2021-02-08 TensorFlow 2.0 MuZero
Author: 互联网
References:
[1] ColinFred. Monte Carlo Tree Search (MCTS) code walkthrough [Python]. 2019-03-23.
[2] 饼干Japson, Deep Reinforcement Learning Lab. An in-depth walkthrough of the MuZero algorithm. 2021-01-19.
[3] Tangarf. A study report on the MuZero algorithm. 2020-08-31.
[4] 带带弟弟好吗. AlphaGo, part three: MuZero. 2020-08-30.
[5] Original Google paper: Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model.
[6] Reference GitHub code 1.
[7] Reference GitHub code 2.
import tensorflow as tf
import numpy as np
class MuZeroModels(object):
    def __init__(
        self,
        representation_layer_list: list,
        dynamics_layer_list: list,
        prediction_layer_list: list,
    ):
        # representation function: obs 1, 2, 3, ... -> hidden state
        self.representation = tf.keras.Sequential(
            representation_layer_list,
            name="representation"
        )
        # dynamics function: hidden state(k) AND action -> hidden state(k+1) AND reward
        self.dynamics = tf.keras.Sequential(
            dynamics_layer_list,
            name="dynamics"
        )
        # prediction function: hidden state -> policy AND value
        self.prediction = tf.keras.Sequential(
            prediction_layer_list,
            name="prediction"
        )
    @staticmethod
    def loss(
        reward_target,
        value_target,
        policy_target,
        reward_pred,
        value_pred,
        policy_pred
    ):
        # mean squared error on the reward head, cross-entropy on the value and policy heads
        return tf.losses.mean_squared_error(
            y_pred=reward_pred,
            y_true=reward_target
        ) + tf.losses.categorical_crossentropy(
            y_pred=value_pred,
            y_true=value_target
        ) + tf.losses.categorical_crossentropy(
            y_pred=policy_pred,
            y_true=policy_target
        )
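For concreteness, here is a minimal sketch of how the three networks could be wired together. The layer lists and sizes below are illustrative assumptions, not part of the original post; a real implementation would also need output heads that match how the MCTS code later unpacks the prediction output into (policy, value) and the dynamics output into (hidden state, reward).
# Hypothetical construction; all sizes are illustrative assumptions.
hidden_size = 32
num_actions = 4
muzero = MuZeroModels(
    representation_layer_list=[
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(hidden_size, activation="tanh"),
    ],
    dynamics_layer_list=[
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(hidden_size + 1),  # hidden state(k+1) plus a scalar reward
    ],
    prediction_layer_list=[
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(num_actions + 1),  # policy logits plus a scalar value
    ],
)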
class minmax(object):
    '''Tracks the minimum and maximum values seen during the search and
    normalizes Q values into [0, 1], as in the MuZero pseudocode.'''
    def __init__(self):
        self.maximum = -float("inf")
        self.minimum = float("inf")

    def update(self, value):
        self.maximum = max(self.maximum, value)
        self.minimum = min(self.minimum, value)

    def normalize(self, value):
        if self.maximum > self.minimum:
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value
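A quick usage check of the normalization (a tiny sketch, not from the original post): after observing returns 2.0 and 6.0, a return of 4.0 normalizes to 0.5.
stats = minmax()
stats.update(2.0)
stats.update(6.0)
print(stats.normalize(4.0))  # (4.0 - 2.0) / (6.0 - 2.0) = 0.5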
class TreeNode(object):
    def __init__(
        self,
        parent,
        prior_p,
        hidden_state,
        reward,
        is_PVP: bool = False,
        gamma=0.997
    ):
        self._parent = parent
        self._children = {}
        self._num_visits = 0
        self._Q = 0
        self._U = 0
        self._P = prior_p
        self._hidden_state = hidden_state
        self.reward = reward
        self._is_PVP = is_PVP
        self._gamma = gamma

    def expand(self, action_priorP_hiddenStates_reward):
        '''
        :param action_priorP_hiddenStates_reward: iterable of tuples
            (action, prior probability of the action, hidden state, reward);
            a new child node is created for each entry to grow the tree.
        '''
        for action, prob, hidden_state, reward in action_priorP_hiddenStates_reward:
            if action not in self._children:
                self._children[action] = TreeNode(
                    parent=self,
                    prior_p=prob,
                    hidden_state=hidden_state,
                    reward=reward,
                    is_PVP=self._is_PVP,
                    gamma=self._gamma
                )
    def select(self, c_puct_1=1.25, c_puct_2=19652):
        '''
        :param c_puct_1: 1.25, following the paper
        :param c_puct_2: 19652, following the paper
        :return: the (action, child) pair with the highest UCB value
        '''
        return max(
            self._children.items(),
            key=lambda node_tuple: node_tuple[1].get_value(c_puct_1, c_puct_2)
        )
    def _update(self, value, reward, minmax):
        '''
        :param reward: discounted reward accumulated while backing up from the
            leaf node n_l to the current node n_k
        :param value: the value estimated at the leaf node n_l, multiplied by gamma ^ (l - k)
        Note: this method is not meant to be called from outside the class.
        '''
        _G = reward + value
        minmax.update(_G)
        _G = minmax.normalize(_G)
        self._Q = (self._num_visits * self._Q + _G) / (self._num_visits + 1)
        self._num_visits += 1

    def backward_update(self, minmax, value, backward_reward=0):
        '''
        :param backward_reward: sum of the rewards backed up from the leaf, each
            multiplied by the discount factor gamma
        :param value: the value function estimated at the leaf node
        Note: call this only on the leaf node; the value function only evaluates
        the final (leaf) state, and interior nodes are then updated recursively.
        '''
        self._update(value, backward_reward, minmax)
        if self._is_PVP:
            # two-player setting: the sign of the backed-up return flips between players
            all_rewards = self.reward - self._gamma * backward_reward
        else:
            all_rewards = self.reward + self._gamma * backward_reward
        if self._parent:
            self._parent.backward_update(minmax, self._gamma * value, all_rewards)
    def get_value(self, c_puct_1=1.25, c_puct_2=19652):
        '''
        :param c_puct_1: 1.25, following the paper
        :param c_puct_2: 19652, following the paper
        :return: the pUCT score of this node
        Note: the UCB value here is computed differently from AlphaZero.
        '''
        self._U = self._P * \
            (np.sqrt(self._parent._num_visits) / (1 + self._num_visits)) * \
            (
                c_puct_1 + np.log(
                    (self._parent._num_visits + c_puct_2 + 1) / c_puct_2)
            )
        return self._Q + self._U

    def is_leaf(self):
        return self._children == {}

    def is_root(self):
        return self._parent is None
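For reference, get_value implements the pUCT rule from the MuZero paper (Appendix B): a child a of node s is scored as
Q(s, a) + P(s, a) * sqrt(N(s)) / (1 + N(s, a)) * (c1 + log((N(s) + c2 + 1) / c2))
where N(s) is the parent's visit count, N(s, a) the child's visit count, P(s, a) the prior from the prediction network, Q(s, a) the min-max normalized mean value, and c1 = 1.25, c2 = 19652.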
class MCTS(object):
    def __init__(
        self,
        model: 'MuZeroModels',
        observations,
        reward,
        is_PVP: bool = False,
        gamma=0.997,
        num_playout=50,
        c_puct_1=1.25,
        c_puct_2=19652,
    ):
        self._muzero_model = model
        self._minmax = minmax()
        # the root's hidden state comes from the representation network
        self._root = TreeNode(
            parent=None,
            prior_p=1.0,
            hidden_state=self._muzero_model.representation.predict(observations),
            reward=reward,
            is_PVP=is_PVP,
            gamma=gamma
        )
        self._c_puct_1 = c_puct_1
        self._c_puct_2 = c_puct_2
        self._num_playout = num_playout
    def _playout(self):
        # one simulation: descend to a leaf, expand it with the dynamics and
        # prediction networks, then back the leaf value up to the root
        node = self._root
        while True:
            if node.is_leaf():
                break
            _, node = node.select(self._c_puct_1, self._c_puct_2)
        action_probs, value = self._muzero_model.prediction.predict(node._hidden_state)[0]
        action_probs = list(action_probs)
        action_priorP_hiddenStates_reward = []
        for action_num, prob in enumerate(action_probs):
            action = 'action:' + str(action_num)
            action_num_one_hot = [1 if i == action_num else 0 for i in range(len(action_probs))]
            next_hidden_state, reward = self._muzero_model.dynamics.predict([node._hidden_state, action_num_one_hot])
            action_priorP_hiddenStates_reward.append((action, prob, next_hidden_state, reward))
        node.expand(action_priorP_hiddenStates_reward)
        node.backward_update(minmax=self._minmax, value=value)
    def choice_action(self):
        # run the simulations, then turn the root's visit counts into a
        # probability distribution over actions (softmax over visit counts)
        for _ in range(self._num_playout):
            self._playout()
        actions = []
        visits = []
        for action, node in self._root._children.items():
            actions.append(action)
            visits.append(node._num_visits)
        exp_visits = np.exp(visits)
        return actions, exp_visits / np.sum(exp_visits)

    def __str__(self):
        return "MuZero_MCTS"
class MuZero:
    def __init__(self):
        # placeholder for the full training / self-play loop
        pass
P.S.: The code is not fully finished; corrections are welcome if you spot any errors.
Source: https://blog.csdn.net/weixin_41369892/article/details/113754384