强化学习记录——PolicyGradient跑CartPole-v0

发布时间:2024-04-12 13:01

代码cr:MorvanZhou (Morvan) · GitHub

一、CartPole-v0环境介绍:

强化学习记录——PolicyGradient跑CartPole-v0_第1张图片

一根杆子由一个非驱动的关节连接到小车上,小车沿着无摩擦的轨道移动。这个系统是通过对小车施加+1或-1的力来控制的。钟摆开始直立,目的是防止它倒下。柱子保持直立的每一步将获得+1奖励。当电线杆与垂直的距离超过15度,或者车与中心的距离超过2.4个单位时,就结束了。

环境脚本:gym/gym/envs/classic_control at master · openai/gym · GitHub

二、RL_brain.py

class PolicyGradient:

#初始化
def _init_(self,n_actions,
           n_features,
           learning_rate=0.01,
           reward_decay=0.95,
           output_graph=False,):

#建立policy gradient神经网络
def _bulid_net(self):

#选行为
def choose_action(self,observation):

#存储回合transition
def store_transition(self,s,a,r):

#学习更新参数
def learn(self,s,a,r,s_):

#衰减回合的reward
def _discount_and_norm_rewards(self):

三、run_CartPole.py

import gym
from RL_brain import PolicyGradient
import matplotlib.pyplot as plt

DISPLAY_REWARD_THRESHOLD = 400  #当回合总reward大于400时显示模拟窗口
RENDER = False  #在屏幕上显示模拟窗口会拖慢运行速度,我们等计算机学的差不多了再显示模拟

env = gym.make('CartPole-v0')
env.seed(1)  # reproducible, general Policy gradient has high variance
env = env.unwrapped #取消限制

print(env.action_space) #显示可用action
print(env.observation_space) #显示可用state的observation
print(env.observation_space.high) #显示observation最高值
print(env.observation_space.low) #显示observation最低值

RL = PolicyGradient( #定义RL部分的policygradient以及相关参数
    n_actions=env.action_space.n,
    n_features=env.observation_space.shape[0],
    learning_rate=0.02,
    reward_decay=0.99,
    # output_graph=True,
)

for i_episode in range(3000):#基础版的policy gradient是回合更新

    observation = env.reset() #获得初始环境observation

    while True:
        if RENDER: env.render() #重绘环境的一幅图像

        action = RL.choose_action(observation) #选取动作

        observation_, reward, done, info = env.step(action) #将这一状态的动作传入、获取下一状态的观测值、奖励以及是否已经终结
 

        RL.store_transition(observation, action, reward) #存储这一回合的transition

        if done: #回合结束
            ep_rs_sum = sum(RL.ep_rs) #对reward进行处理,使其更有导向性的引导policygradient的gradient方向

            if 'running_reward' not in globals():
                running_reward = ep_rs_sum
            else:
                running_reward = running_reward * 0.99 + ep_rs_sum * 0.01
            if running_reward > DISPLAY_REWARD_THRESHOLD: RENDER = True     # rendering
            print("episode:", i_episode, "  reward:", int(running_reward))

            vt = RL.learn()

            if i_episode == 0:
                plt.plot(vt)    # plot the episode vt
                plt.xlabel('episode steps')
                plt.ylabel('normalized state-action value')
                plt.show()
            break

        observation = observation_

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号