Search

DQN으로 TSP 문제 풀어보기

Q-러닝은 각 상태(State)에서 특정 행동(Action)을 했을 때, 앞으로 얻게 될 총 보상이 얼마일지를 나타내는 Q 값을 배우는 것을 목표로 합니다.
TSP 예를 들면, 지금 도시에서 (State) 다른 도시를 방문할 때 이동거리가 작아 지도록 이동(Action) 하면 Q 값은 높아지고 반면에 이동거리가 길어지면 Q 값은 낮아지게 되는 것으로 생각할 수 있습니다.
Q-러닝은 이 Q 값을 계속 업데이트하면서 학습을 합니다. 현재 상태에서 어떤 행동을 했을 때 얻은 즉각적인 보상과, 그 행동을 한 후 도달한 다음 상태에서 얻을 수 있는 최대 Q 값을 이용해서 현재 Q 값을 조정합니다.
전통적인 Q-러닝은 가능한 상태와 행동의 조합이 많지 않을 때, 이 Q 값을 표(Q-table)로 만들어서 관리할 수 있었습니다.
하지만 복잡한 문제의 경우에는 가능한 상태의 수가 급격하게 많아져서 Q-table을 만드는 것은 불가능해집니다.
DQN은 바로 이 문제를 해결하기 위해 인공신경망을 사용하는 아이디어입니다. Q-table을 직접 만드는 대신, 딥러닝 모델에게 Q 값을 '예측'하게 하는 것입니다.
신경망의 입력: 현재 환경의 상태 (예: 현재까지 방문한 도시의 Array) 신경망의 출력: 현재 상태에서 방문할 수 있는 다른 도시들의 거리에 대한 Q 값들
단순히 Q-네트워크만 사용하면 학습이 불안정해지는 문제가 발생하기 때문에 DQN은 두 가지 중요한 기술을 추가합니다.
경험 리플레이 (Experience Replay) 에이전트가 환경과 상호작용하면서 얻은 경험들 (현재 상태, 취한 행동, 받은 보상, 다음 상태)을 경험 메모리 저장 공간에 차곡차곡 쌓아둡니다. 신경망을 학습할 때는 이 경험 메모리에서 무작위로 일부 경험들을 뽑아서 사용합니다. 에이전트가 경험한 데이터를 순서대로 학습하면 데이터 간의 상관관계가 높아져 학습이 불안정해지기 때문에 경험 리플레이는 데이터를 섞어서 사용함으로써 이러한 문제를 완화하고 데이터 효율성을 높입니다.
타겟 네트워크 (Target Network): DQN 학습은 현재 Q-네트워크가 예측한 Q 값과 목표 Q 값(벨만 방정식 기반) 사이의 오차를 줄이는 방향으로 이루어 집니다. 그런데 이 목표 Q 값을 계산할 때, 만약 현재 학습하고 있는 Q-네트워크의 최신 가중치를 바로 사용하면 목표 값이 계속 변동하게 됩니다. 마치 움직이는 과녁을 맞추는 것처럼 학습이 매우 불안정하게 됩니다.
그래서 DQN은 목표 Q 값을 계산할 때는 현재 학습 중인 Q-네트워크와 똑같이 생겼지만, 가중치 업데이트를 훨씬 느리게(또는 일정 주기마다 복사해오는) 별도의 네트워크를 사용하고, 이 네트워크를 타겟 네트워크라고 합니다.
TSP를 위한 State 정의
처음에 TSP 문제를 DQN으로 도전 했을때 품질이 좋지 않았는데 아래와 같이 State를 수정했더니 품질이 좋아졌습니다.
def _get_state(self): current_city_one_hot = np.zeros(self.num_cities) current_city_one_hot[self.current_city] = 1.0 visited_state = np.array(self.visited, dtype=float) state = np.concatenate((current_city_one_hot, visited_state)) return state
Python
복사
소스코드
import torch import torch.nn as nn import torch.optim as optim import numpy as np import random from collections import deque import matplotlib.pyplot as plt # 하이퍼파라미터 설정 NUM_CITIES = 20 BATCH_SIZE = 64 GAMMA = 0.99 EPS_START = 1.0 EPS_END = 0.01 EPS_DECAY = 0.9999 MEMORY_SIZE = 10000 LEARNING_RATE = 0.0008 NUM_EPISODES = 30000 MAX_STEPS = NUM_CITIES*2 UPDATE_TARGET_EVERY = 100 INVALID_ACTION_PENALTY = -2000 # TSP 환경 정의 class TSPEnv: def __init__(self, num_cities): self.num_cities = num_cities self.cities = None self.current_city = None self.visited = None self.distances = None self.tour = [] self.total_distance = 0 self.reset() def reset(self): self.cities = np.array([[0.0, 0.0], [0.0, 900], [100, 500], [200, 200], [400, 100], [400, 800], [700, 200], [800, 500], [900, 0.0], [900, 900], [0.0, 100], [0.0, 700], [400, 0.0], [400, 100], [400, 800], [400, 900], [700, 0.0], [700, 900], [900, 200], [900, 700]]) self.distances = np.zeros((self.num_cities, self.num_cities)) for i in range(self.num_cities): for j in range(i+1, self.num_cities): dist = int(np.linalg.norm(self.cities[i] - self.cities[j])) self.distances[i,j] = self.distances[j,i] = dist self.current_city = 0 self.visited = [False] * self.num_cities self.visited[self.current_city] = True self.tour = [self.current_city] self.total_distance = 0 return self._get_state() def step(self, action): done = False if self.visited[action]: return self._get_state(), INVALID_ACTION_PENALTY, True, {} distance = self.distances[self.current_city, action] reward = -distance self.total_distance += distance self.current_city = action self.visited[action] = True self.tour.append(action) if all(self.visited): return_to_start_distance = self.distances[self.current_city, self.tour[0]] self.total_distance += return_to_start_distance reward -= return_to_start_distance done = True return self._get_state(), reward, done, {} def _get_state(self): current_city_one_hot = np.zeros(self.num_cities) current_city_one_hot[self.current_city] = 1.0 visited_state = np.array(self.visited, dtype=float) state = np.concatenate((current_city_one_hot, visited_state)) return state def get_valid_actions(self): valid_actions = [i for i in range(self.num_cities) if not self.visited[i]] return valid_actions class DQN(nn.Module): def __init__(self, input_size, output_size): super(DQN, self).__init__() self.fc1 = nn.Linear(input_size, 256) self.fc2 = nn.Linear(256, 256) self.fc3 = nn.Linear(256, output_size) def forward(self, x): x = torch.relu(self.fc1(x)) x = torch.relu(self.fc2(x)) return self.fc3(x) class ReplayMemory: def __init__(self, capacity): self.memory = deque(maxlen=capacity) def push(self, state, action, reward, next_state, done): self.memory.append((state, action, reward, next_state, done)) def sample(self, batch_size): return random.sample(self.memory, batch_size) def __len__(self): return len(self.memory) class DQNAgent: def __init__(self, input_size, output_size): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.policy_net = DQN(input_size, output_size).to(self.device) self.target_net = DQN(input_size, output_size).to(self.device) self.target_net.load_state_dict(self.policy_net.state_dict()) self.target_net.eval() self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE) self.memory = ReplayMemory(MEMORY_SIZE) self.epsilon = EPS_START self.output_size = output_size self.input_size = input_size def select_action(self, state, env): valid_actions = env.get_valid_actions() if np.random.rand() <= self.epsilon: if not valid_actions: return None return random.choice(valid_actions) else: with torch.no_grad(): state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device) q_values = self.policy_net(state_tensor) for i in range(self.output_size): if i not in valid_actions: q_values[0, i] = -float('inf') if torch.all(q_values == -float('inf')): return None action = q_values.argmax().item() return action def update_model(self): if len(self.memory) < BATCH_SIZE: return transitions = self.memory.sample(BATCH_SIZE) batch = tuple(zip(*transitions)) state_batch = torch.tensor(np.array(batch[0]), dtype=torch.float32).to(self.device) action_batch = torch.tensor(batch[1], dtype=torch.int64).unsqueeze(1).to(self.device) reward_batch = torch.tensor(batch[2], dtype=torch.float32).unsqueeze(1).to(self.device) next_state_batch = torch.tensor(np.array(batch[3]), dtype=torch.float32).to(self.device) done_batch = torch.tensor(batch[4], dtype=torch.float32).unsqueeze(1).to(self.device) q_values = self.policy_net(state_batch).gather(1, action_batch) with torch.no_grad(): next_q_values_policy = self.policy_net(next_state_batch) next_visited_batch = next_state_batch[:, self.output_size:] next_q_values_policy[next_visited_batch.bool()] = -float('inf') # 마스킹된 Q 값에서 최대 Q 값을 가지는 액션의 인덱스 선택 next_actions = next_q_values_policy.argmax(dim=1, keepdim=True) # 타겟 네트워크에서 선택된 다음 액션에 대한 Q 값 계산 max_next_q_values = self.target_net(next_state_batch).gather(1, next_actions) expected_q_values = reward_batch + (1 - done_batch) * GAMMA * max_next_q_values loss = nn.MSELoss()(q_values, expected_q_values) self.optimizer.zero_grad() loss.backward() self.optimizer.step() def update_epsilon(self): self.epsilon = max(EPS_END, self.epsilon * EPS_DECAY) def update_target_network(self): self.target_net.load_state_dict(self.policy_net.state_dict()) env = TSPEnv(NUM_CITIES) state_size = NUM_CITIES + NUM_CITIES action_size = env.num_cities agent = DQNAgent(state_size, action_size) total_rewards = [] last_episode_done = False for episode in range(NUM_EPISODES): state = env.reset() total_reward = 0 done = False # 에피소드 시작 시 초기화 for step in range(MAX_STEPS): action = agent.select_action(state, env) if action is None: done = True break next_state, reward, done, _ = env.step(action) agent.memory.push(state, action, reward, next_state, done) state = next_state total_reward += reward if len(agent.memory) > BATCH_SIZE: agent.update_model() if done: break agent.update_epsilon() if episode % UPDATE_TARGET_EVERY == 0: agent.update_target_network() # 진행 상황 출력 if episode % 500 == 0 or episode == NUM_EPISODES - 1: final_distance = env.total_distance if done else float('inf') print(f"Episode: {episode}, Total Reward: {total_reward:.2f}, Epsilon: {agent.epsilon:.4f}, Final Distance: {final_distance:.2f}") if episode == NUM_EPISODES - 1: last_episode_done = done # 최종 경로 시각화 (마지막 에피소드가 완료된 경우에만) if last_episode_done: final_route = env.tour final_distance = env.total_distance # 시각화 함수 def plot_route(env, route, distance): coords = env.cities route_coords = np.append(coords[route], [coords[route[0]]], axis=0) plt.figure(figsize=(10, 10)) plt.plot(route_coords[:, 0], route_coords[:, 1], marker='o', linestyle='-', color='blue') plt.scatter(coords[:, 0], coords[:, 1], c='red', s=50, zorder=5) plt.title('TSP DQN :'+str(distance)) plt.xlabel('X Coordinate') plt.ylabel('Y Coordinate') plt.grid(True) plt.show() plot_route(env, final_route, final_distance)
Python
복사
결과는 아래와 같습니다.
Reference: