/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.opensearch.planner.rules;

import java.util.List;
import java.util.function.Function;
import java.util.function.Predicate;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.AbstractRelNode;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexSlot;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.immutables.value.Value;
import org.opensearch.sql.calcite.plan.OpenSearchRuleConfig;
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
import org.opensearch.sql.calcite.utils.PlanUtils;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.expression.function.PPLBuiltinOperators;
import org.opensearch.sql.expression.function.udf.binning.WidthBucketFunction;
import org.opensearch.sql.opensearch.planner.rules.ImmutableAggregateIndexScanRule;
import org.opensearch.sql.opensearch.planner.rules.InterruptibleRelRule;
import org.opensearch.sql.opensearch.storage.scan.AbstractCalciteIndexScan;
import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan;

@Value.Enclosing
public class AggregateIndexScanRule
extends InterruptibleRelRule<Config> {
    protected AggregateIndexScanRule(Config config) {
        super(config);
    }

    @Override
    protected void onMatchImpl(RelOptRuleCall call) {
        if (call.rels.length == 5) {
            List newProjects;
            LogicalAggregate aggregate = (LogicalAggregate)call.rel(0);
            LogicalProject topProject = (LogicalProject)call.rel(1);
            LogicalFilter filter = (LogicalFilter)call.rel(2);
            LogicalProject bottomProject = (LogicalProject)call.rel(3);
            CalciteLogicalIndexScan scan = (CalciteLogicalIndexScan)call.rel(4);
            boolean ignoreNullBucket = Config.aggIgnoreNullBucket.test((Aggregate)aggregate);
            List<Integer> groupRefList = aggregate.getGroupSet().asList().stream().map(topProject.getProjects()::get).filter(rex -> ignoreNullBucket || this.isTimeSpan((RexNode)rex)).flatMap(expr -> PlanUtils.getInputRefs(expr).stream()).map(RexSlot::getIndex).toList();
            if (this.isNotNullDerivedFromAgg(filter, groupRefList) && (newProjects = RelOptUtil.pushPastProjectUnlessBloat((List)topProject.getProjects(), (Project)bottomProject, (int)100)) != null) {
                RelBuilder relBuilder = call.builder();
                relBuilder.push((RelNode)scan);
                relBuilder.project((Iterable)newProjects, (Iterable)topProject.getRowType().getFieldNames());
                RelNode node = relBuilder.build();
                if (node instanceof LogicalProject) {
                    LogicalProject newProject = (LogicalProject)node;
                    this.apply(call, aggregate, newProject, scan);
                } else if (node.equals((Object)scan)) {
                    this.apply(call, aggregate, null, scan);
                }
            }
        } else if (call.rels.length == 4) {
            LogicalAggregate aggregate = (LogicalAggregate)call.rel(0);
            LogicalFilter filter = (LogicalFilter)call.rel(1);
            LogicalProject project = (LogicalProject)call.rel(2);
            CalciteLogicalIndexScan scan = (CalciteLogicalIndexScan)call.rel(3);
            List groupList = aggregate.getGroupSet().asList();
            if (this.isNotNullDerivedFromAgg(filter, groupList)) {
                this.apply(call, aggregate, project, scan);
            }
        } else if (call.rels.length == 3) {
            LogicalAggregate aggregate = (LogicalAggregate)call.rel(0);
            LogicalProject project = (LogicalProject)call.rel(1);
            CalciteLogicalIndexScan scan = (CalciteLogicalIndexScan)call.rel(2);
            this.apply(call, aggregate, project, scan);
        } else if (call.rels.length == 2) {
            LogicalAggregate aggregate = (LogicalAggregate)call.rel(0);
            CalciteLogicalIndexScan scan = (CalciteLogicalIndexScan)call.rel(1);
            this.apply(call, aggregate, null, scan);
        } else {
            throw new AssertionError((Object)String.format("The length of rels should be %s but got %s", this.operands.size(), call.rels.length));
        }
    }

    private boolean isTimeSpan(RexNode rex) {
        RexLiteral unitLiteral;
        Object e;
        RexCall rexCall;
        return rex instanceof RexCall && (rexCall = (RexCall)rex).getKind() == SqlKind.OTHER_FUNCTION && rexCall.getOperator().getName().equalsIgnoreCase(BuiltinFunctionName.SPAN.name()) && rexCall.getOperands().size() == 3 && (e = rexCall.getOperands().get(2)) instanceof RexLiteral && (unitLiteral = (RexLiteral)e).getTypeName() != SqlTypeName.NULL;
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    private boolean isNotNullDerivedFromAgg(LogicalFilter filter, List<Integer> groupRefList) {
        Function<RexNode, Boolean> isNotNullFromAgg = rex -> {
            RexInputRef ref;
            Object patt0$temp;
            RexCall rexCall;
            return rex instanceof RexCall && (rexCall = (RexCall)rex).isA(SqlKind.IS_NOT_NULL) && (patt0$temp = rexCall.getOperands().get(0)) instanceof RexInputRef && groupRefList.contains((ref = (RexInputRef)patt0$temp).getIndex());
        };
        RexNode condition = filter.getCondition();
        if (isNotNullFromAgg.apply(condition) != false) return true;
        if (!(condition instanceof RexCall)) return false;
        RexCall rexCall = (RexCall)condition;
        if (rexCall.getOperator() != SqlStdOperatorTable.AND) return false;
        if (!rexCall.getOperands().stream().allMatch(isNotNullFromAgg::apply)) return false;
        return true;
    }

    protected void apply(RelOptRuleCall call, LogicalAggregate aggregate, LogicalProject project, CalciteLogicalIndexScan scan) {
        AbstractRelNode newRelNode = scan.pushDownAggregate((Aggregate)aggregate, (Project)project);
        if (newRelNode != null) {
            call.transformTo((RelNode)newRelNode);
        }
    }

    @Value.Immutable
    public static interface Config
    extends OpenSearchRuleConfig {
        public static final Config DEFAULT = ImmutableAggregateIndexScanRule.Config.builder().build().withDescription("Agg-Project-TableScan").withOperandSupplier(b0 -> b0.operand(LogicalAggregate.class).oneInput(b1 -> b1.operand(LogicalProject.class).predicate(Predicate.not(Project::containsOver).and(PlanUtils::distinctProjectList).or(Config::containsWidthBucketFuncOnDate)).oneInput(b2 -> b2.operand(CalciteLogicalIndexScan.class).predicate(Predicate.not(AbstractCalciteIndexScan::isLimitPushed).and(AbstractCalciteIndexScan::noAggregatePushed)).noInputs())));
        public static final Config COUNT_STAR = ImmutableAggregateIndexScanRule.Config.builder().build().withDescription("Agg[count()]-TableScan").withOperandSupplier(b0 -> b0.operand(LogicalAggregate.class).predicate(agg -> agg.getGroupSet().isEmpty() && agg.getAggCallList().stream().allMatch(call -> call.getAggregation().kind == SqlKind.COUNT && call.getArgList().isEmpty())).oneInput(b1 -> b1.operand(CalciteLogicalIndexScan.class).predicate(Predicate.not(AbstractCalciteIndexScan::isLimitPushed).and(AbstractCalciteIndexScan::noAggregatePushed)).noInputs()));
        public static final Predicate<Aggregate> aggIgnoreNullBucket = agg -> agg.getHints().stream().anyMatch(hint -> hint.hintName.equals("stats_args") && ((String)hint.kvOptions.get("bucket_nullable")).equals("false"));
        public static final Predicate<Aggregate> maybeTimeSpanAgg = agg -> agg.getGroupSet().stream().allMatch(group -> OpenSearchTypeFactory.isTimeBasedType(((RelDataTypeField)agg.getInput().getRowType().getFieldList().get(group)).getType()));
        public static final Config BUCKET_NON_NULL_AGG = ImmutableAggregateIndexScanRule.Config.builder().build().withDescription("Agg-Filter-Project-TableScan").withOperandSupplier(b0 -> b0.operand(LogicalAggregate.class).predicate(aggIgnoreNullBucket).oneInput(b1 -> b1.operand(LogicalFilter.class).predicate(PlanUtils::mayBeFilterFromBucketNonNull).oneInput(b2 -> b2.operand(LogicalProject.class).predicate(Predicate.not(Project::containsOver).and(PlanUtils::distinctProjectList).or(Config::containsWidthBucketFuncOnDate)).oneInput(b3 -> b3.operand(CalciteLogicalIndexScan.class).predicate(Predicate.not(AbstractCalciteIndexScan::isLimitPushed).and(AbstractCalciteIndexScan::noAggregatePushed)).noInputs()))));
        public static final Config BUCKET_NON_NULL_AGG_WITH_UDF = ImmutableAggregateIndexScanRule.Config.builder().build().withDescription("Agg-Project-Filter-Project-TableScan").withOperandSupplier(b0 -> b0.operand(LogicalAggregate.class).predicate(aggIgnoreNullBucket.or(maybeTimeSpanAgg)).oneInput(b1 -> b1.operand(LogicalProject.class).predicate(Predicate.not(Project::containsOver).and(PlanUtils::distinctProjectList)).oneInput(b2 -> b2.operand(LogicalFilter.class).predicate(PlanUtils::mayBeFilterFromBucketNonNull).oneInput(b3 -> b3.operand(LogicalProject.class).predicate(Predicate.not(Project::containsOver).and(PlanUtils::distinctProjectList).or(Config::containsWidthBucketFuncOnDate)).oneInput(b4 -> b4.operand(CalciteLogicalIndexScan.class).predicate(Predicate.not(AbstractCalciteIndexScan::isLimitPushed).and(AbstractCalciteIndexScan::noAggregatePushed)).noInputs())))));

        default public AggregateIndexScanRule toRule() {
            return new AggregateIndexScanRule(this);
        }

        public static boolean containsWidthBucketFuncOnDate(LogicalProject project) {
            return project.getProjects().stream().anyMatch(expr -> {
                RexCall rexCall;
                return expr instanceof RexCall && (rexCall = (RexCall)expr).getOperator().equals((Object)PPLBuiltinOperators.WIDTH_BUCKET) && WidthBucketFunction.dateRelatedType(((RexNode)rexCall.getOperands().getFirst()).getType());
            });
        }
    }
}

