编程语言
首页 > 编程语言> > ML-Agents案例之“排序算法超硬核版”

ML-Agents案例之“排序算法超硬核版”

作者:互联网

本案例源自ML-Agents官方的示例,Github地址:https://github.com/Unity-Technologies/ml-agents,本文是详细的配套讲解。

本文基于我前面发的两篇文章,需要对ML-Agents有一定的了解,详情请见:Unity强化学习之ML-Agents的使用ML-Agents命令及配置大全

我前面的相关文章有:

ML-Agents案例之Crawler

ML-Agents案例之推箱子游戏

ML-Agents案例之跳墙游戏

ML-Agents案例之食物收集者

ML-Agents案例之双人足球

Unity人工智能之不断自我进化的五人足球赛

ML-Agents案例之地牢逃脱

ML-Agents案例之金字塔

ML-Agents案例之蠕虫

ML-Agents案例之机器人学走路

ML-Agents案例之看图配对

在这里插入图片描述

环境说明

如图所示,智能体在一个圆形的房间中,墙壁上会随机出现带有数字的方块,智能体需要按照数字从小到大与方块进行碰撞,碰撞过的方块会变成绿色,分数+1,一旦碰撞顺序不对,游戏结束,分数-1。

这个案例的挑战是,我们不会告诉智能体怎么排序是对的,智能体需要在环境中试错,从而自己学习到这种从小到大排序,碰撞对应方块的行为模式,同时墙壁上出现的数字方块的个数是不定的,也就是说每个episode我们都需要接收不同个数的输入,这应该怎么处理呢?

状态输入:这里用到了一个新的传感器Buffer Sensor。

在这里插入图片描述

这个传感器的作用是可以接收个数变化的状态输入。我们需要每次传入一个向量,这个向量我们可以用数组listObservation表示。通过 m_BufferSensor.AppendObservation(listObservation)传入到BufferSensor中,而BufferSensor可以接收无数个这样的向量输入,但是每个向量的维度必须相同。也就是说即使我们输入的向量个数每次都不同,我们还是能训练网络还是产生我们所期望的输出,具体是怎么实现的项目代码中没有,集成在了ML-Agents包中,根据我的经验,应该用了Self-attention这种网络的结构,这样就能接收不同个数向量的输入了。

除了传给BufferSensor的输入之外,还传入了四维的向量,分别是智能体位置到场地中心的向量在x轴和z轴上的分量,智能体前进方向在x轴和z轴的分量。

动作输出:输出三个离散值,每个离散值包含0-2三个数,第一个离散值决定了前进后退,第二个离散值决定了左移右移,第三个离散值决定了左转右转。

在这里插入图片描述

代码讲解

智能体下挂载的脚本除去万年不变的Decesion Requester,Model Overrider,Behavior Parameters,以及刚刚说明的Buffer Sensor,就只剩下智能体的只有文件SorterAgent.cs了:

头文件:

using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Random = UnityEngine.Random;

定义变量:

// 默认数字方块的最大数量,可在编辑器中滑动调节,调节范围为1 - 20
[Range(1, 20)]
public int DefaultMaxNumTiles;
// 方块数字的最大值
private const int k_HighestTileValue = 20;
// 生成方块的数量
int m_NumberOfTilesToSpawn;
// 方块的最大数量
int m_MaxNumberOfTiles;
// 刚体
Rigidbody m_AgentRb;

// BufferSensorComponent 是一个传感器,允许观察不同数量的输入
BufferSensorComponent m_BufferSensor;
// 数字方块的列表
public List<NumberTile> NumberTilesList = new List<NumberTile>();
// 出现在场景中的方块列表
private List<NumberTile> CurrentlyVisibleTilesList = new List<NumberTile>();
// 已经被接触过的方块列表
private List<Transform> AlreadyTouchedList = new List<Transform>();

private List<int> m_UsedPositionsList = new List<int>();
// 初始位置
private Vector3 m_StartingPos;
// 整个场景
GameObject m_Area;
// 环境参数,可以从配置文件中获取
EnvironmentParameters m_ResetParams;
// 下一个想要碰撞的数字方块的索引
private int m_NextExpectedTileIndex;

初始化方法Initialize():

