Unity ML-Agents 2.0 (Flappy Bird Training)

Version: Unity 2021.3.5f1

Training Flappy Bird with Unity ML-Agents 2.0

I trained a Flappy Bird agent using ML-Agents.

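The bird is pulled down by gravity and keeps falling; a jump gives it an upward velocity (the code sets the Rigidbody2D velocity directly rather than applying a force).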

Pipes move at the same speed as the scrolling ground and spawn with random heights. (The gap between the upper and lower pipe depends on the level.)

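For scoring, a target object is placed in the gap between each pair of pipes. Touching it earns a reward (+0.02), while touching the ground or a pipe costs a penalty (-0.3) and restarts the episode.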
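To make the reward scale concrete, the sketch below adds up the three reward terms that appear in BirdAgent.cs: 1 / MaxStep per OnActionReceived call, +0.02 per target, and -0.3 on a crash. It is plain C# rather than Unity code. `RewardTally` and `EpisodeReturn` are hypothetical names, and the MaxStep of 5000 used in the example afterwards is an assumption, since the post does not say which value was set.

```csharp
using System;

// Hypothetical helper (not part of the project): totals the rewards
// that BirdAgent.cs hands out over one episode.
public static class RewardTally
{
    public const float TargetReward = 0.02f;  // AddReward(0.02f) on a "Target" hit
    public const float CrashPenalty = -0.3f;  // AddReward(-0.3f) on a ground/pipe hit

    // stepsSurvived: number of OnActionReceived calls in the episode
    // targetsHit:    number of pipe gaps passed
    // crashed:       true if the episode ended with a collision
    // maxStep:       the agent's MaxStep setting (must be > 0)
    public static float EpisodeReturn(int stepsSurvived, int targetsHit, bool crashed, int maxStep)
    {
        if (maxStep <= 0)
            throw new ArgumentOutOfRangeException(nameof(maxStep), "MaxStep must be positive.");

        float survival = stepsSurvived * (1f / maxStep); // per-step bonus, at most about 1.0 in total
        float targets = targetsHit * TargetReward;
        float crash = crashed ? CrashPenalty : 0f;
        return survival + targets + crash;
    }
}
```

With the assumed MaxStep of 5000, an agent that crashes after 500 steps and 3 gaps gets 500/5000 + 3 × 0.02 - 0.3 = -0.14, so an early crash outweighs everything it earned before.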

The observations are the agent's own Y position and vertical velocity. (Hmm, maybe using the local position is the problem...)
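As written, the agent observes only its own height and vertical speed and receives nothing about the pipes. One possible extension, which was not what was trained here, is to also observe where the next gap is. The sketch below builds such an observation vector in plain C#. `BirdObservation`, the two gap-related inputs, and both scale constants are assumptions made for illustration.

```csharp
using System;

// Hypothetical sketch: builds an observation vector that also describes
// the next pipe gap. Only birdY and birdVelocityY are observed in the
// original BirdAgent.cs; the other inputs and both scale constants are
// assumptions.
public static class BirdObservation
{
    public const float PositionScale = 50f; // assumed rough size of the play area
    public const float VelocityScale = 25f; // same magnitude as the default jumpForce

    public static float[] Build(float birdY, float birdVelocityY,
                                float gapCenterY, float distanceToPipeX)
    {
        return new[]
        {
            Clamp(birdY / PositionScale),                 // own height
            Clamp(birdVelocityY / VelocityScale),         // own vertical speed
            Clamp((gapCenterY - birdY) / PositionScale),  // how far above or below the gap is
            Clamp(distanceToPipeX / PositionScale),       // how soon the pipe arrives
        };
    }

    // Keeps each value in [-1, 1], a convenient range for network inputs.
    private static float Clamp(float v) => Math.Max(-1f, Math.Min(1f, v));
}
```

If something like this were fed to the sensor, the vector observation size in the agent's Behavior Parameters would have to be 4 instead of 2.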

The gap between the upper and lower pipe changes with the curriculum parameter (60 is the default).
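The way the parameter turns into pipe sizes can be read off GameMain.CreatePipeRoutine: the two pipes share a length budget of 150 - Level, split at random. The sketch below restates that split as stand-alone C#. `PipeSplit` is a hypothetical name, and `System.Random` stands in for `UnityEngine.Random` so the code runs outside Unity.

```csharp
using System;

// Hypothetical stand-alone version of the split done in
// GameMain.CreatePipeRoutine; System.Random replaces UnityEngine.Random.
public static class PipeSplit
{
    // Combined length budget of the two pipes for a given level value.
    public static int TotalLength(int level) => 150 - level;

    // Returns (first, second) with first + second == TotalLength(level).
    // Like Unity's integer Random.Range, Next(min, max) excludes max.
    public static (int first, int second) Split(int level, Random rng)
    {
        int total = TotalLength(level);
        if (total <= 3)
            throw new ArgumentOutOfRangeException(nameof(level), "Level must be below 147.");
        int first = rng.Next(3, total);   // 3 .. total - 1
        int second = total - first;       // 1 .. total - 3
        return (first, second);
    }
}
```

A higher value therefore means shorter pipes and a wider gap: the first lesson (value 100) gives a budget of 50, while the last one (value 60) gives 90.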

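At first, while the level is low, the agent does well enough that you can really feel it learning, but as the levels go up, learning becomes less effective and the agent gives up more and more often.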

My guess is that the environment becomes too hard for the agent's control to handle, so it gives up. (I can't clear it myself either, lol.)

Heuristic mode (manual control, which lets you test the environment yourself) has a bit of input delay, so fine control is out of reach...
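Part of that delay may come from how input is polled: ML-Agents calls Heuristic only when the agent requests a decision, so a short key tap that falls between two decisions can be missed. A common workaround is to latch the key press in Update and consume it in Heuristic. The class below is a plain C# sketch of such a latch. `JumpLatch` is a hypothetical name, and whether this removes the delay in this project is an assumption, since the post does not show the decision settings.

```csharp
// Hypothetical helper: remembers a key press until the next decision reads it.
public sealed class JumpLatch
{
    private bool pending;

    // Call from Update() when the jump key goes down,
    // e.g. if (Input.GetKeyDown(KeyCode.Space)) latch.Press();
    public void Press() => pending = true;

    // Call from Heuristic(): returns 1 once per press, otherwise 0.
    public int Consume()
    {
        int action = pending ? 1 : 0;
        pending = false;
        return action;
    }
}
```

In BirdAgent, Heuristic would then write `latch.Consume()` into `discreteActions[0]` instead of reading `Input.GetKey` directly.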

Using ML-Agents itself was not the hard part; building the environment took by far the most time. (The environment matters for training!)

The biggest takeaway this time was learning the basics of using ML-Agents.

Curriculum Parameters

behaviors:
  FlyBird:
    trainer_type: ppo
    hyperparameters:
      batch_size: 256
      buffer_size: 4096
      learning_rate: 0.0003
      beta: 0.005
      epsilon: 0.2
      lambd: 0.95
      num_epoch: 3
      learning_rate_schedule: constant
    network_settings:
      normalize: false
      hidden_units: 256
      num_layers: 2
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.95
        strength: 1.0
    keep_checkpoints: 5
    max_steps: 9000000
    time_horizon: 128
    summary_freq: 20000
environment_parameters:
  LevelUp:
    curriculum:
      - name: Lesson0
        completion_criteria:
          measure: progress
          behavior: FlyBird
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.10
        value: 100
      - name: Lesson1
        completion_criteria:
          measure: progress
          behavior: FlyBird
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.15
        value: 95
      - name: Lesson2
        completion_criteria:
          measure: progress
          behavior: FlyBird
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.20
        value: 90
      - name: Lesson3
        completion_criteria:
          measure: progress
          behavior: FlyBird
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.25
        value: 85
      - name: Lesson4
        completion_criteria:
          measure: progress
          behavior: FlyBird
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.30
        value: 80
      - name: Lesson5
        completion_criteria:
          measure: progress
          behavior: FlyBird
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.35
        value: 75
      - name: Lesson6
        completion_criteria:
          measure: progress
          behavior: FlyBird
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.4
        value: 70
      - name: Lesson7
        completion_criteria:
          measure: progress
          behavior: FlyBird
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.45
        value: 65
      - name: Lesson9
        value: 60
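With `measure: progress`, ML-Agents advances lessons by training progress (current step divided by `max_steps`) rather than by reward. The sketch below converts the thresholds in the config above into approximate step counts. `CurriculumSchedule` is a hypothetical helper, and the result ignores `min_lesson_length` and `signal_smoothing`, so the real switch points can come somewhat later.

```csharp
using System;

// Hypothetical helper: converts the `progress` thresholds from the
// config above into approximate training-step counts.
public static class CurriculumSchedule
{
    public const long MaxSteps = 9_000_000; // max_steps in the config

    // Thresholds of Lesson0 .. Lesson7, copied from the config.
    public static readonly double[] Thresholds =
        { 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45 };

    // Step at which progress (= step / max_steps) reaches the threshold.
    public static long StepForThreshold(double threshold)
    {
        return (long)Math.Round(threshold * MaxSteps);
    }
}
```

Lesson0 ends at roughly step 900,000 and the last switch happens at roughly step 4,050,000, so more than half of the run is spent on the narrowest gap, no matter how well the agent is doing at that point.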

BirdAgent.cs

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using UnityEngine.Events;

public class BirdAgent : Agent
{
    public GameMain GameMain;
    EnvironmentParameters m_ResetParams;


    public Rigidbody2D rBody; 
    public float jumpForce = 25f;
    private bool Diecheck = false;

    public UnityAction onDIe;

    public override void Initialize()
    {
        m_ResetParams = Academy.Instance.EnvironmentParameters;
        this.rBody = this.gameObject.GetComponent<Rigidbody2D>();
        ConfigureAgent();
    }

    public override void OnEpisodeBegin()
    {
        this.Diecheck = false;
        this.transform.localPosition = new Vector3(-40f, 15f, 0f);
        // Clear leftover momentum so each episode starts from rest.
        this.rBody.velocity = Vector2.zero;
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(this.gameObject.transform.localPosition.y); // 1
        //sensor.AddObservation(StepCount/MaxStep); // 1
        sensor.AddObservation(this.rBody.velocity.y); // 1
    }

    public override void OnActionReceived(ActionBuffers actions)
    {
        // Small survival bonus per step; guarded because MaxStep defaults to 0.
        if (this.MaxStep > 0)
        {
            this.AddReward(1f / this.MaxStep);
        }

        var discreteActions = actions.DiscreteActions;

        if (discreteActions[0] == 1)
        {
            this.rBody.velocity = Vector2.up * this.jumpForce;
        }     
    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var discreteActions = actionsOut.DiscreteActions;

        if (Input.GetKey(KeyCode.Space))
        {
            discreteActions[0] = 1;

        }
        else
        {
            discreteActions[0] = 0;
        }
    }



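    private void OnCollisionEnter2D(Collision2D collision)
    {
        if (collision.gameObject.CompareTag("Target"))
        {
            // Passed through a pipe gap: small reward, then remove the marker.
            Debug.Log("Target");
            this.AddReward(0.02f);
            Destroy(collision.gameObject);
        }
        else
        {
            // Hit the ground or a pipe: penalize once and end the episode.
            if (!Diecheck)
            {
                this.Diecheck = true;
                AddReward(-0.3f);
                this.onDIe?.Invoke(); // assigned by GameMain.Start
                EndEpisode();
            }
        }
    }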


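    // Reads the current curriculum value ("LevelUp"); falls back to 60 when none is set.
    void ConfigureAgent()
    {
        this.GameMain.Level = (int)m_ResetParams.GetWithDefault("LevelUp", 60);
    }

    // Polled every physics step so that a lesson change is picked up right away.
    private void FixedUpdate()
    {
        this.ConfigureAgent();
    }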


}

GameMain.cs

using System.Collections;
using System.Collections.Generic;
using UnityEngine;


public class GameMain : MonoBehaviour
{
    public GameObject env; 

    public GameObject pipePrefabs;
    private Pipe pipe;

    public BirdAgent birdAgent;


    public GameObject GroundPrefabs;
    private Ground ground;

    private List<GameObject> pipesList = new List<GameObject>();
    private List<GameObject> GroundList = new List<GameObject>();

    public GameObject backGround;
    private SpriteRenderer backGroundRender;
    public int DieCount = 0;


    public float pipe_groundSpeed = 40f;
    public int Level = 80;
    public float pipeInitTIme = 1.5f;



    private void Awake()
    {
        Application.runInBackground = true;
        StartGameInit();
        this.backGroundRender = this.backGround.GetComponent<SpriteRenderer>();
    }

    void Start()
    {
        StartCoroutine(CreatePipeRoutine());

        this.birdAgent.onDIe = (() => {
            this.DieCount++;
            Debug.Log(this.DieCount);
            // No explicit OnEpisodeBegin() call here: BirdAgent calls EndEpisode()
            // right after this callback, and ML-Agents then invokes OnEpisodeBegin().
            this.StartGameInit();

            StartCoroutine(BackGroundCoroutine());

        });
    }

    public void StartGameInit()
    {
        foreach (var item in this.pipesList)
        {
            if (item != null)
            {
                Destroy(item.gameObject);
            }
        }

        foreach (var item in this.GroundList)
        {
            if (item != null)
            {
                Destroy(item.gameObject);
            }
        }

        this.pipesList = new List<GameObject>();
        this.GroundList = new List<GameObject>();
        
        // Spawn three ground tiles at fixed local X offsets.
        float[] groundXs = { 100f, 312f, 150f };
        foreach (float x in groundXs)
        {
            var groundGO = Instantiate<GameObject>(this.GroundPrefabs);
            groundGO.transform.SetParent(this.env.transform);
            groundGO.transform.localPosition = new Vector3(x, 10f, -1f);
            this.ground = groundGO.GetComponent<Ground>();
            this.ground.ChangeGroundSpeed(this.pipe_groundSpeed); // set the ground speed
            this.GroundList.Add(groundGO);
        }



    }

    private IEnumerator CreatePipeRoutine()
    {

        while (true)
        {
            int num = 150 - this.Level;

            int randfirst = Random.Range(3, num);


            int randsecond = num - randfirst;

            var topGO = Instantiate<GameObject>(this.pipePrefabs);
            topGO.transform.SetParent(this.env.transform);
            this.pipe = topGO.GetComponent<Pipe>();
            this.pipe.SettingPipe(Pipe.eDirType.UP, randfirst);
            this.pipe.removePipe = ((pipe) => {
                pipesList.Remove(pipe); } );
            this.pipe.ChangePipeSpeed(this.pipe_groundSpeed);
            this.pipesList.Add(topGO);
            topGO.transform.localPosition = new Vector3(71f, -22f, -0.5997626f);
            

            topGO = Instantiate<GameObject>(this.pipePrefabs);
            topGO.transform.SetParent(this.env.transform);
            this.pipe = topGO.GetComponent<Pipe>();
            this.pipe.SettingPipe(Pipe.eDirType.Down, randsecond);
            this.pipe.removePipe = ((pipe) => {
                pipesList.Remove(pipe);
            });
            topGO.transform.localPosition = new Vector3(71f, 31f, -0.5997626f);
            this.pipe.ChangePipeSpeed(this.pipe_groundSpeed);
            this.pipesList.Add(topGO);
            yield return new WaitForSeconds(this.pipeInitTIme);

        }
    }

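    // Flashes the background red for 0.3 s when the bird dies.
    private IEnumerator BackGroundCoroutine()
    {
        // Color channels are floats in the 0..1 range, so use the presets
        // instead of 255-based values.
        this.backGroundRender.color = Color.red;

        yield return new WaitForSeconds(0.3f);

        this.backGroundRender.color = Color.white;
    }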

}

Pipe.cs

using UnityEngine;
using UnityEngine.Events;

public class Pipe : MonoBehaviour
{
    public enum eDirType { UP = 1, Down = -1 };
    public eDirType dirType = eDirType.UP;
    public int[] arr = { 10, 20, 30, 40, 50, 60, 70, 80, 90 };
    public GameObject headGO;
    public GameObject bodyGO;
    public float moveSpeed = 20;
    public GameObject target;

    public UnityAction<GameObject> removePipe;


    // Update is called once per frame
    void Update()
    {
        this.transform.Translate(Vector3.left * moveSpeed * Time.deltaTime);

        if (this.transform.localPosition.x < -70f)
        {
            removePipe?.Invoke(this.gameObject); // the callback may be unassigned
            Destroy(this.gameObject);
        }
    }

    public void SettingPipe(eDirType Type, int rand)
    {
        //rand += level;
        this.dirType = Type;
        var localScale = this.bodyGO.transform.localScale;
        // localScale.y = (int)this.dirType * this.arr[rand];
        localScale.y = (float)this.dirType * rand;
        this.bodyGO.transform.localScale = localScale;
        //this.headGO.transform.localPosition = new Vector3(this.headGO.transform.localPosition.x,
        //(int)this.dirType * ((float)arr[rand]/10) * 3, this.headGO.transform.localPosition.z);
        this.headGO.transform.localPosition = new Vector3(this.headGO.transform.localPosition.x,
        ((((int)this.dirType) * ((float)rand / 10) * 3)) + ((int)this.dirType * 1.5f), this.headGO.transform.localPosition.z);

        if (Type == eDirType.Down)
        {
            Destroy(this.target.gameObject);
        }
        else
        {
            this.target.transform.localPosition = new Vector3(this.headGO.transform.localPosition.x, this.headGO.transform.localPosition.y + 12f, this.headGO.transform.localPosition.z);
        }
    }

    public void ChangePipeSpeed(float speed)
    {
        this.moveSpeed = speed;
    }

}

Test.cs

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

public class Test : MonoBehaviour
{
    public GameObject envPrefabs;
    void Start()
    {
        int x = 0;
        int y = 0;

        for (int i = 0; i < 10; i++)
        {

            Instantiate<GameObject>(this.envPrefabs).transform.position = new Vector3(x, y, 0);

            x += 300;

            if (x >= 1500)
            {
                x = 0;
                y += 300;
            }
        }
    }
}
