/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule;

import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.MarkDistinctNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns;
import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures;
import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern;
import org.apache.tsfile.read.common.type.BooleanType;
import org.apache.tsfile.read.common.type.Type;

public class MultipleDistinctAggregationToMarkDistinct
implements Rule<AggregationNode> {
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation().matching((Predicate<AggregationNode>)Predicates.and(MultipleDistinctAggregationToMarkDistinct::hasNoDistinctWithFilterOrMask, (com.google.common.base.Predicate)Predicates.or(MultipleDistinctAggregationToMarkDistinct::hasMultipleDistincts, MultipleDistinctAggregationToMarkDistinct::hasMixedDistinctAndNonDistincts)));

    private static boolean hasNoDistinctWithFilterOrMask(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().values().stream().noneMatch(aggregation -> aggregation.isDistinct() && (aggregation.getFilter().isPresent() || aggregation.getMask().isPresent()));
    }

    private static boolean hasMultipleDistincts(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().values().stream().filter(AggregationNode.Aggregation::isDistinct).map(AggregationNode.Aggregation::getArguments).map(HashSet::new).distinct().count() > 1L;
    }

    private static boolean hasMixedDistinctAndNonDistincts(AggregationNode aggregationNode) {
        long distincts = aggregationNode.getAggregations().values().stream().filter(AggregationNode.Aggregation::isDistinct).count();
        return distincts > 0L && distincts < (long)aggregationNode.getAggregations().size();
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    @Override
    public Rule.Result apply(AggregationNode parent, Captures captures, Rule.Context context) {
        if (!this.shouldAddMarkDistinct(parent, context)) {
            return Rule.Result.empty();
        }
        HashMap markers = new HashMap();
        HashMap<Symbol, AggregationNode.Aggregation> newAggregations = new HashMap<Symbol, AggregationNode.Aggregation>();
        PlanNode subPlan = parent.getChild();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : parent.getAggregations().entrySet()) {
            AggregationNode.Aggregation aggregation = entry.getValue();
            if (aggregation.isDistinct() && !aggregation.getFilter().isPresent() && !aggregation.getMask().isPresent()) {
                Set inputs = aggregation.getArguments().stream().map(Symbol::from).collect(Collectors.toSet());
                Symbol marker = (Symbol)markers.get(inputs);
                if (marker == null) {
                    marker = context.getSymbolAllocator().newSymbol(((Symbol)Iterables.getLast(inputs)).getName(), (Type)BooleanType.BOOLEAN, "distinct");
                    markers.put(inputs, marker);
                    ImmutableSet.Builder distinctSymbols = ImmutableSet.builder().addAll(parent.getGroupingKeys()).addAll(inputs);
                    parent.getGroupIdSymbol().ifPresent(arg_0 -> ((ImmutableSet.Builder)distinctSymbols).add(arg_0));
                    subPlan = new MarkDistinctNode(context.getIdAllocator().genPlanNodeId(), subPlan, marker, (List<Symbol>)ImmutableList.copyOf((Collection)distinctSymbols.build()), Optional.empty());
                }
                newAggregations.put(entry.getKey(), new AggregationNode.Aggregation(aggregation.getResolvedFunction(), aggregation.getArguments(), false, aggregation.getFilter(), aggregation.getOrderingScheme(), Optional.of(marker)));
                continue;
            }
            newAggregations.put(entry.getKey(), aggregation);
        }
        return Rule.Result.ofPlanNode(AggregationNode.builderFrom(parent).setSource(subPlan).setAggregations(newAggregations).setPreGroupedSymbols((List<Symbol>)ImmutableList.of()).build());
    }

    private boolean shouldAddMarkDistinct(AggregationNode aggregationNode, Rule.Context context) {
        if (aggregationNode.getGroupingKeys().isEmpty()) {
            return true;
        }
        return aggregationNode.getGroupingKeys().size() > 1;
    }

    private static boolean hasSingleDistinctAndNonDistincts(AggregationNode aggregationNode) {
        long distincts = aggregationNode.getAggregations().values().stream().filter(AggregationNode.Aggregation::isDistinct).count();
        return distincts == 1L && distincts < (long)aggregationNode.getAggregations().size();
    }
}