public override void Initialize()
{
    // 获取父物体
    m_Area = transform.parent.gameObject;
    // 获取方块的最大数量
    m_MaxNumberOfTiles = k_HighestTileValue;
    // 从配置文件中获取环境参数
    m_ResetParams = Academy.Instance.EnvironmentParameters;
    // 获取传感器脚本
    m_BufferSensor = GetComponent<BufferSensorComponent>();
    // 获取刚体
    m_AgentRb = GetComponent<Rigidbody>();
    // 起始位置
    m_StartingPos = transform.position;
}

状态输入方法:

public override void CollectObservations(VectorSensor sensor)
{
    // 获取智能体到场地中心的x轴和z轴上的距离
    sensor.AddObservation((transform.position.x - m_Area.transform.position.x) / 20f);
    sensor.AddObservation((transform.position.z - m_Area.transform.position.z) / 20f);
	// 获取智能体前进方向的x轴和z轴的值
    sensor.AddObservation(transform.forward.x);
    sensor.AddObservation(transform.forward.z);

    foreach (var item in CurrentlyVisibleTilesList)
    {
        // 定义一个数组,存放一系列观察值,数组长度为数字方块最大数量 + 3,默认初始化全部为0
        float[] listObservation = new float[k_HighestTileValue + 3];
        // 获取方块的数字,设置对应的one-hot向量
        listObservation[item.NumberValue] = 1.0f;
        // 获取方块的坐标(子物体坐标才是真实坐标的,transform本身的位置保持在场景中央,方便旋转)
        var tileTransform = item.transform.GetChild(1);
        // 输入数字方块和智能体的x分量和z分量
        listObservation[k_HighestTileValue] = (tileTransform.position.x - transform.position.x) / 20f;
        listObservation[k_HighestTileValue + 1] = (tileTransform.position.z - transform.position.z) / 20f;
        // 该方块是否已经被碰撞过
        listObservation[k_HighestTileValue + 2] = item.IsVisited ? 1.0f : 0.0f;
        // 把数组添加到Buffer Sensor中(不直接输入到网络的原因是需要添加的数组个数个数是变化的)
        m_BufferSensor.AppendObservation(listObservation);
    }
}

动作输出方法OnActionReceived:

public override void OnActionReceived(ActionBuffers actionBuffers)
{
    // 移动智能体
    MoveAgent(actionBuffers.DiscreteActions);
    // 时间惩罚,激励智能体越快完成越好
    AddReward(-1f / MaxStep);
} 
public void MoveAgent(ActionSegment<int> act)
{
    var dirToGo = Vector3.zero;
    var rotateDir = Vector3.zero;
	// 获取神经网络三个离散输出
    var forwardAxis = act[0];
    var rightAxis = act[1];
    var rotateAxis = act[2];
	// 第一个离散输出决定了前进后退
    switch (forwardAxis)
    {
        case 1:
            dirToGo = transform.forward * 1f;
            break;
        case 2:
            dirToGo = transform.forward * -1f;
            break;
    }
	// 第二个离散输出决定了左移右移
    switch (rightAxis)
    {
        case 1:
            dirToGo = transform.right * 1f;
            break;
        case 2:
            dirToGo = transform.right * -1f;
            break;
    }
	// 第三个离散输出决定了左转右转
    switch (rotateAxis)
    {
        case 1:
            rotateDir = transform.up * -1f;
            break;
        case 2:
            rotateDir = transform.up * 1f;
            break;
    }
	// 执行动作
    transform.Rotate(rotateDir, Time.deltaTime * 200f);
    m_AgentRb.AddForce(dirToGo * 2, ForceMode.VelocityChange);

}

每一个episode(回合)开始时执行的方法OnEpisodeBegin:

public override void OnEpisodeBegin()
{
    // 从配置文件中获取方块的数量,没有的话设为DefaultMaxNumTiles
    m_MaxNumberOfTiles = (int)m_ResetParams.GetWithDefault("num_tiles", DefaultMaxNumTiles);
	// 随机生成方块的数量
    m_NumberOfTilesToSpawn = Random.Range(1, m_MaxNumberOfTiles + 1);
    // 选择将要生成的对应的方块并加入列表中
    SelectTilesToShow();
    // 生成方块及调整位置
    SetTilePositions();

    transform.position = m_StartingPos;
    m_AgentRb.velocity = Vector3.zero;
    m_AgentRb.angularVelocity = Vector3.zero;
}

