( Reference : Fastcampus lecture )

[ 9. Value Iteration Practice 2 ]

Code comparison with Policy Iteration

1. Policy Iteration

Until the policy converges,

  • (1) Policy Evaluation and
  • (2) Policy Improvement

are repeated iteratively.

def policy_iteration(self, policy=None):
    if policy is None:
        pi_old = self.policy
    else:
        pi_old = policy
    info = dict()
    info['v'] = list()
    info['pi'] = list()
    info['converge'] = None

    steps = 0
    converged = False
    # alternate (1) policy evaluation and (2) policy improvement
    # until the improved policy stops changing
    while True:
        v_old = self.policy_evaluation(pi_old)                 # (1) evaluate the current policy
        pi_improved = self.policy_improvement(pi_old, v_old)   # (2) greedy improvement w.r.t. v_old
        steps += 1
        
        info['v'].append(v_old)
        info['pi'].append(pi_old)

        # check convergence
        diff = np.linalg.norm(pi_improved - pi_old)
        if diff <= self.error_tol:
            if not converged:  
                info['converge'] = steps
            break
        else:
            pi_old = pi_improved
    return info


%%time
dp_agent.reset_policy()
info_pi = dp_agent.policy_iteration()
Wall time: 4.99 ms
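
The policy_evaluation and policy_improvement helpers called above are not shown in this note. Below is a minimal sketch of what they could look like, assuming self.R has shape [num. states x num. actions], self.P has shape [num. actions x num. states x num. states], and the agent stores a discount factor self.gamma (these shapes are assumptions, inferred from how value_iteration below uses R and P), not the course implementation itself.

# sketch only: assumed methods of the same DP agent class as above
import numpy as np

def policy_evaluation(self, pi):
    # R^pi(s) = sum_a pi(a|s) R(s, a)   -> [num. states]
    r_pi = (pi * self.R).sum(axis=1)
    # P^pi(s, s') = sum_a pi(a|s) P(a, s, s')   -> [num. states x num. states]
    p_pi = np.einsum('sa,ast->st', pi, self.P)
    # solve the linear system V = R^pi + gamma * P^pi V exactly
    return np.linalg.solve(np.eye(self.ns) - self.gamma * p_pi, r_pi)

def policy_improvement(self, pi, v):
    # q(a, s) = R(s, a) + gamma * sum_s' P(a, s, s') v(s')   -> [num. actions x num. states]
    q = self.R.T + self.gamma * self.P.dot(v)
    # greedy policy: probability 1 on the argmax action of each state
    pi_new = np.zeros_like(pi)
    pi_new[np.arange(q.shape[1]), q.argmax(axis=0)] = 1
    return pi_new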


2. Value Iteration

\(V_{k+1}=\max_{a \in \mathcal{A}}\left(R^{a}+\gamma P^{a} V_{k}\right)\).
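
Before looking at the implementation, the vectorized form of this backup can be checked on a small standalone example. The 2-state, 2-action MDP below is hypothetical (made-up R, P, gamma), only to show the tensor shapes used inside value_iteration:

import numpy as np

# hypothetical 2-state, 2-action MDP (made-up numbers, for shape illustration only)
ns, na, gamma = 2, 2, 0.9
R = np.array([[1.0, 0.0],        # R[s, a] -> [num. states x num. actions]
              [0.0, 2.0]])
P = np.array([[[0.8, 0.2],       # P[a, s, s'] -> [num. actions x num. states x num. states]
               [0.1, 0.9]],
              [[0.5, 0.5],
               [0.3, 0.7]]])

v = np.zeros(ns)
for _ in range(100):
    # R.T + gamma * P.dot(v) has shape [num. actions x num. states];
    # max over axis 0 takes the best action in each state
    v = (R.T + gamma * P.dot(v)).max(axis=0)
print(v)  # approximate fixed point of the Bellman optimality backup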

def value_iteration(self, v_init=None, compute_pi=False):
    if v_init is not None:
        v_old = v_init
    else:
        v_old = np.zeros(self.ns)

    info = dict()
    info['v'] = list()
    info['pi'] = list()
    info['converge'] = None

    steps = 0
    converged = False

    while True:
        v_improved = (self.R.T + self.gamma * self.P.dot(v_old)).max(axis=0)  # backup over [num. actions x num. states], then max over actions -> [num. states]
        info['v'].append(v_improved)

        ######### Optional (computing the policy is not strictly required) #########
        # compute a greedy policy from the current value estimate
        if compute_pi:
            ## 1) Compute v -> q
            q_pi = self.R.T + self.gamma * self.P.dot(v_improved)  # [num. actions x num. states]

            ## 2) Construct greedy policy
            pi = np.zeros_like(self.policy)                         # [num. states x num. actions]
            pi[np.arange(q_pi.shape[1]), q_pi.argmax(axis=0)] = 1   # probability 1 on the argmax action of each state
            info['pi'].append(pi)
            
        steps += 1

        # check convergence
        diff = np.linalg.norm(v_improved - v_old)

        if diff <= self.error_tol:
            if not converged:  
                info['converge'] = steps
            break
        else:
            v_old = v_improved
    return info


We can see that it runs faster than Policy Iteration.

%%time
dp_agent.reset_policy()
info_vi = dp_agent.value_iteration(compute_pi=True)
Wall time: 1.01 ms
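
As a sanity check, both methods should end up at (approximately) the same optimal value function and the same greedy policy. A sketch of that comparison, assuming the info dictionaries returned by the cells above:

import numpy as np

v_star_pi = info_pi['v'][-1]   # last value function stored by policy_iteration
v_star_vi = info_vi['v'][-1]   # last value function stored by value_iteration
print(np.linalg.norm(v_star_pi - v_star_vi))   # expected to be small

pi_star_pi = info_pi['pi'][-1]
pi_star_vi = info_vi['pi'][-1]
# compare the greedy action chosen in each state
print((pi_star_pi.argmax(axis=1) == pi_star_vi.argmax(axis=1)).all())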