Search

REINFORCE으로 TSP 문제 풀어보기

앞에서 설명한 DQN은 가치 기반(Value-based) 학습 방법의 대표적인 예입니다. 즉, 각 상태(State)에서 행동(Action)을 했을 때의 미래 보상 합계인 Q 값을 배우는 데 집중합니다. 그리고 학습된 Q 값을 바탕으로 가장 Q 값이 높은 행동을 선택하는 정책을 따릅니다.
반면, REINFORCE는 정책 기반(Policy-based) 학습 방법의 가장 기본적인 형태입니다. REINFORCE는 Q 값 대신, 어떤 상태에서 어떤 행동을 할 확률이 높은지를 직접적으로 나타내는 정책(Policy) 자체를 배웁니다.
강화학습에서 정책 π는 에이전트가 주어진 상태 s에서 어떤 행동 a를 선택할지를 결정하는 전략입니다. Policy-based 방법에서는 이 정책을 수학적인 함수 형태로 표현하고, 그 함수의 파라미터 θ를 학습을 통해 조정합니다.
정책 기반 방법은 총 보상의 기댓값을 최대화하기 위해, 정책 파라미터 θ를 총 보상 기댓값 함수의 경사(Gradient) 방향으로 업데이트합니다.
마치 산을 오를 때 가장 가파른 경사 방향으로 발걸음을 옮기는 것처럼 작동하는데 이를 정책 경사(Policy Gradient)라고 합니다.
특징
DQN
REINFORCE
접근방식
가치기반
정책기반
학습대상
행동가치함수
정책
목적함수
TD오차의 최소화
총보상값의 기댓값 최대화
최적화
경사하강법
경사상승법
결과
최적Q 함수 학습을 통한 최적정책 도출
최적 정책 직접 학습
import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import numpy as np import matplotlib.pyplot as plt learning_rate = 0.0009 gamma = 0.98 num_cities = 20 cities = np.array([[0.0, 0.0], [0.0, 900], [100, 500], [200, 200], [400, 100], [400, 800], [700, 200], [800, 500], [900, 0.0], [900, 900], [0.0, 100], [0.0, 700], [400, 0.0], [400, 100], [400, 800], [400, 900], [700, 0.0], [700, 900], [900, 200], [900, 700]]) device = torch.device("cpu") #if torch.backends.mps.is_available(): # device = torch.device("mps") print(device) class Policy(nn.Module): def __init__(self): super(Policy, self).__init__() self.data = [] self.fc1 = nn.Linear(num_cities, num_cities).to(device) self.fc2 = nn.Linear(num_cities, num_cities).to(device) self.optimizer = optim.Adam(self.parameters(), lr=learning_rate) def forward(self, x): x = F.relu(self.fc1(x)) x = F.softmax(self.fc2(x), dim=0) return x def put_data(self, item): self.data.append(item) def train_net(self): R = 0 self.optimizer.zero_grad() for r, prob in self.data[::-1]: R = r + gamma * R loss = -torch.log(prob) * R loss.backward() self.optimizer.step() self.data = [] def train_net(self, gt): R = gt self.optimizer.zero_grad() for r, prob in self.data[::-1]: #R = r + gamma * R loss = -torch.log(prob) * R loss.backward() self.optimizer.step() self.data = [] def calculate_reward(self, tour): distance = 0 for i in range(len(tour) - 1): city1 = cities[tour[i]] city2 = cities[tour[i + 1]] distance += np.linalg.norm(city1 - city2) distance += np.linalg.norm(cities[tour[-1]] - cities[tour[0]]) return 1 / int(distance) def main(): pi = Policy() score = 0.0 print_interval = 1000 for n_epi in range(150000): state = np.zeros(num_cities) tour = [0] state[0] = 1 pi.data = [] for _ in range(num_cities - 1): state_tensor = torch.from_numpy(state).float().to(device) prob = pi(state_tensor) available_cities = np.where(state == 0)[0] action_probs = prob.cpu().detach().numpy()[available_cities] action_probs = np.nan_to_num(action_probs, nan=0.001) action_probs /= action_probs.sum() action = np.random.choice(available_cities, p=action_probs) tour.append(action) reward = pi.calculate_reward(tour) pi.put_data((reward, prob[action])) state[action] = 1 score += reward pi.train_net(pi.calculate_reward(tour)) if n_epi % print_interval == 0 and n_epi != 0: print("# of episode :{}, avg score : {}".format(n_epi, score / print_interval)) score = 0.0 print("Optimal Tour:", tour) # 경로 시각화 plt.figure(figsize=(10, 10)) for i in range(num_cities - 1): plt.plot([cities[tour[i]][0], cities[tour[i + 1]][0]], [cities[tour[i]][1], cities[tour[i + 1]][1]], 'b-') plt.plot([cities[tour[-1]][0], cities[tour[0]][0]], [cities[tour[-1]][1], cities[tour[0]][1]], 'b-') plt.scatter(cities[:, 0], cities[:, 1], c='r', marker='o') route_distance = calculate_total_distance(cities, tour) plt.xlabel('X') plt.ylabel('Y') plt.title('TSP REINFORCE: '+str(route_distance)) plt.grid(True) plt.show() def calculate_total_distance(coordinates, tour): total_distance = 0 num_cities = len(coordinates) for i in range(num_cities-1): current_city_index = tour[i] next_city_index = tour[i+1] total_distance += calculate_distance(coordinates[current_city_index], coordinates[next_city_index]) #total_distance += calculate_distance(coordinates[num_cities-1], coordinates[0]) return total_distance def calculate_distance(point1, point2): return int(np.sqrt((point1[0] - point2[0])**2 + (point1[1] - point2[1])**2)) if __name__ == '__main__': main()
Python
복사
결과는 아래와 같습니다.
Reference: