Reinforcement Learning Series (IV) - PolicyGradient Example
2022-06-23 02:28:00 【languageX】
In the previous article we used the simple Random Guessing Algorithm and Hill Climbing Algorithm to solve the CartPole problem. Both methods only change the action-decision step, and they do so by randomly perturbing the weights. For simple problems with few parameters this can give reasonable results, but once the problem becomes more complex and requires a large number of parameters, such random search is no longer practical. This article introduces how to solve the CartPole problem with PolicyGradient.
PolicyGradient example
The policy-based approach was introduced in the algorithm-overview article: we model the policy directly, representing it with a neural network that outputs a probability for each action.
We keep the same learning framework as before; only the most important step, choose_action, is changed so that the action is predicted by the PolicyGradient model.
First, let's look at the learning loop. The main logic is explained in the code comments.
The exploration process
# The learning process: explore for 1000 episodes
for i_episode in range(1000):
    # Reset the environment at the start of every episode
    observation = env.reset()
    while True:
        if RENDER: env.render()
        # Make a decision with the policy model
        action = RL.choose_action(observation)
        # Execute the action; returns the next observation, the reward, etc.
        observation_, reward, done, info = env.step(action)
        # Store the observation, action and reward; these sequences are used for learning
        RL.store_transition(observation, action, reward)
        # This episode is over
        if done:
            ep_rs_sum = sum(RL.ep_rs)
            if 'running_reward' not in globals():
                running_reward = ep_rs_sum
            else:
                # Exponentially smoothed average of the episode returns
                running_reward = running_reward * 0.99 + ep_rs_sum * 0.01
            # Start rendering once the running reward exceeds the threshold; otherwise keep learning
            if running_reward > DISPLAY_REWARD_THRESHOLD: RENDER = True
            print("episode:", i_episode, "rewards:", int(running_reward), "RENDER", RENDER)
            # Learn once per episode
            vt = RL.learn()
            break
        # Advance the agent: update the observation
        observation = observation_
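For context, the loop above assumes that env, RL, RENDER and DISPLAY_REWARD_THRESHOLD already exist. A minimal setup sketch might look like the following (the threshold value is illustrative, the classic gym API that returns four values from env.step is assumed, and PolicyGradient is the class shown below):

import gym

DISPLAY_REWARD_THRESHOLD = 400   # illustrative: start rendering once the running reward exceeds this
RENDER = False                   # rendering slows training down, so it starts disabled

env = gym.make('CartPole-v0')
env = env.unwrapped              # remove the default step limit so episodes can run longer

RL = PolicyGradient(
    n_actions=env.action_space.n,                # 2 discrete actions for CartPole
    n_features=env.observation_space.shape[0],   # 4-dimensional observation
    learning_rate=0.01,
    reward_decay=0.95,
)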
Model update process
The most important calls are action = RL.choose_action(observation), which decides the action at every step, and vt = RL.learn(), which is executed once at the end of every episode.
Here RL is an instance of PolicyGradient. Let's now focus on the PolicyGradient model code and walk through it:
import numpy as np
# The code below uses the TF 1.x graph API; with TensorFlow 2.x installed it is typically
# accessed via the compat module (assumption; adjust the import to your environment):
import tensorflow.compat.v1 as tf


class PolicyGradient:
    def __init__(
            self,
            n_actions,
            n_features,
            learning_rate=0.01,
            reward_decay=0.95,
            output_graph=False,
    ):
        # Dimension of the action space -- 2 for CartPole
        self.n_actions = n_actions
        # Dimension of the state features -- 4 for CartPole
        self.n_features = n_features
        # Learning rate
        self.lr = learning_rate
        # Discount rate for the return
        self.gamma = reward_decay
        # Observations, actions and rewards collected during one episode
        self.ep_obs, self.ep_as, self.ep_rs = [], [], []
        # Build the policy network
        self._build_net()
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())
        if output_graph:
            tf.summary.FileWriter("logs/", self.sess.graph)

    def _build_net(self):
        """Build the policy network."""
        # Compatibility between TF 2.x and the 1.x graph API
        tf.disable_eager_execution()
        with tf.name_scope('input'):
            # Observed states -- [B, 4]
            self.tf_obs = tf.placeholder(tf.float32, [None, self.n_features], name="observations")
            # Executed actions -- [B, ]
            self.tf_acts = tf.placeholder(tf.int32, [None, ], name="actions_num")
            # Discounted cumulative returns -- [B, ]
            self.tf_vt = tf.placeholder(tf.float32, [None, ], name="actions_value")
        # Network structure: two fully connected layers
        layer = tf.layers.dense(
            inputs=self.tf_obs,
            units=10,
            activation=tf.nn.tanh,
            kernel_initializer=tf.random_normal_initializer(mean=0, stddev=0.3),
            bias_initializer=tf.constant_initializer(0.1),
            name='fc1',
        )
        all_act = tf.layers.dense(
            inputs=layer,
            units=self.n_actions,
            activation=None,
            kernel_initializer=tf.random_normal_initializer(mean=0, stddev=0.3),
            bias_initializer=tf.constant_initializer(0.1),
            name='fc2'
        )
        # Use softmax to turn the logits into a probability for each action
        self.all_act_prob = tf.nn.softmax(all_act, name='act_prob')
        # Define the loss function
        with tf.name_scope('loss'):
            # Maximizing the total reward (log_p * R) is equivalent to minimizing -(log_p * R),
            # and TensorFlow only provides minimize(loss)
            neg_log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=all_act, labels=self.tf_acts)
            # or equivalently:
            # neg_log_prob = tf.reduce_sum(-tf.log(self.all_act_prob) * tf.one_hot(self.tf_acts, self.n_actions), axis=1)
            loss = tf.reduce_mean(neg_log_prob * self.tf_vt)
        # Define the training op that updates the parameters
        with tf.name_scope('train'):
            self.train_op = tf.train.AdamOptimizer(self.lr).minimize(loss)
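    # The loss above is the REINFORCE surrogate:
    #     loss = mean over t of ( -log pi_theta(a_t | s_t) * v_t )
    # where pi_theta(a_t | s_t) is the softmax probability of the action actually taken and
    # v_t is the discounted, normalized return fed in through tf_vt. Minimizing it raises the
    # probability of actions followed by above-average returns and lowers it otherwise.
    # The two neg_log_prob formulations above are equivalent; e.g. (illustrative numbers) with
    # logits [1.2, -0.3] and action 0, both compute -log(softmax([1.2, -0.3])[0]) ~= 0.20.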

    def choose_action(self, observation, type="random"):
        """Choose an action, i.e. sample a behaviour in state s according to the
        current action probability distribution.
        :param observation: current observation
        :param type: "random" samples from the distribution, otherwise take the argmax
        :return: the action selected by the policy
        """
        prob_weights = self.sess.run(self.all_act_prob, feed_dict={self.tf_obs: observation[np.newaxis, :]})
        # Sample according to the predicted probabilities, or take the maximum.
        # ("random" keeps some randomness and therefore more exploration.)
        if type == "random":
            action = np.random.choice(range(prob_weights.shape[1]), p=prob_weights.ravel())
        else:
            action = np.argmax(prob_weights.ravel())
        return action
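    # Usage note (illustrative): during training, the default type="random" keeps exploration by
    # sampling from the predicted distribution; to evaluate a trained policy greedily, one can
    # call choose_action(observation, type="max"), which falls through to the argmax branch.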

    def store_transition(self, s, a, r):
        """Store one transition: keep the state, action and reward of each step in the episode.
        :param s: observation at the step
        :param a: action taken at the step
        :param r: reward received at the step
        """
        self.ep_obs.append(s)
        self.ep_as.append(a)
        self.ep_rs.append(r)

    def learn(self):
        """Learn from the data collected in one episode and update the policy network parameters."""
        # Compute the discounted (and normalized) cumulative return of the episode
        discounted_ep_rs_norm = self._discount_and_norm_rewards()
        # Run the training op to update the parameters
        self.sess.run(self.train_op, feed_dict={
            self.tf_obs: np.vstack(self.ep_obs),
            self.tf_acts: np.array(self.ep_as),
            self.tf_vt: discounted_ep_rs_norm,
        })
        # Clear the episode data and wait for the next episode
        self.ep_obs, self.ep_as, self.ep_rs = [], [], []
        return discounted_ep_rs_norm
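    # Shape note: after an episode of T steps, np.vstack(self.ep_obs) has shape [T, n_features]
    # ([T, 4] for CartPole), np.array(self.ep_as) has shape [T], and discounted_ep_rs_norm has
    # shape [T], matching the placeholders defined in _build_net.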

    def _discount_and_norm_rewards(self):
        """Discount and normalize the rewards of one episode."""
        discounted_ep_rs = np.zeros_like(self.ep_rs)
        running_add = 0
        # Because we care about the long-term cumulative reward, iterate in reverse order:
        # return at time t = reward at time t + gamma * return at time t+1
        for t in reversed(range(0, len(self.ep_rs))):
            running_add = running_add * self.gamma + self.ep_rs[t]
            discounted_ep_rs[t] = running_add
        # Normalize
        discounted_ep_rs -= np.mean(discounted_ep_rs)
        discounted_ep_rs /= np.std(discounted_ep_rs)
        return discounted_ep_rs
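As a small worked example of what _discount_and_norm_rewards computes (the rewards and gamma below are chosen purely for illustration), the reverse-order pass turns per-step rewards into discounted returns, which are then normalized:

ep_rs = [1.0, 1.0, 1.0]       # illustrative per-step rewards
gamma = 0.95                  # illustrative discount rate
discounted = [0.0, 0.0, 0.0]
running_add = 0.0
for t in reversed(range(len(ep_rs))):
    running_add = running_add * gamma + ep_rs[t]
    discounted[t] = running_add
print(discounted)             # approximately [2.8525, 1.95, 1.0]
# Subtracting the mean and dividing by the standard deviation then yields the vt values that
# are fed into tf_vt, so steps followed by larger cumulative returns get larger weights.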
Summary
With the above in place, the framework of the whole exploration-and-learning process can be summarized as follows:
# The learning process: explore for N episodes
for i_episode in range(N):
    observation = env.reset()
    while True:
        # Decide on an action (replaceable module)
        action = choose_action(observation)
        observation_, reward, done, _ = env.step(action)
        # Store the transition sequence
        store_transition(observation, action, reward)
        if done:
            # Model learning (replaceable module)
            vt = learn()
            break
        # Advance the agent: update the observation
        observation = observation_
The most important piece is the decision model: from the current observation it should produce the action that best guides the agent, so as to maximize the long-term return.
The two key parts of the decision model are the network design (this article uses a simple two-layer fully connected network; more complex networks can be designed) and the loss design (whose goal is to maximize the long-term return).
In the next article, we will introduce Actor-Critic, a scheme that combines the policy-based and value-based approaches.
Code reference:
The Chinese women's football team is really awesome!~