simplestarの技術ブログ

目的を書いて、思想と試行、結果と考察、そして具体的な手段を記録します。

Unity:3層ニューラルネットワークプログラム

人工知能とか機械学習とかを勉強していきたいと思います。
3層ニューラルネットワークをご存知でしょうか?

まずは解説されている書籍を読んでみます。
「ディジタル画像処理」のニューラルネットワークの説明を参考にしました。
書籍リンク:CG-ARTS協会 | 書籍・教材

ネットでもニューラルネットワークについて色々と情報が入手できますが
わかりやすいという点で、この書籍の説明より良いものを見たことがありません。

さて、この本の説明内容にそってプログラムを書きます。

先に結果だけ書くと、うまく動作していることがわかりました。
f:id:simplestar_tech:20160131223731p:plain
f:id:simplestar_tech:20160131223745p:plain

3層ニューラルネットワークのプログラムがこちら

using UnityEngine;
using System.Collections;
using System.Collections.Generic;
using System.IO;

public class NeuralNetwork3
{
    public float Rate = 0.05f;

    public void Setup(int inputUnitCount, int intermediateUnitCount, int outputUnitCount)
    {
        _inputCount = inputUnitCount;
        _unitCount = intermediateUnitCount;
        _resultCount = outputUnitCount;
        _layerInputIntermediateValues = new float[_inputCount * _unitCount];
        _intermediateUnits = new float[_unitCount];
        _layerIntermediateOutputValues = new float[_unitCount * _resultCount];
        _ek = new float[_resultCount];

        for (int i = 0; i < _layerInputIntermediateValues.Length; i++)
        {
            _layerInputIntermediateValues[i] = Random.value;
        }
        for (int i = 0; i < _layerIntermediateOutputValues.Length; i++)
        {
            _layerIntermediateOutputValues[i] = Random.value;
        }
    }

    public void Learn(float[] inputVector, float[] outputVector, float[] resultVector)
    {
        _CalculateIntemediateUnits(inputVector);
        _CalculateResultVector(resultVector);
        _UpdateIntermediateOutputValues(resultVector, outputVector);
        _UpdateInputIntermediateValues(inputVector, resultVector);
    }

    int _inputCount;
    int _unitCount;
    int _resultCount;
    float[] _layerInputIntermediateValues;
    float[] _layerIntermediateOutputValues;
    float[] _intermediateUnits;
    float[] _ek;

    private void _CalculateIntemediateUnits(float[] inputVector)
    {
        for (int unitIdx = 0; unitIdx < _unitCount; unitIdx++)
        {
            _intermediateUnits[unitIdx] = 0;
            for (int inputIdx = 0; inputIdx < _inputCount; inputIdx++)
            {
                _intermediateUnits[unitIdx] += inputVector[inputIdx] * _layerInputIntermediateValues[_inputCount * unitIdx + inputIdx];
            }
        }
    }

    private void _CalculateResultVector(float[] resultVector)
    {
        for (int resIdx = 0; resIdx < _resultCount; resIdx++)
        {
            resultVector[resIdx] = 0;
            for (int unitIdx = 0; unitIdx < _unitCount; unitIdx++)
            {
                resultVector[resIdx] += _intermediateUnits[unitIdx] * _layerIntermediateOutputValues[_unitCount * resIdx + unitIdx];
            }
        }
    }

    private void _UpdateIntermediateOutputValues(float[] resultVector, float[] outputVector)
    {
        for (int resIdx = 0; resIdx < _resultCount; resIdx++)
        {
            float diff = (resultVector[resIdx] - outputVector[resIdx]);
            float s = 1 / (1 + Mathf.Exp(-resultVector[resIdx]));
            _ek[resIdx] = diff * s * (1 - s);
            for (int unitIdx = 0; unitIdx < _unitCount; unitIdx++)
            {
                _layerIntermediateOutputValues[_unitCount * resIdx + unitIdx] -= Rate * _ek[resIdx] * _intermediateUnits[unitIdx];
            }
        }
    }

