忍者ブログ

軽Lab

 Javaを中心とした、プログラミング関係のナレッジベース

Home > Javaで機械学習 - Deeplearning4j入門 > Java DeepLearning4j ファイル入出力(永続化)

Java DeepLearning4j ファイル入出力(永続化)

ニューラルネットワークは学習に時間がかかるため、実際に利用する場合には事前に学習したニューラルネットワークの設定を読み込んで利用することが多い。Deeplearning4jにおいても事前学習したニューラルネットワークをファイルとして保存・再構築することができる。今回はその方法を確認していく。


■ ファイル入出力用関数

 Deeplearning4jでニューラルネットワークをファイルに出力、ファイルから構築するには以下の関数を利用する。writeModel関数で出力したクラスがMultiLayerNetworkの場合はrestoreMultiLayerNetwork関数で再構築でき、writeModel関数で出力したクラスがComputationGraphクラスの場合にはrestoreComputationGraph関数で再構築できる。

関数 入出力 内容
ModelSerializer.writeModel 出力 ニューラルネットワークをファイルに出力
ModelSerializer.restoreMultiLayerNetwork 入力 ファイルからMultiLayerNetworkクラスのニューラルネットワークを構築
ModelSerializer.restoreComputationGraph 入力 ファイルからComputationGraphクラスのニューラルネットワークを構築


■ サンプルプログラム

 以下にDeepLearning 4jを用いて学習済みのニューラルネットワークをファイルに出力、ファイルから再構築するプログラムを示す。サンプルでは前記事で作成した単純パーセプトロンをファイルに一度出力した後、出力したファイルをもとに単純パーセプトロンを再構築・利用している。

サンプルプログラム

import java.io.File;
import java.io.IOException;

import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * OR計算を行うパーセプトロン
 * @author karura
 */
public class Perceptron
{
    // 変数
    protected static final Logger log = LoggerFactory.getLogger( LenetMnistExample.class );       // ロガー
    
    // メイン関数
    public static void main(String[] args) throws Exception
    {
        // 変数定義
        int seed        = 123;          // 乱数シード
        int iterations  = 1000;         // 学習の試行回数
        int inputNum    = 2;            // 入力数
        int outputNum   = 1;            // 出力数
        INDArray    tIn     = Nd4j.create( new float[]{ 1 , 1 ,         // 入力1
                                                        1 , 0 ,         // 入力2
                                                        0 , 1 ,         // 入力3
                                                        0 , 0 },        // 入力4
                                           new int[]{ 4 , 2 } );        // サイズ
        INDArray    tOut    = Nd4j.create( new float[]{ 1 , 1 , 1 , 0} ,
                                           new int[]{ 4 , 1 } );        // サイズ
        DataSet     train   = new DataSet( tIn , tOut );                // 入出力を対応付けたデータセット
        System.out.println( train );
        
        // ニューラルネットワークを定義
        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations)
                .learningRate(0.01)
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater( Updater.NONE )
                .list()
                .layer(0, new OutputLayer.Builder( LossFunctions.LossFunction.MSE )
                        .nIn(inputNum)
                        .nOut(outputNum)
                        .activation("sigmoid")
                        .build())
                .backprop(true).pretrain(false);
        
        // ニューラルネットワークを作成
        MultiLayerConfiguration conf        = builder.build();
        MultiLayerNetwork       perceptron  = new MultiLayerNetwork(conf);
        perceptron.init();
        
        // 確認用のリスナーを追加
        perceptron.setListeners( new ScoreIterationListener(1) );
        
        // 学習(fit)
        perceptron.fit( train );
        
        // パーセプトロンをファイルに出力
        File    f1   = new File( "output/perceptron.dl4j" );
        ModelSerializer.writeModel( perceptron , f1 , true );
        
        // 構築済みパーセプトロンの読込
        File                f2      = new File( "output/perceptron.dl4j" );
        MultiLayerNetwork   model   = null;
        try {
            // ファイルから読込
            model   = ModelSerializer.restoreMultiLayerNetwork( f2 );
        } catch (IOException e) { e.printStackTrace(); }
        
        // パーセプトロンの使用
        for( int i=0 ; i<train.numExamples() ; i++ )
        {
            // i個目のサンプルについて、
            INDArray    input  = train.get(i).getFeatureMatrix();
            INDArray    answer = train.get(i).getLabels();
            INDArray    output = model.output( input , false );
            System.out.println( "result" + i );
            System.out.println( " input  : " + input );
            System.out.println( " output : " + output );
            System.out.println( " answer : " + answer );
            System.out.flush();
            
        }
        
    }
}

実行結果

00:59:11.485 [main] DEBUG org.nd4j.nativeblas.NativeOps - Number of threads used for linear algebra 1
00:59:11.492 [main] DEBUG org.nd4j.nativeblas.NativeOps - Number of threads used for linear algebra 1

…中略…

00:58:40.017 [main] INFO  o.d.o.l.ScoreIterationListener - Score at iteration 998 is 0.0031967340037226677
00:58:40.019 [main] INFO  o.d.o.l.ScoreIterationListener - Score at iteration 999 is 0.0031930606346577406
result0
 input  : [1.00, 1.00]
 output : 1.00
 answer : 1.00
result1
 input  : [1.00, 0.00]
 output : 0.93
 answer : 1.00
result2
 input  : [0.00, 1.00]
 output : 0.93
 answer : 1.00
result3
 input  : [0.00, 0.00]
 output : 0.12
 answer : 0.00

解説

 単純パーセプトロンの実装については割愛するが72行目までで単純パーセプトロンの構築・学習が完了しており、75行目~76行目でファイルへと書き出している。その後、79行目~80行目ではファイルをもとにして学習後の単純パーセプトロンを再構築している。

 今回は単一のプログラム内でファイル出力・入力を行っているので意味は特にない。しかし、例えば学習に10時間かかるニューラルネットワークあったとしても、事前にファイル出力しておけば利用の際にはファイル読込にかかる数秒程度の時間しかかからないため高速な動作が可能になる。

■ 参照

  1. DeepLearning4j 公式 「Saving and Loading a Neural Network」
Home > Javaで機械学習 - Deeplearning4j入門 > Java DeepLearning4j ファイル入出力(永続化)

- ランダム記事 -
- PR -

コメント

プロフィール

管理者:
 連絡はContactよりお願いします。

PR