/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor.normalization;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import lombok.Generated;
import org.apache.commons.lang3.Range;
import org.apache.commons.lang3.math.NumberUtils;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.common.TriConsumer;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;
import org.opensearch.neuralsearch.processor.NormalizeScoresDTO;
import org.opensearch.neuralsearch.processor.SearchShard;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplanationUtils;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationUtil;

/**
 * Reciprocal Rank Fusion (RRF) score normalization.
 *
 * <p>Each hit's score is replaced by {@code 1 / (rankConstant + rank)}. Here {@code rank} is the
 * 1-based position of the hit within its sub-query's {@link TopDocs}. The quotient is computed with
 * {@link BigDecimal} at 10 decimal places using {@link RoundingMode#HALF_UP} and then converted to
 * {@code float}.
 *
 * <p>The only supported parameter is {@value #PARAM_NAME_RANK_CONSTANT}. It defaults to
 * {@value #DEFAULT_RANK_CONSTANT} and must lie in [{@value #MIN_RANK_CONSTANT},
 * {@value #MAX_RANK_CONSTANT}].
 */
public class RRFNormalizationTechnique implements ScoreNormalizationTechnique, ExplainableTechnique {
    public static final String TECHNIQUE_NAME = "rrf";
    public static final int DEFAULT_RANK_CONSTANT = 60;
    public static final String PARAM_NAME_RANK_CONSTANT = "rank_constant";
    private static final Set<String> SUPPORTED_PARAMS = Set.of(PARAM_NAME_RANK_CONSTANT);
    private static final int MIN_RANK_CONSTANT = 1;
    private static final int MAX_RANK_CONSTANT = 10000;
    private static final Range<Integer> RANK_CONSTANT_RANGE = Range.of(MIN_RANK_CONSTANT, MAX_RANK_CONSTANT);
    private final int rankConstant;

    /**
     * Creates the technique from user-supplied parameters.
     *
     * @param params technique parameters; may be {@code null} or lack {@value #PARAM_NAME_RANK_CONSTANT},
     *               in which case {@value #DEFAULT_RANK_CONSTANT} is used
     * @param scoreNormalizationUtil helper used to reject parameters other than the supported ones
     * @throws IllegalArgumentException if the rank constant is not an integer or is out of range
     *                                  (unsupported-parameter errors depend on
     *                                  {@code ScoreNormalizationUtil#validateParams})
     */
    public RRFNormalizationTechnique(Map<String, Object> params, ScoreNormalizationUtil scoreNormalizationUtil) {
        scoreNormalizationUtil.validateParams(params, SUPPORTED_PARAMS);
        this.rankConstant = getRankConstant(params);
    }

