Version : Unity 2021.3.5f1
Unity ML-Agents 2.0 을 이용한 Flappy Bird 훈련
ML-Agents 를 이용하여 Flappy Bird 를 교육시켰다.
Flappy Bird는 중력의 영향을 받아 아래로 계속 떨어지고, 점프를 통하여 위쪽으로 속도를 줄 수 있다.
파이프는 배경과 같은 속도로 랜덤하게 나오도록 설계 ( 파이프[위/아래] 사이의 거리는 레벨에 따라 다르다. )
점수는 파이프와 파이프 사이에 오브젝트를 만들어 부딪히면 점수를 주고, 바닥에 닿거나 파이프에 부딪히면 점수를 깎고 에피소드를 다시 시작한다.
관측하는 값은 자신의 Y 포지션과 Y축 속도 (음 로컬이 문제인가….)
파이프와 파이프 사이의 거리는 커리큘럼 파라미터에 따라 달라진다. (60은 default)
처음 레벨이 낮을 경우에는 학습을 하고 있다고 느낄 정도로 정말 잘하지만
점점 가면 갈수록 학습의 능률이 떨어지고 포기하는 경우의 수가 많아진다.
아마 인공지능의 컨트롤로 극복할 수 없는 환경이 되어서 포기하는 듯 하다. (내가 플레이 해도 못 깬다 ㅋㅋ )
Heuristic(휴리스틱-직접 컨트롤, 환경을 테스트 해볼 수 있다.)은 딜레이가 좀 있어서 미세한 컨트롤은 무리…
ML-Agents 를 사용하는 방법이 어렵다기 보다는 환경을 구성하는 부분에서 가장 많은 시간을 잡아먹었다. (교육에는 환경이 중요하다!)
이번에 ML-Agents의 기본적인 사용법을 익힌 것이 가장 큰 수확
커리큘럼 파라미터 (Curriculum Parameter)
# ML-Agents trainer configuration for the FlyBird behavior.
# The "LevelUp" environment parameter is read by BirdAgent through
# GetWithDefault("LevelUp", 60) and copied into GameMain.Level, where the
# combined pipe length is computed as (150 - Level). A smaller value therefore
# means longer pipes and a narrower gap, so each lesson is harder than the last.
behaviors:
  FlyBird:
    trainer_type: ppo
    hyperparameters:
      batch_size: 256
      buffer_size: 4096
      learning_rate: 0.0003
      beta: 0.005
      epsilon: 0.2
      lambd: 0.95
      num_epoch: 3
      learning_rate_schedule: constant
    network_settings:
      normalize: false
      hidden_units: 256
      num_layers: 2
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.95
        strength: 1.0
    keep_checkpoints: 5
    max_steps: 9000000
    time_horizon: 128
    summary_freq: 20000
environment_parameters:
  LevelUp:
    curriculum:
      - name: Lesson0
        completion_criteria:
          measure: progress
          behavior: FlyBird
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.10
        value: 100
      - name: Lesson1
        completion_criteria:
          measure: progress
          behavior: FlyBird
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.15
        value: 95
      - name: Lesson2
        completion_criteria:
          measure: progress
          behavior: FlyBird
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.20
        value: 90
      - name: Lesson3
        completion_criteria:
          measure: progress
          behavior: FlyBird
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.25
        value: 85
      - name: Lesson4
        completion_criteria:
          measure: progress
          behavior: FlyBird
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.30
        value: 80
      - name: Lesson5
        completion_criteria:
          measure: progress
          behavior: FlyBird
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.35
        value: 75
      - name: Lesson6
        completion_criteria:
          measure: progress
          behavior: FlyBird
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.4
        value: 70
      - name: Lesson7
        completion_criteria:
          measure: progress
          behavior: FlyBird
          signal_smoothing: true
          min_lesson_length: 100
          threshold: 0.45
        value: 65
      # Final lesson: no completion criteria. The name skips "Lesson8" in the
      # original configuration and is kept as written.
      - name: Lesson9
        value: 60
BirdAgent.cs
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using UnityEngine.Events;

/// <summary>
/// ML-Agents agent that controls the bird. It observes its own local Y position
/// and vertical velocity, and has a single discrete action branch where the
/// value 1 means "jump".
/// </summary>
public class BirdAgent : Agent
{
    private const string TargetTag = "Target";
    private const string LevelParameterName = "LevelUp";
    private const float DefaultLevel = 60f;