void SelectTilesToShow()
{
	// 清除两个列表
    CurrentlyVisibleTilesList.Clear();
    AlreadyTouchedList.Clear();

    // 共生成nunLeft个方块
    int numLeft = m_NumberOfTilesToSpawn;
    while (numLeft > 0)
    {
        // 在范围内取随机数生成对应方块
        int rndInt = Random.Range(0, k_HighestTileValue);
        var tmp = NumberTilesList[rndInt];
        // 如果对应的方块列表中没有才进行添加
        if (!CurrentlyVisibleTilesList.Contains(tmp))
        {
            CurrentlyVisibleTilesList.Add(tmp);
            numLeft--;
        }
    }

    // 给方块列表列表按照数字升序进行排序
    CurrentlyVisibleTilesList.Sort((x, y) => x.NumberValue.CompareTo(y.NumberValue));
    m_NextExpectedTileIndex = 0;
}

void SetTilePositions()
{
	// 清空列表
    m_UsedPositionsList.Clear();
    // 重置所有方块的状态,ResetTile方法可以在数字方块的脚本中看到
    foreach (var item in NumberTilesList)
    {
        item.ResetTile();
        item.gameObject.SetActive(false);
    }

    foreach (var item in CurrentlyVisibleTilesList)
    {
        bool posChosen = false;
        // rndPosIndx决定了我们方块的旋转角度(即在圆形场地的哪里)
        int rndPosIndx = 0;
        while (!posChosen)
        {
            rndPosIndx = Random.Range(0, k_HighestTileValue);
            // 这个旋转角度是否被选了,没被选就加入列表中
            if (!m_UsedPositionsList.Contains(rndPosIndx))
            {
                m_UsedPositionsList.Add(rndPosIndx);
                posChosen = true;
            }
        }
        // 执行方块角度的旋转并激活物体
        item.transform.localRotation = Quaternion.Euler(0, rndPosIndx * (360f / k_HighestTileValue), 0);
        item.gameObject.SetActive(true);
    }
}

当与别的物体开始发生碰撞执行方法OnCollisionEnter:

private void OnCollisionEnter(Collision col)
{
    // 只检测和数字方块的碰撞
    if (!col.gameObject.CompareTag("tile"))
    {
        return;
    }
    // 如果方块已经碰撞过,也排除在碰撞对象之外
    if (AlreadyTouchedList.Contains(col.transform))
    {
        return;
    }
    // 如果碰撞的顺序错误,奖励-1,结束游戏
    if (col.transform.parent != CurrentlyVisibleTilesList[m_NextExpectedTileIndex].transform)
    {
        AddReward(-1);
        EndEpisode();
    }
    // 碰撞到正确的方块的情况
    else
    {
        // 奖励+1
        AddReward(1);
        // 改变方块的材质
        var tile = col.gameObject.GetComponentInParent<NumberTile>();
        tile.VisitTile();
        // 索引+1
        m_NextExpectedTileIndex++;
		// 把方块加入到已接触列表中
        AlreadyTouchedList.Add(col.transform);

        // 如果完成了所有的任务,游戏结束
        if (m_NextExpectedTileIndex == m_NumberOfTilesToSpawn)
        {
            EndEpisode();
        }
    }
}

当智能体没有模型,人想手动录制示例时可以采用Heuristic方法:

public override void Heuristic(in ActionBuffers actionsOut)
{
    var discreteActionsOut = actionsOut.DiscreteActions;
    //forward
    if (Input.GetKey(KeyCode.W))
    {
        discreteActionsOut[0] = 1;
    }
    if (Input.GetKey(KeyCode.S))
    {
        discreteActionsOut[0] = 2;
    }
    //rotate
    if (Input.GetKey(KeyCode.A))
    {
        discreteActionsOut[2] = 1;
    }
    if (Input.GetKey(KeyCode.D))
    {
        discreteActionsOut[2] = 2;
    }
    //right
    if (Input.GetKey(KeyCode.E))
    {
        discreteActionsOut[1] = 1;
    }
    if (Input.GetKey(KeyCode.Q))
    {
        discreteActionsOut[1] = 2;
    }
}

