/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.codecs.lucene104;

import java.io.IOException;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene104.OffHeapScalarQuantizedVectorValues;
import org.apache.lucene.codecs.lucene104.QuantizedByteVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;

/**
 * {@link FlatVectorsScorer} for vectors written by the Lucene 104 scalar-quantized vectors format.
 *
 * <p>When the supplied {@link KnnVectorValues} are {@link QuantizedByteVectorValues}, scores are
 * computed directly from the quantized codes and their per-vector corrective terms. All other
 * vector values, and every {@code byte[]} query, are handed to the non-quantized delegate.
 */
public class Lucene104ScalarQuantizedVectorScorer implements FlatVectorsScorer {
    /** Scorer used whenever the vector values are not quantized, and for byte[] queries. */
    private final FlatVectorsScorer nonQuantizedDelegate;

    /**
     * {@code SCALE_LUT[b - 1] == 1 / (2^b - 1)} for {@code b} in 1..8, i.e. the fraction of the
     * quantization interval {@code [lower, upper]} represented by one step of a {@code b}-bit code.
     */
    private static final float[] SCALE_LUT = new float[] {
        1.0f, 0.33333334f, 0.14285715f, 0.06666667f,
        0.032258064f, 0.015873017f, 0.007874016f, 0.003921569f
    };

    /**
     * @param nonQuantizedDelegate scorer used for vector values that are not {@link
     *     QuantizedByteVectorValues} and for byte[] queries
     */
    public Lucene104ScalarQuantizedVectorScorer(FlatVectorsScorer nonQuantizedDelegate) {
        this.nonQuantizedDelegate = nonQuantizedDelegate;
    }