    // Local position the bird is moved to at the start of every episode.
    private static readonly Vector3 StartLocalPosition = new Vector3(-40f, 15f, 0f);

    public GameMain GameMain;
    EnvironmentParameters m_ResetParams;
    public Rigidbody2D rBody;
    public float jumpForce = 25f;

    // Set on the first fatal collision of an episode so the penalty is applied once.
    private bool Diecheck = false;

    /// <summary>Invoked once per episode when the bird hits anything that is not a target.</summary>
    public UnityAction onDIe;

    /// <summary>
    /// Caches the environment parameters and the Rigidbody2D on this GameObject,
    /// then applies the current curriculum level.
    /// </summary>
    public override void Initialize()
    {
        m_ResetParams = Academy.Instance.EnvironmentParameters;
        this.rBody = this.gameObject.GetComponent<Rigidbody2D>();
        ConfigureAgent();
    }

    /// <summary>
    /// Clears the death flag and puts the bird back at its start position.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        this.Diecheck = false;
        this.transform.localPosition = StartLocalPosition;

        // Without this the bird would start the new episode with whatever
        // velocity it had at the moment it died.
        if (this.rBody != null)
        {
            this.rBody.velocity = Vector2.zero;
        }
    }

    /// <summary>
    /// Adds two observations: the local Y position and the vertical velocity.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(this.gameObject.transform.localPosition.y); // 1
        sensor.AddObservation(this.rBody.velocity.y);                     // 1
    }

    /// <summary>
    /// Adds a small per-step reward and performs a jump when the first discrete
    /// action equals 1.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actions)
    {
        // Guard against MaxStep == 0, which would make the division produce
        // an infinite reward.
        if (this.MaxStep > 0)
        {
            this.AddReward(1f / this.MaxStep);
        }

        var discreteActions = actions.DiscreteActions;
        if (discreteActions[0] == 1)
        {
            this.rBody.velocity = Vector2.up * this.jumpForce;
        }
    }

    /// <summary>
    /// Manual control: writes 1 to the first discrete action while Space is held, otherwise 0.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var discreteActions = actionsOut.DiscreteActions;
        discreteActions[0] = Input.GetKey(KeyCode.Space) ? 1 : 0;
    }

    /// <summary>
    /// Rewards touching a "Target" object and destroys it; any other collision
    /// applies a penalty, raises <see cref="onDIe"/> and ends the episode.
    /// </summary>
    private void OnCollisionEnter2D(Collision2D collision)
    {
        if (collision.gameObject.CompareTag(TargetTag))
        {
            Debug.Log("Target");
            this.AddReward(0.02f);
            Destroy(collision.gameObject);
            return;
        }

        if (this.Diecheck)
        {
            return;
        }

        this.Diecheck = true;
        AddReward(-0.3f);
        // onDIe is assigned from outside this class and may not be set yet.
        this.onDIe?.Invoke();
        EndEpisode();
    }

    /// <summary>
    /// Copies the "LevelUp" environment parameter (default 60) into GameMain.Level.
    /// </summary>
    void ConfigureAgent()
    {
        if (this.GameMain == null || m_ResetParams == null)
        {
            return;
        }

        this.GameMain.Level = (int)m_ResetParams.GetWithDefault(LevelParameterName, DefaultLevel);
    }

    // Re-reads the level every physics step.
    private void FixedUpdate()
    {
        this.ConfigureAgent();
    }
}
GameMain.cs
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// Owns one training environment: spawns the scrolling ground pieces and the
/// pipe pairs, and rebuilds the environment every time the bird dies.
/// </summary>
public class GameMain : MonoBehaviour
{
    public GameObject env;
    public GameObject pipePrefabs;
    public BirdAgent birdAgent;
    public GameObject GroundPrefabs;
    public GameObject backGround;
    public int DieCount = 0;
    public float pipe_groundSpeed = 40f;
    public int Level = 80;
    public float pipeInitTIme = 1.5f;

