/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.shuffle.celeborn;

import java.io.IOException;
import java.io.OutputStream;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.LongAdder;
import javax.annotation.Nullable;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.client.write.DataPusher;
import org.apache.celeborn.client.write.PushTask;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.util.Utils;
import org.apache.celeborn.shaded.com.google.common.annotations.VisibleForTesting;
import org.apache.spark.Aggregator;
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.annotation.Private;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.shuffle.celeborn.CelebornShuffleHandle;
import org.apache.spark.shuffle.celeborn.OpenByteArrayOutputStream;
import org.apache.spark.shuffle.celeborn.SendBufferPool;
import org.apache.spark.shuffle.celeborn.SparkUtils;
import org.apache.spark.shuffle.celeborn.TaskInterruptedHelper;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.execution.UnsafeRowSerializer;
import org.apache.spark.sql.execution.columnar.CelebornBatchBuilder;
import org.apache.spark.sql.execution.columnar.CelebornColumnarBatchBuilder;
import org.apache.spark.sql.execution.columnar.CelebornColumnarBatchCodeGenBuild;
import org.apache.spark.sql.execution.metric.SQLMetric;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.BlockManagerId;
import org.apache.spark.unsafe.Platform;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Option;
import scala.Product2;
import scala.collection.Iterator;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;

/**
 * Hash-based Celeborn shuffle writer. It serializes each record into an in-memory send buffer
 * for the record's reduce partition and pushes the buffered bytes to Celeborn.
 *
 * <p>Data flow, as implemented below:
 * <ul>
 *   <li>Each partition's send buffer starts at {@code clientPushBufferInitialSize} bytes. When a
 *       record does not fit, the buffer is doubled, up to {@code clientPushBufferMaxSize}.</li>
 *   <li>When a buffer at its maximum size cannot take the next record, its contents are handed to
 *       the {@link DataPusher} and the buffer is reused from offset 0.</li>
 *   <li>A single record larger than the maximum buffer size is pushed directly with
 *       {@code ShuffleClient#pushData}.</li>
 *   <li>When {@link #write} finishes, the remaining buffered bytes are sent with
 *       {@code mergeData}/{@code pushMergedData} and {@code mapperEnd} is called. A
 *       {@link MapStatus} is then built from the per-partition byte counts.</li>
 * </ul>
 *
 * <p>When columnar shuffle is enabled and the dependency's schema is supported, the fast-write
 * path instead accumulates rows in per-partition {@link CelebornBatchBuilder}s. Each batch is
 * pushed once it reaches {@code columnarShuffleBatchSize} rows.
 *
 * <p>NOTE(review): this class assumes it is driven by one task thread (write, then stop), as
 * Spark does. Apart from the volatile {@code stopping} flag, nothing here synchronizes.
 */
