Code credit: MorvanZhou (Morvan) · GitHub
A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track. The system is controlled by applying a force of +1 or -1 to the cart. The pendulum starts upright, and the goal is to keep it from falling over. A reward of +1 is given for every step the pole stays upright. The episode ends when the pole is more than 15 degrees from vertical, or when the cart moves more than 2.4 units away from the center.
Environment script: gym/gym/envs/classic_control at master · openai/gym · GitHub
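Before writing the agent, it helps to see the raw environment interface on its own. The snippet below is a minimal sketch of my own (not part of Morvan's code) that runs one episode of CartPole-v0 with random actions, just to show what step() returns and how an episode terminates; it only uses gym calls that also appear in the script later on.

import gym

env = gym.make('CartPole-v0')
observation = env.reset()   # 4-dim state: cart position, cart velocity, pole angle, pole angular velocity
total_reward = 0
while True:
    action = env.action_space.sample()                  # random action: 0 = push left, 1 = push right
    observation, reward, done, info = env.step(action)  # old gym API: 4-tuple per step
    total_reward += reward                               # +1 for every step the pole stays up
    if done:                                             # pole fell over or cart left the track
        print('episode finished, total reward:', total_reward)
        break
env.close()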
class PolicyGradient:
    # initialization
    def __init__(self, n_actions,
                 n_features,
                 learning_rate=0.01,
                 reward_decay=0.95,
                 output_graph=False):
        ...

    # build the policy-gradient neural network
    def _build_net(self):
        ...

    # choose an action
    def choose_action(self, observation):
        ...

    # store one transition of the current episode
    def store_transition(self, s, a, r):
        ...

    # learn, i.e. update the network parameters (called once at the end of each episode)
    def learn(self):
        ...

    # discount and normalize the episode's rewards
    def _discount_and_norm_rewards(self):
        ...
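The framework-independent parts of this class are the episode buffer and the reward post-processing. Below is a minimal NumPy sketch of how store_transition and _discount_and_norm_rewards can be implemented, in the spirit of Morvan's repository but not a verbatim copy; _build_net, choose_action and learn depend on the chosen deep-learning framework and are omitted. The attribute names ep_obs and ep_as are my own shorthand (ep_rs is the one the training script actually reads).

import numpy as np

class PolicyGradient:
    def __init__(self, n_actions, n_features, learning_rate=0.01, reward_decay=0.95, output_graph=False):
        self.n_actions = n_actions
        self.n_features = n_features
        self.lr = learning_rate
        self.gamma = reward_decay                          # discount factor used below
        self.ep_obs, self.ep_as, self.ep_rs = [], [], []   # per-episode buffers, cleared after learn()

    def store_transition(self, s, a, r):
        # append one (state, action, reward) triple for the current episode
        self.ep_obs.append(s)
        self.ep_as.append(a)
        self.ep_rs.append(r)

    def _discount_and_norm_rewards(self):
        # discounted return for each step, computed backwards through the episode
        discounted_ep_rs = np.zeros(len(self.ep_rs))
        running_add = 0.0
        for t in reversed(range(len(self.ep_rs))):
            running_add = running_add * self.gamma + self.ep_rs[t]
            discounted_ep_rs[t] = running_add
        # normalize to zero mean and unit variance so the gradient signal is better scaled
        discounted_ep_rs -= np.mean(discounted_ep_rs)
        discounted_ep_rs /= np.std(discounted_ep_rs)
        return discounted_ep_rs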
import gym
from RL_brain import PolicyGradient
import matplotlib.pyplot as plt
DISPLAY_REWARD_THRESHOLD = 400  # show the rendering window once the running episode reward exceeds 400
RENDER = False  # rendering slows training down, so only turn it on after the agent has learned enough
env = gym.make('CartPole-v0')
env.seed(1) # reproducible, general Policy gradient has high variance
env = env.unwrapped  # remove the built-in episode step limit
print(env.action_space)             # show the available actions
print(env.observation_space)        # show the observation space
print(env.observation_space.high)   # upper bounds of the observations
print(env.observation_space.low)    # lower bounds of the observations
RL = PolicyGradient(  # instantiate the policy-gradient agent with its hyperparameters
    n_actions=env.action_space.n,
    n_features=env.observation_space.shape[0],
    learning_rate=0.02,
    reward_decay=0.99,
    # output_graph=True,
)
for i_episode in range(3000):    # the basic policy gradient updates once per episode
    observation = env.reset()    # get the initial observation of the environment

    while True:
        if RENDER: env.render()  # redraw one frame of the environment

        action = RL.choose_action(observation)               # pick an action from the current policy

        observation_, reward, done, info = env.step(action)  # apply the action; get the next observation, the reward, and whether the episode has ended

        RL.store_transition(observation, action, reward)     # store this step's transition

        if done:                 # the episode has ended
            ep_rs_sum = sum(RL.ep_rs)

            if 'running_reward' not in globals():
                running_reward = ep_rs_sum
            else:
                running_reward = running_reward * 0.99 + ep_rs_sum * 0.01
            if running_reward > DISPLAY_REWARD_THRESHOLD: RENDER = True  # rendering
            print("episode:", i_episode, "  reward:", int(running_reward))

            vt = RL.learn()      # learn() post-processes the rewards so they guide the gradient direction better (a small worked example follows the script)

            if i_episode == 0:
                plt.plot(vt)     # plot the vt of the first episode
                plt.xlabel('episode steps')
                plt.ylabel('normalized state-action value')
                plt.show()
            break

        observation = observation_
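The vt plotted after the first episode is the output of _discount_and_norm_rewards. The tiny worked example below (my own illustration, not from the original post) shows why the curve slopes downward over the episode steps: early steps accumulate more discounted future reward than late steps, and the normalization centers the values around zero.

import numpy as np

gamma = 0.99
ep_rs = [1.0, 1.0, 1.0, 1.0]   # CartPole gives +1 for every surviving step

# discounted returns, computed backwards: G_t = r_t + gamma * G_{t+1}
discounted = np.zeros(len(ep_rs))
running_add = 0.0
for t in reversed(range(len(ep_rs))):
    running_add = running_add * gamma + ep_rs[t]
    discounted[t] = running_add
print(discounted)              # [3.940399 2.9701   1.99     1.      ]

# zero-mean, unit-variance normalization, as in _discount_and_norm_rewards
vt = (discounted - discounted.mean()) / discounted.std()
print(vt)                      # roughly [ 1.34  0.45 -0.44 -1.35 ]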