/*
 * Decompiled with CFR 0.152.
 */
package org.genericsystem.reinforcer;

import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.UnaryOperator;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.commons.io.FileUtils;
import org.datavec.api.records.impl.Record;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataURI;
import org.datavec.api.records.reader.impl.FileRecordReader;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;

public class VecRecordReader
extends FileRecordReader {
    private final WordVectors dictionary;
    private final TokenizerFactory tokenizer;

    public VecRecordReader(WordVectors dictionary, TokenizerFactory tokenizer) {
        this.dictionary = dictionary;
        this.tokenizer = tokenizer;
    }

    public org.datavec.api.records.Record nextRecord() {
        File next;
        if (this.iter == null || !this.iter.hasNext()) {
            this.advanceToNextLocation();
        }
        this.currentFile = next = (File)this.iter.next();
        this.invokeListeners(next);
        List<Writable> ret = this.loadFromFile(next);
        return new Record(ret, (RecordMetaData)new RecordMetaDataURI(next.toURI(), FileRecordReader.class));
    }

    private List<Writable> loadFromFile(File next) {
        ArrayList<Writable> ret = new ArrayList<Writable>();
        try {
            String text = FileUtils.readFileToString((File)next);
            List wordList = Arrays.asList(text.split(" ")).stream().map(word -> this.tokenizer.getTokenPreProcessor().preProcess(word)).filter(word -> this.dictionary.hasWord(word)).collect(Collectors.toList());
            if (wordList.isEmpty()) {
                return this.next();
            }
            INDArray average = this.dictionary.getWordVectorsMean(wordList);
            for (int i = 0; i < average.columns(); ++i) {
                ret.add((Writable)new DoubleWritable(average.getRow(0).getDouble(i)));
            }
            if (this.appendLabel) {
                ret.add((Writable)new IntWritable(this.labels.indexOf(next.getParentFile().getName())));
            }
        }
        catch (IOException e) {
            e.printStackTrace();
        }
        return ret;
    }
}

