忍者ブログ

軽Lab

 Javaを中心とした、プログラミング関係のナレッジベース

Home > > Java DeepLearning4j 基本的な利用方法

Java DeepLearning4j 基本的な利用方法

×

[PR]上記の広告は3ヶ月以上新規記事投稿のないブログに表示されています。新しい記事を書く事で広告が消えます。

Home > > Java DeepLearning4j 基本的な利用方法

- ランダム記事 -

コメント

ただいまコメントを受けつけておりません。

Home > Javaで機械学習 - Deeplearning4j入門 > Java DeepLearning4j 基本的な利用方法

Java DeepLearning4j 基本的な利用方法

前回の記事ではDeeplearning4jの開発環境構築について確認した。今回はDeeplearning 4jでニューラルネットワークを構築する基本的な方法について確認する。また、Deeplearning 4jの動作を確認するため、単純パーセプトロンの実装をサンプルプログラムで確認する。


■ Deeplearning 4jの基本的な利用方法

 環境構築の際にも述べたとおり、Deeplearning 4jはDeep Learningの構築を容易にするフレームワークである。Deep Learningの定義は一般的に『4つ以上の層を持つューラルネットワークによる学習』と言われており、Deeplearning4jではこの階層型ニューラルネットワークを構築・学習・利用する事ができる。Deeplearning 4jでニューラルネットワークを構築する基本的な手順は次のとおりである。

  1. 学習・テストデータの準備
  2. ニューラルネットワークの定義・構築
  3. ニューラルネットワークの学習
  4. ニューラルネットワークの利用

 ニューラルネットワークについては前回紹介したとおりであり、Deep Learningでは階層型ニューラルネットワークを利用する。階層型ニューラルネットワークの特徴として、入出力のペアデータを与えることで入力→出力の変換を自動的に学習することができる点があげられる。この特徴によって、プログラマが入力→出力の変換式を知らなくてもそれらしい計算式を得ることができる(学習することができる)ため、Deep Learningは人工知能の分野で近年注目を集めている。

図:階層型ニューラルネットワークのイメージ。入力[1,0,1]に対して、
ニューラルネットワーク内の関数が連鎖的に計算を実施し、出力[0,1]を出力している

 階層型ニューラルネットワークの種類としては多層パーセプトロンや畳み込みニューラルネットワークなど多数存在するが、Deeplearning 4jではニューラルネットワークの定義を変えることにより実現できるようになっている。以下ではそれぞれのフェーズについてみていく。

学習・テストデータの準備

 Deeplearning 4jの学習データやテストデータ、実際の入力データはDataSetクラスとして用意する必要がある。DataSetはND4Jライブラリに属するクラスで、内部的にメモリ節約機構を持ち、行列計算用の関数を備えている(*1,*2)。クラス内部には入力(featureMatrix)と出力(label)の2種類のデータを保持している。

 DataSetで新規データを作成する例を以下に示す。以下のコード・スぺニットでは、(1,1)(1,0)(0,1)(0,0)という4つの2次元入力と、1,1,1,0という4つの1次元出力を持つDataSetを作成している。入力と出力は対になっており、入力配列i番目の出力は、出力配列i番目に格納する。
INDArray    tIn     = Nd4j.create( new float[]{ 1 , 1 ,         // 入力1
                                                1 , 0 ,         // 入力2
                                                0 , 1 ,         // 入力3
                                                0 , 0 },        // 入力4
                                   new int[]{ 4 , 2 } );        // サイズ
INDArray    tOut    = Nd4j.create( new float[]{ 1 , 1 , 1 , 0} ,
                                   new int[]{ 4 , 1 } );        // サイズ
DataSet     train   = new DataSet( tIn , tOut );                // 入出力を対応付けたデータセット
 また、MnistデータベースやIris flower datasetといった学習用のデータについてはDeeplearning 4j内部に保持しており、特定のクラスを呼び出すことで利用できる。詳細については別記事で確認する。

ニューラルネットワークの定義・構築、学習、利用


 ニューラルネットワークの定義~利用は以下の流れで実施する。具体的な利用例はサンプルプログラムで確認する。

  1. MultiLayerConfiguration.Builderクラスでニューラルネットワークの動作を定義する
  2. MultiLayerConfiguration.Builder::build関数を呼び出すことで、ニューラルネットワークの設計書にあたるMultiLayerConfigurationインスタンスを作成する
  3. MultiLayerConfigurationインスタンスをもとに、ニューラルネットワークを表すMultiLayerNetworkインスタンスを作成する
  4. MultiLayerNetwork::fit関数で学習を実施
  5. MultiLayerNetwork::output関数で出力


■ サンプルプログラム(単純パーセプトロンの構築)

 以下にDeepLearning 4jを用いて単純パーセプトロン(OR計算を処理)を構成するサンプルプログラムを示す。単純パーセプトロンとは、別記事で解説しているとおり入力層と出力層だけを持つニューラルネットワークである(下図)。

図:単純パーセプトロンのイメージ

 サンプルでは活性化関数にシグモイド関数、誤差関数にMSE(\(\sum(t-o)^2\))を利用し、誤差逆伝搬法によって4つの訓練データで1000回学習している。以下のプログラムは前回作成したDeepLearning4j用のプロジェクト内部で動作する。

サンプルプログラム

import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * OR計算を行うパーセプトロン
 * @author karura
 */
public class Perceptron
{
    // 変数
    protected static final Logger log = LoggerFactory.getLogger( LenetMnistExample.class );       // ロガー
    
