
2021-2-08 TensorFlow 2.0 MuZero


References:
[1] ColinFred. Monte Carlo tree search (MCTS) code explained in detail [Python]. 2019-03-23.
[2] 饼干Japson, 深度强化学习实验室. [In-depth paper study report] The MuZero algorithm explained step by step. 2021-01-19.
[3] Tangarf. MuZero algorithm study report. 2020-08-31.
[4] 带带弟弟好吗. AlphaGo, version three: MuZero. 2020-08-30.
[5] Original Google paper: Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model.
[6] Reference GitHub code 1.
[7] Reference GitHub code 2.

import tensorflow as tf
import numpy as np

class MuZeroModels(object):
    def __init__(
        self,
        representation_layer_list: "list",
        dynamics_layer_list: "list",
        prediction_layer_list: "list",
    ):

        # h: the representation function, maps the observation history o_1, o_2, ... to a hidden state
        self.representation = tf.keras.Sequential(
            representation_layer_list,
            name="representation"
        )

        # g: the dynamics function, maps (hidden state s_k, action a_k) to (hidden state s_{k+1}, reward r_k)
        self.dynamics = tf.keras.Sequential(
            dynamics_layer_list,
            name="dynamics"
        )

        # f: the prediction function, maps a hidden state to (policy, value)
        self.prediction = tf.keras.Sequential(
            prediction_layer_list,
            name="prediction"
        )

    @staticmethod
    def loss(
        reward_target,
        value_target,
        policy_target,
        reward_pred,
        value_pred,
        policy_pred
    ):
        # combined loss: MSE on the reward, cross-entropy on the value and policy heads
        return tf.losses.mean_squared_error(
            y_pred=reward_pred,
            y_true=reward_target
        ) + tf.losses.categorical_crossentropy(
            y_pred=value_pred,
            y_true=value_target
        ) + tf.losses.categorical_crossentropy(
            y_pred=policy_pred,
            y_true=policy_target
        )
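
# ---------------------------------------------------------------------------
# Usage sketch for MuZeroModels (a hedged example: the layer sizes, the 10-dim
# observation and the dummy targets below are my own assumptions, not from the
# original post). The dynamics network is not called here, because a plain
# Sequential model cannot consume the (hidden state, action) pair directly.
# ---------------------------------------------------------------------------
_sketch_models = MuZeroModels(
    representation_layer_list=[tf.keras.layers.Dense(32, activation="relu"),
                               tf.keras.layers.Dense(8)],
    dynamics_layer_list=[tf.keras.layers.Dense(32, activation="relu"),
                         tf.keras.layers.Dense(8)],
    prediction_layer_list=[tf.keras.layers.Dense(32, activation="relu"),
                           tf.keras.layers.Dense(4, activation="softmax")],
)
_hidden = _sketch_models.representation(np.random.rand(1, 10).astype("float32"))
print(_hidden.shape)  # (1, 8): one observation mapped to an 8-dim hidden state

# MSE on the reward plus cross-entropy on the value and policy heads
_loss = MuZeroModels.loss(
    reward_target=np.array([[0.0]], dtype="float32"),
    value_target=np.array([[1.0, 0.0]], dtype="float32"),
    policy_target=np.array([[0.25, 0.25, 0.25, 0.25]], dtype="float32"),
    reward_pred=np.array([[0.1]], dtype="float32"),
    value_pred=np.array([[0.8, 0.2]], dtype="float32"),
    policy_pred=np.array([[0.4, 0.2, 0.2, 0.2]], dtype="float32"),
)
print(float(tf.reduce_mean(_loss)))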

class minmax(object):
    def __init__(self):
        self.maximum = -float("inf")
        self.minimum = float("inf")

    def update(self, value):
        self.maximum = max(self.maximum, value)
        self.minimum = min(self.minimum, value)

    def normalize(self, value):
        if self.maximum > self.minimum:
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value
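
# ---------------------------------------------------------------------------
# Quick check of the running min-max statistics, which squash backed-up
# returns into [0, 1] before they are mixed into Q (the numbers are arbitrary).
# ---------------------------------------------------------------------------
_stats = minmax()
for _g in (0.5, 2.0, -1.0):
    _stats.update(_g)
print(_stats.normalize(0.5))  # (0.5 - (-1.0)) / (2.0 - (-1.0)) = 0.5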

class TreeNode(object):
    def __init__(
        self,
        parent,
        prior_p,
        hidden_state,
        reward,
        is_PVP: 'bool'=False,
        gamma=0.997
    ):
        self._parent = parent
        self._children = {}
        self._num_visits = 0
        self._Q = 0
        self._U = 0
        self._P = prior_p

        self._hidden_state = hidden_state
        self.reward = reward

        self._is_PVP = is_PVP
        self._gamma = gamma

    def expand(self, action_priorP_hiddenStates_reward):
        '''
        :param action_priorP_hiddenStates_reward: iterable of tuples
            (action, prior probability, hidden state, reward), one per candidate action
        Expands the tree by creating a new child node for every listed action.
        '''
        for action, prob, hidden_state, reward in action_priorP_hiddenStates_reward:
            if action not in self._children.keys():
                self._children[action] = TreeNode(
                    parent=self,
                    prior_p=prob,
                    hidden_state=hidden_state,
                    reward=reward,
                    is_PVP=self._is_PVP,
                    gamma=self._gamma
                )

    def select(self, c_puct_1=1.25, c_puct_2=19652):
        '''
        :param c_puct_1: set to 1.25, following the paper
        :param c_puct_2: set to 19652, following the paper
        :return: the (action, child node) pair with the largest pUCT value
        '''
        return max(
            self._children.items(),
            key=lambda node_tuple: node_tuple[1].get_value(c_puct_1, c_puct_2)
        )

    def _update(self, value, reward, minmax):
        '''
        :param value: the value the model estimated at the final leaf node n_l, discounted by gamma^(l-k)
        :param reward: the rewards accumulated back from the leaf node n_l to this node n_k (with the discount factor applied)
        :param minmax: running min-max statistics used to normalise the return
        Note: this method is internal and does not need to be called from outside the class.
        '''
        _G = reward + value
        minmax.update(_G)
        _G = minmax.normalize(_G)
        self._Q = (self._num_visits * self._Q + _G) / (self._num_visits + 1)
        self._num_visits += 1

    def backward_update(self, minmax, value, backward_reward=0):
        '''
        :param backward_reward: the sum of the rewards backed up from the leaf, each discounted by gamma
        :param value: the value function estimated at the final leaf node
        :param minmax: running min-max statistics used to normalise the return
        Note: call this only on the leaf node; it recurses up to the root by itself.
        The value function only evaluates the final (leaf) state.
        '''
        self._update(value, backward_reward, minmax)
        if self._is_PVP:
            all_rewards = self.reward - self._gamma * backward_reward
        else:
            all_rewards = self.reward + self._gamma * backward_reward

        if self._parent:
            self._parent.backward_update(minmax, self._gamma * value, all_rewards)

    def get_value(self, c_puct_1=1.25, c_puct_2=19652):
        '''
        :param c_puct_1: set to 1.25, following the paper
        :param c_puct_2: set to 19652, following the paper
        :return: the pUCT value Q + U of this node
        Note: the exploration term here is not the same as the UCB used in AlphaZero.
        '''
        self._U = self._P *\
                  (np.sqrt(self._parent._num_visits)/(1 + self._num_visits)) *\
                  (
                    c_puct_1 + np.log(
                      (self._parent._num_visits + c_puct_2 + 1)/c_puct_2)
                  )
        return self._Q + self._U

    def is_leaf(self):
        return self._children == {}

    def is_root(self):
        return self._parent is None
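
# ---------------------------------------------------------------------------
# Sanity check of expand/select on a dummy two-action tree (hidden states are
# left as None and all numbers are arbitrary assumptions). With equal visit
# counts, the child with the larger prior gets the larger pUCT value.
# ---------------------------------------------------------------------------
_root = TreeNode(parent=None, prior_p=1.0, hidden_state=None, reward=0.0)
_root._num_visits = 1  # the pUCT exploration term uses the parent's visit count
_root.expand([("action:0", 0.7, None, 0.0),
              ("action:1", 0.3, None, 0.0)])
_action, _child = _root.select()
print(_action)  # action:0 -- the larger prior wins while visit counts are equal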

class MCTS(object):
    def __init__(
        self,
        model: 'MuZeroModels',
        observations,
        reward,
        is_PVP: 'bool'=False,
        gamma=0.997,
        num_playout=50,
        c_puct_1=1.25,
        c_puct_2=19652,
    ):

        self._muzero_model = model
        self._minmax = minmax()
        self._root = TreeNode(
            parent=None,
            prior_p=1.0,
            hidden_state=self._muzero_model.representation.predict(observations),
            reward=reward,
            is_PVP=is_PVP,
            gamma=gamma
        )
        self._c_puct_1 = c_puct_1
        self._c_puct_2 = c_puct_2
        self._num_playout = num_playout

    def _playout(self):
        node = self._root
        while True:
            if node.is_leaf():
                break
            _, node = node.select(self._c_puct_1, self._c_puct_2)
        # the prediction network is expected to output (policy, value) for the leaf's hidden state
        action_probs, value = self._muzero_model.prediction.predict(node._hidden_state)
        action_probs = list(np.ravel(action_probs))

        action_priorP_hiddenStates_reward = []

        for action_num, prob in enumerate(action_probs):
            action = 'action:' + str(action_num)

            # one-hot encode the candidate action for the dynamics network
            action_num_one_hot = [1 if i == action_num else 0 for i in range(len(action_probs))]
            # the dynamics network is expected to output (next hidden state, reward)
            next_hidden_state, reward = self._muzero_model.dynamics.predict([node._hidden_state, action_num_one_hot])

            action_priorP_hiddenStates_reward.append((action, prob, next_hidden_state, reward))

        node.expand(action_priorP_hiddenStates_reward)

        node.backward_update(minmax=self._minmax, value=value)

    def choice_action(self):
        for _ in range(self._num_playout):
            self._playout()
        actions = []
        visits = []
        for action, node in self._root._children.items():
            actions.append(action)
            visits.append(node._num_visits)

        # softmax over the visit counts gives the action probabilities returned to the caller
        exp_visits = np.exp(visits)

        return actions, exp_visits / np.sum(exp_visits)

    def __str__(self):
        return "MuZero_MCTS"

class MuZero:
    def __init__(self):
        pass

PS: the code is not fully finished yet; corrections are welcome if you find mistakes.

Source: https://blog.csdn.net/weixin_41369892/article/details/113754384