daisukeの技術ブログ

AI、機械学習、最適化、Pythonなどについて、技術調査、技術書の理解した内容、ソフトウェア/ツール作成について書いていきます

ゼロから作るDeep Learning 4(強化学習編)の4章のサンプルコードに画像ファイルの出力機能を追加する

前回(ゼロから作るDeep Learning 4(強化学習編)の4章のサンプルコードにアニメーション化の機能を追加する - daisukeの技術ブログ)に続き、今回は、同じく「ゼロから作るDeep Learning 4(強化学習編)」の4章のサンプルコードの「policy_iter.py」に対して、ステップごとの更新された価値関数が書かれたマップの画像ファイルを、ファイル出力する機能を追加していきます。

また、ついでに、同じく4章のサンプルコードの「policy_eval.py」のステートごとの更新された価値関数が書かれたマップの画像ファイルを、ファイル出力する機能を追加します。

参考文献

はじめに

前回、マップの変化をアニメーション化したことで、どのようにマップが変化していくのかを直感的に理解することができました。一方で、細かく、価値関数や、方策が更新されていくのかを確認するには、少し不便でした。

そこで、今回は、前回アニメーション化していたものを、単純に画像ファイルとして出力する機能を追加しました。

なお、今回も、機能を追加したソースコードは、以下のGitHubに格納しています。

github.com

使い方

