/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.logical;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TimeZone;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.logical.LogicalTableScan;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.catalog.Catalog;
import org.apache.flink.table.catalog.CatalogPartitionSpec;
import org.apache.flink.table.catalog.CatalogTable;
import org.apache.flink.table.catalog.ObjectIdentifier;
import org.apache.flink.table.catalog.ObjectPath;
import org.apache.flink.table.catalog.ResolvedCatalogTable;
import org.apache.flink.table.catalog.exceptions.CatalogException;
import org.apache.flink.table.catalog.exceptions.PartitionNotExistException;
import org.apache.flink.table.catalog.exceptions.TableNotExistException;
import org.apache.flink.table.catalog.exceptions.TableNotPartitionedException;
import org.apache.flink.table.catalog.stats.CatalogColumnStatistics;
import org.apache.flink.table.catalog.stats.CatalogTableStatistics;
import org.apache.flink.table.connector.source.DynamicTableSource;
import org.apache.flink.table.connector.source.abilities.SupportsPartitionPushDown;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.plan.stats.TableStats;
import org.apache.flink.table.planner.calcite.FlinkContext;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.plan.abilities.source.PartitionPushDownSpec;
import org.apache.flink.table.planner.plan.abilities.source.SourceAbilityContext;
import org.apache.flink.table.planner.plan.abilities.source.SourceAbilitySpec;
import org.apache.flink.table.planner.plan.schema.TableSourceTable;
import org.apache.flink.table.planner.plan.stats.FlinkStatistic;
import org.apache.flink.table.planner.plan.utils.FlinkRelOptUtil;
import org.apache.flink.table.planner.plan.utils.PartitionPruner;
import org.apache.flink.table.planner.plan.utils.RexNodeExtractor;
import org.apache.flink.table.planner.plan.utils.RexNodeToExpressionConverter;
import org.apache.flink.table.planner.utils.CatalogTableStatisticsConverter;
import org.apache.flink.table.planner.utils.ShortcutUtils;
import org.apache.flink.table.types.logical.LogicalType;
import scala.Option;
import scala.Tuple2;
import scala.collection.JavaConversions;
import scala.collection.Seq;

/**
 * Planner rule that matches a {@link Filter} sitting directly on top of a {@link LogicalTableScan}
 * and pushes the partition-related part of the filter condition into the scanned table source.
 *
 * <p>The rule only fires for a {@link TableSourceTable} whose source implements {@link
 * SupportsPartitionPushDown}, whose catalog table is partitioned, and which does not already
 * carry a {@link PartitionPushDownSpec}. On a match the remaining partitions are computed, applied
 * to a copy of the source, and the scan (plus a filter holding any non-partition predicates)
 * replaces the original sub-plan.
 */
public class PushPartitionIntoTableSourceScanRule extends RelOptRule {
    public static final PushPartitionIntoTableSourceScanRule INSTANCE = new PushPartitionIntoTableSourceScanRule();

    /** Creates the rule with the operand pattern {@code Filter(LogicalTableScan)}. */
    public PushPartitionIntoTableSourceScanRule() {
        super(operand(Filter.class, operand(LogicalTableScan.class, none())), "PushPartitionIntoTableSourceScanRule");
    }

