Reinforcement Learning (4): Actor-Critic
- 1. Network Structure
- 2. Network Functions
- 3. Policy Network Update: Policy Gradient
- 4. Value Network Update: Temporal Difference (TD)
- 5. Training Procedure
- 6. Example

1. Network Structure
- State-value function: $V_\pi(s_t)=\sum_a Q_\pi(s_t,a)\cdot\pi(a|s_t)$
- The policy function is approximated by a policy network: $\pi(a|s)\approx\pi(a|s;\theta)$
- The action-value function is approximated by a value network: $q(s,a;W)\approx Q_\pi(s,a)$
- State-value function after the neural-network approximation: $V(s;\theta,W)=\sum_a q(s,a;W)\cdot\pi(a|s;\theta)$
- The policy network is updated repeatedly so as to increase the state value.
- The value network is updated repeatedly so as to predict the obtained return more accurately. (A minimal sketch of the two networks follows below.)
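As a rough illustration of this structure, here is a minimal sketch of the two networks in TensorFlow/Keras, assuming a CartPole-like task with a 4-dimensional state and 2 discrete actions (the same setting as the example in Section 6). The layer sizes are arbitrary, and `policy_net` / `value_net` are names introduced here for illustration only:

```python
import tensorflow as tf

STATE_DIM, N_ACTIONS = 4, 2   # assumed CartPole-like dimensions

# Policy network pi(a|s; theta): maps a state to a probability distribution over actions
policy_net = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(N_ACTIONS, activation="softmax"),
])
policy_net.build(input_shape=[None, STATE_DIM])

# Value network q(s, a; W): maps a (state, one-hot action) pair to a scalar action value
value_net = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(1),
])
value_net.build(input_shape=[None, STATE_DIM + N_ACTIONS])
```

The example in Section 6 follows the same idea, except that its policy network outputs raw logits and applies the softmax separately.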
2. Network Functions
Policy Network
- The policy network approximates the policy function: $\pi(a|s_t)\approx\pi(a|s_t;\theta)$
- State-value function and its approximation (evaluated numerically in the sketch after this list): $V_\pi(s_t)=\sum_a\pi(a|s_t)\,Q_\pi(s_t,a)$ and $V(s_t;\theta)=\sum_a\pi(a|s_t;\theta)\cdot Q_\pi(s_t,a)$
- Objective function maximized by policy learning: $J(\theta)=\mathbb{E}_S[V(S;\theta)]$
- The parameters are updated by gradient ascent on the policy: $\theta\gets\theta+\beta\cdot\frac{\partial V(s;\theta)}{\partial\theta}$
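Continuing the sketch above, the approximate state value $V(s;\theta,W)=\sum_a q(s,a;W)\cdot\pi(a|s;\theta)$ can be evaluated by querying the value network once per action and weighting each action value by the policy probability. This is only an illustration of the formula, reusing the hypothetical `policy_net` and `value_net` from the previous sketch:

```python
import numpy as np

def approx_state_value(state):
    """V(s; theta, W) = sum_a pi(a|s; theta) * q(s, a; W)."""
    s = tf.convert_to_tensor([state], dtype=tf.float32)        # shape (1, STATE_DIM)
    probs = policy_net(s)[0]                                    # pi(.|s; theta), shape (N_ACTIONS,)
    q_values = []
    for a in range(N_ACTIONS):
        a_onehot = tf.one_hot([a], depth=N_ACTIONS)             # shape (1, N_ACTIONS)
        q_values.append(value_net(tf.concat([s, a_onehot], axis=1))[0, 0])
    return float(tf.reduce_sum(probs * tf.stack(q_values)))

v = approx_state_value(np.zeros(STATE_DIM, dtype=np.float32))   # e.g. the all-zero CartPole state
```

The objective $J(\theta)=\mathbb{E}_S[V(S;\theta)]$ is then the expectation of this quantity over the states the agent visits.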
3. Policy Network Update: Policy Gradient
Policy Network
- The policy gradient is $g(a,\theta)=\frac{\partial \ln\pi(a|s;\theta)}{\partial\theta}\cdot q(s,a;W)$, with $\frac{\partial V(s;\theta,W)}{\partial\theta}=\mathbb{E}[g(A,\theta)]$.
- A stochastic policy gradient can be used instead; it is an unbiased estimate (one such step is sketched below): $a\sim\pi(\cdot|s_t;\theta)$, $\theta_{t+1}=\theta_t+\beta\cdot g(a,\theta_t)$
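Below is a minimal sketch of one stochastic policy-gradient step, again reusing the hypothetical `policy_net` and `value_net` from Section 1. The step size `beta` is illustrative; the example code in Section 6 achieves the same effect with an Adam optimizer and a cross-entropy loss weighted by the TD error instead of $q(s,a;W)$:

```python
def policy_gradient_step(state, beta=1e-3):
    s = tf.convert_to_tensor([state], dtype=tf.float32)
    # Sample a ~ pi(.|s_t; theta)
    probs = policy_net(s).numpy().ravel()
    a = int(np.random.choice(N_ACTIONS, p=probs / probs.sum()))
    # q(s, a; W) is treated as a constant with respect to theta
    a_onehot = tf.one_hot([a], depth=N_ACTIONS)
    q_sa = value_net(tf.concat([s, a_onehot], axis=1))[0, 0]
    # g(a, theta) = q(s, a; W) * d ln pi(a|s; theta) / d theta
    with tf.GradientTape() as tape:
        log_pi = tf.math.log(policy_net(s)[0, a] + 1e-8)
    grads = tape.gradient(log_pi, policy_net.trainable_variables)
    # Gradient ascent: theta <- theta + beta * g(a, theta_t)
    for var, g in zip(policy_net.trainable_variables, grads):
        var.assign_add(beta * q_sa * g)
    return a
```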
4. Value Network Update: Temporal Difference (TD)

- TD target: $y_t=r_t+\gamma\,q(s_{t+1},a_{t+1};W_t)$
- Loss function: $loss=\frac{1}{2}\left[q(s_t,a_t;W_t)-y_t\right]^2$
- The weights are updated by gradient descent (sketched below): $W_{t+1}=W_t-\alpha\cdot\left.\frac{\partial\,loss}{\partial W}\right|_{W=W_t}$
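A corresponding minimal sketch of the TD update for the value network, using the same hypothetical networks as above; `GAMMA` and the SGD learning rate stand in for $\gamma$ and $\alpha$. As in the example of Section 6, the bootstrap term is dropped on terminal states:

```python
GAMMA = 0.9                                                     # illustrative discount factor
value_optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)   # plays the role of alpha

def td_step(state, action, reward, next_state, next_action, done):
    x  = tf.concat([tf.convert_to_tensor([state], tf.float32),
                    tf.one_hot([action], depth=N_ACTIONS)], axis=1)
    x_ = tf.concat([tf.convert_to_tensor([next_state], tf.float32),
                    tf.one_hot([next_action], depth=N_ACTIONS)], axis=1)
    # TD target: y_t = r_t + gamma * q(s_{t+1}, a_{t+1}; W_t)
    y = reward + (0.0 if done else GAMMA * float(value_net(x_)[0, 0]))
    with tf.GradientTape() as tape:
        q = value_net(x)[0, 0]
        loss = 0.5 * tf.square(q - y)                           # loss = 1/2 [q(s_t,a_t;W_t) - y_t]^2
    grads = tape.gradient(loss, value_net.trainable_variables)
    value_optimizer.apply_gradients(zip(grads, value_net.trainable_variables))
    return float(y - q)                                         # TD error, useful for the actor update
```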
5. Training Procedure

In each update the agent executes a single action and receives a single reward: it observes the reward and the next state, updates the value network with the TD error, and updates the policy network with the stochastic policy gradient. One such step is sketched below; the full implementation is given in Section 6.
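Tying the sketches together, one episode of this per-step loop might look as follows (using the same classic `gym` API as the example in Section 6; the helper functions are the hypothetical ones defined above):

```python
import gym

env = gym.make("CartPole-v0")
state = env.reset()
for _ in range(200):
    # Actor: sample a_t ~ pi(.|s_t; theta) and take one policy-gradient step
    action = policy_gradient_step(state)
    next_state, reward, done, _info = env.step(action)
    # a_{t+1} ~ pi(.|s_{t+1}; theta), needed only for the TD target
    probs = policy_net(tf.convert_to_tensor([next_state], tf.float32)).numpy().ravel()
    next_action = int(np.random.choice(N_ACTIONS, p=probs / probs.sum()))
    # Critic: one TD step on the value network
    td_step(state, action, reward, next_state, next_action, done)
    state = next_state
    if done:
        break
env.close()
```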
6. Example

Convergence of this network is fairly sensitive to choices such as the model size and the activation functions.
```python
# -*- coding: utf-8 -*-
# @Time    : 2022/3/29 21:51
# @Author  : CyrusMay WJ
# @FileName: AC.py
# @Software: PyCharm
# @Blog    : https://blog.csdn.net/Cyrus_May

import tensorflow as tf
import numpy as np
import logging
import sys
import gym


class Critic():
    """Value network q(s, a; W): scores a (state, one-hot action) pair."""

    def __init__(self, logger=None, input_dim=6, gamma=0.9):
        self.logger = logger
        self.__build_model(input_dim)
        self.gamma = gamma
        self.optimizer = tf.optimizers.Adam(learning_rate=0.001)

    def __build_model(self, input_dim):
        self.model = tf.keras.Sequential([
            tf.keras.layers.Dense(32, activation="relu"),
            tf.keras.layers.Dense(1)
        ])
        self.model.build(input_shape=[None, input_dim])

    def predict(self, action, state):
        action = tf.one_hot([action], depth=2)
        state = tf.convert_to_tensor([state])
        x = tf.concat([action, state], axis=1)
        return self.model(x)[0][0]

    def train(self, state, state_, action, action_, reward, done):
        action = tf.one_hot([action], depth=2)
        state = tf.convert_to_tensor([state])
        action_ = tf.one_hot([action_], depth=2)
        state_ = tf.convert_to_tensor([state_])
        x = tf.concat([action, state], axis=1)
        x_ = tf.concat([action_, state_], axis=1)
        done = 0 if done else 1  # drop the bootstrap term on terminal states
        with tf.GradientTape() as tape:
            q = self.model(x)
            q_ = self.model(x_)
            # TD error: r_t + gamma * q(s_{t+1}, a_{t+1}; W) - q(s_t, a_t; W)
            Td_error = (reward + done * self.gamma * q_ - q)
            loss = tf.square(Td_error)
        dt = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(dt, self.model.trainable_variables))
        return Td_error


class Actor():
    """Policy network pi(a|s; theta): outputs logits over the discrete actions."""

    def __init__(self, logger=None, input_dim=4, gamma=0.9, output_dim=2):
        self.logger = logger
        self.__build_model(input_dim, output_dim)
        self.gamma = gamma
        self.optimizer = tf.optimizers.Adam(learning_rate=0.001)
        self.output_dim = output_dim

    def __build_model(self, input_dim, output_dim=2):
        self.model = tf.keras.Sequential([
            tf.keras.layers.Dense(32, activation="relu"),
            tf.keras.layers.Dense(output_dim)
        ])
        self.model.build(input_shape=[None, input_dim])

    def predict(self, state):
        # Sample an action from pi(.|s; theta)
        state = tf.convert_to_tensor([state])
        logits = self.model(state)
        prob = tf.nn.softmax(logits).numpy()
        action = np.random.choice([i for i in range(self.output_dim)], p=prob.ravel())
        return action

    def train(self, state, action, TD_error, done):
        state = tf.convert_to_tensor([state])
        with tf.GradientTape() as tape:
            logits = self.model(state)
            # cross-entropy = -ln pi(a|s; theta); weighting it by the TD error gives a loss
            # whose gradient descent is equivalent to policy-gradient ascent
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=[action], logits=logits)
            loss = tf.reduce_sum(tf.multiply(TD_error, loss))
        dt = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(dt, self.model.trainable_variables))


class Agent():
    def __init__(self, gamma=0.9, logger=None):
        self.gamma = gamma
        self.logger = logger
        self.env = gym.make("CartPole-v0")
        self.actor = Actor(logger=logger, input_dim=4, gamma=self.gamma, output_dim=2)
        self.critic = Critic(logger=logger, input_dim=6, gamma=self.gamma)

    def train(self, tran_epochs=1000, max_act=100):
        history_returns = []
        for epoch in range(tran_epochs):
            single_returns = 0
            state = self.env.reset()
            for iter in range(max_act):
                self.env.render()
                action = self.actor.predict(state)
                state_, reward, done, info = self.env.step(action)
                action_ = self.actor.predict(state_)
                # Critic update returns the TD error, which then weights the actor update
                TD_error = self.critic.train(state, state_, action, action_, reward, done)
                self.actor.train(state, action, TD_error, done)
                single_returns += reward
                state = state_
                if done:
                    break
            # Exponential moving average of the episode returns
            if history_returns:
                history_returns.append(history_returns[-1] * 0.9 + 0.1 * single_returns)
            else:
                history_returns.append(single_returns)
            self.logger.info("epoch: {}/{} || epoch return: {:,.4f} || history return: {:,.4f}".format(
                epoch + 1, tran_epochs, single_returns, history_returns[-1]))
        self.env.close()

    def test(self, max_act=1000):
        state = self.env.reset()
        single_returns = 0
        for iter in range(max_act):
            self.env.render()
            action = self.actor.predict(state)
            state_, reward, done, info = self.env.step(action)
            single_returns += reward
            state = state_
            if done:
                self.logger.info("End in {} iterations".format(iter + 1))
                break
        if not done:
            self.logger.info("success and return is {}".format(single_returns))


if __name__ == '__main__':
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    screen_handler = logging.StreamHandler(sys.stdout)
    screen_handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(module)s.%(funcName)s:%(lineno)d - %(levelname)s - %(message)s')
    screen_handler.setFormatter(formatter)
    logger.addHandler(screen_handler)
    agent = Agent(logger=logger)
    agent.train(tran_epochs=2000, max_act=500)
    agent.test()
```

Parts of this article are notes written while following a study video on Bilibili.
by CyrusMay 2022 03 29
A color you cannot touch, is it called a rainbow?
An embrace you cannot see, is it called a breeze?
(Mayday, "Starry Sky")