挂载在数字方块上的脚本NumberTile.cs:

using UnityEngine;

public class NumberTile : MonoBehaviour
{
    // 方块上的数字
    public int NumberValue;
    // 默认材质和成功时转换用的材质
    public Material DefaultMaterial;
    public Material SuccessMaterial;
    // 是否已经碰撞过
    private bool m_Visited;
    // 渲染,用于转换材质
    private MeshRenderer m_Renderer;

    public bool IsVisited
    {
        get { return m_Visited; }
    }
	// 用于转换材质的方法
    public void VisitTile()
    {
        m_Renderer.sharedMaterial = SuccessMaterial;
        m_Visited = true;
    }
	// 重置方块的方法,材质还原,m_Visited状态还原
    public void ResetTile()
    {
        if (m_Renderer is null)
        {
            m_Renderer = GetComponentInChildren<MeshRenderer>();
        }
        m_Renderer.sharedMaterial = DefaultMaterial;
        m_Visited = false;
    }
}

配置文件

behaviors:
  Sorter:
    trainer_type: ppo
    hyperparameters:
      batch_size: 512
      buffer_size: 40960
      learning_rate: 0.0003
      beta: 0.005
      epsilon: 0.2
      lambd: 0.95
      num_epoch: 3
      learning_rate_schedule: constant
    network_settings:
      normalize: False
      hidden_units: 128
      num_layers: 2
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 1.0
    keep_checkpoints: 5
    max_steps: 5000000
    time_horizon: 256
    summary_freq: 10000
environment_parameters:
  num_tiles:
    curriculum:
      - name: Lesson0 # The '-' is important as this is a list
        completion_criteria:
          measure: progress
          behavior: Sorter
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.3
        value: 2.0
      - name: Lesson1
        completion_criteria:
          measure: progress
          behavior: Sorter
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.4
        value: 4.0
      - name: Lesson2
        completion_criteria:
          measure: progress
          behavior: Sorter
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.45
        value: 6.0
      - name: Lesson3
        completion_criteria:
          measure: progress
          behavior: Sorter
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.5
        value: 8.0
      - name: Lesson4
        completion_criteria:
          measure: progress
          behavior: Sorter
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.55
        value: 10.0
      - name: Lesson5
        completion_criteria:
          measure: progress
          behavior: Sorter
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.6
        value: 12.0
      - name: Lesson6
        completion_criteria:
          measure: progress
          behavior: Sorter
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.65
        value: 14.0
      - name: Lesson7
        completion_criteria:
          measure: progress
          behavior: Sorter
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.7
        value: 16.0
      - name: Lesson8
        completion_criteria:
          measure: progress
          behavior: Sorter
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.75
        value: 18.0
      - name: Lesson9
        value: 20.0

可以看到配置文件采用了最为常用的PPO算法,而且是没有带其他“配件”例如LSTM,内在奖励机制等模块的普通PPO,唯一的不同是这里加入了Curriculum Learning(课程学习),也就是说,这种能够数十个方块的排序的智能体是很难一下子训练出来的,因此我们需要从易到难给它安排任务,从一开始能排序两个方块逐渐两个两个递增,最后达到20个。关于Curriculum Learning有关参数的详细解释,请查看我前面的文章ML-Agents案例之跳墙游戏

效果演示

在这里插入图片描述

后记

本案例相比于之前的案例的创新点在于引入了Buffer Sensor,这个传感器是用于接收不同个数向量的输入的,而并非像以往的传感器一样挂在智能体下就能用,这是为了处理类似该案例情况下接收信息数量随环境改变的情况的,这种情况有很多,例如智能体在面对敌人时,敌人的个数是不确定的,敌人发射子弹的数量也是不确定的,这时候,我们就需要用到Buffer Sensor,用来接受不同个数的输入,当然这样的训练往往也需要更多的样本,各种数量的输入都需要覆盖到,否则就会过拟合。为了达到这个目的,这里用到了之前的Curriculum Learning(课程学习)来使训练样本多样化,同时使得训练从易到难,使得智能体的策略具有鲁棒性。

标签:ML,transform,Agents,硬核,var,public,方块
来源: https://blog.csdn.net/tianjuewudi/article/details/121944658