    /**
     * Decides whether partition push-down is applicable to the matched filter/scan pair.
     *
     * @param call the rule call; rel 0 is the filter, rel 1 the table scan
     * @return {@code false} if the filter has no condition, the scanned table is not a {@link
     *     TableSourceTable}, the source does not implement {@link SupportsPartitionPushDown}, the
     *     catalog table is not partitioned (or has no partition keys), or a {@link
     *     PartitionPushDownSpec} has already been applied; {@code true} otherwise
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        Filter filter = call.rel(0);
        if (filter.getCondition() == null) {
            return false;
        }
        TableSourceTable tableSourceTable = call.rel(1).getTable().unwrap(TableSourceTable.class);
        if (tableSourceTable == null) {
            return false;
        }
        DynamicTableSource dynamicTableSource = tableSourceTable.tableSource();
        if (!(dynamicTableSource instanceof SupportsPartitionPushDown)) {
            return false;
        }
        CatalogTable catalogTable = (CatalogTable) tableSourceTable.contextResolvedTable().getTable();
        if (!catalogTable.isPartitioned() || catalogTable.getPartitionKeys().isEmpty()) {
            return false;
        }
        // Having a PartitionPushDownSpec already means this rule has been applied to the scan before.
        return Arrays.stream(tableSourceTable.abilitySpecs()).noneMatch(spec -> spec instanceof PartitionPushDownSpec);
    }

    /**
     * Splits the filter condition into partition and non-partition predicates, prunes the
     * partitions with the partition predicates, applies the remaining partitions to a copy of the
     * table source and transforms the call to the new scan (wrapped in a filter when
     * non-partition predicates remain).
     *
     * <p>Nothing is transformed when the composed partition predicate is always true.
     *
     * @param call the rule call; rel 0 is the filter, rel 1 the table scan
     * @throws TableException if a partition key is missing from the input columns, or if the
     *     partitions can neither be listed by the connector nor loaded from the catalog
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        Filter filter = call.rel(0);
        LogicalTableScan scan = call.rel(1);
        TableSourceTable tableSourceTable = scan.getTable().unwrap(TableSourceTable.class);
        RelDataType inputFieldTypes = filter.getInput().getRowType();
        List<String> inputFieldNames = inputFieldTypes.getFieldNames();
        List<String> partitionFieldNames = ((ResolvedCatalogTable) tableSourceTable.contextResolvedTable().getResolvedTable()).getPartitionKeys();
        RelBuilder relBuilder = call.builder();
        RexBuilder rexBuilder = relBuilder.getRexBuilder();

        // _1 holds the partition predicates, _2 the remaining (non-partition) predicates.
        Tuple2<Seq<RexNode>, Seq<RexNode>> allPredicates = RexNodeExtractor.extractPartitionPredicateList(filter.getCondition(), FlinkRelOptUtil.getMaxCnfNodeCount(scan), inputFieldNames.toArray(new String[0]), rexBuilder, partitionFieldNames.toArray(new String[0]));
        RexNode partitionPredicate = RexUtil.composeConjunction(rexBuilder, JavaConversions.seqAsJavaList(allPredicates._1()));
        if (partitionPredicate.isAlwaysTrue()) {
            // No usable partition predicate: leave the plan untouched.
            return;
        }

        LogicalType[] partitionFieldTypes = partitionFieldNames.stream().map(name -> {
            int index = inputFieldNames.indexOf(name);
            if (index < 0) {
                throw new TableException(String.format("Partitioned key '%s' isn't found in input columns. Validator should have checked that.", name));
            }
            return inputFieldTypes.getFieldList().get(index).getType();
        }).map(FlinkTypeFactory::toLogicalType).toArray(LogicalType[]::new);

        // Re-index the input references so they point into the partition key list.
        RexNode finalPartitionPredicate = this.adjustPartitionPredicate(inputFieldNames, partitionFieldNames, partitionPredicate);
        FlinkContext context = ShortcutUtils.unwrapContext(scan);
        Function<List<Map<String, String>>, List<Map<String, String>>> defaultPruner = partitions -> PartitionPruner.prunePartitions(context.getTableConfig(), partitionFieldNames.toArray(new String[0]), partitionFieldTypes, partitions, finalPartitionPredicate);
        List<Map<String, String>> remainingPartitions = this.readPartitionsAndPrune(rexBuilder, context, tableSourceTable, defaultPruner, allPredicates._1(), inputFieldNames);

        // Apply the push-down to a copy so the source of the original table is left untouched.
        DynamicTableSource dynamicTableSource = tableSourceTable.tableSource().copy();
        PartitionPushDownSpec partitionPushDownSpec = new PartitionPushDownSpec(remainingPartitions);
        partitionPushDownSpec.apply(dynamicTableSource, SourceAbilityContext.from(scan));

        TableStats newTableStat = this.mergePartitionStats(tableSourceTable, remainingPartitions);
        FlinkStatistic newStatistic = FlinkStatistic.builder().statistic(tableSourceTable.getStatistic()).tableStats(newTableStat).build();
        TableSourceTable newTableSourceTable = tableSourceTable.copy(dynamicTableSource, newStatistic, new SourceAbilitySpec[]{partitionPushDownSpec});
        LogicalTableScan newScan = LogicalTableScan.create(scan.getCluster(), newTableSourceTable, scan.getHints());

        RexNode nonPartitionPredicate = RexUtil.composeConjunction(rexBuilder, JavaConversions.seqAsJavaList(allPredicates._2()));
        if (nonPartitionPredicate.isAlwaysTrue()) {
            call.transformTo(newScan);
        } else {
            Filter newFilter = filter.copy(filter.getTraitSet(), newScan, nonPartitionPredicate);
            call.transformTo(newFilter);
        }
    }

    /**
     * Merges the catalog statistics of all remaining partitions of a permanent table.
     *
     * @param tableSourceTable the table being scanned
     * @param remainingPartitions the partitions that survived pruning
     * @return the merged statistics, or {@code null} if the table is not permanent, there are no
     *     remaining partitions, or the statistics of any partition are unavailable
     */
    private TableStats mergePartitionStats(TableSourceTable tableSourceTable, List<Map<String, String>> remainingPartitions) {
        if (!tableSourceTable.contextResolvedTable().isPermanent()) {
            return null;
        }
        ObjectIdentifier identifier = tableSourceTable.contextResolvedTable().getIdentifier();
        ObjectPath tablePath = identifier.toObjectPath();
        // NOTE(review): relies on a permanent table always carrying a catalog - confirm.
        Catalog catalog = tableSourceTable.contextResolvedTable().getCatalog().get();
        TableStats merged = null;
        for (Map<String, String> partition : remainingPartitions) {
            Optional<TableStats> partitionStats = this.getPartitionStats(catalog, tablePath, partition);
            if (!partitionStats.isPresent()) {
                // One missing partition makes the merged result incomplete, so give up entirely.
                return null;
            }
            merged = merged == null ? partitionStats.get() : merged.merge(partitionStats.get());
        }
        return merged;
    }