    // Local positions (relative to env) of the ground pieces created on reset.
    private static readonly Vector3[] GroundSpawnPositions =
    {
        new Vector3(100f, 10f, -1f),
        new Vector3(312f, 10f, -1f),
        new Vector3(150f, 10f, -1f),
    };

    // Local spawn positions (relative to env) for the two pipes of a pair.
    private static readonly Vector3 UpPipeSpawnPosition = new Vector3(71f, -22f, -0.5997626f);
    private static readonly Vector3 DownPipeSpawnPosition = new Vector3(71f, 31f, -0.5997626f);

    private const float FlashDuration = 0.3f;

    private readonly List<GameObject> pipesList = new List<GameObject>();
    private readonly List<GameObject> GroundList = new List<GameObject>();
    private SpriteRenderer backGroundRender;
    private Coroutine flashRoutine;

    private void Awake()
    {
        Application.runInBackground = true;
        StartGameInit();
        this.backGroundRender = this.backGround.GetComponent<SpriteRenderer>();
    }

    void Start()
    {
        StartCoroutine(CreatePipeRoutine());
        this.birdAgent.onDIe = HandleBirdDied;
    }

    /// <summary>
    /// Death callback: counts the death, rebuilds the environment and flashes
    /// the background.
    /// </summary>
    private void HandleBirdDied()
    {
        this.DieCount++;
        Debug.Log(this.DieCount);

        // OnEpisodeBegin is not called here: BirdAgent calls EndEpisode()
        // right after raising onDIe.
        this.StartGameInit();

        // Restart the flash so an earlier, still-running flash cannot turn the
        // background white in the middle of the new one.
        if (this.flashRoutine != null)
        {
            StopCoroutine(this.flashRoutine);
        }
        this.flashRoutine = StartCoroutine(BackGroundCoroutine());
    }

    /// <summary>
    /// Destroys every tracked pipe and ground object and spawns a fresh set of
    /// ground pieces.
    /// </summary>
    public void StartGameInit()
    {
        DestroyAll(this.pipesList);
        DestroyAll(this.GroundList);

        foreach (Vector3 position in GroundSpawnPositions)
        {
            SpawnGround(position);
        }
    }

    // Destroys every live object in the list and empties it.
    private static void DestroyAll(List<GameObject> objects)
    {
        foreach (GameObject item in objects)
        {
            if (item != null)
            {
                Destroy(item);
            }
        }
        objects.Clear();
    }

    // Creates one ground piece under env at the given local position.
    private void SpawnGround(Vector3 localPosition)
    {
        GameObject groundGO = Instantiate<GameObject>(this.GroundPrefabs);
        groundGO.transform.SetParent(this.env.transform);
        groundGO.transform.localPosition = localPosition;
        // Apply the shared scroll speed to the ground.
        groundGO.GetComponent<Ground>().ChangeGroundSpeed(this.pipe_groundSpeed);
        this.GroundList.Add(groundGO);
    }

    /// <summary>
    /// Spawns an UP/Down pipe pair every <see cref="pipeInitTIme"/> seconds.
    /// The two pipe lengths always add up to (150 - Level).
    /// </summary>
    private IEnumerator CreatePipeRoutine()
    {
        while (true)
        {
            int combinedLength = 150 - this.Level;
            int upLength = Random.Range(3, combinedLength);
            int downLength = combinedLength - upLength;

            SpawnPipe(Pipe.eDirType.UP, upLength, UpPipeSpawnPosition);
            SpawnPipe(Pipe.eDirType.Down, downLength, DownPipeSpawnPosition);

            yield return new WaitForSeconds(this.pipeInitTIme);
        }
    }

    // Creates one pipe under env, configures it and starts tracking it.
    private void SpawnPipe(Pipe.eDirType direction, int length, Vector3 localPosition)
    {
        GameObject pipeGO = Instantiate<GameObject>(this.pipePrefabs);
        pipeGO.transform.SetParent(this.env.transform);

        Pipe spawned = pipeGO.GetComponent<Pipe>();
        spawned.SettingPipe(direction, length);
        // The pipe reports back when it removes itself so the list stays current.
        spawned.removePipe = removed => { this.pipesList.Remove(removed); };
        spawned.ChangePipeSpeed(this.pipe_groundSpeed);

        this.pipesList.Add(pipeGO);
        pipeGO.transform.localPosition = localPosition;
    }