    private void _UpdateInputIntermediateValues(float[] inputVector, float[] resultVector)
    {
        for (int unitIdx = 0; unitIdx < _unitCount; unitIdx++)
        {
            float diffs = 0;
            for (int resIdx = 0; resIdx < _resultCount; resIdx++)
            {
                diffs += _ek[resIdx] * _layerIntermediateOutputValues[_unitCount * resIdx + unitIdx];
            }
            float s = 1 / (1 + Mathf.Exp(-_intermediateUnits[unitIdx]));
            diffs *= diffs * s * (1 - s);
            for (int inputIdx = 0; inputIdx < _inputCount; inputIdx++)
            {
                _layerInputIntermediateValues[_inputCount * unitIdx + inputIdx] -= Rate * diffs * inputVector[inputIdx];
            }
        }
    }
}

使い方:
公開変数 Rate は学習率です。
対象データに応じて適切に設定してください。
入力ベクトルのサイズと中間層のユニット数と出力ベクトルのサイズを渡して Setup します。
その後、用意している対象データを Learn に渡すことで内部の重み係数が更新されます。

以上です。

正しいかテストするコードは次の通り

using UnityEngine;
using System.Collections;
using System.Collections.Generic;
using System.IO;

public class TestNN3Behaviour : MonoBehaviour
{

    public int DataCount = 50;
    public int IterationCount = 50;

    NeuralNetwork3 _neuralNetwork3 = new NeuralNetwork3();
    List<InputSet> _inputSets;

    enum Janken
    {
        GOO,
        CYO,
        PAR,
        MAX
    }

    enum Wairo
    {
        NO,
        LO,
        HI,
        MAX
    }

    struct InputSet
    {
        public Janken aite;
        public Wairo wairo;
        public Janken jibun;
    }

    // Use this for initialization
    void Start()
    {
        int inputUnitCount = 6;
        int intermediateUnitCount = 9;
        int outputUnitCount = 9;
        _neuralNetwork3.Setup(inputUnitCount, intermediateUnitCount, outputUnitCount);
        _CreateInputSet();
        float[] resultVector = new float[outputUnitCount];
        string nowString = System.DateTime.Now.ToString("HHmmss");
        for (int itrIdx = 0; itrIdx < IterationCount; itrIdx++)
        {
            int victoryCount = 0;
            float error = 0;
            for (int dataIdx = 0; dataIdx < _inputSets.Count; dataIdx++)
            {
                InputSet input = _inputSets[dataIdx];
                float[] inputVector = _GetInputVector(ref input);
                float[] outputVector = _GetOutputVector(ref input);

                _neuralNetwork3.Learn(inputVector, outputVector, resultVector);

                Janken result = _GetResultJanken(resultVector);
                if (result == input.jibun)
                    ++victoryCount;
                for (int resIdx = 0; resIdx < outputUnitCount; resIdx++)
                {
                    float diff = (resultVector[resIdx] - outputVector[resIdx]);
                    error += diff * diff;
                }
            }
            float vRate = victoryCount / (float)_inputSets.Count * 100;

            using (StreamWriter writer = new StreamWriter("Result" + nowString + ".csv", true))
            {
                writer.WriteLine(vRate.ToString("0.00000") + "," + error.ToString("0.00000"));
            }
        }
    }

    // Update is called once per frame
    void Update()
    {

    }

    private Janken _GetResultJanken(float[] resultVector)
    {
        Janken result = Janken.MAX;
        float maxRes = 0;
        int maxResIdx = 0;
        for (int resIdx = 0; resIdx < resultVector.Length; resIdx++)
        {
            if (maxRes < resultVector[resIdx])
            {
                maxRes = resultVector[resIdx];
                maxResIdx = resIdx;
            }
        }
        result = (Janken)(maxResIdx % (int)Janken.MAX);
        return result;
    }

    private float[] _GetOutputVector(ref InputSet input)
    {
        float[] outputVector = new float[9];
        outputVector[0] = (input.jibun == Janken.GOO && input.wairo == Wairo.NO) ? 1 : 0;
        outputVector[1] = (input.jibun == Janken.CYO && input.wairo == Wairo.NO) ? 1 : 0;
        outputVector[2] = (input.jibun == Janken.PAR && input.wairo == Wairo.NO) ? 1 : 0;
        outputVector[3] = (input.jibun == Janken.GOO && input.wairo == Wairo.LO) ? 1 : 0;
        outputVector[4] = (input.jibun == Janken.CYO && input.wairo == Wairo.LO) ? 1 : 0;
        outputVector[5] = (input.jibun == Janken.PAR && input.wairo == Wairo.LO) ? 1 : 0;
        outputVector[6] = (input.jibun == Janken.GOO && input.wairo == Wairo.HI) ? 1 : 0;
        outputVector[7] = (input.jibun == Janken.CYO && input.wairo == Wairo.HI) ? 1 : 0;
        outputVector[8] = (input.jibun == Janken.PAR && input.wairo == Wairo.HI) ? 1 : 0;
        return outputVector;
    }