    /**
     * Rewrites the input references of the partition predicate so that each index refers to the
     * field's position in {@code partitionFieldNames} instead of its position in {@code
     * inputFieldNames}.
     *
     * @param inputFieldNames field names of the filter input, indexed by the original references
     * @param partitionFieldNames partition key names, defining the new indices
     * @param partitionPredicate the predicate to rewrite
     * @return the rewritten predicate
     * @throws TableException if a referenced field is not a partition key
     */
    private RexNode adjustPartitionPredicate(final List<String> inputFieldNames, final List<String> partitionFieldNames, RexNode partitionPredicate) {
        return partitionPredicate.accept(new RexShuttle() {

            @Override
            public RexNode visitInputRef(RexInputRef inputRef) {
                int index = inputRef.getIndex();
                String fieldName = inputFieldNames.get(index);
                int newIndex = partitionFieldNames.indexOf(fieldName);
                if (newIndex < 0) {
                    throw new TableException(String.format("Field name '%s' isn't found in partitioned columns. Validator should have checked that.", fieldName));
                }
                if (newIndex == index) {
                    return inputRef;
                }
                return new RexInputRef(newIndex, inputRef.getType());
            }
        });
    }

    /**
     * Obtains the partitions of the table and prunes them. Partitions listed by the connector
     * itself take precedence; otherwise they are read from the catalog.
     *
     * @param rexBuilder builder handed to the expression converter
     * @param context planner context providing configuration and catalogs
     * @param tableSourceTable the table being scanned
     * @param pruner pruning function applied to a full partition list
     * @param partitionPredicate the partition predicates extracted from the filter
     * @param inputFieldNames field names of the filter input
     * @return the partitions remaining after pruning
     * @throws TableException if the connector lists no partitions and no catalog is available,
     *     or if the catalog reports the table as missing or not partitioned
     */
    private List<Map<String, String>> readPartitionsAndPrune(RexBuilder rexBuilder, FlinkContext context, TableSourceTable tableSourceTable, Function<List<Map<String, String>>, List<Map<String, String>>> pruner, Seq<RexNode> partitionPredicate, List<String> inputFieldNames) {
        Optional<Catalog> catalogOptional = tableSourceTable.contextResolvedTable().getCatalog();
        DynamicTableSource dynamicTableSource = tableSourceTable.tableSource();
        Optional<List<Map<String, String>>> optionalPartitions = ((SupportsPartitionPushDown) dynamicTableSource).listPartitions();
        if (optionalPartitions.isPresent()) {
            return pruner.apply(optionalPartitions.get());
        }
        ObjectIdentifier identifier = tableSourceTable.contextResolvedTable().getIdentifier();
        if (!catalogOptional.isPresent()) {
            throw new TableException(String.format("Table '%s' connector doesn't provide partitions, and it cannot be loaded from the catalog", identifier.asSummaryString()));
        }
        try {
            return this.readPartitionFromCatalogAndPrune(rexBuilder, context, catalogOptional.get(), identifier, inputFieldNames, partitionPredicate, pruner);
        } catch (TableNotExistException tableNotExistException) {
            // Keep the catalog exception as the cause so the failure stays diagnosable.
            throw new TableException(String.format("Table %s is not found in catalog.", identifier.asSummaryString()), tableNotExistException);
        } catch (TableNotPartitionedException tableNotPartitionedException) {
            throw new TableException(String.format("Table %s is not a partitionable source. Validator should have checked it.", identifier.asSummaryString()), tableNotPartitionedException);
        }
    }

