/*
 * Decompiled with CFR 0.152.
 */
package org.apache.uniffle.client.record.reader;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import org.apache.hadoop.io.DataInputBuffer;
import org.apache.hadoop.io.RawComparator;
import org.apache.uniffle.client.api.ShuffleServerClient;
import org.apache.uniffle.client.factory.ShuffleServerClientFactory;
import org.apache.uniffle.client.record.Record;
import org.apache.uniffle.client.record.RecordBlob;
import org.apache.uniffle.client.record.RecordBuffer;
import org.apache.uniffle.client.record.metrics.MetricsReporter;
import org.apache.uniffle.client.record.reader.BufferedSegment;
import org.apache.uniffle.client.record.reader.KeyValueReader;
import org.apache.uniffle.client.record.reader.KeyValuesReader;
import org.apache.uniffle.client.record.writer.Combiner;
import org.apache.uniffle.client.request.RssGetSortedShuffleDataRequest;
import org.apache.uniffle.client.response.RssGetSortedShuffleDataResponse;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.merger.MergeState;
import org.apache.uniffle.common.merger.Merger;
import org.apache.uniffle.common.merger.Segment;
import org.apache.uniffle.common.netty.buffer.ManagedBuffer;
import org.apache.uniffle.common.records.RecordsReader;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.serializer.SerInputStream;
import org.apache.uniffle.common.serializer.Serializer;
import org.apache.uniffle.common.serializer.SerializerFactory;
import org.apache.uniffle.common.serializer.SerializerInstance;
import org.apache.uniffle.common.serializer.writable.ComparativeOutputBuffer;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.shaded.com.google.common.annotations.VisibleForTesting;
import org.apache.uniffle.shaded.io.netty.buffer.ByteBuf;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Reads the records of a set of partitions whose data has been merged (sorted) on the shuffle
 * servers, and merges the per-partition streams into a single ordered stream on the client.
 *
 * <p>{@link #start()} launches, for every partition, one {@code RecordsFetcher} thread (plus one
 * {@code RecordsCombiner} thread when a combiner is configured), and a single
 * {@code RecordsMerger} thread that merges all partitions into {@code results}. Callers consume
 * {@code results} through {@link #rawKeyValueReader()}, {@link #keyValueReader()} or
 * {@link #keyValuesReader()}.
 *
 * <p>When a worker thread fails, the failure is recorded in {@code error} and {@code stop} is
 * set; consumers blocked in {@code Queue#take()} then receive an {@link RssException} whose cause
 * is that failure.
 */
public class RMRecordsReader<K, V, C> {
    private static final Logger LOG = LoggerFactory.getLogger(RMRecordsReader.class);
    private String appId;
    private final int shuffleId;
    private final Set<Integer> partitionIds;
    private final RssConf rssConf;
    private final Class<K> keyClass;
    private final Class<V> valueClass;
    private final Comparator comparator;
    private boolean raw;
    private final Combiner combiner;
    private boolean isMapCombine;
    private final MetricsReporter metrics;
    private final String clientType;
    private SerializerInstance serializerInstance;
    private final int retryMax;
    private final long retryIntervalMax;
    private final long initFetchSleepTime;
    private final long maxFetchSleepTime;
    private final int maxBufferPerPartition;
    private final int maxRecordsNumPerBuffer;
    private Map<Integer, List<ShuffleServerInfo>> shuffleServerInfoMap;
    // Set when the pipeline must shut down, either by close() or by a worker failure.
    private volatile boolean stop = false;
    // First failure reported by a worker thread; surfaced to consumers by Queue#take().
    private volatile Throwable error = null;
    // Per-partition queues feeding the combiner threads (only populated when combiner != null).
    private Map<Integer, Queue<RecordBuffer>> combineBuffers = JavaUtils.newConcurrentMap();
    // Per-partition queues feeding the merger thread.
    private Map<Integer, Queue<RecordBuffer>> mergeBuffers = JavaUtils.newConcurrentMap();
    // Merged output read by the key/value readers; set to null by close().
    private Queue<Record> results;

    /**
     * Creates a reader that uses the GRPC shuffle server client. See the full constructor for the
     * meaning of each parameter.
     */
    public RMRecordsReader(String appId, int shuffleId, Set<Integer> partitionIds, Map<Integer, List<ShuffleServerInfo>> shuffleServerInfoMap, RssConf rssConf, Class<K> keyClass, Class<V> valueClass, Comparator<K> comparator, boolean raw, Combiner combiner, boolean isMapCombine, MetricsReporter metrics) {
        this(appId, shuffleId, partitionIds, shuffleServerInfoMap, rssConf, keyClass, valueClass, comparator, raw, combiner, isMapCombine, metrics, ClientType.GRPC.name());
    }

    /**
     * Creates a reader. No thread is started until {@link #start()} is called.
     *
     * @param appId application id sent with every fetch request
     * @param shuffleId shuffle id sent with every fetch request
     * @param partitionIds partitions to read; each one gets its own fetcher thread
     * @param shuffleServerInfoMap candidate servers per partition; the last entry is tried first
     * @param rssConf client configuration (buffer sizes, fetch sleep times, retry settings)
     * @param keyClass key class used for deserialization
     * @param valueClass value class used for deserialization
     * @param comparator key comparator; required (and used as a {@link RawComparator}) when
     *     {@code raw} is true, otherwise a hash-code based ordering is used when null
     * @param raw whether keys and values are kept as serialized {@link ComparativeOutputBuffer}s
     * @param combiner optional combiner; when non-null a combiner thread runs per partition
     * @param isMapCombine forwarded to {@code RecordBlob#combine} for the first combine pass
     * @param metrics optional reporter, incremented once per fetched record
     * @param clientType shuffle server client type passed to {@link ShuffleServerClientFactory}
     * @throws RssException if {@code raw} is true and {@code comparator} is null
     */
    public RMRecordsReader(String appId, int shuffleId, Set<Integer> partitionIds, Map<Integer, List<ShuffleServerInfo>> shuffleServerInfoMap, RssConf rssConf, Class<K> keyClass, Class<V> valueClass, Comparator<K> comparator, boolean raw, Combiner combiner, boolean isMapCombine, MetricsReporter metrics, String clientType) {
        this.appId = appId;
        this.shuffleId = shuffleId;
        this.partitionIds = partitionIds;
        this.shuffleServerInfoMap = shuffleServerInfoMap;
        this.rssConf = rssConf;
        this.keyClass = keyClass;
        this.valueClass = valueClass;
        this.raw = raw;
        if (raw && comparator == null) {
            throw new RssException("RawComparator must be set!");
        }
        this.comparator = comparator != null ? comparator : new Comparator<K>() {

            // Fallback ordering by hash code. Note that keys with equal hash codes compare as
            // equal, so isSameKey() will group them together.
            @Override
            public int compare(K o1, K o2) {
                int h1 = o1 == null ? 0 : o1.hashCode();
                int h2 = o2 == null ? 0 : o2.hashCode();
                return Integer.compare(h1, h2);
            }
        };
        this.combiner = combiner;
        this.isMapCombine = isMapCombine;
        this.metrics = metrics;
        this.clientType = clientType;
        if (this.raw) {
            SerializerFactory factory = new SerializerFactory(rssConf);
            Serializer serializer = factory.getSerializer(keyClass);
            assert (factory.getSerializer(valueClass).getClass().equals(serializer.getClass()));
            this.serializerInstance = serializer.newInstance();
        }
        this.initFetchSleepTime = rssConf.get(RssClientConf.RSS_CLIENT_REMOTE_MERGE_FETCH_INIT_SLEEP_MS).intValue();
        this.maxFetchSleepTime = rssConf.get(RssClientConf.RSS_CLIENT_REMOTE_MERGE_FETCH_MAX_SLEEP_MS).intValue();
        int maxBuffer = rssConf.get(RssClientConf.RSS_CLIENT_REMOTE_MERGE_READER_MAX_BUFFER);
        // The global buffer budget is split evenly across partitions, with at least one each.
        this.maxBufferPerPartition = Math.max(1, maxBuffer / partitionIds.size());
        this.maxRecordsNumPerBuffer = rssConf.get(RssClientConf.RSS_CLIENT_REMOTE_MERGE_READER_MAX_RECORDS_PER_BUFFER);
        this.results = new Queue<>(this.maxBufferPerPartition * this.maxRecordsNumPerBuffer * partitionIds.size());
        this.retryMax = rssConf.getInteger("rss.client.retry.max", 50);
        this.retryIntervalMax = rssConf.getLong("rss.client.retry.interval.max", 10000L);
        LOG.info("RMRecordsReader constructed for partitions {}", partitionIds);
    }

    /**
     * Creates the per-partition queues and starts the fetcher (and combiner) threads for every
     * partition, followed by the single merger thread.
     */
    public void start() {
        for (int partitionId : this.partitionIds) {
            this.mergeBuffers.put(partitionId, new Queue<>(this.maxBufferPerPartition));
            if (this.combiner != null) {
                this.combineBuffers.put(partitionId, new Queue<>(this.maxBufferPerPartition));
            }
            RecordsFetcher fetcher = new RecordsFetcher(partitionId);
            fetcher.start();
            if (this.combiner != null) {
                RecordsCombiner combineThread = new RecordsCombiner(partitionId);
                combineThread.start();
            }
        }
        RecordsMerger recordMerger = new RecordsMerger();
        recordMerger.start();
    }

    /**
     * Signals all worker threads to stop (they poll {@code stop}) and drops every buffered record.
     * After this call {@code results} is null.
     */
    public void close() {
        this.error = null;
        this.stop = true;
        for (Queue<RecordBuffer> buffer : this.mergeBuffers.values()) {
            buffer.clear();
        }
        this.mergeBuffers.clear();
        if (this.combiner != null) {
            for (Queue<RecordBuffer> buffer : this.combineBuffers.values()) {
                buffer.clear();
            }
            this.combineBuffers.clear();
        }
        if (this.results != null) {
            this.results.clear();
            this.results = null;
        }
    }

    /**
     * Records a worker failure and tells every pipeline stage to stop. The first failure wins so
     * that the root cause is not overwritten by the secondary failures it triggers in other
     * threads (e.g. an {@link RssException} thrown from {@code Queue#take()}). {@code error} is
     * written before {@code stop} so any thread that observes {@code stop} also sees the error.
     */
    private void markFailed(Throwable cause) {
        if (this.error == null) {
            this.error = cause;
        }
        this.stop = true;
    }

    /**
     * Returns whether two keys are equal according to the configured comparator. In raw mode both
     * keys are {@link ComparativeOutputBuffer}s compared byte-wise with the {@link RawComparator}.
     */
    private boolean isSameKey(Object k1, Object k2) {
        if (this.raw) {
            ComparativeOutputBuffer buffer1 = (ComparativeOutputBuffer) k1;
            ComparativeOutputBuffer buffer2 = (ComparativeOutputBuffer) k2;
            return ((RawComparator) this.comparator).compare(buffer1.getData(), 0, buffer1.getLength(), buffer2.getData(), 0, buffer2.getLength()) == 0;
        }
        return this.comparator.compare(k1, k2) == 0;
    }

    /**
     * Returns a reader over the merged records whose keys and values are the serialized buffers.
     * {@code hasNext()} must be called before each {@code next()}.
     *
     * @throws RssException if this reader was not created in raw mode
     */
    public KeyValueReader<ComparativeOutputBuffer, ComparativeOutputBuffer> rawKeyValueReader() {
        if (!this.raw) {
            throw new RssException("rawKeyValueReader is not supported!");
        }
        return new KeyValueReader<ComparativeOutputBuffer, ComparativeOutputBuffer>() {
            private Record<ComparativeOutputBuffer, ComparativeOutputBuffer> curr = null;

            @Override
            public boolean hasNext() throws IOException {
                try {
                    if (this.curr != null) {
                        return true;
                    }
                    // take() returns null once the merger has finished and the queue is drained.
                    this.curr = RMRecordsReader.this.results.take();
                    return this.curr != null;
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new IOException(e);
                }
            }

            @Override
            public Record<ComparativeOutputBuffer, ComparativeOutputBuffer> next() throws IOException {
                Record next = Record.create(this.curr.getKey(), this.curr.getValue());
                this.curr = null;
                return next;
            }
        };
    }

    /**
     * Returns a reader over the merged records. In raw mode each key and value is deserialized
     * with the configured serializer when read. {@code hasNext()} must be called before each
     * {@code next()}.
     */
    public KeyValueReader<K, C> keyValueReader() {
        return new KeyValueReader<K, C>() {
            private Record<K, C> curr = null;

            @Override
            public boolean hasNext() throws IOException {
                try {
                    if (this.curr != null) {
                        return true;
                    }
                    // take() returns null once the merger has finished and the queue is drained.
                    this.curr = RMRecordsReader.this.results.take();
                    return this.curr != null;
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new IOException(e);
                }
            }

            @Override
            public Record<K, C> next() throws IOException {
                Record record = Record.create(this.getCurrentKey(), this.getCurrentValue());
                this.curr = null;
                return record;
            }

            public K getCurrentKey() throws IOException {
                if (RMRecordsReader.this.raw) {
                    ComparativeOutputBuffer keyBuffer = (ComparativeOutputBuffer) ((Object) this.curr.getKey());
                    DataInputBuffer keyInputBuffer = new DataInputBuffer();
                    keyInputBuffer.reset(keyBuffer.getData(), 0, keyBuffer.getLength());
                    return RMRecordsReader.this.serializerInstance.deserialize(keyInputBuffer, RMRecordsReader.this.keyClass);
                }
                return this.curr.getKey();
            }

            public C getCurrentValue() throws IOException {
                if (RMRecordsReader.this.raw) {
                    ComparativeOutputBuffer valueBuffer = (ComparativeOutputBuffer) ((Object) this.curr.getValue());
                    DataInputBuffer valueInputBuffer = new DataInputBuffer();
                    valueInputBuffer.reset(valueBuffer.getData(), 0, valueBuffer.getLength());
                    return RMRecordsReader.this.serializerInstance.deserialize(valueInputBuffer, RMRecordsReader.this.valueClass);
                }
                return this.curr.getValue();
            }
        };
    }

    /**
     * Returns a reader that groups consecutive merged records with the same key.
     *
     * <p>{@code next()} does not advance by itself: the group start moves to the first record of
     * the next key only when the iterator from {@code getCurrentValues()} reports no more values.
     * Callers should therefore exhaust the values of each group before calling {@code next()}
     * again.
     */
    public KeyValuesReader<K, C> keyValuesReader() {
        return new KeyValuesReader() {
            // First record of the current key group; null before the first group and at the end.
            private Record<K, C> start = null;

            @Override
            public boolean next() throws IOException {
                try {
                    if (this.start == null) {
                        this.start = RMRecordsReader.this.results.take();
                        return this.start != null;
                    }
                    return true;
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new IOException(e);
                }
            }

            @Override
            public K getCurrentKey() throws IOException {
                if (RMRecordsReader.this.raw) {
                    ComparativeOutputBuffer keyBuffer = (ComparativeOutputBuffer) ((Object) this.start.getKey());
                    DataInputBuffer keyInputBuffer = new DataInputBuffer();
                    keyInputBuffer.reset(keyBuffer.getData(), 0, keyBuffer.getLength());
                    return RMRecordsReader.this.serializerInstance.deserialize(keyInputBuffer, RMRecordsReader.this.keyClass);
                }
                return this.start.getKey();
            }

            public Iterable<C> getCurrentValues() throws IOException {
                return new Iterable<C>() {

                    @Override
                    public Iterator<C> iterator() {
                        return new Iterator<C>() {
                            Record<K, C> curr;
                            {
                                this.curr = start;
                            }

                            @Override
                            public boolean hasNext() {
                                if (this.curr != null && RMRecordsReader.this.isSameKey(this.curr.getKey(), start.getKey())) {
                                    return true;
                                }
                                // Key changed (or stream ended): the pending record starts the next group.
                                start = this.curr;
                                return false;
                            }

                            @Override
                            public C next() {
                                try {
                                    C ret;
                                    if (RMRecordsReader.this.raw) {
                                        ComparativeOutputBuffer valueBuffer = (ComparativeOutputBuffer) ((Object) this.curr.getValue());
                                        DataInputBuffer valueInputBuffer = new DataInputBuffer();
                                        valueInputBuffer.reset(valueBuffer.getData(), 0, valueBuffer.getLength());
                                        ret = RMRecordsReader.this.serializerInstance.deserialize(valueInputBuffer, RMRecordsReader.this.valueClass);
                                    } else {
                                        ret = this.curr.getValue();
                                    }
                                    // Read ahead one record so hasNext() can compare its key.
                                    this.curr = RMRecordsReader.this.results.take();
                                    return ret;
                                } catch (InterruptedException e) {
                                    Thread.currentThread().interrupt();
                                    throw new RssException(e);
                                } catch (IOException e) {
                                    throw new RssException(e);
                                }
                            }
                        };
                    }
                };
            }
        };
    }

    /** Creates the client used to fetch sorted data from {@code shuffleServerInfo}. */
    @VisibleForTesting
    public ShuffleServerClient createShuffleServerClient(ShuffleServerInfo shuffleServerInfo) {
        return ShuffleServerClientFactory.getInstance().getShuffleServerClient(this.clientType, shuffleServerInfo, this.rssConf);
    }

    /**
     * Merges the per-partition record buffers from {@code mergeBuffers} into {@code results},
     * then marks {@code results} as complete (unless the reader was stopped).
     */
    class RecordsMerger
    extends Thread {
        RecordsMerger() {
            this.setName("RecordsMerger");
        }

        @Override
        public void run() {
            try {
                // Seed the merge with the first buffer of every partition that has any data.
                ArrayList<Segment> segments = new ArrayList<Segment>();
                for (int partitionId : RMRecordsReader.this.partitionIds) {
                    RecordBuffer firstBuffer = RMRecordsReader.this.mergeBuffers.get(partitionId).take();
                    if (firstBuffer != null) {
                        segments.add(new BufferedSegment(firstBuffer));
                    }
                }
                try (Merger.MergeQueue mergeQueue = new Merger.MergeQueue(RMRecordsReader.this.rssConf, segments, RMRecordsReader.this.keyClass, RMRecordsReader.this.valueClass, RMRecordsReader.this.comparator, RMRecordsReader.this.raw, false)) {
                    mergeQueue.init();
                    // When a partition's segment is exhausted, replace it with that partition's next buffer.
                    mergeQueue.setPopSegmentHook(pid -> {
                        try {
                            RecordBuffer recordBuffer = RMRecordsReader.this.mergeBuffers.get(pid).take();
                            if (recordBuffer == null) {
                                return null;
                            }
                            return new BufferedSegment(recordBuffer);
                        } catch (InterruptedException ex) {
                            throw new RssException(ex);
                        }
                    });
                    while (!RMRecordsReader.this.stop && mergeQueue.next()) {
                        RMRecordsReader.this.results.put(Record.create(mergeQueue.getCurrentKey(), mergeQueue.getCurrentValue()));
                    }
                }
                if (!RMRecordsReader.this.stop) {
                    RMRecordsReader.this.results.setProducerDone(true);
                }
            } catch (Throwable e) {
                // Any failure (checked or not) must stop the pipeline; otherwise consumers blocked
                // on results would wait forever for a producer that is gone.
                RMRecordsReader.this.markFailed(e);
            }
        }
    }

    /**
     * Combines the record buffers of one partition (from {@code combineBuffers}) and forwards the
     * combined buffers to that partition's queue in {@code mergeBuffers}. Records sharing a key
     * are kept in the same outgoing buffer so they can be combined together.
     */
    class RecordsCombiner
    extends Thread {
        private final int partitionId;
        private RecordBuffer<K, C> cached;
        private final Queue<RecordBuffer> nextQueue;

        RecordsCombiner(int partitionId) {
            this.partitionId = partitionId;
            this.cached = new RecordBuffer(partitionId);
            this.nextQueue = RMRecordsReader.this.mergeBuffers.get(partitionId);
            this.setName("RecordsCombiner-" + partitionId);
        }

        @Override
        public void run() {
            try {
                while (!RMRecordsReader.this.stop) {
                    RecordBuffer current = RMRecordsReader.this.combineBuffers.get(this.partitionId).take();
                    if (current == null) {
                        // Upstream finished: flush what is cached and signal completion downstream.
                        if (this.cached.size() > 0) {
                            this.sendCachedBuffer(this.cached);
                        }
                        this.nextQueue.setProducerDone(true);
                        return;
                    }
                    if (this.cached.size() > 0 && !RMRecordsReader.this.isSameKey(this.cached.getLastKey(), current.getFirstKey())) {
                        this.sendCachedBuffer(this.cached);
                        this.cached = new RecordBuffer(this.partitionId);
                    }
                    RecordBlob recordBlob = new RecordBlob(this.partitionId);
                    recordBlob.addRecords(current);
                    recordBlob.combine(RMRecordsReader.this.combiner, RMRecordsReader.this.isMapCombine);
                    for (Record record : recordBlob.getResult()) {
                        // Only cut a new buffer at a key boundary so a key is never split.
                        if (this.cached.size() >= RMRecordsReader.this.maxRecordsNumPerBuffer && !RMRecordsReader.this.isSameKey(record.getKey(), this.cached.getLastKey())) {
                            this.sendCachedBuffer(this.cached);
                            this.cached = new RecordBuffer(this.partitionId);
                        }
                        this.cached.addRecord(record);
                    }
                }
            } catch (Throwable e) {
                // Record the failure instead of letting the thread die silently, which would leave
                // the merger waiting on nextQueue forever.
                RMRecordsReader.this.markFailed(e);
            }
        }

        /** Combines {@code cachedBuffer} once more and puts the result on the merge queue. */
        private void sendCachedBuffer(RecordBuffer<K, C> cachedBuffer) throws InterruptedException {
            RecordBlob recordBlob = new RecordBlob(this.partitionId);
            recordBlob.addRecords(cachedBuffer);
            recordBlob.combine(RMRecordsReader.this.combiner, true);
            RecordBuffer recordBuffer = new RecordBuffer(this.partitionId);
            recordBuffer.addRecords(recordBlob.getResult());
            this.nextQueue.put(recordBuffer);
        }
    }

    /**
     * Fetches the sorted blocks of one partition from a shuffle server, block by block, and
     * pushes them as {@link RecordBuffer}s of at most {@code maxRecordsNumPerBuffer} records to
     * the combiner queue (when a combiner is set) or directly to the merge queue.
     *
     * <p>Servers are tried from the last entry of the partition's server list towards the first;
     * on an error response the fetcher fails over to the next server and retries the same block.
     */
    class RecordsFetcher
    extends Thread {
        private final int partitionId;
        private long sleepTime;
        // Id of the next block to request; advanced from each successful response.
        private long blockId = 1L;
        private RecordBuffer recordBuffer;
        private final Queue<RecordBuffer> nextQueue;
        private final List<ShuffleServerInfo> serverInfos;
        private ShuffleServerClient client;
        // Index into serverInfos of the server currently in use.
        private int choose;
        // Last failure message, reported once all servers have been tried.
        private String fetchError;

        RecordsFetcher(int partitionId) {
            this.partitionId = partitionId;
            this.sleepTime = RMRecordsReader.this.initFetchSleepTime;
            this.recordBuffer = new RecordBuffer(partitionId);
            this.nextQueue = RMRecordsReader.this.combiner == null ? RMRecordsReader.this.mergeBuffers.get(partitionId) : RMRecordsReader.this.combineBuffers.get(partitionId);
            this.serverInfos = RMRecordsReader.this.shuffleServerInfoMap.get(partitionId);
            this.choose = this.serverInfos.size() - 1;
            this.client = RMRecordsReader.this.createShuffleServerClient(this.serverInfos.get(this.choose));
            this.setName("RecordsFetcher-" + partitionId);
        }

        /**
         * Switches {@code client} to the next candidate server.
         *
         * @throws RssException if every candidate server has already been tried
         */
        private void nextShuffleServerInfo() {
            if (this.choose <= 0) {
                throw new RssException("Fetch sorted record failed, last error message is " + this.fetchError);
            }
            --this.choose;
            this.client = RMRecordsReader.this.createShuffleServerClient(this.serverInfos.get(this.choose));
        }

        @Override
        public void run() {
            while (!RMRecordsReader.this.stop) {
                try {
                    RssGetSortedShuffleDataRequest request = new RssGetSortedShuffleDataRequest(RMRecordsReader.this.appId, RMRecordsReader.this.shuffleId, this.partitionId, this.blockId, RMRecordsReader.this.retryMax, RMRecordsReader.this.retryIntervalMax);
                    RssGetSortedShuffleDataResponse response = this.client.getSortedShuffleData(request);
                    if (response.getStatusCode() != StatusCode.SUCCESS || response.getMergeState() == MergeState.INTERNAL_ERROR.code()) {
                        this.fetchError = response.getMessage();
                        // Retry the same block on the next server. Once all servers are exhausted
                        // this throws and the failure is recorded below. The previous code broke out
                        // of the loop here, leaving downstream threads waiting forever.
                        // NOTE(review): assumes replicas expose the same block ids — confirm.
                        this.nextShuffleServerInfo();
                    } else if (response.getMergeState() == MergeState.INITED.code()) {
                        this.fetchError = "Remote merge should be started!";
                        this.nextShuffleServerInfo();
                    } else if (response.getMergeState() == MergeState.MERGING.code() && response.getNextBlockId() == -1L) {
                        // Everything merged so far has been read; back off until more is merged.
                        LOG.info("RMRecordsFetcher will sleep {} ms", this.sleepTime);
                        Thread.sleep(this.sleepTime);
                        this.sleepTime = Math.min(this.sleepTime * 2L, RMRecordsReader.this.maxFetchSleepTime);
                    } else if (response.getMergeState() == MergeState.DONE.code() && response.getNextBlockId() == -1L) {
                        // All data has been read: flush the last partial buffer and finish.
                        if (this.recordBuffer.size() > 0) {
                            this.nextQueue.put(this.recordBuffer);
                        }
                        this.nextQueue.setProducerDone(true);
                        break;
                    } else if (response.getMergeState() == MergeState.DONE.code() || response.getMergeState() == MergeState.MERGING.code()) {
                        this.sleepTime = RMRecordsReader.this.initFetchSleepTime;
                        this.blockId = response.getNextBlockId();
                        ManagedBuffer managedBuffer = null;
                        ByteBuf byteBuf = null;
                        RecordsReader reader = null;
                        try {
                            managedBuffer = response.getData();
                            byteBuf = managedBuffer.byteBuf();
                            reader = new RecordsReader(RMRecordsReader.this.rssConf, SerInputStream.newInputStream(byteBuf), RMRecordsReader.this.keyClass, RMRecordsReader.this.valueClass, RMRecordsReader.this.raw, false);
                            reader.init();
                            while (reader.next()) {
                                if (RMRecordsReader.this.metrics != null) {
                                    RMRecordsReader.this.metrics.incRecordsRead(1L);
                                }
                                if (this.recordBuffer.size() >= RMRecordsReader.this.maxRecordsNumPerBuffer) {
                                    this.nextQueue.put(this.recordBuffer);
                                    this.recordBuffer = new RecordBuffer(this.partitionId);
                                }
                                this.recordBuffer.addRecord(reader.getCurrentKey(), reader.getCurrentValue());
                            }
                        } finally {
                            // Only release resources here. A control-flow statement in this finally
                            // would discard any exception thrown while reading the block.
                            if (reader != null) {
                                reader.close();
                            }
                            if (byteBuf != null) {
                                byteBuf.release();
                            }
                            if (managedBuffer != null) {
                                managedBuffer.release();
                            }
                        }
                    } else {
                        this.fetchError = "Receive wrong offset from server, offset is " + response.getNextBlockId();
                        this.nextShuffleServerInfo();
                    }
                } catch (Throwable e) {
                    RMRecordsReader.this.markFailed(e);
                    LOG.error("Found exception when fetch sorted record, caused by ", e);
                }
            }
        }
    }

    /**
     * Bounded hand-off queue between pipeline stages. The producer calls
     * {@link #setProducerDone(boolean)} after its last {@link #put}; {@link #take()} then drains
     * what remains and returns null when empty.
     */
    class Queue<E> {
        private final LinkedBlockingQueue<E> queue;
        private volatile boolean producerDone = false;

        Queue(int capacity) {
            this.queue = new LinkedBlockingQueue<>(capacity);
        }

        public void setProducerDone(boolean producerDone) {
            this.producerDone = producerDone;
        }

        /** Blocks while the queue is full. */
        public void put(E recordBuffer) throws InterruptedException {
            this.queue.put(recordBuffer);
        }

        /**
         * Returns the next element, waiting while the producer is still running and the reader
         * has not been stopped.
         *
         * @return the next element, or null if none arrives once the producer is done or the
         *     reader was stopped without an error
         * @throws RssException if the reader was stopped because a worker failed; the failure is
         *     attached as the cause
         */
        public E take() throws InterruptedException {
            while (!this.producerDone && !RMRecordsReader.this.stop) {
                E e = this.queue.poll(100L, TimeUnit.MILLISECONDS);
                if (e != null) {
                    return e;
                }
            }
            // Read the volatile once so the check and the exception use the same value.
            Throwable cause = RMRecordsReader.this.error;
            if (cause != null) {
                throw new RssException("RMShuffleReader fetch record failed, caused by " + cause, cause);
            }
            return this.queue.poll(100L, TimeUnit.MILLISECONDS);
        }

        public void clear() {
            this.queue.clear();
            this.producerDone = false;
        }
    }
}