@Private
public class HashBasedShuffleWriter<K, V, C>
extends ShuffleWriter<K, V> {
    private static final Logger logger = LoggerFactory.getLogger(HashBasedShuffleWriter.class);
    private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
    /** Initial capacity (1 MiB) of the scratch buffer used to serialize a single record. */
    private static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 0x100000;
    /** Size of a newly created per-partition send buffer. */
    private final int PUSH_BUFFER_INIT_SIZE;
    /** Upper bound a send buffer may grow to; larger records are pushed on their own. */
    private final int PUSH_BUFFER_MAX_SIZE;
    private final ShuffleDependency<K, V, C> dep;
    private final Partitioner partitioner;
    private final ShuffleWriteMetricsReporter writeMetrics;
    private final int stageId;
    private final int shuffleId;
    private final int mapId;
    private final TaskContext taskContext;
    private final ShuffleClient shuffleClient;
    private final int numMappers;
    private final int numPartitions;
    private final CelebornConf conf;
    /** Set by {@link #close()} at the end of a successful {@link #write}. */
    @Nullable
    private MapStatus mapStatus;
    /** Bytes allocated for send buffers by this writer; reported to task metrics in stop(). */
    private long peakMemoryUsedBytes = 0L;
    /** Scratch buffer that receives one serialized record at a time in the generic path. */
    private final OpenByteArrayOutputStream serBuffer;
    private final SerializationStream serOutputStream;
    /** Per-partition send buffers from {@link SendBufferPool}; null once returned to the pool. */
    private byte[][] sendBuffers;
    /** Number of valid bytes in each send buffer; null once the buffers are returned. */
    private int[] sendOffsets;
    /** Bytes written per partition; used to build the final {@link MapStatus}. */
    private final LongAdder[] mapStatusLengths;
    /** Per-partition record counts not yet reported to {@code writeMetrics}. */
    private final long[] tmpRecords;
    private CelebornBatchBuilder[] celebornBatchBuilders;
    private final SendBufferPool sendBufferPool;
    private volatile boolean stopping = false;
    private DataPusher dataPusher;
    private StructType schema;
    private final boolean unsafeRowFastWrite;
    private boolean isColumnarShuffle = false;
    private int columnarShuffleBatchSize;
    private boolean columnarShuffleCodeGenEnabled;
    private boolean columnarShuffleDictionaryEnabled;
    private double columnarShuffleDictionaryMaxFactor;

    /**
     * Creates a writer for one map task.
     *
     * @param handle shuffle handle carrying the dependency and the number of mappers
     * @param taskContext context of the running map task
     * @param conf Celeborn client configuration
     * @param client client used to push data to Celeborn
     * @param metrics reporter receiving bytes/records written and write time
     * @param sendBufferPool pool supplying per-partition send buffers and push-task queues
     * @throws IOException if the serialization stream cannot be created
     */
    public HashBasedShuffleWriter(CelebornShuffleHandle<K, V, C> handle, TaskContext taskContext, CelebornConf conf, ShuffleClient client, ShuffleWriteMetricsReporter metrics, SendBufferPool sendBufferPool) throws IOException {
        this.mapId = taskContext.partitionId();
        this.dep = handle.dependency();
        this.stageId = taskContext.stageId();
        this.shuffleId = this.dep.shuffleId();
        SerializerInstance serializer = this.dep.serializer().newInstance();
        this.partitioner = this.dep.partitioner();
        this.writeMetrics = metrics;
        this.taskContext = taskContext;
        this.numMappers = handle.numMappers();
        this.numPartitions = this.dep.partitioner().numPartitions();
        this.shuffleClient = client;
        this.conf = conf;
        this.unsafeRowFastWrite = conf.clientPushUnsafeRowFastWrite();
        this.serBuffer = new OpenByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE);
        this.serOutputStream = serializer.serializeStream(this.serBuffer);
        this.mapStatusLengths = new LongAdder[this.numPartitions];
        for (int i = 0; i < this.numPartitions; ++i) {
            this.mapStatusLengths[i] = new LongAdder();
        }
        this.tmpRecords = new long[this.numPartitions];
        this.PUSH_BUFFER_INIT_SIZE = conf.clientPushBufferInitialSize();
        this.PUSH_BUFFER_MAX_SIZE = conf.clientPushBufferMaxSize();
        this.sendBufferPool = sendBufferPool;
        this.sendBuffers = sendBufferPool.acquireBuffer(this.numPartitions);
        this.sendOffsets = new int[this.numPartitions];
        try {
            LinkedBlockingQueue<PushTask> pushTaskQueue = sendBufferPool.acquirePushTaskQueue();
            this.dataPusher = new DataPusher(this.shuffleId, this.mapId, taskContext.attemptNumber(), taskContext.taskAttemptId(), this.numMappers, this.numPartitions, conf, this.shuffleClient, pushTaskQueue, bytes -> this.writeMetrics.incBytesWritten(bytes), this.mapStatusLengths);
        }
        catch (InterruptedException e) {
            TaskInterruptedHelper.throwTaskKillException();
        }
        if (conf.columnarShuffleEnabled()) {
            this.columnarShuffleBatchSize = conf.columnarShuffleBatchSize();
            this.columnarShuffleCodeGenEnabled = conf.columnarShuffleCodeGenEnabled();
            this.columnarShuffleDictionaryEnabled = conf.columnarShuffleDictionaryEnabled();
            this.columnarShuffleDictionaryMaxFactor = conf.columnarShuffleDictionaryMaxFactor();
            this.schema = SparkUtils.getSchema(this.dep);
            this.celebornBatchBuilders = new CelebornBatchBuilder[this.numPartitions];
            this.isColumnarShuffle = this.schema != null && CelebornBatchBuilder.supportsColumnarType(this.schema);
        }
    }

    /**
     * Writes all records of this map task and finishes the map output.
     *
     * <p>It picks the fast unsafe-row path (row or columnar), the map-side-combine path, or the
     * generic serialized path. It then flushes everything and builds the map status.
     *
     * @param records the map task's output records
     * @throws IOException if pushing data fails
     * @throws UnsupportedOperationException if map-side combine is requested without an aggregator
     */
    @Override
    public void write(Iterator<Product2<K, V>> records) throws IOException {
        try {
            if (this.canUseFastWrite()) {
                if (this.isColumnarShuffle) {
                    logger.info("Fast columnar write of columnar shuffle {} for stage {}.", this.shuffleId, this.stageId);
                    this.fastColumnarWrite0(records);
                } else {
                    this.fastWrite0(records);
                }
            } else if (this.dep.mapSideCombine()) {
                if (this.dep.aggregator().isEmpty()) {
                    throw new UnsupportedOperationException("When using map side combine, an aggregator must be specified.");
                }
                this.write0(this.dep.aggregator().get().combineValuesByKey(records, this.taskContext));
            } else {
                this.write0(records);
            }
            this.close();
        }
        catch (InterruptedException e) {
            TaskInterruptedHelper.throwTaskKillException();
        }
    }

    /**
     * Reports whether records can be copied as raw {@link UnsafeRow} bytes.
     *
     * <p>That requires all three of the following:
     * <ul>
     *   <li>the fast-write option is enabled;</li>
     *   <li>the dependency uses {@link UnsafeRowSerializer};</li>
     *   <li>the partitioner is Spark's {@code PartitionIdPassthrough}, so each record key is
     *       already the partition id.</li>
     * </ul>
     */
    @VisibleForTesting
    boolean canUseFastWrite() {
        if (!this.unsafeRowFastWrite || !(this.dep.serializer() instanceof UnsafeRowSerializer)) {
            return false;
        }
        // Compared by simple class name rather than by type.
        return "PartitionIdPassthrough".equals(this.partitioner.getClass().getSimpleName());
    }

    /**
     * Columnar fast path. It accumulates rows per partition in batch builders and pushes each
     * batch once it reaches {@code columnarShuffleBatchSize} rows. Partial batches are flushed in
     * {@link #closeColumnarWrite()}.
     */
    private void fastColumnarWrite0(Iterator iterator) throws IOException {
        final Iterator<Product2<Integer, UnsafeRow>> records = iterator;
        SQLMetric dataSize = SparkUtils.getDataSize((UnsafeRowSerializer) this.dep.serializer());
        while (records.hasNext()) {
            Product2<Integer, UnsafeRow> record = records.next();
            int partitionId = record._1();
            UnsafeRow row = record._2();
            CelebornBatchBuilder builder = this.celebornBatchBuilders[partitionId];
            if (builder == null) {
                builder = this.newBatchBuilder();
                builder.newBuilders();
                this.celebornBatchBuilders[partitionId] = builder;
            }
            builder.writeRow((InternalRow) row);
            if (builder.getRowCnt() >= this.columnarShuffleBatchSize) {
                byte[] arr = builder.buildColumnBytes();
                this.pushGiantRecord(partitionId, arr, arr.length);
                if (dataSize != null) {
                    dataSize.add((long) arr.length);
                }
                // Start a fresh batch for this partition.
                builder.newBuilders();
            }
            this.tmpRecords[partitionId]++;
        }
    }

    /** Creates the batch builder for one partition: code-generated unless dictionary encoding is on. */
    private CelebornBatchBuilder newBatchBuilder() {
        if (this.columnarShuffleCodeGenEnabled && !this.columnarShuffleDictionaryEnabled) {
            return new CelebornColumnarBatchCodeGenBuild().create(this.schema, this.columnarShuffleBatchSize);
        }
        return new CelebornColumnarBatchBuilder(this.schema, this.columnarShuffleBatchSize, this.columnarShuffleDictionaryMaxFactor, this.columnarShuffleDictionaryEnabled);
    }

    /**
     * Row fast path. Each {@link UnsafeRow} is copied straight into the partition's send buffer,
     * framed as a 4-byte length prefix followed by the row bytes.
     */
    private void fastWrite0(Iterator iterator) throws IOException, InterruptedException {
        final Iterator<Product2<Integer, UnsafeRow>> records = iterator;
        SQLMetric dataSize = SparkUtils.getDataSize((UnsafeRowSerializer) this.dep.serializer());
        while (records.hasNext()) {
            Product2<Integer, UnsafeRow> record = records.next();
            int partitionId = record._1();
            UnsafeRow row = record._2();
            int rowSize = row.getSizeInBytes();
            int serializedRecordSize = 4 + rowSize;
            if (dataSize != null) {
                dataSize.add((long) rowSize);
            }
            // Integer.reverseBytes makes the native-order putInt big-endian on little-endian hosts.
            // NOTE(review): presumably this matches UnsafeRowSerializer's stream framing - confirm.
            if (serializedRecordSize > this.PUSH_BUFFER_MAX_SIZE) {
                byte[] giantBuffer = new byte[serializedRecordSize];
                Platform.putInt(giantBuffer, Platform.BYTE_ARRAY_OFFSET, Integer.reverseBytes(rowSize));
                Platform.copyMemory(row.getBaseObject(), row.getBaseOffset(), giantBuffer, Platform.BYTE_ARRAY_OFFSET + 4, rowSize);
                this.pushGiantRecord(partitionId, giantBuffer, serializedRecordSize);
            } else {
                int offset = this.getOrUpdateOffset(partitionId, serializedRecordSize);
                byte[] buffer = this.getOrCreateBuffer(partitionId);
                Platform.putInt(buffer, Platform.BYTE_ARRAY_OFFSET + offset, Integer.reverseBytes(rowSize));
                Platform.copyMemory(row.getBaseObject(), row.getBaseOffset(), buffer, Platform.BYTE_ARRAY_OFFSET + offset + 4, rowSize);
                this.sendOffsets[partitionId] = offset + serializedRecordSize;
            }
            this.tmpRecords[partitionId]++;
        }
    }

    /**
     * Generic path. Each key/value pair is serialized with the dependency's serializer into
     * {@code serBuffer}, then appended to the partition chosen by the partitioner.
     */
    private void write0(Iterator iterator) throws IOException, InterruptedException {
        final Iterator<Product2<?, ?>> records = iterator;
        while (records.hasNext()) {
            Product2<?, ?> record = records.next();
            Object key = record._1();
            int partitionId = this.partitioner.getPartition(key);
            this.serBuffer.reset();
            this.serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
            this.serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
            this.serOutputStream.flush();
            int serializedRecordSize = this.serBuffer.size();
            assert (serializedRecordSize > 0);
            if (serializedRecordSize > this.PUSH_BUFFER_MAX_SIZE) {
                this.pushGiantRecord(partitionId, this.serBuffer.getBuf(), serializedRecordSize);
            } else {
                int offset = this.getOrUpdateOffset(partitionId, serializedRecordSize);
                byte[] buffer = this.getOrCreateBuffer(partitionId);
                System.arraycopy(this.serBuffer.getBuf(), 0, buffer, offset, serializedRecordSize);
                this.sendOffsets[partitionId] = offset + serializedRecordSize;
            }
            this.tmpRecords[partitionId]++;
        }
    }

    /** Returns the partition's send buffer, lazily allocating one of the initial size. */
    private byte[] getOrCreateBuffer(int partitionId) {
        byte[] buffer = this.sendBuffers[partitionId];
        if (buffer == null) {
            buffer = new byte[this.PUSH_BUFFER_INIT_SIZE];
            this.sendBuffers[partitionId] = buffer;
            this.peakMemoryUsedBytes += (long) this.PUSH_BUFFER_INIT_SIZE;
        }
        return buffer;
    }

    /** Synchronously pushes one oversized record (or a columnar batch) via {@code pushData}. */
    private void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throws IOException {
        logger.debug("Push giant record, size {}.", numBytes);
        int bytesWritten = this.shuffleClient.pushData(this.shuffleId, this.mapId, this.taskContext.attemptNumber(), partitionId, buffer, 0, numBytes, this.numMappers, this.numPartitions);
        this.mapStatusLengths[partitionId].add(bytesWritten);
        this.writeMetrics.incBytesWritten((long) bytesWritten);
    }

    /**
     * Returns the offset at which a record of {@code serializedRecordSize} bytes can be written
     * into the partition's send buffer.
     *
     * <p>It first grows the buffer by doubling, up to the maximum size. If the record still does
     * not fit, the buffered bytes are handed to the data pusher and the offset restarts at 0. The
     * caller must re-read the buffer via {@link #getOrCreateBuffer}, because growth replaces it.
     */
    private int getOrUpdateOffset(int partitionId, int serializedRecordSize) throws IOException, InterruptedException {
        int offset = this.sendOffsets[partitionId];
        byte[] buffer = this.getOrCreateBuffer(partitionId);
        while (buffer.length - offset < serializedRecordSize && buffer.length < this.PUSH_BUFFER_MAX_SIZE) {
            // Computed in long, with a floor of 1. A zero-length buffer would otherwise never grow
            // (infinite loop), and doubling a buffer of 1 GiB or more would overflow int.
            int grownSize = (int) Math.min(Math.max(1L, 2L * buffer.length), (long) this.PUSH_BUFFER_MAX_SIZE);
            byte[] newBuffer = new byte[grownSize];
            this.peakMemoryUsedBytes += (long) (newBuffer.length - buffer.length);
            System.arraycopy(buffer, 0, newBuffer, 0, offset);
            this.sendBuffers[partitionId] = newBuffer;
            buffer = newBuffer;
        }
        if (buffer.length - offset < serializedRecordSize) {
            this.flushSendBuffer(partitionId, buffer, offset);
            this.updateRecordsWrittenMetrics();
            offset = 0;
        }
        return offset;
    }

    /** Hands the first {@code size} bytes of {@code buffer} to the data pusher, timing the call. */
    private void flushSendBuffer(int partitionId, byte[] buffer, int size) throws IOException, InterruptedException {
        long start = System.nanoTime();
        logger.debug("Flush buffer, size {}.", Utils.bytesToString(size));
        this.dataPusher.addTask(partitionId, buffer, size);
        this.writeMetrics.incWriteTime(System.nanoTime() - start);
    }

    /** Builds and merges every partition's pending (partial) columnar batch. */
    private void closeColumnarWrite() throws IOException {
        SQLMetric dataSize = SparkUtils.getDataSize((UnsafeRowSerializer) this.dep.serializer());
        for (int i = 0; i < this.numPartitions; ++i) {
            CelebornBatchBuilder builders = this.celebornBatchBuilders[i];
            if (builders == null || builders.getRowCnt() <= 0) continue;
            byte[] buffers = builders.buildColumnBytes();
            if (dataSize != null) {
                dataSize.add((long) buffers.length);
            }
            int bytesWritten = this.shuffleClient.mergeData(this.shuffleId, this.mapId, this.taskContext.attemptNumber(), i, buffers, 0, buffers.length, this.numMappers, this.numPartitions);
            this.celebornBatchBuilders[i] = null;
            this.mapStatusLengths[i].add(bytesWritten);
            this.writeMetrics.incBytesWritten((long) bytesWritten);
        }
    }

    /** Merges every partition's remaining send-buffer bytes, then returns the buffers to the pool. */
    private void closeRowWrite() throws IOException {
        for (int i = 0; i < this.numPartitions; ++i) {
            int size = this.sendOffsets[i];
            if (size <= 0) continue;
            int bytesWritten = this.shuffleClient.mergeData(this.shuffleId, this.mapId, this.taskContext.attemptNumber(), i, this.sendBuffers[i], 0, size, this.numMappers, this.numPartitions);
            this.sendBuffers[i] = null;
            this.mapStatusLengths[i].add(bytesWritten);
            this.writeMetrics.incBytesWritten((long) bytesWritten);
        }
        this.releaseSendBuffers();
    }

    /** Returns the send buffers to the pool if this writer still holds them. Idempotent. */
    private void releaseSendBuffers() {
        if (this.sendBuffers != null) {
            this.sendBufferPool.returnBuffer(this.sendBuffers);
            this.sendBuffers = null;
            this.sendOffsets = null;
        }
    }

    /**
     * Finishes the map output. It runs these steps in order:
     * <ol>
     *   <li>waits for the data pusher and returns its idle queue;</li>
     *   <li>merges and pushes all remaining buffered data;</li>
     *   <li>reports the remaining record counts;</li>
     *   <li>calls {@code mapperEnd};</li>
     *   <li>builds {@link #mapStatus}.</li>
     * </ol>
     */
    private void close() throws IOException, InterruptedException {
        long pushMergedDataTime = System.nanoTime();
        this.dataPusher.waitOnTermination();
        this.sendBufferPool.returnPushTaskQueue(this.dataPusher.getIdleQueue());
        this.shuffleClient.prepareForMergeData(this.shuffleId, this.mapId, this.taskContext.attemptNumber());
        if (this.isColumnarShuffle) {
            this.closeColumnarWrite();
        }
        // Always flush the row send buffers. A columnar-enabled writer still writes rows through
        // write0 when fast write is not possible, and those buffer tails must not be dropped. On
        // the columnar fast path all offsets are 0, so this only returns the buffers to the pool.
        this.closeRowWrite();
        this.shuffleClient.pushMergedData(this.shuffleId, this.mapId, this.taskContext.attemptNumber());
        this.writeMetrics.incWriteTime(System.nanoTime() - pushMergedDataTime);
        this.updateRecordsWrittenMetrics();
        long waitStartTime = System.nanoTime();
        this.shuffleClient.mapperEnd(this.shuffleId, this.mapId, this.taskContext.attemptNumber(), this.numMappers);
        this.writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
        BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId();
        this.mapStatus = SparkUtils.createMapStatus(bmId, SparkUtils.unwrap(this.mapStatusLengths), this.taskContext.taskAttemptId());
    }

    /** Adds the pending per-partition record counts to {@code writeMetrics} and resets them. */
    private void updateRecordsWrittenMetrics() {
        long recordsWritten = 0L;
        for (int i = 0; i < this.numPartitions; ++i) {
            recordsWritten += this.tmpRecords[i];
            this.tmpRecords[i] = 0L;
        }
        this.writeMetrics.incRecordsWritten(recordsWritten);
    }

    /**
     * Stops the writer.
     *
     * <p>It reports peak execution memory and returns the map status when {@code success} is true.
     * It always releases any pooled send buffers still held (for example after a failed write)
     * and calls {@code shuffleClient.cleanup}.
     *
     * @param success whether the map task succeeded
     * @return the map status on the first successful stop, otherwise empty
     * @throws IllegalStateException if {@code success} is true but {@link #write} never completed
     */
    @Override
    public Option<MapStatus> stop(boolean success) {
        try {
            this.taskContext.taskMetrics().incPeakExecutionMemory(this.peakMemoryUsedBytes);
            if (this.stopping) {
                return Option.empty();
            }
            this.stopping = true;
            if (!success) {
                return Option.empty();
            }
            if (this.mapStatus == null) {
                throw new IllegalStateException("Cannot call stop(true) without having called write()");
            }
            return Option.apply(this.mapStatus);
        }
        finally {
            try {
                this.releaseSendBuffers();
            }
            finally {
                this.shuffleClient.cleanup(this.shuffleId, this.mapId, this.taskContext.attemptNumber());
            }
        }
    }

    /** Not supported: Celeborn cannot be combined with Spark's push-based shuffle. */
    public long[] getPartitionLengths() {
        throw new UnsupportedOperationException("Celeborn is not compatible with Spark push mode, please set spark.shuffle.push.enabled to false");
    }
}