    /**
     * Reads partitions from the catalog, letting the catalog filter them when possible.
     *
     * <p>Each partition predicate is converted to an expression. If any predicate cannot be
     * converted, or the catalog throws {@link UnsupportedOperationException} for filtered
     * listing, all partitions are listed and pruned with {@code pruner} instead.
     *
     * @param rexBuilder builder handed to the expression converter
     * @param context planner context providing function catalog, catalog manager and time zone
     * @param catalog the catalog to read partitions from
     * @param tableIdentifier identifier of the table
     * @param allFieldNames field names the predicates' input references refer to
     * @param partitionPredicate the partition predicates
     * @param pruner pruning function used on the fallback path
     * @return the partition specs remaining after filtering or pruning
     * @throws TableNotExistException propagated from the catalog
     * @throws TableNotPartitionedException propagated from the catalog
     */
    private List<Map<String, String>> readPartitionFromCatalogAndPrune(RexBuilder rexBuilder, FlinkContext context, Catalog catalog, ObjectIdentifier tableIdentifier, List<String> allFieldNames, Seq<RexNode> partitionPredicate, Function<List<Map<String, String>>, List<Map<String, String>>> pruner) throws TableNotExistException, TableNotPartitionedException {
        ObjectPath tablePath = tableIdentifier.toObjectPath();
        RexNodeToExpressionConverter converter = new RexNodeToExpressionConverter(rexBuilder, allFieldNames.toArray(new String[0]), context.getFunctionCatalog(), context.getCatalogManager(), TimeZone.getTimeZone(context.getTableConfig().getLocalTimeZone()));
        List<Expression> partitionFilters = new ArrayList<>();
        for (RexNode node : JavaConversions.seqAsJavaList(partitionPredicate)) {
            Option<ResolvedExpression> subExpr = node.accept(converter);
            if (subExpr.isEmpty()) {
                // A predicate that cannot be converted rules out catalog-side filtering.
                return this.readPartitionFromCatalogWithoutFilterAndPrune(catalog, tablePath, pruner);
            }
            partitionFilters.add(subExpr.get());
        }
        try {
            return catalog.listPartitionsByFilter(tablePath, partitionFilters).stream().map(CatalogPartitionSpec::getPartitionSpec).collect(Collectors.toList());
        } catch (UnsupportedOperationException ignored) {
            // The catalog cannot list by filter: fall back to listing everything and pruning.
            return this.readPartitionFromCatalogWithoutFilterAndPrune(catalog, tablePath, pruner);
        }
    }

    /**
     * Lists every partition of the table from the catalog and applies {@code pruner} to the list.
     *
     * @param catalog the catalog to read partitions from
     * @param tablePath path of the table inside the catalog
     * @param pruner pruning function applied to the full partition list
     * @return the partitions remaining after pruning
     * @throws TableNotExistException propagated from the catalog
     * @throws CatalogException propagated from the catalog
     * @throws TableNotPartitionedException propagated from the catalog
     */
    private List<Map<String, String>> readPartitionFromCatalogWithoutFilterAndPrune(Catalog catalog, ObjectPath tablePath, Function<List<Map<String, String>>, List<Map<String, String>>> pruner) throws TableNotExistException, CatalogException, TableNotPartitionedException {
        List<Map<String, String>> allPartitions = catalog.listPartitions(tablePath).stream().map(CatalogPartitionSpec::getPartitionSpec).collect(Collectors.toList());
        return pruner.apply(allPartitions);
    }

    /**
     * Loads the table and column statistics of one partition from the catalog.
     *
     * @param catalog the catalog holding the statistics
     * @param tablePath path of the table inside the catalog
     * @param partition the partition spec as key/value pairs
     * @return the converted statistics, or an empty optional if the catalog reports that the
     *     partition does not exist
     */
    private Optional<TableStats> getPartitionStats(Catalog catalog, ObjectPath tablePath, Map<String, String> partition) {
        try {
            CatalogPartitionSpec spec = new CatalogPartitionSpec(partition);
            CatalogTableStatistics partitionStat = catalog.getPartitionStatistics(tablePath, spec);
            CatalogColumnStatistics partitionColStat = catalog.getPartitionColumnStatistics(tablePath, spec);
            TableStats stats = CatalogTableStatisticsConverter.convertToTableStats(partitionStat, partitionColStat);
            return Optional.of(stats);
        } catch (PartitionNotExistException ignored) {
            // Statistics for a partition unknown to the catalog are treated as unavailable.
            return Optional.empty();
        }
    }
}