    /// <summary>
    /// Tints the background red for a short time, then restores white.
    /// </summary>
    private IEnumerator BackGroundCoroutine()
    {
        // Unity Color components are in the 0..1 range, so use the named colors
        // instead of 255-based values.
        this.backGroundRender.color = Color.red;
        yield return new WaitForSeconds(FlashDuration);
        this.backGroundRender.color = Color.white;
        this.flashRoutine = null;
    }
}
Pipe.cs
using UnityEngine;
using UnityEngine.Events;

/// <summary>
/// A single pipe that moves left every frame and destroys itself once its
/// local X position drops below a fixed threshold.
/// </summary>
public class Pipe : MonoBehaviour
{
    public enum eDirType
    {
        UP = 1,
        Down = -1
    };

    // Local X below which the pipe removes itself.
    private const float DespawnLocalX = -70f;

    // Vertical offset of the target object from the pipe head (UP pipes only).
    private const float TargetOffsetY = 12f;

    public eDirType dirType = eDirType.UP;
    public int[] arr = { 10, 20, 30, 40, 50, 60, 70, 80, 90 };
    public GameObject headGO;
    public GameObject bodyGO;
    public float moveSpeed = 20;
    public GameObject target;

    /// <summary>Invoked with this GameObject just before the pipe destroys itself.</summary>
    public UnityAction<GameObject> removePipe;

    // Moves the pipe left each frame and despawns it past the threshold.
    private void Update()
    {
        this.transform.Translate(Vector3.left * moveSpeed * Time.deltaTime);

        if (this.transform.localPosition.x < DespawnLocalX)
        {
            // removePipe is assigned by the spawner and may not be set.
            removePipe?.Invoke(this.gameObject);
            Destroy(this.gameObject);
        }
    }

    /// <summary>
    /// Sets the pipe direction and length: scales the body along Y, positions
    /// the head, and either destroys the target (Down) or places it above the
    /// head (UP).
    /// </summary>
    /// <param name="Type">Direction of the pipe; its enum value (+1 / -1) is used as a sign.</param>
    /// <param name="rand">Pipe length used for the body's Y scale and the head offset.</param>
    public void SettingPipe(eDirType Type, int rand)
    {
        this.dirType = Type;
        int sign = (int)this.dirType;

        Vector3 bodyScale = this.bodyGO.transform.localScale;
        bodyScale.y = (float)this.dirType * rand;
        this.bodyGO.transform.localScale = bodyScale;

        Vector3 headPosition = this.headGO.transform.localPosition;
        headPosition.y = (sign * ((float)rand / 10) * 3) + (sign * 1.5f);
        this.headGO.transform.localPosition = headPosition;

        if (this.target == null)
        {
            return;
        }

        if (Type == eDirType.Down)
        {
            Destroy(this.target);
        }
        else
        {
            this.target.transform.localPosition =
                new Vector3(headPosition.x, headPosition.y + TargetOffsetY, headPosition.z);
        }
    }

    /// <summary>Sets the horizontal movement speed used in Update.</summary>
    public void ChangePipeSpeed(float speed)
    {
        this.moveSpeed = speed;
    }
}
Test.cs
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// Instantiates several copies of the environment prefab on a grid when the
/// scene starts. The defaults create 10 copies, 5 per row, 300 units apart.
/// </summary>
public class Test : MonoBehaviour
{
    public GameObject envPrefabs;

    /// <summary>Number of environment copies to create.</summary>
    public int envCount = 10;

    /// <summary>Number of copies placed in one row before starting the next row.</summary>
    public int columns = 5;

    /// <summary>Distance between neighbouring copies on both axes.</summary>
    public float spacing = 300f;

    void Start()
    {
        if (this.envPrefabs == null)
        {
            Debug.LogWarning("Test: envPrefabs is not assigned; no environments were created.");
            return;
        }

        // Avoid a division by zero if columns is set to 0 or less.
        int perRow = Mathf.Max(1, this.columns);

        for (int i = 0; i < this.envCount; i++)
        {
            float x = (i % perRow) * this.spacing;
            float y = (i / perRow) * this.spacing;
            Instantiate<GameObject>(this.envPrefabs).transform.position = new Vector3(x, y, 0f);
        }
    }
}