package fastfix;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Fast-fix version of {@link WordVectorSerializer#loadTxt} that avoids the
 * ArrayIndexOutOfBoundsException in the original implementation, which sized
 * the per-row vector array to {@code split.length - 2} instead of
 * {@code split.length - 1}.
 *
 * @author karura
 */
public class WordVectorSerializerFastFix extends WordVectorSerializer
{
    public static WordVectors loadTxtVectors(File vectorsFile)
            throws FileNotFoundException, UnsupportedEncodingException
    {
        Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(vectorsFile);
        return WordVectorSerializer.fromPair(pair);
    }

    public static Pair<InMemoryLookupTable, VocabCache> loadTxt(File vectorsFile)
            throws FileNotFoundException, UnsupportedEncodingException
    {
        BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(vectorsFile), "UTF-8"));
        VocabCache<VocabWord> cache = new AbstractCache<>();
        LineIterator iter = IOUtils.lineIterator(reader);

        // Sniff the first line: a word2vec text file may open with a
        // "<vocabSize> <vectorLength>" header rather than a vector row.
        String line = null;
        boolean hasHeader = false;
        if (iter.hasNext()) {
            line = iter.nextLine();
            if (!line.contains(" ")) {
                hasHeader = true;
            } else {
                String[] split = line.split(" ");
                try {
                    // If every token after the first parses as a number, this is
                    // a vector row, unless the line is too short to be one.
                    for (int x = 1; x < split.length; x++) {
                        Double.parseDouble(split[x]);
                    }
                    if (split.length < 4) {
                        hasHeader = true;
                    }
                } catch (Exception e) {
                    hasHeader = true;
                    try {
                        reader.close();
                    } catch (Exception ignored) {}
                }
            }
        }

        // A header line carries no vector data: reopen the file (with UTF-8,
        // not the platform default charset) and skip past the header.
        if (hasHeader) {
            line = "";
            iter.close();
            reader = new BufferedReader(
                    new InputStreamReader(new FileInputStream(vectorsFile), "UTF-8"));
            iter = IOUtils.lineIterator(reader);
            iter.nextLine();
        }

        List<INDArray> arrays = new ArrayList<>();
        while (iter.hasNext()) {
            // The first data line may already be sitting in "line" from the sniff above.
            if (line.isEmpty()) {
                line = iter.nextLine();
            }
            String[] split = line.split(" ");
            String word = split[0].replaceAll("_Az92_", " ");
            VocabWord word1 = new VocabWord(1.0D, word);
            word1.setIndex(cache.numWords());
            cache.addToken(word1);
            cache.addWordToIndex(word1.getIndex(), word);
            cache.putVocabWord(word);

            // FastFix: a row holds one word plus (split.length - 1) components,
            // so the original "split.length - 2" dropped the last component and
            // made the loop below overrun the array.
            // float[] vector = new float[split.length - 2];
            float[] vector = new float[split.length - 1];
            for (int i = 1; i < split.length; i++) {
                vector[i - 1] = Float.parseFloat(split[i]);
            }
            INDArray row = Nd4j.create(vector);
            arrays.add(row);
            line = "";
        }

        // Stack all row vectors into the syn0 weight matrix of the lookup table.
        INDArray syn = Nd4j.vstack(arrays);
        InMemoryLookupTable lookupTable = (InMemoryLookupTable) new InMemoryLookupTable.Builder()
                .vectorLength(arrays.get(0).columns())
                .useAdaGrad(false)
                .cache(cache)
                .build();
        Nd4j.clearNans(syn);
        lookupTable.setSyn0(syn);

        iter.close();
        try {
            reader.close();
        } catch (Exception ignored) {}

        return new Pair<>(lookupTable, cache);
    }
}
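/*
 * Usage sketch (not part of the fix itself): how one might load a word2vec
 * text dump through the fast-fixed loader. The file name "vectors.txt" and
 * the probe words "day"/"night" are assumptions for illustration; the file
 * is assumed to hold whitespace-separated "word v1 v2 ... vn" rows.
 * getWordVectorMatrix and similarity are standard WordVectors methods.
 */
class WordVectorSerializerFastFixDemo
{
    public static void main(String[] args) throws Exception
    {
        // Assumed path to a word2vec-format text file on disk.
        WordVectors vec = WordVectorSerializerFastFix.loadTxtVectors(new File("vectors.txt"));

        // Look up a single embedding; with the fix, every component of the
        // row is populated, including the last one the old code truncated.
        INDArray v = vec.getWordVectorMatrix("day");
        System.out.println("dim = " + v.columns() + ", vector = " + v);

        // Cosine similarity between two (assumed in-vocabulary) words.
        System.out.println("similarity = " + vec.similarity("day", "night"));
    }
}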