본문 바로가기

유니티/수업 내용

게임 인공지능 ML-Agents RollerBall 복습

더보기
using System.Collections;
using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;

/// <summary>
/// ML-Agents agent for the RollerBall exercise: a Rigidbody-driven ball that is rewarded for
/// getting within <see cref="targetReachDistance"/> of a randomly placed target and penalized
/// for falling (local Y at or below 0).
/// </summary>
/// <remarks>
/// Observations (7 floats): agent local position (3), target local position (3), and the
/// distance between them (1).
/// Actions: one discrete branch where 0 = forward (+Z), 1 = back (-Z), 2 = left (-X),
/// 3 = right (+X); any other value applies no force.
/// NOTE(review): the Behavior Parameters component must match this layout
/// (vector observation size 7, one discrete branch) — confirm in the scene.
/// </remarks>
public class RollerAgent : Agent
{
    // Cached once in Initialize(); used to push the ball and to clear its momentum on reset.
    private Rigidbody rBody;

    // Magnitude of the force applied on every action step.
    [SerializeField]
    private float moveForce = 10f;

    // The object the agent is rewarded for reaching.
    [SerializeField]
    private Transform target;

    // Half-width of the square (local X/Z) in which the target is respawned each episode.
    [SerializeField]
    private float targetSpawnRange = 4f;

    // Agent-to-target distance at or below which the target counts as reached.
    [SerializeField]
    private float targetReachDistance = 1.5f;

    // Local Y at which both the agent and the target are placed on reset.
    private const float SpawnHeight = 0.5f;

    /// <summary>
    /// One-time setup: caches the Rigidbody.
    /// ML-Agents calls Initialize() once when the agent is first enabled, before any
    /// OnEpisodeBegin(). Caching in Start() instead can leave rBody null if an episode
    /// begins from OnEnable (e.g. an agent enabled after the Academy has already stepped).
    /// </summary>
    public override void Initialize()
    {
        this.rBody = this.GetComponent<Rigidbody>();
    }

    /// <summary>
    /// Starts a new episode: puts the agent back at the local origin (at SpawnHeight) with no
    /// momentum, and moves the target to a random X/Z within ±targetSpawnRange.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        // Reset the agent position.
        this.transform.localPosition = new Vector3(0f, SpawnHeight, 0f);

        // Respawn the target somewhere in [-targetSpawnRange, targetSpawnRange] on X and Z.
        this.target.localPosition = new Vector3(
            Random.Range(-this.targetSpawnRange, this.targetSpawnRange),
            SpawnHeight,
            Random.Range(-this.targetSpawnRange, this.targetSpawnRange));

        // Clear linear AND angular momentum; a ball left spinning would roll off on its own
        // at the start of the new episode.
        this.rBody.velocity = Vector3.zero;
        this.rBody.angularVelocity = Vector3.zero;
    }

    /// <summary>
    /// Sends the observations to the policy / Python trainer.
    /// The order and count (7 floats) must stay fixed between training and inference.
    /// </summary>
    /// <param name="sensor">Vector sensor the observations are appended to.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        Vector3 agentPos = this.transform.localPosition;
        Vector3 targetPos = this.target.localPosition;

        // Agent x, y, z.
        sensor.AddObservation(agentPos);

        // Target x, y, z.
        sensor.AddObservation(targetPos);

        // Distance to the target.
        sensor.AddObservation(Vector3.Distance(targetPos, agentPos));
    }

    /// <summary>
    /// Most recently received discrete action. Public, so Unity serializes it and shows it in
    /// the Inspector — presumably kept for debugging.
    /// </summary>
    public int a;

    /// <summary>
    /// Applies the chosen action as a force, then assigns the terminal reward if the agent fell
    /// (-1) or reached the target (+1), ending the episode in either case.
    /// </summary>
    /// <param name="actions">Action buffers; only DiscreteActions[0] is read.</param>
    public override void OnActionReceived(ActionBuffers actions)
    {
        a = actions.DiscreteActions[0];

        this.rBody.AddForce(ActionToDirection(a) * this.moveForce);

        // Fell off: penalize and stop here. EndEpisode() resets the agent right away (it runs
        // OnEpisodeBegin), so the target check below must not run against the reset positions —
        // doing so could grant +1 to an empty, freshly started episode.
        if (this.transform.localPosition.y <= 0f)
        {
            SetReward(-1.0f);
            EndEpisode();
            return;
        }

        // Reached the target.
        if (Vector3.Distance(this.target.localPosition, this.transform.localPosition) <= this.targetReachDistance)
        {
            SetReward(1.0f);
            EndEpisode();
        }
    }

    /// <summary>
    /// Maps a discrete action index to a unit movement direction.
    /// </summary>
    /// <param name="action">Discrete action value.</param>
    /// <returns>The direction for 0–3, or <see cref="Vector3.zero"/> for any other value.</returns>
    private static Vector3 ActionToDirection(int action)
    {
        switch (action)
        {
            case 0:
                return Vector3.forward; // (0, 0, 1)
            case 1:
                return Vector3.back;    // (0, 0, -1)
            case 2:
                return Vector3.left;    // (-1, 0, 0)
            case 3:
                return Vector3.right;   // (1, 0, 0)
            default:
                return Vector3.zero;    // e.g. the heuristic's "no key pressed" value
        }
    }

    /// <summary>
    /// Manual control: W/S move forward/back and A/D move left/right. A/D win when combined with
    /// W/S, because the single discrete branch can hold only one action.
    /// </summary>
    /// <param name="actionsOut">Buffers to fill; only DiscreteActions[0] is written.</param>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        // -1 matches no case in ActionToDirection, so it applies no force.
        // NOTE(review): -1 lies outside any discrete branch range; if demonstrations are
        // recorded, consider adding an explicit no-op action to the branch instead.
        int outAction = -1;

        if (Input.GetKey(KeyCode.W))
        {
            outAction = 0; // move forward
        }
        else if (Input.GetKey(KeyCode.S))
        {
            outAction = 1; // move back
        }

        if (Input.GetKey(KeyCode.A))
        {
            outAction = 2; // go left
        }
        else if (Input.GetKey(KeyCode.D))
        {
            outAction = 3; // go right
        }

        // Write through the segment indexer so the segment's Offset is honored;
        // writing to DiscreteActions.Array[0] would bypass it.
        ActionSegment<int> discreteActionsOut = actionsOut.DiscreteActions;
        discreteActionsOut[0] = outAction;
    }
}