    /**
     * Overwrites every hit's score, in place, with its reciprocal-rank score.
     *
     * @param normalizeScoresDTO carrier of the per-shard compound top docs; {@code null} entries are skipped
     */
    @Override
    public void normalize(NormalizeScoresDTO normalizeScoresDTO) {
        List<CompoundTopDocs> queryTopDocs = normalizeScoresDTO.getQueryTopDocs();
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            // Only the in-place score rewrite is needed here, so the per-hit callback is a no-op.
            processTopDocs(compoundQueryTopDocs, (docId, score, subQueryIndex) -> {});
        }
    }

    @Override
    public String describe() {
        return String.format(Locale.ROOT, "%s, rank_constant [%s]", TECHNIQUE_NAME, rankConstant);
    }

    /**
     * Computes per-document explanation details for the RRF normalization.
     *
     * <p>Note: like {@link #normalize}, this also overwrites each hit's {@code ScoreDoc.score} with its
     * normalized value.
     *
     * @param queryTopDocs per-shard compound top docs; {@code null} entries are skipped
     * @return explanation details keyed by document/shard, built by
     *         {@code ExplanationUtils#getDocIdAtQueryForNormalization}
     */
    @Override
    public Map<DocIdAtSearchShard, ExplanationDetails> explain(List<CompoundTopDocs> queryTopDocs) {
        Map<DocIdAtSearchShard, List<Float>> normalizedScores = new HashMap<>();
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                continue;
            }
            int numberOfSubQueries = compoundQueryTopDocs.getTopDocs().size();
            processTopDocs(
                compoundQueryTopDocs,
                (docId, score, subQueryIndex) -> ScoreNormalizationUtil.setNormalizedScore(
                    normalizedScores,
                    docId,
                    subQueryIndex,
                    numberOfSubQueries,
                    score.floatValue()
                )
            );
        }
        return ExplanationUtils.getDocIdAtQueryForNormalization(normalizedScores, this);
    }

    /** Applies RRF normalization to every sub-query result list of one shard's compound top docs. */
    private void processTopDocs(CompoundTopDocs compoundQueryTopDocs, TriConsumer<DocIdAtSearchShard, Float, Integer> scoreProcessor) {
        if (Objects.isNull(compoundQueryTopDocs)) {
            return;
        }
        List<TopDocs> topDocsList = compoundQueryTopDocs.getTopDocs();
        SearchShard searchShard = compoundQueryTopDocs.getSearchShard();
        for (int topDocsIndex = 0; topDocsIndex < topDocsList.size(); ++topDocsIndex) {
            processTopDocsEntry(topDocsList.get(topDocsIndex), searchShard, topDocsIndex, scoreProcessor);
        }
    }

    /**
     * Normalizes the hits of a single sub-query. It reports each (doc, score, sub-query index) to
     * {@code scoreProcessor} and then writes the score back into the {@link ScoreDoc}.
     */
    private void processTopDocsEntry(
        TopDocs topDocs,
        SearchShard searchShard,
        int topDocsIndex,
        TriConsumer<DocIdAtSearchShard, Float, Integer> scoreProcessor
    ) {
        ScoreDoc[] scoreDocs = topDocs.scoreDocs;
        // The array index is the 0-based rank. Iterating by index avoids an O(n) indexOf lookup per hit.
        for (int position = 0; position < scoreDocs.length; position++) {
            ScoreDoc scoreDoc = scoreDocs[position];
            float normalizedScore = calculateNormalizedScore(position);
            DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(scoreDoc.doc, searchShard);
            scoreProcessor.apply(docIdAtSearchShard, normalizedScore, topDocsIndex);
            scoreDoc.score = normalizedScore;
        }
    }

    /**
     * Returns {@code 1 / (rankConstant + position + 1)}, rounded HALF_UP to 10 decimals, as a float.
     *
     * @param position 0-based rank of the hit within its sub-query results
     */
    private float calculateNormalizedScore(int position) {
        return BigDecimal.ONE.divide(BigDecimal.valueOf(rankConstant + position + 1), 10, RoundingMode.HALF_UP).floatValue();
    }

    /** Reads and validates the rank constant, falling back to {@link #DEFAULT_RANK_CONSTANT} when absent. */
    private int getRankConstant(Map<String, Object> params) {
        if (Objects.isNull(params) || !params.containsKey(PARAM_NAME_RANK_CONSTANT)) {
            return DEFAULT_RANK_CONSTANT;
        }
        int rankConstant = getParamAsInteger(params, PARAM_NAME_RANK_CONSTANT);
        validateRankConstant(rankConstant);
        return rankConstant;
    }

    /** @throws IllegalArgumentException if {@code rankConstant} is outside [MIN_RANK_CONSTANT, MAX_RANK_CONSTANT] */
    private void validateRankConstant(int rankConstant) {
        if (!RANK_CONSTANT_RANGE.contains(rankConstant)) {
            throw new IllegalArgumentException(
                String.format(
                    Locale.ROOT,
                    "rank constant must be in the interval between %d and %d, submitted rank constant: %d",
                    MIN_RANK_CONSTANT,
                    MAX_RANK_CONSTANT,
                    rankConstant
                )
            );
        }
    }

    /**
     * Parses {@code parameters.get(fieldName)} as an integer via its {@code String.valueOf} form and
     * {@link NumberUtils#createInteger}. String values such as {@code "60"} are therefore accepted,
     * while values like {@code 60.0} or {@code null} are rejected.
     *
     * @throws IllegalArgumentException if the value cannot be parsed as an integer
     */
    private static int getParamAsInteger(Map<String, Object> parameters, String fieldName) {
        try {
            return NumberUtils.createInteger(String.valueOf(parameters.get(fieldName)));
        } catch (NumberFormatException e) {
            // Keep the parse failure as the cause so the original input problem remains diagnosable.
            throw new IllegalArgumentException(String.format(Locale.ROOT, "parameter [%s] must be an integer", fieldName), e);
        }
    }

    @Generated
    @Override
    public String toString() {
        return "RRFNormalizationTechnique(TECHNIQUE_NAME=rrf, rankConstant=" + rankConstant + ")";
    }
}