$ git clone https://github.com/dk0893/deep-learning-from-scratch-4.git -b v1.1-dk0893
Cloning into 'deep-learning-from-scratch-4'...
remote: Enumerating objects: 425, done.
remote: Counting objects: 100% (148/148), done.
remote: Compressing objects: 100% (33/33), done.
ta 115)R, pack-erceueiving objectss:e  d81 %2774[K5/425) 3eu40s/ed 1425)16 (d e  l
Receiving objects: 100% (425/425), 922.49 KiB | 0 bytes/s, done.
Resolving deltas: 100% (246/246), done.
Checking connectivity... done.
Note: checking out '4eca9bf48e1afbf56628107a33bffe2440df6000'.

You are in 'detached HEAD' state. You can look around, make experimental
changes and commit them, and you can discard any commits you make in this
state without impacting any branches by performing another checkout.

If you want to create a new branch to retain commits you create, you may
do so (now or later) by using -b with the checkout command again. Example:

  git checkout -b <new-branch-name>


$ cd deep-learning-from-scratch-4/

$ python ch04/policy_iter.py --ope ani_step --fpath policy_iter_step.gif
save animation: policy_iter_step.gif

$ python ch04/policy_iter.py --ope im_step
save image: images\policy_iter_step_0.png
save image: images\policy_iter_step_1.png
save image: images\policy_iter_step_2.png
save image: images\policy_iter_step_3.png
save image: images\policy_iter_step_4.png

$ python ch04/policy_iter.py --ope ani_state --fpath policy_iter_state.gif
save animation: policy_iter_state.gif
ImageStore.cnt=384

$ python ch04/policy_iter.py --ope im_state
save image: images\policy_iter_step0_phase00_state_(0, 0).png
save image: images\policy_iter_step0_phase00_state_(0, 1).png
save image: images\policy_iter_step0_phase00_state_(0, 2).png
save image: images\policy_iter_step0_phase00_state_(0, 3).png
save image: images\policy_iter_step4_phase00_state_(2, 0).png
save image: images\policy_iter_step4_phase00_state_(2, 1).png
save image: images\policy_iter_step4_phase00_state_(2, 2).png
save image: images\policy_iter_step4_phase00_state_(2, 3).png
ImageStore.cnt=384

今回は、参考として、Google Colaboratoryで実行できるファイル(ch04-exec.ipynb)も一緒にコミットしておきました。

今回の機能追加の設計方針

前回の機能追加は、オリジナルのソースコードに対して、特に何も考えずに変更してしまいましたが、今回は、オリジナルのソースコードになるべく影響を与えないような変更にしました。

具体的には、オリジナルのソースコードに対して、最小限の追加で実現するようにしました。こうすることで、オリジナルのソースコードが更新された場合に、今回の機能追加分をマージすることが簡単になりますし、他の機能を追加したくなった場合に、複雑な構成にならないようにできます。

オリジナルからの変更内容

policy_iter.py

以下のように、policy_iter()は、1行の変更で実現できています。

--- deep-learning-from-scratch-4-org/ch04/policy_iter.py        2024-03-20 18:07:05.107000000 +0900
+++ deep-learning-from-scratch-4/ch04/policy_iter.py    2024-03-23 20:52:32.819000000 +0900
@@ -3,6 +3,7 @@
     sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
 from collections import defaultdict
 from common.gridworld import GridWorld
+from common.image_store import ImageStore
 from ch04.policy_eval import policy_eval
@@ -44,7 +45,7 @@
         new_pi = greedy_policy(V, env, gamma)

         if is_render:
-            env.render_v(V, pi)
+            ImageStore.st_step( env, V, pi )

         if new_pi == pi:
             break
@@ -53,7 +54,25 @@
     return pi

+def parse_args():
+    import argparse
+    parser = argparse.ArgumentParser( description='policy_iter.py' )
+    parser.add_argument( '--ope',   default=None,              help='select output operation, [None or im_step or im_state or ani_step or ani_state]' )
+    parser.add_argument( '--dpath', default="images",          help='input save image directory path' )
+    parser.add_argument( '--fpath', default='policy_iter.gif', help='input save animation path' )
+    return parser.parse_args()
+
 if __name__ == '__main__':
+    args = parse_args()
+    ImageStore.init( args.ope, args.dpath, args.fpath )
     env = GridWorld()
     gamma = 0.9
     pi = policy_iter(env, gamma)
+    ImageStore.output( env.renderer.fig )

policy_eval.py

こちらも、2行を追加しただけで実現できています。

--- deep-learning-from-scratch-4-org/ch04/policy_eval.py        2024-03-20 18:07:05.100000000 +0900
+++ deep-learning-from-scratch-4/ch04/policy_eval.py    2024-03-21 23:07:58.401000000 +0900
@@ -3,12 +3,14 @@
     sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
 from collections import defaultdict
 from common.gridworld import GridWorld
+from common.image_store import ImageStore

 def eval_onestep(pi, V, env, gamma=0.9):
     for state in env.states():
         if state == env.goal_state:
             V[state] = 0
+            ImageStore.st_state( env, V, pi, state )
             continue

         action_probs = pi[state]
@@ -18,6 +20,7 @@
             r = env.reward(state, action, next_state)
             new_V += action_prob * (r + gamma * V[next_state])
         V[state] = new_V
+        ImageStore.st_state( env, V, pi, state )
     return V

image_store.py

こちらは、新規追加したファイルです。クラス変数を使うことで、import文を書くだけで、一行の追加で変更ができています。

通常のクラスの使い方のようなインスタンスを生成する方法の場合、既存の関数の引数への追加が多く必要になってしまいます。

import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import ArtistAnimation

class ImageStore:
    
    ope   = None
    dpath = None
    fpath = None
    
    artists_step  = []
    artists_state = []
    animation     = False
    
    step  = 0
    phase = 0
    cnt   = 0
    
    debug = False
    
    def init( ope=None, dpath=None, fpath=None, debug=False ):
        
        ImageStore.ope   = ope
        ImageStore.dpath = dpath
        ImageStore.fpath = fpath
        ImageStore.debug = debug
        
        if ImageStore.ope == "ani_step" or ImageStore.ope == "ani_state":
            ImageStore.animation = True
        elif ImageStore.ope == "im_step" or ImageStore.ope == "im_state":
            os.makedirs( ImageStore.dpath, exist_ok=True )
    
    def st_step( env, V, pi ):
        
        if ImageStore.ope == "im_step" or ImageStore.ope == "ani_step" or ImageStore.ope is None:
            frame = env.render_v(V, pi, title=f"step={ImageStore.step}")
            ImageStore.artists_step.append( frame )
        
        if ImageStore.ope == "im_step":
            fpath = os.path.join( ImageStore.dpath, f"policy_iter_step_{ImageStore.step}.png" )
            plt.savefig( fpath )
            plt.close()
            
            print( f"save image: {fpath}" )
        
        ImageStore.step += 1
        ImageStore.phase = 0
    
    def st_state( env, V, pi, state ):
        
        if ImageStore.ope == "im_state" or ImageStore.ope == "ani_state":
            frame = env.render_v( V, pi, title=f"step={ImageStore.step} phase={ImageStore.phase} state={state}" )
            ImageStore.artists_state.append( frame )
            
            if ImageStore.ope == "im_state":
                fpath = os.path.join( ImageStore.dpath, f"policy_iter_step{ImageStore.step}_phase{ImageStore.phase:02d}_state_{state}.png" )
                plt.savefig( fpath )
                plt.close()
                print( f"save image: {fpath}" )
            
            ImageStore.cnt += 1
            if ImageStore.cnt % np.prod(env.shape) == 0:
                ImageStore.phase += 1
            
            if ImageStore.debug:
                if ImageStore.ope == "ani_state" and ImageStore.phase == 1:
                    ImageStore.ope = "ani_end"
    
    def output( fig ):
        
        if ImageStore.ope == "ani_step" or ImageStore.ope == "ani_state" or ImageStore.ope == "ani_end":
            artists = ImageStore.artists_step if ImageStore.ope == "ani_step" else ImageStore.artists_state
            interval = 2000 if ImageStore.ope == "ani_step" else 500
            anim = ArtistAnimation( fig, artists, interval=interval )
            anim.save( ImageStore.fpath )
            
            print( f"save animation: {ImageStore.fpath}" )
        
        if ImageStore.ope == "im_state" or ImageStore.ope == "ani_state":
            print( f"ImageStore.cnt={ImageStore.cnt}" )

おわりに

価値関数の更新は思っていたより、たくさんの更新が必要だったことが、今回の画像ファイルの出力で分かりました。

何度も更新過程の画像ファイルを見たことで、理解が深まりました。

今回は以上です!

最後までお読みいただき、ありがとうございました。