( Reference : Fastcampus lecture )

[ 23. Q-learning Practice ]


1. Review

(figure 2)
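
As a quick review (this restates the update rules used in the code below): both SARSA and Q-learning learn a tabular action-value function Q(s, a) from individual transitions, and they differ only in how the TD target is formed.

\(\text{SARSA}: \; Q(s, a) \leftarrow Q(s, a)+\alpha\left(r+\gamma Q\left(s^{\prime}, a^{\prime}\right)-Q(s, a)\right)\)

\(\text{Q-learning}: \; Q(s, a) \leftarrow Q(s, a)+\alpha\left(r+\gamma \max _{a^{\prime}} Q\left(s^{\prime}, a^{\prime}\right)-Q(s, a)\right)\)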


2. Import Packages

import numpy as np
import matplotlib.pyplot as plt

from src.part2.temporal_difference import QLearner
from src.common.gridworld import GridworldEnv
from src.common.grid_visualization import visualize_value_function, visualize_policy

np.random.seed(0)


3. Create the Agent

QLearner inherits from the TDAgent we saw earlier.
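
The constructor below reads env.nS and env.nA, so the Gridworld environment has to be created before the agent. A minimal sketch, assuming GridworldEnv accepts a [rows, cols] grid shape for the 4x4 grid used in the course (check the course repo for the exact constructor signature):

nx, ny = 4, 4
env = GridworldEnv([ny, nx])  # assumption: the constructor takes the grid shape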

qlearning_agent = QLearner(gamma=1.0,
                           lr=1e-1,
                           num_states=env.nS,
                           num_actions=env.nA,
                           epsilon=1.0)
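
With epsilon=1.0 the agent starts out fully exploratory. get_action(s) presumably follows an ε-greedy rule: pick a random action with probability ε, otherwise the greedy action. A minimal sketch of that rule (a hypothetical standalone function, not the repo's implementation):

def epsilon_greedy(q, s, epsilon, num_actions):
    # explore: random action with probability epsilon
    if np.random.rand() < epsilon:
        return np.random.randint(num_actions)
    # exploit: greedy action w.r.t. the current Q-table
    return int(np.argmax(q[s, :]))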


4. SARSA vs Q-Learning

(1) SARSA

\(Q(s, a) \leftarrow Q(s, a)+\alpha\left(r+\gamma Q\left(s^{\prime}, a^{\prime}\right)-Q(s, a)\right)\).

def update_sample(self, s, a, r, s_, a_, done):
    # SARSA: bootstrap from the next action a_ that was actually taken in s_
    td_target = r + self.gamma * self.q[s_, a_] * (1 - done)
    self.q[s, a] += self.lr * (td_target - self.q[s, a])


(2) Q-learning

\(Q(s, a) \leftarrow Q(s, a)+\alpha\left(r+\gamma \max _{a^{\prime}} Q\left(s^{\prime}, a^{\prime}\right)-Q(s, a)\right)\).

def update_sample(self, s, a, r, s_, done):
    # Q-learning: bootstrap from the greedy (max) Q-value at s_, regardless of the next action taken
    td_target = r + self.gamma * self.q[s_, :].max() * (1 - done)
    self.q[s, a] += self.lr * (td_target - self.q[s, a])

Difference :

  • SARSA : self.q[s_, a_]
  • Q-Learning : self.q[s_, :].max()
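
To make the difference concrete, here is a small numeric sketch (made-up Q-table and transition) that computes both TD targets from the same sample:

q = np.zeros((2, 3))                 # toy Q-table: 2 states, 3 actions
q[1] = [0.0, 0.5, 1.0]               # Q-values of the next state s_ = 1
gamma, r, s_, a_ = 1.0, 0.0, 1, 0    # suppose the behaviour policy explored and took a_ = 0

sarsa_target = r + gamma * q[s_, a_]             # 0.0 : uses the action actually taken (on-policy)
qlearning_target = r + gamma * q[s_, :].max()    # 1.0 : uses the greedy action (off-policy)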


5. Run Iteration

Number of episodes : 10,000 (log print interval : 1,000)

num_episode = 10000
print_log = 1000

qlearning_qs = []
iter_idx = []
qlearning_rewards = []


for i in range(num_episode):
    reward_sum = 0
    env.reset()    
    while True:
        # (1) observe the state -> (2) choose an action -> (3) receive the reward and next state
        s = env.s
        a = qlearning_agent.get_action(s)
        s_, r, done, info = env.step(a)
        
        ##### [Unlike SARSA, this step does not exist] (4) choose an action for the next state #####
        
        # (5) update the Q-table with the s, a, r, s_ obtained above
        qlearning_agent.update_sample(state=s,
                                      action=a,
                                      reward=r,
                                      next_state=s_,
                                      done=done)
        reward_sum += r
        if done:
            break
    
    qlearning_rewards.append(reward_sum)
    
    if i % print_log == 0:
        print("Running {} th episode".format(i))
        print("Reward sum : {}".format(reward_sum))
        qlearning_qs.append(qlearning_agent.q.copy())
        iter_idx.append(i)
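
After training, the reward curve and the learned values can be inspected with the packages imported above. The grid visualization helpers (visualize_value_function, visualize_policy) from the course repo could be used here as well, but since their exact signatures are not shown in this post, the sketch below sticks to plain matplotlib and numpy (the 4x4 grid shape is an assumption):

# learning curve: per-episode reward plus a moving average to smooth the noise
window = 100
smoothed = np.convolve(qlearning_rewards, np.ones(window) / window, mode='valid')

fig, ax = plt.subplots(1, 2, figsize=(12, 4))
ax[0].plot(qlearning_rewards, alpha=0.3, label='episode reward')
ax[0].plot(smoothed, label='moving average ({})'.format(window))
ax[0].set_xlabel('episode')
ax[0].set_ylabel('reward sum')
ax[0].legend()

# greedy state values V(s) = max_a Q(s, a), reshaped to the assumed 4x4 grid
v = qlearning_agent.q.max(axis=1).reshape(4, 4)
im = ax[1].imshow(v, cmap='viridis')
ax[1].set_title('V(s) = max_a Q(s, a)')
fig.colorbar(im, ax=ax[1])
plt.show()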