    private float[] _GetInputVector(ref InputSet input)
    {
        float[] inputVector = new float[6];
        inputVector[0] = input.aite == Janken.GOO ? 1 : 0;
        inputVector[1] = input.aite == Janken.CYO ? 1 : 0;
        inputVector[2] = input.aite == Janken.PAR ? 1 : 0;
        inputVector[3] = input.wairo == Wairo.NO ? 1 : 0;
        inputVector[4] = input.wairo == Wairo.LO ? 1 : 0;
        inputVector[5] = input.wairo == Wairo.HI ? 1 : 0;
        return inputVector;
    }

    private void _CreateInputSet()
    {
        _inputSets = new List<InputSet>();
        for (int dataIdx = 0; dataIdx < DataCount; dataIdx++)
        {
            InputSet inputSet = new InputSet() { aite = (Janken)Random.Range(0, (int)Janken.MAX), wairo = (Wairo)Random.Range(0, (int)Wairo.MAX) };
            inputSet = _SetWinnerHand(inputSet);
            _inputSets.Add(inputSet);
        }
    }

    private static InputSet _SetWinnerHand(InputSet inputSet)
    {
        switch (inputSet.aite)
        {
            case Janken.GOO:
                switch (inputSet.wairo)
                {
                    case Wairo.NO:
                        inputSet.jibun = Janken.PAR;
                        break;
                    case Wairo.LO:
                        inputSet.jibun = Janken.GOO;
                        break;
                    case Wairo.HI:
                        inputSet.jibun = Janken.CYO;
                        break;
                    case Wairo.MAX:
                        break;
                    default:
                        break;
                }
                break;
            case Janken.CYO:
                switch (inputSet.wairo)
                {
                    case Wairo.NO:
                        inputSet.jibun = Janken.GOO;
                        break;
                    case Wairo.LO:
                        inputSet.jibun = Janken.CYO;
                        break;
                    case Wairo.HI:
                        inputSet.jibun = Janken.PAR;
                        break;
                    case Wairo.MAX:
                        break;
                    default:
                        break;
                }
                break;
            case Janken.PAR:
                switch (inputSet.wairo)
                {
                    case Wairo.NO:
                        inputSet.jibun = Janken.CYO;
                        break;
                    case Wairo.LO:
                        inputSet.jibun = Janken.PAR;
                        break;
                    case Wairo.HI:
                        inputSet.jibun = Janken.GOO;
                        break;
                    case Wairo.MAX:
                        break;
                    default:
                        break;
                }
                break;
            case Janken.MAX:
                break;
            default:
                break;
        }
        return inputSet;
    }
}

テスト内容は入力ベクトルが6要素
組み合わせとしては 9 通りの入力に対して、9通りの答えを返すというもの
出力ベクトルは9要素とし
中間層は9ユニットとしています。

どういうテストかというと
ジャンケンで勝つログを与えるというものです。
賄賂を受け取るとその量に応じてあいこにしたり、わざと負けたりするというのが正解としています。
これ結構意地悪なルールで勝率100%は難しいと思って設定しました。

データ数50で50回イテレーションを回すと(合計50 x 50 Learn を実行するということ)
記録するコードの通り csv には最初に示した通りの結果が出ました。
なんと!こんないじわるなルールでも正解率100%のAIができたのです。

是非あなたが持っている価値あるデータを入力に3層ニューラルネットワークを学習させてみてください。
誤差は順調に減少していくと思いますから、そうなったとき好ましい結果を返すAIになってくれると良いですね。

今回のプログラムはUnityパッケージとしてダウンロードできるようにしておきます。
http://file.blenderbluelog.anime-movie.net/ANN3Layer.zip