package org.genericsystem.reinforcer;

import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.UnaryOperator;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.commons.io.FileUtils;
import org.datavec.api.records.Record;
import org.datavec.api.records.metadata.RecordMetaDataURI;
import org.datavec.api.records.reader.impl.FileRecordReader;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/genericsystem/reinforcer/VecRecordReader.class */
public class VecRecordReader extends FileRecordReader {
    private final WordVectors dictionary;
    private final TokenizerFactory tokenizer;

    public VecRecordReader(WordVectors wordVectors, TokenizerFactory tokenizerFactory) {
        this.dictionary = wordVectors;
        this.tokenizer = tokenizerFactory;
    }

    public Record nextRecord() {
        if (this.iter == null || !this.iter.hasNext()) {
            advanceToNextLocation();
        }
        File file = (File) this.iter.next();
        this.currentFile = file;
        invokeListeners(file);
        return new org.datavec.api.records.impl.Record(loadFromFile(file), new RecordMetaDataURI(file.toURI(), FileRecordReader.class));
    }

    private List<Writable> loadFromFile(File file) {
        List list;
        ArrayList arrayList = new ArrayList();
        try {
            list = (List) Arrays.asList(FileUtils.readFileToString(file).split(" ")).stream().map(str -> {
                return this.tokenizer.getTokenPreProcessor().preProcess(str);
            }).filter(str2 -> {
                return this.dictionary.hasWord(str2);
            }).collect(Collectors.toList());
        } catch (IOException e) {
            e.printStackTrace();
        }
        if (list.isEmpty()) {
            return next();
        }
        INDArray wordVectorsMean = this.dictionary.getWordVectorsMean(list);
        for (int i = 0; i < wordVectorsMean.columns(); i++) {
            arrayList.add(new DoubleWritable(wordVectorsMean.getRow(0).getDouble(i)));
        }
        if (this.appendLabel) {
            arrayList.add(new IntWritable(this.labels.indexOf(file.getParentFile().getName())));
        }
        return arrayList;
    }
}
