编程语言
首页 > 编程语言> > 强化学习(一)--Sarsa与Q-learning算法

强化学习(一)--Sarsa与Q-learning算法

作者:互联网

强化学习(一)--Sarsa与Q-learning算法


最近实验室有一个项目要用到强化学习,在这开个新坑来记录下强化学习的学习过程。
第一节就先来最简单的基于表格型的RL算法,包括经典的Sarsa和Q-learning算法。

由于时间原因,关于算法的理论知识不再详细介绍,重点是研究怎么编程实现,代码是参考的飞浆PaddlePaddle公开课的代码,下来又自己手撸了一遍。飞浆PaddlePaddle公开课是我认为最适合入门强化学习的公开课,科老师讲解的真的非常清晰,公开课地址

1. SARSA算法

sarsa算法是最基础的on-policy算法,它采用的是TD单步更新的方式,每一个step都会更新Q表格,Q表格的更新公式为:这也是代码最核心的部分,它就是将Q值不断逼近目标值,也就是未来总收益。
在这里插入图片描述
Sarsa的名字就来源于它更新Q表格时所用到的五个参数:S,A,R,S’,A’,它的算法伪代码为:
在这里插入图片描述
第一次看伪代码可能会有些懵,公开课里很贴心的给出了流程图:
![在这里插入图片描述](https://www.icode9.com/i/ll/?i=20210315185651139.png?,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzM3MzMzMDQ4,size_16,color_FFFFFF,t_70在这里插入图片描述
根据流程图很容易就能编程实现Sarsa算法。

2. Q-learning算法

Q-learning算法则是off-policy算法,与sarsa算法一样都是采用查表的方式,不同的地方在于它的A’默认为最优策略选择的动作,而sarsa的A则是下一个状态要实际执行的动作。

因此,Q-learning算法的Q表更新公式有些不同,可以看到Target_Q使用的是下个状态下最大的Q值来更新Q表格:
在这里插入图片描述
Q-learning算法的伪代码,可以看到下一时刻的动作并不一定去执行:
在这里插入图片描述
它的流程图,与sarsa的对比就能看出不同:
在这里插入图片描述

3. 代码实现

以sarsa算法为例,来讲解一下怎么进行代码实现,这里使用的环境为gym中的CliffWalking,它有四个动作 :0 up, 1 right, 2 down, 3 left。小乌龟每走一步reward = -1,掉入黑色方框内reward=-100,小乌龟被拖到起点重新开始。
在这里插入图片描述

3.1主函数

主函数主要承担导入环境,定义智能体,训练及测试。

# 主函数
def main():

    # 导入环境
    env = gym.make("CliffWalking-v0")
    env = CliffWalkingWapper(env)

    # env = gym.make("FrozenLake-v0",is_slippery = False)
    # env = FrozenLakeWapper(env)

    agent = SarsaAgent(
        obs_n = env.observation_space.n,
        act_n = env.action_space.n,
        learning_rate = 0.1,
        gamma = 0.9,
        e_greed = 0.1)

    is_render = False

    # 进行500个轮次的训练
    for episode in range(1000):
        ep_reward,ep_steps = run_episode(env,agent,is_render)  # 一个episode
        print("Episode %s: steps = %s ,reward = %1.f" %(episode,ep_steps,ep_reward))

        # 每20个episode渲染一下看看效果
        if episode % 20 == 0:
            is_render = True
        else:
            is_render = False

    # 训练结束,测试效果
    test_episode(env,agent)

3.2训练及测试函数

训练函数的思路很简单,就按照上边的流程图实现就ok了。

def run_episode(env,agent,is_render=False):
    total_steps = 0    # 记录每个episode走了多少step
    total_reward = 0   # 记录每个episode获得的总reward

    obs = env.reset() # 获得s
    action = agent.sample(obs) # 选择一个动作a

    while True:
        next_obs,reward,done,_ = env.step(action)  # 与环境交互获得s',r
        next_action = agent.sample(next_obs)       # 获得a'

        # 更新Q表格
        agent.learn(obs,action,reward,next_obs,next_action,done)

        # s<-s'  a<-a'
        obs = next_obs
        action = next_action

        total_steps+=1
        total_reward+=reward

        if is_render:  # 是否渲染图像
            env.render()
        if done:       # 是否结束episode训练
            break
    return total_reward,total_steps

测试函数的思路和训练函数是一样的,只不过不再需要更新Q表格,而动作的选取完全是基于Q表格,实现函数为SarsaAgent 类中的predict() 函数。

# 测试函数
def test_episode(env,agent):
    total_reward = 0
    obs = env.reset()

    while True:
        action = agent.predict(obs)
        next_obs,reward,done,_ = env.step(action)

        total_reward += reward
        obs = next_obs
        time.sleep(0.5)
        env.render()
        if done:
            print('test reward = %.1f'%(total_reward))
            break

3.3 SarsaAgent类的实现

SarsaAgent类有三个主要函数组成:sample()、predict()、learn()

3.3.1 sample函数

sample函数主要实现e_greedy方法来选择动作,满足强化学习中的探索和利用。

    # 根据状态,选择动作 采用e-greedy算法
    def sample(self,obs):
        if np.random.uniform(0,1)<(1.0 - self.epsilon):  # 根据Q表格选择动作
            action = self.predict(obs)
        else:
            action = np.random.choice(self.act_n)        # 随机选择动作,探索
        return action

3.3.2 predict函数

predict函数则是实现查Q表,选择该状态下Q值最大的动作。

    # 查表格选择动作
    def predict(self,obs):
        Q_list = self.Q[obs,:]   # 取这一状态下所有a的Q值
        max_q = np.max(Q_list)   # 选择Q值最大的
        action_list = np.where(Q_list==max_q)[0]   # max_q有可能对应多个动作,取出所有的动作
        action = np.random.choice(action_list)     # 随机选择这些动作
        return action

3.3.3 learn函数

learn函数则是利用S,A,R,S’,A’的值完成Q表的更新。

    # 更新Q表格的方法
    def learn(self,obs,action,reward,next_obs,next_action,done):

        predict_Q = self.Q[obs,action]
        if done:
            target_Q = reward
        else:
            target_Q = reward + self.gamma * self.Q[next_obs,next_action]

        self.Q[obs,action] += self.lr * (target_Q - predict_Q)  # 修正Q值

3.4 Q-learning算法的改变

Q-learning算法的实现也比较类似,其中只需要改变learn函数和run_episode函数中的内容,其他保持不变。

    def learn(self,obs,action,reward,next_obs,done):
        predict_Q = self.Q[obs,action]

        if done:
            target_Q = reward
        else:
            target_Q = reward+self.gamma*np.max(self.Q[next_obs,:])

        self.Q[obs,action] += self.lr*(target_Q-predict_Q)
def run_episode(env,agent,is_render):
    total_reward = 0
    total_steps = 0

    obs = env.reset()

    while True:
        action = agent.sample(obs)
        next_obs,reward,done,_ = env.step(action)

        agent.learn(obs,action,reward,next_obs,done)

        obs = next_obs

        total_steps += 1
        total_reward+=reward

        if is_render:
            env.render()
        if done:
            break
    return total_reward,total_steps

标签:episode,env,--,算法,Sarsa,learning,action,reward,obs
来源: https://blog.csdn.net/qq_37333048/article/details/114848199