其他分享
首页 > 其他分享> > baselines库中cmd_util.py模块对atari游戏的包装为什么要分成两部分并在中间加入flatten操作呢?

baselines库中cmd_util.py模块对atari游戏的包装为什么要分成两部分并在中间加入flatten操作呢?

作者:互联网

如题:

cmd_util.py模块中对应的代码:

 

 

可以看到不论是atari游戏还是retro游戏,在进行游戏环境包装的时候都是分成两部分的,如atari游戏,第一部分是make_atari,第二部分是wrap_deepmind,在两者之间有一个FlattenObservation操作。

 

通过FlattenObservation的代码可以知道,该操作是将observation的space从dict变为np.array,也就是gym.spaces.Dict变为gym.spaces.Box类型:

import numpy as np
import gym.spaces as spaces
from gym import ObservationWrapper


class FlattenObservation(ObservationWrapper):
    r"""Observation wrapper that flattens the observation."""
    def __init__(self, env):
        super(FlattenObservation, self).__init__(env)

        flatdim = spaces.flatdim(env.observation_space)
        self.observation_space = spaces.Box(low=-float('inf'), high=float('inf'), shape=(flatdim,), dtype=np.float32)

    def observation(self, observation):
        return spaces.flatten(self.env.observation_space, observation)

 

 

对atari游戏的两个包装方法来看:

def make_atari(env_id, max_episode_steps=None):
    env = gym.make(env_id)
    assert 'NoFrameskip' in env.spec.id
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    if max_episode_steps is not None:
        env = TimeLimit(env, max_episode_steps=max_episode_steps)
    return env

def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False):
    """Configure environment for DeepMind-style Atari.
    """
    if episode_life:
        env = EpisodicLifeEnv(env)
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = WarpFrame(env)
    if scale:
        env = ScaledFloatFrame(env)
    if clip_rewards:
        env = ClipRewardEnv(env)
    if frame_stack:
        env = FrameStack(env, 4)
    return env

make_atari部分并不对observation部分进行处理,而wrap_deepmind部分才对observation部分进行处理,因此在baselines库中对这两部分拆开并在中间进行FlattenObservation操作 ,这样以好保证在wrap_deepmind部分的操作可以直接对np.array类型的observation进行操作。

 

个人评价:

其实感觉这个FlattenObservation操作还是有一定欠缺的,就是对MultiDiscrete的observation,没有对observation进行one-hot操作。

而这个代码中对Discrete的observation是进行了one-hot编码,而对MultiDiscrete的observation并没有进行one-hot编码,而这个对应MultiDiscrete是否应该进行one-hot编码也是要看具体情况的,如果observation的spaces虽然属于MultiDiscrete但是它的spaces.shape的很大,也就是observation的空间维度很大,这样的话也没有必要进行one-hot编码,但是如果shape比较小,如为2,这样的,那么就有必要one-hot。

如:

 

import gym

obs_space=gym.spaces.MultiDiscrete((3,5))

print(obs_space.shape)
print(obs_space.nvec)

可以知道如果observation的space属于上面的情况,那么不one-hot编码observation的空间编码长度为2, 如果one-hot编码后长度为8。

也就是不one-hot编码的一个observation,如:(2,3) ,one-hot编码后为(010 00100),

从这个形式上来看,好像对于MultiDsicrete的observation是否进行one-hot编码好像也没有太大的影响,或许baselines中的设置还是说的过去的。

 

 

但是这个代码中还有一个地方需要注意:

        self.observation_space = spaces.Box(low=-float('inf'), high=float('inf'), shape=(flatdim,), dtype=np.float32)

从这个代码中可以看到不论observation的原始数据类型是什么,只要进行了flatten操作都会把数据类型转为np.float32,这样的操作可能导致精度损失,有可能造成空间存储变大,所以这个FlattenObservation操作是非必要不使用的,不然很可能出问题的。

或许这也是在run.py中对使用FlattenObservation操作的限制了:

 

 

可以看到在baselines中只有对observation_space属于gym.spaces.Dict的才进行FlattenObservation操作。

 

 

 给出一个自己FlattenObservation操作单独写在一个文件中的代码:

import numpy as np
import gym.spaces as spaces
from gym import ObservationWrapper

from gym.spaces import Box
from gym.spaces import Discrete
from gym.spaces import MultiDiscrete
from gym.spaces import MultiBinary
from gym.spaces import Tuple
from gym.spaces import Dict


def flatdim(space):
    if isinstance(space, Box):
        return int(np.prod(space.shape))
    elif isinstance(space, Discrete):
        return int(space.n)
    elif isinstance(space, Tuple):
        return int(sum([flatdim(s) for s in space.spaces]))
    elif isinstance(space, Dict):
        return int(sum([flatdim(s) for s in space.spaces.values()]))
    elif isinstance(space, MultiBinary):
        return int(space.n)
    elif isinstance(space, MultiDiscrete):
        return int(np.prod(space.shape))
    else:
        raise NotImplementedError


def flatten(space, x):
    if isinstance(space, Box):
        return np.asarray(x, dtype=np.float32).flatten()
    elif isinstance(space, Discrete):
        onehot = np.zeros(space.n, dtype=np.float32)
        onehot[x] = 1.0
        return onehot
    elif isinstance(space, Tuple):
        return np.concatenate([flatten(s, x_part) for x_part, s in zip(x, space.spaces)])
    elif isinstance(space, Dict):
        return np.concatenate([flatten(s, x[key]) for key, s in space.spaces.items()])
    elif isinstance(space, MultiBinary):
        return np.asarray(x).flatten()
    elif isinstance(space, MultiDiscrete):
        return np.asarray(x).flatten()
    else:
        raise NotImplementedError


class FlattenObs(ObservationWrapper):
    r"""Observation wrapper that flattens the observation."""

    def __init__(self, env):
        super(FlattenObs, self).__init__(env)

        _flatdim = flatdim(env.observation_space)
        self.observation_space = spaces.Box(low=-float('inf'), high=float('inf'), shape=(_flatdim,), dtype=np.float32)

    def observation(self, observation):
        return flatten(self.env.observation_space, observation)


if __name__ == '__main__':
    import gym
    FlattenObs(gym.make('Pong-v0'))
    print(gym.make('Pong-v0').observation_space)
    print(gym.make('Pong-v0').observation_space.dtype)

 

 

 

 

==========================================

 

标签:cmd,observation,space,gym,py,库中,spaces,env,np
来源: https://www.cnblogs.com/devilmaycry812839668/p/16125142.html