Reinforcement Learning: Monte Carlo Tree Search (MCTS)
- 1. Single-State Monte Carlo Planning
- 1.1 Characteristics
- 1.2 Mathematical Model
- 2. Upper Confidence Bound Strategy
- 3. Monte Carlo Tree Search
- 3.1 Selection
- 3.2 Expansion
- 3.3 Simulation
- 3.4 Backpropagation
- 3.5 Flowchart
- 4. Code Implementation
1. Single-State Monte Carlo Planning
We take the multi-armed bandit problem as the running example.
1.1 Characteristics
It is a sequential decision problem that must keep a balance between exploitation and exploration: exploitation replays the choices that yielded the best rewards in past decisions, while exploration tries other options in the hope of larger future rewards.
1.2 Mathematical Model
- Suppose there are $k$ bandit arms; after selecting arm $I_t$ at step $t$, the reward obtained is $V_{I_t}$.
- After $n$ plays, the regret (whose first term is the maximal achievable reward) is: $Q_n = \displaystyle\max_{i=1,\dots,k} \sum_{t=1}^{n} V_{i,t} - \sum_{t=1}^{n} V_{I_t,t}$
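As a quick numerical illustration of this regret (with hypothetical arm means, not taken from the original post), the sketch below compares uniformly random play against always pulling the best arm:

```python
import numpy as np

rng = np.random.default_rng(0)
means = np.array([0.2, 0.5, 0.8])   # hypothetical Bernoulli arm means
n = 10000

# Play uniformly at random for n rounds and record the rewards.
arms = rng.integers(0, len(means), size=n)
rewards = rng.random(n) < means[arms]

# Regret: expected reward of always playing the best arm minus what we collected.
regret = n * means.max() - rewards.sum()
print(f"random play regret over {n} rounds: {regret:.0f}")  # roughly 3000
```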
2. Upper Confidence Bound Strategy
(upper confidence bound, UCB)
In the formula, $s$ is the number of times the arm has been selected in the past.
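The UCB formula itself appears only as an image in the original post; a standard form of the UCB1 score it describes is reproduced below, where $\bar{V}$ is the arm's empirical mean reward, $n$ is the total number of plays so far, $s$ is as defined above, and $c$ is an exploration constant (the code in section 4 sets $c = 1/\sqrt{2}$):

```latex
\mathrm{UCB} = \bar{V} + c\sqrt{\frac{2\ln n}{s}}
```

The arm with the highest score is played next; the exploration term shrinks as $s$ grows, so arms that have rarely been tried keep being revisited.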
3. Monte Carlo Tree Search
3.1 Selection
Starting from the root node R, repeatedly select the child with the highest UCB score until a leaf node L is reached.
3.2 Expansion
If the leaf node L is not a terminal node, create one of its unvisited successors at random and select it as the new node C.
3.3 Simulation
From node C, simulate the game until it reaches a terminal state.
3.4 Backpropagation
Use the simulation result to walk back up the path, updating the win count and visit count of every node that led to it.
3.5 Flowchart
(The flowchart image from the original post is omitted here; it depicts the selection, expansion, simulation, and backpropagation loop described above.)
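In place of the flowchart, here is a minimal sketch of one full search iteration. The node fields and helper methods (`children`, `visits`, `wins`, `is_fully_expanded`, `is_terminal`, `expand_one_untried_child`, `rollout`) are hypothetical stand-ins, not part of the implementation in section 4:

```python
import math

def mcts_iteration(root, c=math.sqrt(2)):
    """One iteration of the selection/expansion/simulation/backpropagation loop.

    rollout() is assumed to return 1 for a win by the player the search
    optimizes for, else 0.
    """
    # 1. Selection: descend by UCB while the node has no untried moves left.
    node, path = root, [root]
    while node.children and node.is_fully_expanded():
        node = max(node.children, key=lambda ch:
                   ch.wins / (ch.visits + 1e-8)
                   + c * math.sqrt(math.log(node.visits + 1) / (ch.visits + 1e-8)))
        path.append(node)
    # 2. Expansion: add an unvisited successor if the game is not over.
    if not node.is_terminal():
        node = node.expand_one_untried_child()
        path.append(node)
    # 3. Simulation: random playout from the new node to a terminal state.
    reward = node.rollout()
    # 4. Backpropagation: update statistics along the selected path.
    for visited in path:
        visited.visits += 1
        visited.wins += reward
```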
4. Code Implementation
In practice, the details of MCTS are tuned to the task at hand. Below is an MCTS implementation for Gomoku (five-in-a-row):
```python
# -*- coding: utf-8 -*-
# @Time    : 2022/4/4 14:55
# @Author  : CyrusMay WJ
# @FileName: mcts.py
# @Software: PyCharm
# @Blog    : https://blog.csdn.net/Cyrus_May
import copy
import datetime

import numpy as np


class Agent:
    """Gomoku board state. turn: 0 means black to move, 1 means white to move."""

    def __init__(self, width=15, height=15, logger=None):
        self.width = width
        self.height = height
        self.logger = logger
        self.turn = 0
        self.__init_board()

    def __init_board(self):
        self.black_board = np.zeros([self.width, self.height])
        self.white_board = np.zeros([self.width, self.height])
        self.all_board = self.black_board + self.white_board

    def judge_terminal(self):
        # Check whether the player who made the *last* move has won.
        if self.turn:  # white to move, so black moved last
            return self.__judge(self.black_board)
        return self.__judge(self.white_board)

    def __judge(self, board):
        for i in range(self.width):
            for j in range(self.height):
                if self.height - j >= 5 and board[i, j:j + 5].sum() == 5:  # horizontal
                    return 1
                if self.width - i >= 5 and board[i:i + 5, j].sum() == 5:  # vertical
                    return 1
                if self.width - i >= 5 and self.height - j >= 5 and \
                        sum(board[i + d, j + d] for d in range(5)) == 5:  # diagonal
                    return 1
                if i >= 4 and self.height - j >= 5 and \
                        sum(board[i - d, j + d] for d in range(5)) == 5:  # anti-diagonal
                    return 1
        return 0

    def update_board(self, x, y):
        # Place a stone for the player to move.
        if self.turn:
            self.white_board[x, y] = 1
        else:
            self.black_board[x, y] = 1
        self.all_board[x, y] = 1

    def next_state(self):
        # Random empty cell for the rollout; (None, None) if the board is full.
        x, y = np.where(self.all_board == 0)
        if not x.shape[0]:
            return None, None
        idx = np.random.choice(np.arange(x.shape[0]))
        return x[idx], y[idx]

    def childs_state(self):
        # Coordinates of every empty cell, i.e. all candidate moves.
        return np.where(self.all_board == 0)


class Node:
    def __init__(self, agent, parent=None):
        self.agent = agent
        self.childs = []        # created per node: a mutable default would be shared
        self.parent = parent
        self.turn = agent.turn  # player to move at this node
        self.reward = 0         # accumulated wins for the searching player
        self.n = 0              # visit count

    def add_child(self, node):
        self.childs.append(node)


class MCTS:
    def __init__(self, max_epochs=10000, max_time=5, logger=None):
        self.logger = logger
        self.max_epochs = max_epochs
        self.c = 1 / np.sqrt(2)   # exploration (balance) factor
        self.max_time = max_time  # time budget in seconds

    def search(self, board):
        # board: 0 = empty, 1 = black stone, 2 = white stone.
        board = np.array(board)
        black_state = (board == 1).astype(np.int32)
        white_state = (board == 2).astype(np.int32)
        turn = 0 if black_state.sum() <= white_state.sum() else 1
        self.agent = Agent(logger=self.logger)
        self.agent.white_board = white_state
        self.agent.black_board = black_state
        self.agent.all_board = white_state + black_state
        self.agent.turn = turn
        self.turn = turn
        return self.run()

    def run(self):
        root = Node(copy.deepcopy(self.agent))
        start = datetime.datetime.now()
        for _ in range(self.max_epochs):
            path = self.selection(root, self.max_epochs)
            path = self.expand(path)
            if not path:
                continue
            reward = self.simulation(path)
            self.backward(path, reward)
            if (datetime.datetime.now() - start).total_seconds() > self.max_time:
                break
        # Choose the root child with the best score and recover its move.
        scores = np.array([self.ucb(node, self.max_epochs) for node in root.childs])
        best = root.childs[np.argmax(scores)]
        x, y = np.where(best.agent.all_board - self.agent.all_board)
        return x[0], y[0]

    def ucb(self, node, epoch):
        # Alternate perspectives: where it is our turn, the opponent chose this
        # node, so they maximize our loss rate instead of our win rate.
        if node.turn == self.turn:
            return (node.n - node.reward) / (node.n + 1e-8) + \
                2 * np.sqrt(2 * np.log(epoch) / (node.n + 1e-8))
        return node.reward / (node.n + 1e-8) + \
            2 * np.sqrt(2 * np.log(epoch) / (node.n + 1e-8))

    def selection(self, root, epoch):
        # Descend from the root, always taking the child with the highest UCB.
        path = [root]
        while root.childs:
            scores = np.array([self.ucb(node, epoch) for node in root.childs])
            root = root.childs[np.argmax(scores)]
            path.append(root)
        return path

    def expand(self, path):
        leaf = path[-1]
        # Only expand leaves that were already visited once (or the root).
        if leaf.n > 0 or len(path) == 1:
            x, y = leaf.agent.childs_state()
            if not x.shape[0]:
                return None  # board full: nothing to expand
            for row, col in zip(x, y):
                agent = copy.deepcopy(leaf.agent)
                agent.update_board(row, col)
                agent.turn = 1 - agent.turn  # the other player moves next
                leaf.add_child(Node(agent, parent=leaf))
            path.append(leaf.childs[0])
        return path

    def simulation(self, path):
        # Random playout from the selected node to the end of the game.
        agent = copy.deepcopy(path[-1].agent)
        while 1:
            if agent.judge_terminal():
                # The player who moved last won; reward 1 if that was us.
                return 1 if agent.turn != self.turn else 0
            x, y = agent.next_state()
            if x is None:
                return 0  # draw
            agent.update_board(x, y)
            agent.turn = 1 - agent.turn

    def backward(self, path, reward):
        # Update visit and win statistics along the whole selected path.
        for node in path:
            node.n += 1
            node.reward += reward
```
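A minimal usage sketch, assuming the board encoding that `search` expects (0 = empty, 1 = black, 2 = white):

```python
# Let MCTS choose black's move on an empty 15x15 board.
board = [[0] * 15 for _ in range(15)]
mcts = MCTS(max_epochs=2000, max_time=5)
x, y = mcts.search(board)
print(f"MCTS plays at ({x}, {y})")
```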