ニューラルネットワークは学習に時間がかかるため、実際に利用する場合には事前に学習したニューラルネットワークの設定を読み込んで利用することが多い。Deeplearning4jにおいても事前学習したニューラルネットワークをファイルとして保存・再構築することができる。今回はその方法を確認していく。
■ ファイル入出力用関数
Deeplearning4jでニューラルネットワークをファイルに出力、ファイルから構築するには以下の関数を利用する。writeModel関数で出力したクラスがMultiLayerNetworkの場合はrestoreMultiLayerNetwork関数で再構築でき、writeModel関数で出力したクラスがComputationGraphクラスの場合にはrestoreComputationGraph関数で再構築できる。
関数 |
入出力 |
内容 |
ModelSerializer.writeModel |
出力 |
ニューラルネットワークをファイルに出力 |
ModelSerializer.restoreMultiLayerNetwork |
入力 |
ファイルからMultiLayerNetworkクラスのニューラルネットワークを構築 |
ModelSerializer.restoreComputationGraph |
入力 |
ファイルからComputationGraphクラスのニューラルネットワークを構築 |
■ サンプルプログラム
以下にDeepLearning 4jを用いて学習済みのニューラルネットワークをファイルに出力、ファイルから再構築するプログラムを示す。サンプルでは前記事で作成した単純パーセプトロンをファイルに一度出力した後、出力したファイルをもとに単純パーセプトロンを再構築・利用している。
サンプルプログラム
import java.io.File;
import java.io.IOException;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* OR計算を行うパーセプトロン
* @author karura
*/
public class Perceptron
{
// 変数
protected static final Logger log = LoggerFactory.getLogger( LenetMnistExample.class ); // ロガー
// メイン関数
public static void main(String[] args) throws Exception
{
// 変数定義
int seed = 123; // 乱数シード
int iterations = 1000; // 学習の試行回数
int inputNum = 2; // 入力数
int outputNum = 1; // 出力数
INDArray tIn = Nd4j.create( new float[]{ 1 , 1 , // 入力1
1 , 0 , // 入力2
0 , 1 , // 入力3
0 , 0 }, // 入力4
new int[]{ 4 , 2 } ); // サイズ
INDArray tOut = Nd4j.create( new float[]{ 1 , 1 , 1 , 0} ,
new int[]{ 4 , 1 } ); // サイズ
DataSet train = new DataSet( tIn , tOut ); // 入出力を対応付けたデータセット
System.out.println( train );
// ニューラルネットワークを定義
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
.seed(seed)
.iterations(iterations)
.learningRate(0.01)
.weightInit(WeightInit.XAVIER)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater( Updater.NONE )
.list()
.layer(0, new OutputLayer.Builder( LossFunctions.LossFunction.MSE )
.nIn(inputNum)
.nOut(outputNum)
.activation("sigmoid")
.build())
.backprop(true).pretrain(false);
// ニューラルネットワークを作成
MultiLayerConfiguration conf = builder.build();
MultiLayerNetwork perceptron = new MultiLayerNetwork(conf);
perceptron.init();
// 確認用のリスナーを追加
perceptron.setListeners( new ScoreIterationListener(1) );
// 学習(fit)
perceptron.fit( train );
// パーセプトロンをファイルに出力
File f1 = new File( "output/perceptron.dl4j" );
ModelSerializer.writeModel( perceptron , f1 , true );
// 構築済みパーセプトロンの読込
File f2 = new File( "output/perceptron.dl4j" );
MultiLayerNetwork model = null;
try {
// ファイルから読込
model = ModelSerializer.restoreMultiLayerNetwork( f2 );
} catch (IOException e) { e.printStackTrace(); }
// パーセプトロンの使用
for( int i=0 ; i<train.numExamples() ; i++ )
{
// i個目のサンプルについて、
INDArray input = train.get(i).getFeatureMatrix();
INDArray answer = train.get(i).getLabels();
INDArray output = model.output( input , false );
System.out.println( "result" + i );
System.out.println( " input : " + input );
System.out.println( " output : " + output );
System.out.println( " answer : " + answer );
System.out.flush();
}
}
}
実行結果
00:59:11.485 [main] DEBUG org.nd4j.nativeblas.NativeOps - Number of threads used for linear algebra 1
00:59:11.492 [main] DEBUG org.nd4j.nativeblas.NativeOps - Number of threads used for linear algebra 1
…中略…
00:58:40.017 [main] INFO o.d.o.l.ScoreIterationListener - Score at iteration 998 is 0.0031967340037226677
00:58:40.019 [main] INFO o.d.o.l.ScoreIterationListener - Score at iteration 999 is 0.0031930606346577406
result0
input : [1.00, 1.00]
output : 1.00
answer : 1.00
result1
input : [1.00, 0.00]
output : 0.93
answer : 1.00
result2
input : [0.00, 1.00]
output : 0.93
answer : 1.00
result3
input : [0.00, 0.00]
output : 0.12
answer : 0.00
解説
単純パーセプトロンの実装については割愛するが72行目までで単純パーセプトロンの構築・学習が完了しており、75行目~76行目でファイルへと書き出している。その後、79行目~80行目ではファイルをもとにして学習後の単純パーセプトロンを再構築している。
今回は単一のプログラム内でファイル出力・入力を行っているので意味は特にない。しかし、例えば学習に10時間かかるニューラルネットワークあったとしても、事前にファイル出力しておけば利用の際にはファイル読込にかかる数秒程度の時間しかかからないため高速な動作が可能になる。
■ 参照
- DeepLearning4j 公式 「Saving and Loading a Neural Network」