    /**
     * Returns a supplier that scores stored quantized vectors against each other, or defers to the
     * delegate when {@code vectorValues} is not quantized.
     */
    @Override
    public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
            VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) throws IOException {
        if (vectorValues instanceof QuantizedByteVectorValues qv) {
            return new ScalarQuantizedVectorScorerSupplier(qv, similarityFunction);
        }
        return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
    }

    /**
     * Returns a scorer for a float query. For quantized vector values the query is quantized once
     * here (after L2-normalization when the similarity is COSINE) and every node is then scored
     * from the quantized codes; otherwise the delegate handles the query.
     *
     * <p>The caller's {@code target} array is never modified: normalization and quantization run on
     * a private copy.
     */
    @Override
    public RandomVectorScorer getRandomVectorScorer(
            VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target)
            throws IOException {
        if (vectorValues instanceof QuantizedByteVectorValues qv) {
            FlatVectorsScorer.checkDimensions(target.length, qv.dimension());
            OptimizedScalarQuantizer quantizer = qv.getQuantizer();
            Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding scalarEncoding = qv.getScalarEncoding();
            // Unpacked query codes, one byte per discretized dimension.
            byte[] scratch = new byte[scalarEncoding.getDiscreteDimensions(qv.dimension())];
            // Symmetric encodings score the unpacked codes directly; asymmetric ones need a
            // separately sized, packed query buffer.
            byte[] targetQuantized = scalarEncoding.isAsymmetric()
                    ? new byte[scalarEncoding.getQueryPackedLength(scratch.length)]
                    : scratch;
            float[] query = ArrayUtil.copyOfSubArray(target, 0, target.length);
            if (similarityFunction == VectorSimilarityFunction.COSINE) {
                VectorUtil.l2normalize(query);
            }
            OptimizedScalarQuantizer.QuantizationResult targetCorrectiveTerms =
                    quantizer.scalarQuantize(query, scratch, scalarEncoding.getQueryBits(), qv.getCentroid());
            if (scalarEncoding == Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.SINGLE_BIT_QUERY_NIBBLE
                    || scalarEncoding == Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.DIBIT_QUERY_NIBBLE) {
                // Re-lay out the nibble query codes into the packed buffer consumed by the
                // bit / dibit dot-product kernels in quantizedScore.
                OptimizedScalarQuantizer.transposeHalfByte(scratch, targetQuantized);
            }
            return new RandomVectorScorer.AbstractRandomVectorScorer(qv) {
                @Override
                public float score(int node) throws IOException {
                    return Lucene104ScalarQuantizedVectorScorer.quantizedScore(
                            targetQuantized, targetCorrectiveTerms, qv, node, similarityFunction);
                }
            };
        }
        return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
    }

    /**
     * Byte-vector queries are never quantized by this scorer: after validating the dimension they
     * are handed to the delegate.
     */
    @Override
    public RandomVectorScorer getRandomVectorScorer(
            VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
            throws IOException {
        FlatVectorsScorer.checkDimensions(target.length, vectorValues.dimension());
        return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
    }

    /**
     * Returns a supplier whose scoring ordinals are read from {@code scoringVectors} and whose
     * scored nodes are read from {@code targetVectors}. {@code targetVectors} is expected to use an
     * asymmetric encoding (checked by an assertion in the supplier).
     */
    public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
            VectorSimilarityFunction similarityFunction,
            QuantizedByteVectorValues scoringVectors,
            QuantizedByteVectorValues targetVectors) {
        return new AsymmetricQuantizedRandomVectorScorerSupplier(scoringVectors, targetVectors, similarityFunction);
    }

    @Override
    public String toString() {
        return "Lucene104ScalarQuantizedVectorScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")";
    }

    /**
     * Scores one stored vector against an already-quantized query.
     *
     * <p>The integer dot product of the codes, computed with the kernel that matches the encoding,
     * is combined with both sides' lower interval, interval width and quantized component sum to
     * estimate the dot product of the de-quantized vectors. The similarity then selects the final
     * transform:
     *
     * <ul>
     *   <li>EUCLIDEAN turns the estimate into a distance {@code d} using both additional corrections
     *       and returns {@code 1 / (1 + max(d, 0))}.
     *   <li>MAXIMUM_INNER_PRODUCT adds the corrections, subtracts the centroid dot product and
     *       returns {@link VectorUtil#scaleMaxInnerProductScore}.
     *   <li>All other similarities do the same correction, clamp to [-1, 1] and map linearly onto
     *       [0, 1].
     * </ul>
     *
     * @param quantizedQuery query codes in the layout expected by the encoding's kernel
     * @param queryCorrections corrective terms recorded when the query was quantized
     * @param targetVectors stored quantized vectors; also supplies the encoding
     * @param targetOrd ordinal of the stored vector to score
     * @param similarityFunction similarity that selects the final score transform
     * @return the similarity score for {@code targetOrd}
     * @throws IOException if reading the stored vector or its corrective terms fails
     */
    private static float quantizedScore(
            byte[] quantizedQuery,
            OptimizedScalarQuantizer.QuantizationResult queryCorrections,
            QuantizedByteVectorValues targetVectors,
            int targetOrd,
            VectorSimilarityFunction similarityFunction)
            throws IOException {
        Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding scalarEncoding = targetVectors.getScalarEncoding();
        byte[] quantizedDoc = targetVectors.vectorValue(targetOrd);
        // Deliberately no default arm: the switch is exhaustive over ScalarEncoding, so adding an
        // encoding without a matching kernel is a compile error rather than a runtime failure.
        float qcDist = switch (scalarEncoding) {
            case UNSIGNED_BYTE -> VectorUtil.uint8DotProduct(quantizedQuery, quantizedDoc);
            case SEVEN_BIT -> VectorUtil.dotProduct(quantizedQuery, quantizedDoc);
            case PACKED_NIBBLE -> VectorUtil.int4DotProductSinglePacked(quantizedQuery, quantizedDoc);
            case SINGLE_BIT_QUERY_NIBBLE -> VectorUtil.int4BitDotProduct(quantizedQuery, quantizedDoc);
            case DIBIT_QUERY_NIBBLE -> VectorUtil.int4DibitDotProduct(quantizedQuery, quantizedDoc);
        };
        OptimizedScalarQuantizer.QuantizationResult indexCorrections = targetVectors.getCorrectiveTerms(targetOrd);
        // Width of one quantization step on each side.
        float queryScale = SCALE_LUT[scalarEncoding.getQueryBits() - 1];
        float scale = SCALE_LUT[scalarEncoding.getBits() - 1];
        float x1 = indexCorrections.quantizedComponentSum();
        float ax = indexCorrections.lowerInterval();
        float lx = (indexCorrections.upperInterval() - ax) * scale;
        float ay = queryCorrections.lowerInterval();
        float ly = (queryCorrections.upperInterval() - ay) * queryScale;
        float y1 = queryCorrections.quantizedComponentSum();
        // Expansion of sum_i (ax + lx * x_i) * (ay + ly * y_i) over all dimensions.
        float score = ax * ay * (float) targetVectors.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist;
        if (similarityFunction == VectorSimilarityFunction.EUCLIDEAN) {
            score = queryCorrections.additionalCorrection() + indexCorrections.additionalCorrection() - 2.0f * score;
            return 1.0f / (1.0f + Math.max(score, 0.0f));
        }
        score += queryCorrections.additionalCorrection() + indexCorrections.additionalCorrection()
                - targetVectors.getCentroidDP();
        if (similarityFunction == VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT) {
            return VectorUtil.scaleMaxInnerProductScore(score);
        }
        score = Math.clamp(score, -1.0f, 1.0f);
        return (1.0f + score) / 2.0f;
    }

    /**
     * Supplier for symmetric encodings: the scoring ordinal and the scored nodes come from the same
     * stored vectors, so both sides share the same code layout.
     */
    private static final class ScalarQuantizedVectorScorerSupplier implements RandomVectorScorerSupplier {
        /**
         * Separate copy read only by setScoringOrdinal. NOTE(review): presumably kept apart from
         * {@link #values} so that reading scored nodes cannot overwrite a reused buffer that
         * {@code targetVector} may alias (UNSIGNED_BYTE / SEVEN_BIT store it directly); confirm.
         */
        private final QuantizedByteVectorValues targetValues;
        /** Values read when scoring nodes. */
        private final QuantizedByteVectorValues values;
        private final VectorSimilarityFunction similarity;

        public ScalarQuantizedVectorScorerSupplier(QuantizedByteVectorValues values, VectorSimilarityFunction similarity)
                throws IOException {
            assert !values.getScalarEncoding().isAsymmetric();
            this.targetValues = values.copy();
            this.values = values;
            this.similarity = similarity;
        }

        @Override
        public UpdateableRandomVectorScorer scorer() throws IOException {
            return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) {
                private byte[] targetVector;
                private OptimizedScalarQuantizer.QuantizationResult targetCorrectiveTerms;

                @Override
                public float score(int node) throws IOException {
                    // Same guard as the asymmetric scorer, instead of an opaque NPE in VectorUtil.
                    if (targetVector == null || targetCorrectiveTerms == null) {
                        throw new IllegalStateException("setScoringOrdinal was not called");
                    }
                    return Lucene104ScalarQuantizedVectorScorer.quantizedScore(
                            targetVector, targetCorrectiveTerms, values, node, similarity);
                }

                @Override
                public void setScoringOrdinal(int node) throws IOException {
                    byte[] rawTargetVector = targetValues.vectorValue(node);
                    switch (values.getScalarEncoding()) {
                        case UNSIGNED_BYTE, SEVEN_BIT -> targetVector = rawTargetVector;
                        case PACKED_NIBBLE -> {
                            // The nibble kernel takes an unpacked query against a packed doc; the
                            // unpack buffer is allocated once and reused.
                            if (targetVector == null) {
                                targetVector = new byte[OptimizedScalarQuantizer.discretize(values.dimension(), 2)];
                            }
                            OffHeapScalarQuantizedVectorValues.unpackNibbles(rawTargetVector, targetVector);
                        }
                        case SINGLE_BIT_QUERY_NIBBLE, DIBIT_QUERY_NIBBLE -> throw new IllegalStateException(
                                values.getScalarEncoding().name()
                                        + " encoding is not supported for symmetric quantization");
                    }
                    targetCorrectiveTerms = targetValues.getCorrectiveTerms(node);
                }
            };
        }

        @Override
        public RandomVectorScorerSupplier copy() throws IOException {
            return new ScalarQuantizedVectorScorerSupplier(values.copy(), similarity);
        }
    }

    /**
     * Supplier for asymmetric encodings: the scoring ordinal selects an already-quantized query from
     * {@code queryVectors}, and nodes are scored against {@code targetVectors}.
     */
    static class AsymmetricQuantizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
        private final QuantizedByteVectorValues queryVectors;
        private final QuantizedByteVectorValues targetVectors;
        private final VectorSimilarityFunction similarityFunction;

        AsymmetricQuantizedRandomVectorScorerSupplier(
                QuantizedByteVectorValues queryVectors,
                QuantizedByteVectorValues targetVectors,
                VectorSimilarityFunction similarityFunction) {
            assert targetVectors.getScalarEncoding().isAsymmetric();
            this.queryVectors = queryVectors;
            this.targetVectors = targetVectors;
            this.similarityFunction = similarityFunction;
        }

        @Override
        public UpdateableRandomVectorScorer scorer() throws IOException {
            // Each scorer gets its own copies of the query and target values.
            final QuantizedByteVectorValues scorerTargets = targetVectors.copy();
            final QuantizedByteVectorValues scorerQueries = queryVectors.copy();
            return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(scorerTargets) {
                private OptimizedScalarQuantizer.QuantizationResult queryCorrections;
                private byte[] vector;

                @Override
                public void setScoringOrdinal(int node) throws IOException {
                    vector = scorerQueries.vectorValue(node);
                    queryCorrections = scorerQueries.getCorrectiveTerms(node);
                }

                @Override
                public float score(int node) throws IOException {
                    if (vector == null || queryCorrections == null) {
                        throw new IllegalStateException("setScoringOrdinal was not called");
                    }
                    return Lucene104ScalarQuantizedVectorScorer.quantizedScore(
                            vector, queryCorrections, scorerTargets, node, similarityFunction);
                }
            };
        }

        @Override
        public RandomVectorScorerSupplier copy() throws IOException {
            return new AsymmetricQuantizedRandomVectorScorerSupplier(
                    queryVectors.copy(), targetVectors.copy(), similarityFunction);
        }
    }
}