    // メイン関数
    public static void main(String[] args) throws Exception
    {
        // 変数定義
        int seed        = 123;          // 乱数シード
        int iterations  = 1000;         // 学習の試行回数
        int inputNum    = 2;            // 入力数
        int outputNum   = 1;            // 出力数
        INDArray    tIn     = Nd4j.create( new float[]{ 1 , 1 ,         // 入力1
                                                        1 , 0 ,         // 入力2
                                                        0 , 1 ,         // 入力3
                                                        0 , 0 },        // 入力4
                                           new int[]{ 4 , 2 } );        // サイズ
        INDArray    tOut    = Nd4j.create( new float[]{ 1 , 1 , 1 , 0} ,
                                           new int[]{ 4 , 1 } );        // サイズ
        DataSet     train   = new DataSet( tIn , tOut );                // 入出力を対応付けたデータセット
        System.out.println( train );
        
        // ニューラルネットワークを定義
        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations)
                .learningRate(0.01)
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater( Updater.NONE )
                .list()
                .layer(0, new OutputLayer.Builder( LossFunctions.LossFunction.MSE )
                        .nIn(inputNum)
                        .nOut(outputNum)
                        .activation("sigmoid")
                        .build())
                .backprop(true).pretrain(false);
        
        // ニューラルネットワークを作成
        MultiLayerConfiguration conf        = builder.build();
        MultiLayerNetwork       perceptron  = new MultiLayerNetwork(conf);
        perceptron.init();
        
        // 確認用のリスナーを追加
        perceptron.setListeners( new ScoreIterationListener(1) );
        
        // 学習(fit)
        perceptron.fit( train );
        
        // パーセプトロンの使用
        for( int i=0 ; i<train.numExamples() ; i++ )
        {
            // i個目のサンプルについて、
            INDArray    input  = train.get(i).getFeatureMatrix();
            INDArray    answer = train.get(i).getLabels();
            INDArray    output = perceptron.output( input , false );
            System.out.println( "result" + i );
            System.out.println( " input  : " + input );
            System.out.println( " output : " + output );
            System.out.println( " answer : " + answer );
            System.out.flush();
            
        }
        
    }
}

実行結果

00:59:11.485 [main] DEBUG org.nd4j.nativeblas.NativeOps - Number of threads used for linear algebra 1
00:59:11.492 [main] DEBUG org.nd4j.nativeblas.NativeOps - Number of threads used for linear algebra 1

…中略…

00:58:40.017 [main] INFO  o.d.o.l.ScoreIterationListener - Score at iteration 998 is 0.0031967340037226677
00:58:40.019 [main] INFO  o.d.o.l.ScoreIterationListener - Score at iteration 999 is 0.0031930606346577406
result0
 input  : [1.00, 1.00]
 output : 1.00
 answer : 1.00
result1
 input  : [1.00, 0.00]
 output : 0.93
 answer : 1.00
result2
 input  : [0.00, 1.00]
 output : 0.93
 answer : 1.00
result3
 input  : [0.00, 0.00]
 output : 0.12
 answer : 0.00

解説

 学習データの作成は33行目~41行目で行っており、入力データtIn[i]のOR計算結果がtOut[i]となるようにデータを作成している。

 45行目〜58行目ではニューラルネットワークの定義を行っている。定義については、ニューラルネットワークの知識が必要になるため、実装前に利用するニューラルネットワークについて調べる必要がある。今回は単純パーセプトロン(入力データ2個、出力データ1個)であるため、入力層と出力層しか存在しない。Deeplearning 4jでは入力層を層としてカウントしないため、ニューラルネットワークの定義では出力層のみ定義する。出力層の定義は52行目〜57行目で行っており、layer関数によりレイヤー番号0で出力層を定義している。レイヤーの定義はlist関数以降で行うようにするとよい。もし、複数レイヤーを利用する場合は、layer番号を1づつずらしながらlayer関数を複数回呼び出すことになる。

 残りの定義部ではニューラルネットワーク全体の設定を行っている。46行目~47行目で学習率0.01で100回の学習をすることを定義。続いて重みパラメータをXAViIER で初期化して、確率的勾配降下法で更新することを定義している(48行目~50行目)。そして、大事なのがニューラルネットワークの学習を有効にするbackprop関数(57行目)で、誤差逆伝搬法により学習をすることを指定している。これらのパラメータについて定義できる値は別記事で詳しく確認する。

 ニューラルネットワークの定義が終わるとインスタンス化し(60行目~62行目)、68行目で学習を実施、76行目では学習が完了した単純パーセプトロンによる出力を行っている。実行結果を見ると、おおむね正しいOR計算の結果に収束していることが確認できる。


■ 参照

  1. JavaDoc 「Class DataSet」
  2. JavaDoc 「Class Nd4j」
  3. stack overflow 「train simple neural network with deeplearning4j」
  4. ND4J: N-Dimensional Arrays for Java
  5. JavaDoc DeepLearning4j
  6. JavaDoc ND4J

改訂履歴
・2016年 6月11日 一部改訂。『主な設定項目』を別記事に移動
- PR -
Home > Javaで機械学習 - Deeplearning4j入門 > Java DeepLearning4j 基本的な利用方法

- ランダム記事 -

コメント

ただいまコメントを受けつけておりません。

QRコード

プロフィール

管理者:
連絡はContactよりお願いします。

PR