본문 바로가기

유니티/수업 내용

ML-Agents - DinoRun

 

더보기
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// Spawns cactus obstacles as children of <c>targetParent</c>, one every
/// <c>spawnTime</c> seconds, starting as soon as the component starts.
/// </summary>
public class TrainingArea : MonoBehaviour
{
    [SerializeField]
    private GameObject cactusPrefab;
    [SerializeField]
    private float spawnTime = 3f;
    [SerializeField]
    private Transform targetParent;

    void Start()
    {
        IEnumerator spawnLoop = GenerateCactus();
        StartCoroutine(spawnLoop);
    }

    // Never-ending coroutine: wait spawnTime seconds (re-read every cycle), then
    // instantiate one cactus under targetParent at local position (10, 0, 0).
    private IEnumerator GenerateCactus()
    {
        for (;;)
        {
            var delay = new WaitForSeconds(spawnTime);
            yield return delay;

            GameObject cactus = Instantiate(cactusPrefab, targetParent);
            cactus.transform.localPosition = new Vector3(10f, 0f, 0f);
        }
    }
}
더보기
using System.Collections;
using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;

/// <summary>
/// ML-Agents agent for DinoRun. Discrete action 0 = keep running, 1 = jump (only
/// honored while grounded). The agent earns a small reward for every second spent
/// running on the ground, and a -1 penalty (ending the episode) when it touches an
/// object tagged "Cactus".
/// </summary>
public class DinoAgent : Agent
{
    // Local y at or below which the dino may count as standing on the ground.
    private const float GroundHeight = 0.2f;

    // |velocity.y| below this is treated as "not moving vertically"; an exact == 0
    // comparison on a physics-driven float is fragile.
    private const float GroundedVelocityEpsilon = 0.01f;

    // Seconds of ground running required before each running reward.
    private const float RunRewardInterval = 1.0f;

    // Reward granted each time RunRewardInterval of ground running has elapsed.
    private const float RunReward = 0.01f;

    private Rigidbody2D rBody;

    [SerializeField]
    private float jumpForce = 300f;

    // Ground-running time accumulated since the last running reward.
    private float delta = 0;

    /// <summary>
    /// One-time setup: caches the Rigidbody2D. Done in Agent.Initialize() rather than
    /// Start() so the reference exists before ML-Agents begins the first episode.
    /// </summary>
    public override void Initialize()
    {
        this.rBody = this.GetComponent<Rigidbody2D>();
    }

    /// <summary>
    /// Resets the running timer and puts the dino back at local (0, 1, 0) with no
    /// residual motion.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        this.delta = 0;
        this.transform.localPosition = new Vector3(0, 1f, 0);

        // Teleporting does not clear momentum; without this the dino would carry the
        // previous episode's velocity (e.g. from the cactus collision) into the new one.
        this.rBody.velocity = Vector2.zero;
        this.rBody.angularVelocity = 0f;
    }

    /// <summary>
    /// Observations (2 floats): the dino's local y position and its vertical velocity.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        // Observe the dino's y coordinate.
        sensor.AddObservation(this.transform.localPosition.y);

        // Observe the dino's vertical velocity.
        sensor.AddObservation(this.rBody.velocity.y);
    }

    /// <summary>
    /// Applies the chosen discrete action: 1 jumps if grounded; anything else counts
    /// as running and accumulates ground time toward the running reward.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actions)
    {
        var action = actions.DiscreteActions;
        bool grounded = this.IsGrounded();

        if (action[0] == 1)
        {
            // Only jump while standing on the ground (no mid-air jumps).
            if (grounded)
            {
                this.Jump();
            }
        }
        else if (grounded)
        {
            this.delta += Time.deltaTime;
            if (this.delta > RunRewardInterval)
            {
                Debug.Log("Get point");
                this.delta = 0;
                // Reward for staying on the ground for a while. AddReward accumulates
                // with any other reward given this step; SetReward would overwrite it.
                this.AddReward(RunReward);
            }
        }
    }

    // True when the dino is low enough and (practically) not moving vertically.
    private bool IsGrounded()
    {
        float verticalSpeed = this.rBody.velocity.y;
        return this.transform.localPosition.y <= GroundHeight
            && verticalSpeed > -GroundedVelocityEpsilon
            && verticalSpeed < GroundedVelocityEpsilon;
    }

    // Pushes the dino upward with jumpForce.
    private void Jump()
    {
        this.rBody.AddForce(Vector2.up * this.jumpForce);
    }

    /// <summary>
    /// Manual-control hook for Heuristic behavior type. Intentionally writes nothing to
    /// actionsOut, so the action stays at whatever ML-Agents supplies (expected to be
    /// 0 = run; NOTE(review): confirm if manual jumping is ever needed).
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
    }

    // Touching a cactus is penalized and ends the episode.
    private void OnCollisionEnter2D(Collision2D collision)
    {
        if (collision.collider.CompareTag("Cactus"))
        {
            // Penalty.
            this.AddReward(-1f);
            this.EndEpisode();
        }
    }
}
더보기
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// Scrolls this ground piece to the left every frame and, once its local x reaches
/// <c>wrapThresholdX</c>, jumps it <c>wrapDistance</c> units to the right so the
/// same pieces can be recycled.
/// </summary>
public class GroundMove : MonoBehaviour
{
    // Scroll speed in units per second (serialized, like CactusMove's speed).
    [SerializeField]
    private float moveSpeed = 3f;

    // Local x at or below which the piece wraps back to the right.
    [SerializeField]
    private float wrapThresholdX = -18f;

    // Distance the piece moves to the right when it wraps.
    [SerializeField]
    private float wrapDistance = 37f;

    void Update()
    {
        this.transform.Translate(Vector2.left * this.moveSpeed * Time.deltaTime);
        if (this.transform.localPosition.x <= this.wrapThresholdX)
        {
            // Offset by a pure x amount rather than assigning a fixed position: any
            // overshoot past the threshold is preserved, and y/z are left untouched
            // (translating by a vector that copied localPosition.z would drift z on
            // every wrap whenever z is non-zero).
            this.transform.Translate(new Vector3(this.wrapDistance, 0f, 0f));
        }
    }
}
더보기
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// Slides this cactus to the left each frame and destroys its GameObject once its
/// local x has reached -12.
/// </summary>
public class CactusMove : MonoBehaviour
{
    // Leftward speed in units per second.
    [SerializeField]
    private float moveSpeed = 3f;

    void Update()
    {
        Vector2 step = Vector2.left * moveSpeed * Time.deltaTime;
        transform.Translate(step);

        bool pastDespawnLine = transform.localPosition.x <= -12f;
        if (pastDespawnLine)
        {
            Destroy(gameObject);
        }
    }
}

 

D:

cd D:\workspace\unity\ml-agents-release_18

mlagents-learn config/ppo/DinoRun.yaml --run-id=DinoRun

tensorboard --logdir results --port 6006