diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/plan/RelOptUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/plan/RelOptUtil.java new file mode 100644 index 0000000000000..05fdcfc7ed589 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/plan/RelOptUtil.java @@ -0,0 +1,4567 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.plan; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import com.google.common.collect.LinkedHashMultimap; +import com.google.common.collect.Lists; +import com.google.common.collect.Multimap; +import org.apache.calcite.adapter.enumerable.EnumerableRules; +import org.apache.calcite.avatica.AvaticaConnection; +import org.apache.calcite.config.CalciteSystemProperty; +import org.apache.calcite.interpreter.Bindables; +import org.apache.calcite.linq4j.Ord; +import org.apache.calcite.linq4j.function.Experimental; +import org.apache.calcite.rel.RelCollations; +import org.apache.calcite.rel.RelHomogeneousShuttle; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelShuttle; +import org.apache.calcite.rel.RelVisitor; +import org.apache.calcite.rel.RelWriter; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Calc; +import org.apache.calcite.rel.core.CorrelationId; +import org.apache.calcite.rel.core.Filter; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.RelFactories; +import org.apache.calcite.rel.core.Sort; +import org.apache.calcite.rel.core.TableScan; +import org.apache.calcite.rel.externalize.RelDotWriter; +import org.apache.calcite.rel.externalize.RelJsonWriter; +import org.apache.calcite.rel.externalize.RelWriterImpl; +import org.apache.calcite.rel.externalize.RelXmlWriter; +import org.apache.calcite.rel.hint.HintStrategyTable; +import org.apache.calcite.rel.hint.Hintable; +import org.apache.calcite.rel.hint.RelHint; +import org.apache.calcite.rel.logical.LogicalAggregate; +import org.apache.calcite.rel.logical.LogicalCalc; +import org.apache.calcite.rel.logical.LogicalFilter; +import org.apache.calcite.rel.logical.LogicalJoin; +import org.apache.calcite.rel.logical.LogicalProject; +import org.apache.calcite.rel.metadata.RelMetadataQuery; +import org.apache.calcite.rel.rules.CoreRules; +import org.apache.calcite.rel.rules.MultiJoin; +import org.apache.calcite.rel.stream.StreamRules; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rel.type.RelDataTypeFieldImpl; +import org.apache.calcite.rex.LogicVisitor; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexCorrelVariable; +import org.apache.calcite.rex.RexExecutor; +import org.apache.calcite.rex.RexExecutorImpl; +import org.apache.calcite.rex.RexFieldAccess; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexLocalRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexOver; +import org.apache.calcite.rex.RexProgram; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.rex.RexSqlStandardConvertletTable; +import org.apache.calcite.rex.RexSubQuery; +import org.apache.calcite.rex.RexToSqlNodeConverter; +import org.apache.calcite.rex.RexToSqlNodeConverterImpl; +import org.apache.calcite.rex.RexUtil; +import org.apache.calcite.rex.RexVisitorImpl; +import org.apache.calcite.runtime.CalciteContextException; +import org.apache.calcite.runtime.PairList; +import org.apache.calcite.schema.ModifiableView; +import org.apache.calcite.sql.SqlExplainFormat; +import org.apache.calcite.sql.SqlExplainLevel; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlLiteral; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.MultisetSqlType; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.type.SqlTypeUtil; +import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.tools.RelBuilderFactory; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.Litmus; +import org.apache.calcite.util.Pair; +import org.apache.calcite.util.Permutation; +import org.apache.calcite.util.Util; +import org.apache.calcite.util.mapping.Mapping; +import org.apache.calcite.util.mapping.MappingType; +import org.apache.calcite.util.mapping.Mappings; +import org.checkerframework.checker.initialization.qual.NotOnlyInitialized; +import org.checkerframework.checker.initialization.qual.UnknownInitialization; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.PolyNull; + +import java.io.PrintWriter; +import java.io.StringWriter; +import java.util.AbstractList; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.BitSet; +import java.util.Collection; +import java.util.Comparator; +import java.util.Deque; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.NavigableSet; +import java.util.Set; +import java.util.TreeSet; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import static java.util.Objects.requireNonNull; +import static org.apache.calcite.rel.type.RelDataTypeImpl.NON_NULLABLE_SUFFIX; + +/** + * RelOptUtil defines static utility methods for use in optimizing {@link RelNode}s. + * + *

FLINK modifications (backport of CALCITE-6764): Lines 2074 ~ 2106 + */ +public abstract class RelOptUtil { + // ~ Static fields/initializers --------------------------------------------- + + public static final double EPSILON = 1.0e-5; + + @SuppressWarnings("Guava") + @Deprecated // to be removed before 2.0 + public static final com.google.common.base.Predicate FILTER_PREDICATE = + f -> !f.containsOver(); + + @SuppressWarnings("Guava") + @Deprecated // to be removed before 2.0 + public static final com.google.common.base.Predicate PROJECT_PREDICATE = + RelOptUtil::notContainsWindowedAgg; + + @SuppressWarnings("Guava") + @Deprecated // to be removed before 2.0 + public static final com.google.common.base.Predicate CALC_PREDICATE = + RelOptUtil::notContainsWindowedAgg; + + // ~ Methods ---------------------------------------------------------------- + + /** Whether this node is a limit without sort specification. */ + public static boolean isPureLimit(RelNode rel) { + return isLimit(rel) && !isOrder(rel); + } + + /** Whether this node is a sort without limit specification. */ + public static boolean isPureOrder(RelNode rel) { + return !isLimit(rel) && isOrder(rel); + } + + /** Whether this node contains a limit specification. */ + public static boolean isLimit(RelNode rel) { + return (rel instanceof Sort) && ((Sort) rel).fetch != null; + } + + /** Whether this node contains a sort specification. */ + public static boolean isOrder(RelNode rel) { + return (rel instanceof Sort) && !((Sort) rel).getCollation().getFieldCollations().isEmpty(); + } + + /** Whether this node contains a offset specification. */ + public static boolean isOffset(RelNode rel) { + return (rel instanceof Sort) && ((Sort) rel).offset != null; + } + + /** Returns a set of tables used by this expression or its children. */ + public static Set findTables(RelNode rel) { + return new LinkedHashSet<>(findAllTables(rel)); + } + + /** Returns a list of all tables used by this expression or its children. */ + public static List findAllTables(RelNode rel) { + final Multimap, RelNode> nodes = + rel.getCluster().getMetadataQuery().getNodeTypes(rel); + final List usedTables = new ArrayList<>(); + if (nodes == null) { + return usedTables; + } + for (Map.Entry, Collection> e : + nodes.asMap().entrySet()) { + if (TableScan.class.isAssignableFrom(e.getKey())) { + for (RelNode node : e.getValue()) { + TableScan scan = (TableScan) node; + usedTables.add(scan.getTable()); + } + } + } + return usedTables; + } + + /** Returns a list of all table qualified names used by this expression or its children. */ + public static List findAllTableQualifiedNames(RelNode rel) { + return findAllTables(rel).stream() + .map(table -> table.getQualifiedName().toString()) + .collect(Collectors.toList()); + } + + /** Returns a list of variables set by a relational expression or its descendants. */ + public static Set getVariablesSet(RelNode rel) { + VariableSetVisitor visitor = new VariableSetVisitor(); + go(visitor, rel); + return visitor.variables; + } + + @Deprecated // to be removed before 2.0 + @SuppressWarnings("MixedMutabilityReturnType") + public static List getVariablesSetAndUsed(RelNode rel0, RelNode rel1) { + Set set = getVariablesSet(rel0); + if (set.size() == 0) { + return ImmutableList.of(); + } + Set used = getVariablesUsed(rel1); + if (used.size() == 0) { + return ImmutableList.of(); + } + final List result = new ArrayList<>(); + for (CorrelationId s : set) { + if (used.contains(s) && !result.contains(s)) { + result.add(s); + } + } + return result; + } + + /** + * Returns the set of variables used by a relational expression or its descendants. + * + *

The set may contain "duplicates" (variables with different ids that, when resolved, will + * reference the same source relational expression). + * + *

The item type is the same as {@link org.apache.calcite.rex.RexCorrelVariable#id}. + */ + public static Set getVariablesUsed(RelNode rel) { + CorrelationCollector visitor = new CorrelationCollector(); + rel.accept(visitor); + return visitor.vuv.variables; + } + + /** + * Returns the set of variables used by the given list of sub-queries and its descendants. + * + * @param subQueries The sub-queries containing correlation variables + * @return A list of correlation identifiers found within the sub-queries. The type of the + * [CorrelationId] parameter corresponds to {@link + * org.apache.calcite.rex.RexCorrelVariable#id}. + */ + public static Set getVariablesUsed(List subQueries) { + // Internally this function calls getVariablesUsed on a RelNode to get all the + // correlated variables in that RelNode + Set correlationIds = new HashSet<>(); + for (RexSubQuery subQ : subQueries) { + correlationIds.addAll(getVariablesUsed(subQ.rel)); + } + return correlationIds; + } + + /** Finds which columns of a correlation variable are used within a relational expression. */ + public static ImmutableBitSet correlationColumns(CorrelationId id, RelNode rel) { + final CorrelationCollector collector = new CorrelationCollector(); + rel.accept(collector); + final ImmutableBitSet.Builder builder = ImmutableBitSet.builder(); + for (int field : collector.vuv.variableFields.get(id)) { + if (field >= 0) { + builder.set(field); + } + } + return builder.build(); + } + + /** + * Returns true, and calls {@link Litmus#succeed()} if a given relational expression does not + * contain a given correlation. + */ + public static boolean notContainsCorrelation( + RelNode r, CorrelationId correlationId, Litmus litmus) { + final Set set = getVariablesUsed(r); + if (!set.contains(correlationId)) { + return litmus.succeed(); + } else { + return litmus.fail("contains {}", correlationId); + } + } + + /** Sets a {@link RelVisitor} going on a given relational expression, and returns the result. */ + public static void go(RelVisitor visitor, RelNode p) { + try { + visitor.go(p); + } catch (Exception e) { + throw new RuntimeException("while visiting tree", e); + } + } + + /** + * Returns a list of the types of the fields in a given struct type. The list is immutable. + * + * @param type Struct type + * @return List of field types + * @see org.apache.calcite.rel.type.RelDataType#getFieldNames() + */ + public static List getFieldTypeList(final RelDataType type) { + return Util.transform(type.getFieldList(), RelDataTypeField::getType); + } + + public static boolean areRowTypesEqual( + RelDataType rowType1, RelDataType rowType2, boolean compareNames) { + if (rowType1 == rowType2) { + return true; + } + if (compareNames) { + // if types are not identity-equal, then either the names or + // the types must be different + return false; + } + if (rowType2.getFieldCount() != rowType1.getFieldCount()) { + return false; + } + final List f1 = rowType1.getFieldList(); + final List f2 = rowType2.getFieldList(); + for (Pair pair : Pair.zip(f1, f2)) { + final RelDataType type1 = pair.left.getType(); + final RelDataType type2 = pair.right.getType(); + // If one of the types is ANY comparison should succeed + if (type1.getSqlTypeName() == SqlTypeName.ANY + || type2.getSqlTypeName() == SqlTypeName.ANY) { + continue; + } + if (!type1.equals(type2)) { + return false; + } + } + return true; + } + + /** + * Verifies that a row type being added to an equivalence class matches the existing type, + * raising an assertion if this is not the case. + * + * @param originalRel canonical rel for equivalence class + * @param newRel rel being added to equivalence class + * @param equivalenceClass object representing equivalence class + */ + public static void verifyTypeEquivalence( + RelNode originalRel, RelNode newRel, Object equivalenceClass) { + RelDataType expectedRowType = originalRel.getRowType(); + RelDataType actualRowType = newRel.getRowType(); + + // Row types must be the same, except for field names. + if (areRowTypesEqual(expectedRowType, actualRowType, false)) { + return; + } + + String s = + "Cannot add expression of different type to set:\n" + + "set type is " + + expectedRowType.getFullTypeString() + + "\nexpression type is " + + actualRowType.getFullTypeString() + + "\nset is " + + equivalenceClass.toString() + + "\nexpression is " + + RelOptUtil.toString(newRel) + + getFullTypeDifferenceString( + "rowtype of original rel", + expectedRowType, + "rowtype of new rel", + actualRowType); + throw new AssertionError(s); + } + + /** + * Copy the {@link org.apache.calcite.rel.hint.RelHint}s from {@code originalRel} to {@code + * newRel} if both of them are {@link Hintable}. + * + *

The two relational expressions are assumed as semantically equivalent, that means the + * hints should be attached to the relational expression that expects to have them. + * + *

Try to propagate the hints to the first relational expression that matches, this is needed + * because many planner rules would generate a sub-tree whose root rel type is different with + * the original matched rel. + * + *

For the worst case, there is no relational expression that can apply these hints, and the + * whole sub-tree would be visited. We add a protection here: if the visiting depth is over than + * 3, just returns, because there are rare cases the new created sub-tree has layers bigger than + * that. + * + *

This is a best effort, we do not know exactly how the nodes are transformed in all kinds + * of planner rules, so for some complex relational expressions, the hints would very probably + * lost. + * + *

This function is experimental and would change without any notes. + * + * @param originalRel Original relational expression + * @param equiv New equivalent relational expression + * @return A copy of {@code newRel} with attached qualified hints from {@code originalRel}, or + * {@code newRel} directly if one of them are not {@link Hintable} + */ + @Experimental + public static RelNode propagateRelHints(RelNode originalRel, RelNode equiv) { + if (!(originalRel instanceof Hintable) || ((Hintable) originalRel).getHints().size() == 0) { + return equiv; + } + final RelShuttle shuttle = + new SubTreeHintPropagateShuttle( + originalRel.getCluster().getHintStrategies(), + ((Hintable) originalRel).getHints()); + return equiv.accept(shuttle); + } + + /** + * Propagates the relational expression hints from root node to leaf node. + * + * @param rel The relational expression + * @param reset Flag saying if to reset the existing hints before the propagation + * @return New relational expression with hints propagated + */ + public static RelNode propagateRelHints(RelNode rel, boolean reset) { + if (reset) { + rel = rel.accept(new ResetHintsShuttle()); + } + final RelShuttle shuttle = + new RelHintPropagateShuttle(rel.getCluster().getHintStrategies()); + return rel.accept(shuttle); + } + + /** + * Copy the {@link org.apache.calcite.rel.hint.RelHint}s from {@code originalRel} to {@code + * newRel} if both of them are {@link Hintable}. + * + *

The hints would be attached directly(e.g. without any filtering). + * + * @param originalRel Original relational expression + * @param newRel New relational expression + * @return A copy of {@code newRel} with attached hints from {@code originalRel}, or {@code + * newRel} directly if one of them are not {@link Hintable} + */ + public static RelNode copyRelHints(RelNode originalRel, RelNode newRel) { + return copyRelHints(originalRel, newRel, false); + } + + /** + * Copy the {@link org.apache.calcite.rel.hint.RelHint}s from {@code originalRel} to {@code + * newRel} if both of them are {@link Hintable}. + * + *

The hints would be filtered by the specified hint strategies if {@code filterHints} is + * true. + * + * @param originalRel Original relational expression + * @param newRel New relational expression + * @param filterHints Flag saying if to filter out unqualified hints for {@code newRel} + * @return A copy of {@code newRel} with attached hints from {@code originalRel}, or {@code + * newRel} directly if one of them are not {@link Hintable} + */ + public static RelNode copyRelHints(RelNode originalRel, RelNode newRel, boolean filterHints) { + if (originalRel == newRel && !filterHints) { + return originalRel; + } + + if (originalRel instanceof Hintable + && newRel instanceof Hintable + && ((Hintable) originalRel).getHints().size() > 0) { + final List hints = ((Hintable) originalRel).getHints(); + if (filterHints) { + HintStrategyTable hintStrategies = originalRel.getCluster().getHintStrategies(); + return ((Hintable) newRel).attachHints(hintStrategies.apply(hints, newRel)); + } else { + // Keep all the hints if filterHints is false for 2 reasons: + // 1. Keep sync with the hints propagation logic, + // see RelHintPropagateShuttle for details. + // 2. We may re-propagate these hints when decorrelating a query. + return ((Hintable) newRel).attachHints(hints); + } + } + return newRel; + } + + /** + * Returns a permutation describing where output fields come from. In the returned map, value of + * {@code map.getTargetOpt(i)} is {@code n} if field {@code i} projects input field {@code n} or + * applies a cast on {@code n}, -1 if it is another expression. + */ + public static Mappings.TargetMapping permutationIgnoreCast( + List nodes, RelDataType inputRowType) { + final Mappings.TargetMapping mapping = + Mappings.create( + MappingType.PARTIAL_FUNCTION, nodes.size(), inputRowType.getFieldCount()); + for (Ord node : Ord.zip(nodes)) { + if (node.e instanceof RexInputRef) { + mapping.set(node.i, ((RexInputRef) node.e).getIndex()); + } else if (node.e.isA(SqlKind.CAST)) { + final RexNode operand = ((RexCall) node.e).getOperands().get(0); + if (operand instanceof RexInputRef) { + mapping.set(node.i, ((RexInputRef) operand).getIndex()); + } + } + } + return mapping; + } + + /** + * Returns a permutation describing where output fields come from. In the returned map, value of + * {@code map.getTargetOpt(i)} is {@code n} if field {@code i} projects input field {@code n}, + * -1 if it is an expression. + */ + public static Mappings.TargetMapping permutation( + List nodes, RelDataType inputRowType) { + final Mappings.TargetMapping mapping = + Mappings.create( + MappingType.PARTIAL_FUNCTION, nodes.size(), inputRowType.getFieldCount()); + for (Ord node : Ord.zip(nodes)) { + if (node.e instanceof RexInputRef) { + mapping.set(node.i, ((RexInputRef) node.e).getIndex()); + } + } + return mapping; + } + + /** + * Returns a permutation describing where the Project's fields come from after the Project is + * pushed down. + */ + public static Mappings.TargetMapping permutationPushDownProject( + List nodes, RelDataType inputRowType, int sourceOffset, int targetOffset) { + final Mappings.TargetMapping mapping = + Mappings.create( + MappingType.PARTIAL_FUNCTION, + inputRowType.getFieldCount() + sourceOffset, + nodes.size() + targetOffset); + for (Ord node : Ord.zip(nodes)) { + if (node.e instanceof RexInputRef) { + mapping.set( + ((RexInputRef) node.e).getIndex() + sourceOffset, node.i + targetOffset); + } + } + return mapping; + } + + @Deprecated // to be removed before 2.0 + public static RelNode createExistsPlan( + RelOptCluster cluster, + RelNode seekRel, + @Nullable List conditions, + @Nullable RexLiteral extraExpr, + @Nullable String extraName) { + assert extraExpr == null || extraName != null; + RelNode ret = seekRel; + + if ((conditions != null) && (conditions.size() > 0)) { + RexNode conditionExp = + RexUtil.composeConjunction(cluster.getRexBuilder(), conditions, true); + + if (conditionExp != null) { + final RelFactories.FilterFactory factory = RelFactories.DEFAULT_FILTER_FACTORY; + ret = factory.createFilter(ret, conditionExp, ImmutableSet.of()); + } + } + + if (extraExpr != null) { + RexBuilder rexBuilder = cluster.getRexBuilder(); + + assert extraExpr == rexBuilder.makeLiteral(true); + + // this should only be called for the exists case + // first stick an Agg on top of the sub-query + // agg does not like no agg functions so just pretend it is + // doing a min(TRUE) + + final RelBuilder relBuilder = RelFactories.LOGICAL_BUILDER.create(cluster, null); + ret = + relBuilder + .push(ret) + .project(extraExpr) + .aggregate( + relBuilder.groupKey(), + relBuilder.min(relBuilder.field(0)).as(extraName)) + .build(); + } + + return ret; + } + + @Deprecated // to be removed before 2.0 + public static Exists createExistsPlan( + RelNode seekRel, SubQueryType subQueryType, Logic logic, boolean notIn) { + final RelBuilder relBuilder = + RelFactories.LOGICAL_BUILDER.create(seekRel.getCluster(), null); + return createExistsPlan(seekRel, subQueryType, logic, notIn, relBuilder); + } + + /** + * Creates a plan suitable for use in EXISTS or IN statements. + * + * @see org.apache.calcite.sql2rel.SqlToRelConverter SqlToRelConverter#convertExists + * @param seekRel A query rel, for example the resulting rel from 'select * from emp' or 'values + * (1,2,3)' or '('Foo', 34)'. + * @param subQueryType Sub-query type + * @param logic Whether to use 2- or 3-valued boolean logic + * @param notIn Whether the operator is NOT IN + * @param relBuilder Builder for relational expressions + * @return A pair of a relational expression which outer joins a boolean condition column, and a + * numeric offset. The offset is 2 if column 0 is the number of rows and column 1 is the + * number of rows with not-null keys; 0 otherwise. + */ + public static Exists createExistsPlan( + RelNode seekRel, + SubQueryType subQueryType, + Logic logic, + boolean notIn, + RelBuilder relBuilder) { + switch (subQueryType) { + case SCALAR: + return new Exists(seekRel, false, true); + default: + break; + } + + switch (logic) { + case TRUE_FALSE_UNKNOWN: + case UNKNOWN_AS_TRUE: + if (notIn && !containsNullableFields(seekRel)) { + logic = Logic.TRUE_FALSE; + } + break; + default: + break; + } + RelNode ret = seekRel; + final RelOptCluster cluster = seekRel.getCluster(); + final RexBuilder rexBuilder = cluster.getRexBuilder(); + final int keyCount = ret.getRowType().getFieldCount(); + final boolean outerJoin = notIn || logic == RelOptUtil.Logic.TRUE_FALSE_UNKNOWN; + if (!outerJoin) { + final LogicalAggregate aggregate = + LogicalAggregate.create( + ret, + ImmutableList.of(), + ImmutableBitSet.range(keyCount), + null, + ImmutableList.of()); + return new Exists(aggregate, false, false); + } + + // for IN/NOT IN, it needs to output the fields + final List exprs = new ArrayList<>(); + if (subQueryType == SubQueryType.IN) { + for (int i = 0; i < keyCount; i++) { + exprs.add(rexBuilder.makeInputRef(ret, i)); + } + } + + final int projectedKeyCount = exprs.size(); + exprs.add(rexBuilder.makeLiteral(true)); + + ret = + relBuilder + .push(ret) + .project(exprs) + .aggregate( + relBuilder.groupKey(ImmutableBitSet.range(projectedKeyCount)), + relBuilder.min(relBuilder.field(projectedKeyCount))) + .build(); + + switch (logic) { + case TRUE_FALSE_UNKNOWN: + case UNKNOWN_AS_TRUE: + return new Exists(ret, true, true); + default: + return new Exists(ret, false, true); + } + } + + @Deprecated // to be removed before 2.0 + public static RelNode createRenameRel(RelDataType outputType, RelNode rel) { + RelDataType inputType = rel.getRowType(); + List inputFields = inputType.getFieldList(); + int n = inputFields.size(); + + List outputFields = outputType.getFieldList(); + assert outputFields.size() == n + : "rename: field count mismatch: in=" + inputType + ", out" + outputType; + + final PairList renames = PairList.of(); + final RexBuilder rexBuilder = rel.getCluster().getRexBuilder(); + Pair.forEach( + inputFields, + outputFields, + (inputField, outputField) -> { + assert inputField.getType().equals(outputField.getType()); + renames.add( + rexBuilder.makeInputRef(inputField.getType(), inputField.getIndex()), + outputField.getName()); + }); + final RelBuilder relBuilder = RelFactories.LOGICAL_BUILDER.create(rel.getCluster(), null); + return relBuilder.push(rel).project(renames.leftList(), renames.rightList(), true).build(); + } + + @Deprecated // to be removed before 2.0 + public static RelNode createFilter(RelNode child, RexNode condition) { + final RelFactories.FilterFactory factory = RelFactories.DEFAULT_FILTER_FACTORY; + return factory.createFilter(child, condition, ImmutableSet.of()); + } + + @Deprecated // to be removed before 2.0 + public static RelNode createFilter( + RelNode child, RexNode condition, RelFactories.FilterFactory filterFactory) { + return filterFactory.createFilter(child, condition, ImmutableSet.of()); + } + + /** + * Creates a filter, using the default filter factory, or returns the original relational + * expression if the condition is trivial. + */ + public static RelNode createFilter(RelNode child, Iterable conditions) { + return createFilter(child, conditions, RelFactories.DEFAULT_FILTER_FACTORY); + } + + /** + * Creates a filter using the default factory, or returns the original relational expression if + * the condition is trivial. + */ + public static RelNode createFilter( + RelNode child, + Iterable conditions, + RelFactories.FilterFactory filterFactory) { + final RelOptCluster cluster = child.getCluster(); + final RexNode condition = + RexUtil.composeConjunction(cluster.getRexBuilder(), conditions, true); + if (condition == null) { + return child; + } else { + return filterFactory.createFilter(child, condition, ImmutableSet.of()); + } + } + + @Deprecated // to be removed before 2.0 + public static RelNode createNullFilter(RelNode rel, Integer[] fieldOrdinals) { + RexNode condition = null; + final RexBuilder rexBuilder = rel.getCluster().getRexBuilder(); + RelDataType rowType = rel.getRowType(); + int n; + if (fieldOrdinals != null) { + n = fieldOrdinals.length; + } else { + n = rowType.getFieldCount(); + } + List fields = rowType.getFieldList(); + for (int i = 0; i < n; ++i) { + int iField; + if (fieldOrdinals != null) { + iField = fieldOrdinals[i]; + } else { + iField = i; + } + RelDataType type = fields.get(iField).getType(); + if (!type.isNullable()) { + continue; + } + RexNode newCondition = + rexBuilder.makeCall( + SqlStdOperatorTable.IS_NOT_NULL, rexBuilder.makeInputRef(type, iField)); + if (condition == null) { + condition = newCondition; + } else { + condition = rexBuilder.makeCall(SqlStdOperatorTable.AND, condition, newCondition); + } + } + if (condition == null) { + // no filtering required + return rel; + } + + final RelFactories.FilterFactory factory = RelFactories.DEFAULT_FILTER_FACTORY; + return factory.createFilter(rel, condition, ImmutableSet.of()); + } + + /** + * Creates a projection which casts a rel's output to a desired row type. + * + *

No need to create new projection if {@code rel} is already a project, instead, create a + * projection with the input of {@code rel} and the new cast expressions. + * + *

The desired row type and the row type to be converted must have the same number of fields. + * + * @param rel producer of rows to be converted + * @param castRowType row type after cast + * @param rename if true, use field names from castRowType; if false, preserve field names from + * rel + * @return conversion rel + */ + public static RelNode createCastRel( + final RelNode rel, RelDataType castRowType, boolean rename) { + return createCastRel(rel, castRowType, rename, RelFactories.DEFAULT_PROJECT_FACTORY); + } + + /** + * Creates a projection which casts a rel's output to a desired row type. + * + *

No need to create new projection if {@code rel} is already a project, instead, create a + * projection with the input of {@code rel} and the new cast expressions. + * + *

The desired row type and the row type to be converted must have the same number of fields. + * + * @param rel producer of rows to be converted + * @param castRowType row type after cast + * @param rename if true, use field names from castRowType; if false, preserve field names from + * rel + * @param projectFactory Project Factory + * @return conversion rel + */ + public static RelNode createCastRel( + final RelNode rel, + RelDataType castRowType, + boolean rename, + RelFactories.ProjectFactory projectFactory) { + assert projectFactory != null; + RelDataType rowType = rel.getRowType(); + if (areRowTypesEqual(rowType, castRowType, rename)) { + // nothing to do + return rel; + } + if (rowType.getFieldCount() != castRowType.getFieldCount()) { + throw new IllegalArgumentException( + "Field counts are not equal: " + + "rowType [" + + rowType + + "] castRowType [" + + castRowType + + "]"); + } + final RexBuilder rexBuilder = rel.getCluster().getRexBuilder(); + List castExps; + RelNode input; + List hints = ImmutableList.of(); + Set correlationVariables; + if (rel instanceof Project) { + // No need to create another project node if the rel + // is already a project. + final Project project = (Project) rel; + castExps = + RexUtil.generateCastExpressions( + rexBuilder, castRowType, ((Project) rel).getProjects()); + input = rel.getInput(0); + hints = project.getHints(); + correlationVariables = project.getVariablesSet(); + } else { + castExps = RexUtil.generateCastExpressions(rexBuilder, castRowType, rowType); + input = rel; + correlationVariables = ImmutableSet.of(); + } + if (rename) { + // Use names and types from castRowType. + return projectFactory.createProject( + input, hints, castExps, castRowType.getFieldNames(), correlationVariables); + } else { + // Use names from rowType, types from castRowType. + return projectFactory.createProject( + input, hints, castExps, rowType.getFieldNames(), correlationVariables); + } + } + + /** Gets all fields in an aggregate. */ + public static Set getAllFields(Aggregate aggregate) { + return getAllFields2(aggregate.getGroupSet(), aggregate.getAggCallList()); + } + + /** Gets all fields in an aggregate. */ + public static Set getAllFields2( + ImmutableBitSet groupSet, List aggCallList) { + final Set allFields = new TreeSet<>(); + allFields.addAll(groupSet.asList()); + for (AggregateCall aggregateCall : aggCallList) { + allFields.addAll(aggregateCall.getArgList()); + if (aggregateCall.filterArg >= 0) { + allFields.add(aggregateCall.filterArg); + } + if (aggregateCall.distinctKeys != null) { + allFields.addAll(aggregateCall.distinctKeys.asList()); + } + allFields.addAll(RelCollations.ordinals(aggregateCall.collation)); + } + return allFields; + } + + /** + * Creates a LogicalAggregate that removes all duplicates from the result of an underlying + * relational expression. + * + * @param rel underlying rel + * @return rel implementing SingleValueAgg + */ + public static RelNode createSingleValueAggRel(RelOptCluster cluster, RelNode rel) { + final int aggCallCnt = rel.getRowType().getFieldCount(); + final List aggCalls = new ArrayList<>(); + + for (int i = 0; i < aggCallCnt; i++) { + aggCalls.add( + AggregateCall.create( + SqlStdOperatorTable.SINGLE_VALUE, + false, + false, + false, + ImmutableList.of(), + ImmutableList.of(i), + -1, + null, + RelCollations.EMPTY, + 0, + rel, + null, + null)); + } + + return LogicalAggregate.create( + rel, ImmutableList.of(), ImmutableBitSet.of(), null, aggCalls); + } + + // CHECKSTYLE: IGNORE 1 + /** + * @deprecated Use {@link RelBuilder#distinct()}. + */ + @Deprecated // to be removed before 2.0 + public static RelNode createDistinctRel(RelNode rel) { + return LogicalAggregate.create( + rel, + ImmutableList.of(), + ImmutableBitSet.range(rel.getRowType().getFieldCount()), + null, + ImmutableList.of()); + } + + @Deprecated // to be removed before 2.0 + public static boolean analyzeSimpleEquiJoin(LogicalJoin join, int[] joinFieldOrdinals) { + RexNode joinExp = join.getCondition(); + if (joinExp.getKind() != SqlKind.EQUALS) { + return false; + } + RexCall binaryExpression = (RexCall) joinExp; + RexNode leftComparand = binaryExpression.operands.get(0); + RexNode rightComparand = binaryExpression.operands.get(1); + if (!(leftComparand instanceof RexInputRef)) { + return false; + } + if (!(rightComparand instanceof RexInputRef)) { + return false; + } + + final int leftFieldCount = join.getLeft().getRowType().getFieldCount(); + RexInputRef leftFieldAccess = (RexInputRef) leftComparand; + if (!(leftFieldAccess.getIndex() < leftFieldCount)) { + // left field must access left side of join + return false; + } + + RexInputRef rightFieldAccess = (RexInputRef) rightComparand; + if (!(rightFieldAccess.getIndex() >= leftFieldCount)) { + // right field must access right side of join + return false; + } + + joinFieldOrdinals[0] = leftFieldAccess.getIndex(); + joinFieldOrdinals[1] = rightFieldAccess.getIndex() - leftFieldCount; + return true; + } + + /** + * Splits out the equi-join components of a join condition, and returns what's left. For + * example, given the condition + * + *

+ * + * L.A = R.X AND L.B = L.C AND (L.D = 5 OR L.E = + * R.Y) + * + *
+ * + *

returns + * + *

+ * + * @param left left input to join + * @param right right input to join + * @param condition join condition + * @param leftKeys The ordinals of the fields from the left input which are equi-join keys + * @param rightKeys The ordinals of the fields from the right input which are equi-join keys + * @param filterNulls List of boolean values for each join key position indicating whether the + * operator filters out nulls or not. Value is true if the operator is EQUALS and false if + * the operator is IS NOT DISTINCT FROM (or an expanded version). If filterNulls + * is null, only join conditions with EQUALS operators are considered equi-join + * components. Rest (including IS NOT DISTINCT FROM) are returned in remaining join + * condition. + * @return remaining join filters that are not equijoins; may return a {@link RexLiteral} true, + * but never null + */ + public static RexNode splitJoinCondition( + RelNode left, + RelNode right, + RexNode condition, + List leftKeys, + List rightKeys, + @Nullable List filterNulls) { + final List nonEquiList = new ArrayList<>(); + + splitJoinCondition(left, right, condition, leftKeys, rightKeys, filterNulls, nonEquiList); + + return RexUtil.composeConjunction(left.getCluster().getRexBuilder(), nonEquiList); + } + + /** + * As {@link #splitJoinCondition(RelNode, RelNode, RexNode, List, List, List)}, but writes + * non-equi conditions to a conjunctive list. + */ + public static void splitJoinCondition( + RelNode left, + RelNode right, + RexNode condition, + List leftKeys, + List rightKeys, + @Nullable List filterNulls, + List nonEquiList) { + splitJoinCondition( + left.getCluster().getRexBuilder(), + left.getRowType().getFieldCount(), + condition, + leftKeys, + rightKeys, + filterNulls, + nonEquiList); + } + + @Deprecated // to be removed before 2.0 + public static boolean isEqui(RelNode left, RelNode right, RexNode condition) { + final List leftKeys = new ArrayList<>(); + final List rightKeys = new ArrayList<>(); + final List filterNulls = new ArrayList<>(); + final List nonEquiList = new ArrayList<>(); + splitJoinCondition( + left.getCluster().getRexBuilder(), + left.getRowType().getFieldCount(), + condition, + leftKeys, + rightKeys, + filterNulls, + nonEquiList); + return nonEquiList.size() == 0; + } + + /** + * Splits out the equi-join (and optionally, a single non-equi) components of a join condition, + * and returns what's left. Projection might be required by the caller to provide join keys that + * are not direct field references. + * + * @param sysFieldList list of system fields + * @param leftRel left join input + * @param rightRel right join input + * @param condition join condition + * @param leftJoinKeys The join keys from the left input which are equi-join keys + * @param rightJoinKeys The join keys from the right input which are equi-join keys + * @param filterNulls The join key positions for which null values will not match. null values + * only match for the "is not distinct from" condition. + * @param rangeOp if null, only locate equi-joins; otherwise, locate a single non-equi join + * predicate and return its operator in this list; join keys associated with the non-equi + * join predicate are at the end of the key lists returned + * @return What's left, never null + */ + public static RexNode splitJoinCondition( + List sysFieldList, + RelNode leftRel, + RelNode rightRel, + RexNode condition, + List leftJoinKeys, + List rightJoinKeys, + @Nullable List filterNulls, + @Nullable List rangeOp) { + return splitJoinCondition( + sysFieldList, + ImmutableList.of(leftRel, rightRel), + condition, + ImmutableList.of(leftJoinKeys, rightJoinKeys), + filterNulls, + rangeOp); + } + + /** + * Splits out the equi-join (and optionally, a single non-equi) components of a join condition, + * and returns what's left. Projection might be required by the caller to provide join keys that + * are not direct field references. + * + * @param sysFieldList list of system fields + * @param inputs join inputs + * @param condition join condition + * @param joinKeys The join keys from the inputs which are equi-join keys + * @param filterNulls The join key positions for which null values will not match. null values + * only match for the "is not distinct from" condition. + * @param rangeOp if null, only locate equi-joins; otherwise, locate a single non-equi join + * predicate and return its operator in this list; join keys associated with the non-equi + * join predicate are at the end of the key lists returned + * @return What's left, never null + */ + public static RexNode splitJoinCondition( + List sysFieldList, + List inputs, + RexNode condition, + List> joinKeys, + @Nullable List filterNulls, + @Nullable List rangeOp) { + final List nonEquiList = new ArrayList<>(); + + splitJoinCondition( + sysFieldList, inputs, condition, joinKeys, filterNulls, rangeOp, nonEquiList); + + // Convert the remainders into a list that are AND'ed together. + return RexUtil.composeConjunction(inputs.get(0).getCluster().getRexBuilder(), nonEquiList); + } + + @Deprecated // to be removed before 2.0 + public static @Nullable RexNode splitCorrelatedFilterCondition( + LogicalFilter filter, List joinKeys, List correlatedJoinKeys) { + final List nonEquiList = new ArrayList<>(); + + splitCorrelatedFilterCondition( + filter, filter.getCondition(), joinKeys, correlatedJoinKeys, nonEquiList); + + // Convert the remainders into a list that are AND'ed together. + return RexUtil.composeConjunction(filter.getCluster().getRexBuilder(), nonEquiList, true); + } + + public static @Nullable RexNode splitCorrelatedFilterCondition( + LogicalFilter filter, + List joinKeys, + List correlatedJoinKeys, + boolean extractCorrelatedFieldAccess) { + return splitCorrelatedFilterCondition( + (Filter) filter, joinKeys, correlatedJoinKeys, extractCorrelatedFieldAccess); + } + + public static @Nullable RexNode splitCorrelatedFilterCondition( + Filter filter, + List joinKeys, + List correlatedJoinKeys, + boolean extractCorrelatedFieldAccess) { + final List nonEquiList = new ArrayList<>(); + + splitCorrelatedFilterCondition( + filter, + filter.getCondition(), + joinKeys, + correlatedJoinKeys, + nonEquiList, + extractCorrelatedFieldAccess); + + // Convert the remainders into a list that are AND'ed together. + return RexUtil.composeConjunction(filter.getCluster().getRexBuilder(), nonEquiList, true); + } + + private static void splitJoinCondition( + List sysFieldList, + List inputs, + RexNode condition, + List> joinKeys, + @Nullable List filterNulls, + @Nullable List rangeOp, + List nonEquiList) { + final int sysFieldCount = sysFieldList.size(); + final RelOptCluster cluster = inputs.get(0).getCluster(); + final RexBuilder rexBuilder = cluster.getRexBuilder(); + final RelDataTypeFactory typeFactory = cluster.getTypeFactory(); + + final ImmutableBitSet[] inputsRange = new ImmutableBitSet[inputs.size()]; + int totalFieldCount = 0; + for (int i = 0; i < inputs.size(); i++) { + final int firstField = totalFieldCount + sysFieldCount; + totalFieldCount = firstField + inputs.get(i).getRowType().getFieldCount(); + inputsRange[i] = ImmutableBitSet.range(firstField, totalFieldCount); + } + + // adjustment array + int[] adjustments = new int[totalFieldCount]; + for (int i = 0; i < inputs.size(); i++) { + final int adjustment = inputsRange[i].nextSetBit(0); + for (int j = adjustment; j < inputsRange[i].length(); j++) { + adjustments[j] = -adjustment; + } + } + + if (condition.getKind() == SqlKind.AND) { + for (RexNode operand : ((RexCall) condition).getOperands()) { + splitJoinCondition( + sysFieldList, inputs, operand, joinKeys, filterNulls, rangeOp, nonEquiList); + } + return; + } + + if (condition instanceof RexCall) { + RexNode leftKey = null; + RexNode rightKey = null; + int leftInput = 0; + int rightInput = 0; + List leftFields = null; + List rightFields = null; + boolean reverse = false; + + final RexCall call = + collapseExpandedIsNotDistinctFromExpr((RexCall) condition, rexBuilder); + SqlKind kind = call.getKind(); + + // Only consider range operators if we haven't already seen one + if ((kind == SqlKind.EQUALS) + || (filterNulls != null && kind == SqlKind.IS_NOT_DISTINCT_FROM) + || (rangeOp != null + && rangeOp.isEmpty() + && (kind == SqlKind.GREATER_THAN + || kind == SqlKind.GREATER_THAN_OR_EQUAL + || kind == SqlKind.LESS_THAN + || kind == SqlKind.LESS_THAN_OR_EQUAL))) { + final List operands = call.getOperands(); + RexNode op0 = operands.get(0); + RexNode op1 = operands.get(1); + + final ImmutableBitSet projRefs0 = InputFinder.bits(op0); + final ImmutableBitSet projRefs1 = InputFinder.bits(op1); + + boolean foundBothInputs = false; + for (int i = 0; i < inputs.size() && !foundBothInputs; i++) { + if (projRefs0.intersects(inputsRange[i]) + && projRefs0.union(inputsRange[i]).equals(inputsRange[i])) { + if (leftKey == null) { + leftKey = op0; + leftInput = i; + leftFields = inputs.get(leftInput).getRowType().getFieldList(); + } else { + rightKey = op0; + rightInput = i; + rightFields = inputs.get(rightInput).getRowType().getFieldList(); + reverse = true; + foundBothInputs = true; + } + } else if (projRefs1.intersects(inputsRange[i]) + && projRefs1.union(inputsRange[i]).equals(inputsRange[i])) { + if (leftKey == null) { + leftKey = op1; + leftInput = i; + leftFields = inputs.get(leftInput).getRowType().getFieldList(); + } else { + rightKey = op1; + rightInput = i; + rightFields = inputs.get(rightInput).getRowType().getFieldList(); + foundBothInputs = true; + } + } + } + + if ((leftKey != null) && (rightKey != null)) { + // replace right Key input ref + rightKey = + rightKey.accept( + new RelOptUtil.RexInputConverter( + rexBuilder, rightFields, rightFields, adjustments)); + + // left key only needs to be adjusted if there are system + // fields, but do it for uniformity + leftKey = + leftKey.accept( + new RelOptUtil.RexInputConverter( + rexBuilder, leftFields, leftFields, adjustments)); + + RelDataType leftKeyType = leftKey.getType(); + RelDataType rightKeyType = rightKey.getType(); + + if (leftKeyType != rightKeyType) { + // perform casting + RelDataType targetKeyType = + typeFactory.leastRestrictive( + ImmutableList.of(leftKeyType, rightKeyType)); + + if (targetKeyType == null) { + throw new AssertionError( + "Cannot find common type for join keys " + + leftKey + + " (type " + + leftKeyType + + ") and " + + rightKey + + " (type " + + rightKeyType + + ")"); + } + + if (leftKeyType != targetKeyType) { + leftKey = rexBuilder.makeCast(targetKeyType, leftKey); + } + + if (rightKeyType != targetKeyType) { + rightKey = rexBuilder.makeCast(targetKeyType, rightKey); + } + } + } + } + + if ((leftKey != null) && (rightKey != null)) { + // found suitable join keys + // add them to key list, ensuring that if there is a + // non-equi join predicate, it appears at the end of the + // key list; also mark the null filtering property + addJoinKey( + joinKeys.get(leftInput), leftKey, (rangeOp != null) && !rangeOp.isEmpty()); + addJoinKey( + joinKeys.get(rightInput), + rightKey, + (rangeOp != null) && !rangeOp.isEmpty()); + if (filterNulls != null && kind == SqlKind.EQUALS) { + // nulls are considered not matching for equality comparison + // add the position of the most recently inserted key + filterNulls.add(joinKeys.get(leftInput).size() - 1); + } + if (rangeOp != null && kind != SqlKind.EQUALS && kind != SqlKind.IS_DISTINCT_FROM) { + SqlOperator op = call.getOperator(); + if (reverse) { + op = requireNonNull(op.reverse()); + } + rangeOp.add(op); + } + return; + } // else fall through and add this condition as nonEqui condition + } + + // The operator is not of RexCall type + // So we fail. Fall through. + // Add this condition to the list of non-equi-join conditions. + nonEquiList.add(condition); + } + + /** Builds an equi-join condition from a set of left and right keys. */ + public static RexNode createEquiJoinCondition( + final RelNode left, + final List leftKeys, + final RelNode right, + final List rightKeys, + final RexBuilder rexBuilder) { + final List leftTypes = RelOptUtil.getFieldTypeList(left.getRowType()); + final List rightTypes = RelOptUtil.getFieldTypeList(right.getRowType()); + return RexUtil.composeConjunction( + rexBuilder, + new AbstractList() { + @Override + public RexNode get(int index) { + final int leftKey = leftKeys.get(index); + final int rightKey = rightKeys.get(index); + return rexBuilder.makeCall( + SqlStdOperatorTable.EQUALS, + rexBuilder.makeInputRef(leftTypes.get(leftKey), leftKey), + rexBuilder.makeInputRef( + rightTypes.get(rightKey), leftTypes.size() + rightKey)); + } + + @Override + public int size() { + return leftKeys.size(); + } + }); + } + + /** + * Returns {@link SqlOperator} for given {@link SqlKind} or returns {@code operator} when {@link + * SqlKind} is not known. + * + * @param kind input kind + * @param operator default operator value + * @return SqlOperator for the given kind + * @see RexUtil#op(SqlKind) + */ + public static SqlOperator op(SqlKind kind, SqlOperator operator) { + switch (kind) { + case EQUALS: + return SqlStdOperatorTable.EQUALS; + case NOT_EQUALS: + return SqlStdOperatorTable.NOT_EQUALS; + case GREATER_THAN: + return SqlStdOperatorTable.GREATER_THAN; + case GREATER_THAN_OR_EQUAL: + return SqlStdOperatorTable.GREATER_THAN_OR_EQUAL; + case LESS_THAN: + return SqlStdOperatorTable.LESS_THAN; + case LESS_THAN_OR_EQUAL: + return SqlStdOperatorTable.LESS_THAN_OR_EQUAL; + case IS_DISTINCT_FROM: + return SqlStdOperatorTable.IS_DISTINCT_FROM; + case IS_NOT_DISTINCT_FROM: + return SqlStdOperatorTable.IS_NOT_DISTINCT_FROM; + default: + return operator; + } + } + + private static void addJoinKey( + List joinKeyList, RexNode key, boolean preserveLastElementInList) { + if (!joinKeyList.isEmpty() && preserveLastElementInList) { + joinKeyList.add(joinKeyList.size() - 1, key); + } else { + joinKeyList.add(key); + } + } + + private static void splitCorrelatedFilterCondition( + LogicalFilter filter, + RexNode condition, + List joinKeys, + List correlatedJoinKeys, + List nonEquiList) { + if (condition instanceof RexCall) { + RexCall call = (RexCall) condition; + if (call.getOperator().getKind() == SqlKind.AND) { + for (RexNode operand : call.getOperands()) { + splitCorrelatedFilterCondition( + filter, operand, joinKeys, correlatedJoinKeys, nonEquiList); + } + return; + } + + if (call.getOperator().getKind() == SqlKind.EQUALS) { + final List operands = call.getOperands(); + RexNode op0 = operands.get(0); + RexNode op1 = operands.get(1); + + if (!RexUtil.containsInputRef(op0) && op1 instanceof RexInputRef) { + correlatedJoinKeys.add(op0); + joinKeys.add((RexInputRef) op1); + return; + } else if (op0 instanceof RexInputRef && !RexUtil.containsInputRef(op1)) { + joinKeys.add((RexInputRef) op0); + correlatedJoinKeys.add(op1); + return; + } + } + } + + // The operator is not of RexCall type + // So we fail. Fall through. + // Add this condition to the list of non-equi-join conditions. + nonEquiList.add(condition); + } + + @SuppressWarnings("unused") + private static void splitCorrelatedFilterCondition( + LogicalFilter filter, + RexNode condition, + List joinKeys, + List correlatedJoinKeys, + List nonEquiList, + boolean extractCorrelatedFieldAccess) { + splitCorrelatedFilterCondition( + (Filter) filter, + condition, + joinKeys, + correlatedJoinKeys, + nonEquiList, + extractCorrelatedFieldAccess); + } + + private static void splitCorrelatedFilterCondition( + Filter filter, + RexNode condition, + List joinKeys, + List correlatedJoinKeys, + List nonEquiList, + boolean extractCorrelatedFieldAccess) { + if (condition instanceof RexCall) { + RexCall call = (RexCall) condition; + if (call.getOperator().getKind() == SqlKind.AND) { + for (RexNode operand : call.getOperands()) { + splitCorrelatedFilterCondition( + filter, + operand, + joinKeys, + correlatedJoinKeys, + nonEquiList, + extractCorrelatedFieldAccess); + } + return; + } + + if (call.getOperator().getKind() == SqlKind.EQUALS) { + final List operands = call.getOperands(); + RexNode op0 = operands.get(0); + RexNode op1 = operands.get(1); + + if (extractCorrelatedFieldAccess) { + if (!RexUtil.containsFieldAccess(op0) && op1 instanceof RexFieldAccess) { + joinKeys.add(op0); + correlatedJoinKeys.add(op1); + return; + } else if (op0 instanceof RexFieldAccess && !RexUtil.containsFieldAccess(op1)) { + correlatedJoinKeys.add(op0); + joinKeys.add(op1); + return; + } + } else { + if (!RexUtil.containsInputRef(op0) && op1 instanceof RexInputRef) { + correlatedJoinKeys.add(op0); + joinKeys.add(op1); + return; + } else if (op0 instanceof RexInputRef && !RexUtil.containsInputRef(op1)) { + joinKeys.add(op0); + correlatedJoinKeys.add(op1); + return; + } + } + } + } + + // The operator is not of RexCall type + // So we fail. Fall through. + // Add this condition to the list of non-equi-join conditions. + nonEquiList.add(condition); + } + + private static void splitJoinCondition( + final RexBuilder rexBuilder, + final int leftFieldCount, + RexNode condition, + List leftKeys, + List rightKeys, + @Nullable List filterNulls, + List nonEquiList) { + if (condition instanceof RexCall) { + RexCall call = (RexCall) condition; + SqlKind kind = call.getKind(); + if (kind == SqlKind.AND) { + for (RexNode operand : call.getOperands()) { + splitJoinCondition( + rexBuilder, + leftFieldCount, + operand, + leftKeys, + rightKeys, + filterNulls, + nonEquiList); + } + return; + } + + if (filterNulls != null) { + call = collapseExpandedIsNotDistinctFromExpr(call, rexBuilder); + kind = call.getKind(); + } + + // "=" and "IS NOT DISTINCT FROM" are the same except for how they + // treat nulls. + if (kind == SqlKind.EQUALS + || (filterNulls != null && kind == SqlKind.IS_NOT_DISTINCT_FROM)) { + final List operands = call.getOperands(); + if ((operands.get(0) instanceof RexInputRef) + && (operands.get(1) instanceof RexInputRef)) { + RexInputRef op0 = (RexInputRef) operands.get(0); + RexInputRef op1 = (RexInputRef) operands.get(1); + + RexInputRef leftField; + RexInputRef rightField; + if ((op0.getIndex() < leftFieldCount) && (op1.getIndex() >= leftFieldCount)) { + // Arguments were of form 'op0 = op1' + leftField = op0; + rightField = op1; + } else if ((op1.getIndex() < leftFieldCount) + && (op0.getIndex() >= leftFieldCount)) { + // Arguments were of form 'op1 = op0' + leftField = op1; + rightField = op0; + } else { + nonEquiList.add(condition); + return; + } + + leftKeys.add(leftField.getIndex()); + rightKeys.add(rightField.getIndex() - leftFieldCount); + if (filterNulls != null) { + filterNulls.add(kind == SqlKind.EQUALS); + } + return; + } + // Arguments were not field references, one from each side, so + // we fail. Fall through. + } + } + + // Add this condition to the list of non-equi-join conditions. + if (!condition.isAlwaysTrue()) { + nonEquiList.add(condition); + } + } + + /** + * Collapses an expanded version of {@code IS NOT DISTINCT FROM} expression. + * + *

Helper method for {@link #splitJoinCondition(RexBuilder, int, RexNode, List, List, List, + * List)} and {@link #splitJoinCondition(List, List, RexNode, List, List, List, List)}. + * + *

If the given expr call is an expanded version of {@code IS NOT DISTINCT FROM} + * function call, collapses it and return a {@code IS NOT DISTINCT FROM} function call. + * + *

For example: {@code t1.key IS NOT DISTINCT FROM t2.key} can rewritten in expanded form as + * {@code t1.key = t2.key OR (t1.key IS NULL AND t2.key IS NULL)}. + * + * @param call Function expression to try collapsing + * @param rexBuilder {@link RexBuilder} instance to create new {@link RexCall} instances. + * @return If the given function is an expanded IS NOT DISTINCT FROM function call, return a IS + * NOT DISTINCT FROM function call. Otherwise return the input function call as it is. + */ + public static RexCall collapseExpandedIsNotDistinctFromExpr( + final RexCall call, final RexBuilder rexBuilder) { + switch (call.getKind()) { + case OR: + return doCollapseExpandedIsNotDistinctFromOrExpr(call, rexBuilder); + + case CASE: + return doCollapseExpandedIsNotDistinctFromCaseExpr(call, rexBuilder); + + default: + return call; + } + } + + private static RexCall doCollapseExpandedIsNotDistinctFromOrExpr( + final RexCall call, final RexBuilder rexBuilder) { + if (call.getKind() != SqlKind.OR || call.getOperands().size() != 2) { + return call; + } + + final RexNode op0 = call.getOperands().get(0); + final RexNode op1 = call.getOperands().get(1); + + if (!(op0 instanceof RexCall) || !(op1 instanceof RexCall)) { + return call; + } + + RexCall opEqCall = (RexCall) op0; + RexCall opNullEqCall = (RexCall) op1; + + // Swapping the operands if necessary + if (opEqCall.getKind() == SqlKind.AND + && (opNullEqCall.getKind() == SqlKind.EQUALS + || opNullEqCall.getKind() == SqlKind.IS_TRUE)) { + RexCall temp = opEqCall; + opEqCall = opNullEqCall; + opNullEqCall = temp; + } + + // Check if EQUALS is actually wrapped in IS TRUE expression + if (opEqCall.getKind() == SqlKind.IS_TRUE) { + RexNode tmp = opEqCall.getOperands().get(0); + if (!(tmp instanceof RexCall)) { + return call; + } + opEqCall = (RexCall) tmp; + } + + if (opNullEqCall.getKind() != SqlKind.AND + || opNullEqCall.getOperands().size() != 2 + || opEqCall.getKind() != SqlKind.EQUALS) { + return call; + } + + final RexNode op10 = opNullEqCall.getOperands().get(0); + final RexNode op11 = opNullEqCall.getOperands().get(1); + if (op10.getKind() != SqlKind.IS_NULL || op11.getKind() != SqlKind.IS_NULL) { + return call; + } + + return doCollapseExpandedIsNotDistinctFrom( + rexBuilder, call, (RexCall) op10, (RexCall) op11, opEqCall); + } + + private static RexCall doCollapseExpandedIsNotDistinctFromCaseExpr( + final RexCall call, final RexBuilder rexBuilder) { + if (call.getKind() != SqlKind.CASE || call.getOperands().size() != 5) { + return call; + } + + final RexNode op0 = call.getOperands().get(0); + final RexNode op1 = call.getOperands().get(1); + final RexNode op2 = call.getOperands().get(2); + final RexNode op3 = call.getOperands().get(3); + final RexNode op4 = call.getOperands().get(4); + + if (!(op0 instanceof RexCall) + || !(op1 instanceof RexCall) + || !(op2 instanceof RexCall) + || !(op3 instanceof RexCall) + || !(op4 instanceof RexCall)) { + return call; + } + + RexCall ifCall = (RexCall) op0; + RexCall thenCall = (RexCall) op1; + RexCall elseIfCall = (RexCall) op2; + RexCall elseIfThenCall = (RexCall) op3; + RexCall elseCall = (RexCall) op4; + + if (ifCall.getKind() != SqlKind.IS_NULL + || thenCall.getKind() != SqlKind.IS_NULL + || elseIfCall.getKind() != SqlKind.IS_NULL + || elseIfThenCall.getKind() != SqlKind.IS_NULL + || elseCall.getKind() != SqlKind.EQUALS) { + return call; + } + + if (!ifCall.equals(elseIfThenCall) || !thenCall.equals(elseIfCall)) { + return call; + } + + return doCollapseExpandedIsNotDistinctFrom(rexBuilder, call, ifCall, elseIfCall, elseCall); + } + + private static RexCall doCollapseExpandedIsNotDistinctFrom( + final RexBuilder rexBuilder, + final RexCall call, + RexCall ifNull0Call, + RexCall ifNull1Call, + RexCall equalsCall) { + final RexNode isNullInput0 = ifNull0Call.getOperands().get(0); + final RexNode isNullInput1 = ifNull1Call.getOperands().get(0); + + final RexNode equalsInput0 = + RexUtil.removeNullabilityCast( + rexBuilder.getTypeFactory(), equalsCall.getOperands().get(0)); + final RexNode equalsInput1 = + RexUtil.removeNullabilityCast( + rexBuilder.getTypeFactory(), equalsCall.getOperands().get(1)); + + if ((isNullInput0.equals(equalsInput0) && isNullInput1.equals(equalsInput1)) + || (isNullInput1.equals(equalsInput0) && isNullInput0.equals(equalsInput1))) { + return (RexCall) + rexBuilder.makeCall( + SqlStdOperatorTable.IS_NOT_DISTINCT_FROM, + ImmutableList.of(isNullInput0, isNullInput1)); + } + + return call; + } + + @Deprecated // to be removed before 2.0 + public static void projectJoinInputs( + RelNode[] inputRels, + List leftJoinKeys, + List rightJoinKeys, + int systemColCount, + List leftKeys, + List rightKeys, + List outputProj) { + RelNode leftRel = inputRels[0]; + RelNode rightRel = inputRels[1]; + final RelOptCluster cluster = leftRel.getCluster(); + final RexBuilder rexBuilder = cluster.getRexBuilder(); + + int origLeftInputSize = leftRel.getRowType().getFieldCount(); + int origRightInputSize = rightRel.getRowType().getFieldCount(); + + final List newLeftFields = new ArrayList<>(); + final List<@Nullable String> newLeftFieldNames = new ArrayList<>(); + + final List newRightFields = new ArrayList<>(); + final List<@Nullable String> newRightFieldNames = new ArrayList<>(); + int leftKeyCount = leftJoinKeys.size(); + int rightKeyCount = rightJoinKeys.size(); + int i; + + for (i = 0; i < systemColCount; i++) { + outputProj.add(i); + } + + for (i = 0; i < origLeftInputSize; i++) { + final RelDataTypeField field = leftRel.getRowType().getFieldList().get(i); + newLeftFields.add(rexBuilder.makeInputRef(field.getType(), i)); + newLeftFieldNames.add(field.getName()); + outputProj.add(systemColCount + i); + } + + int newLeftKeyCount = 0; + for (i = 0; i < leftKeyCount; i++) { + RexNode leftKey = leftJoinKeys.get(i); + + if (leftKey instanceof RexInputRef) { + // already added to the projected left fields + // only need to remember the index in the join key list + leftKeys.add(((RexInputRef) leftKey).getIndex()); + } else { + newLeftFields.add(leftKey); + newLeftFieldNames.add(null); + leftKeys.add(origLeftInputSize + newLeftKeyCount); + newLeftKeyCount++; + } + } + + int leftFieldCount = origLeftInputSize + newLeftKeyCount; + for (i = 0; i < origRightInputSize; i++) { + final RelDataTypeField field = rightRel.getRowType().getFieldList().get(i); + newRightFields.add(rexBuilder.makeInputRef(field.getType(), i)); + newRightFieldNames.add(field.getName()); + outputProj.add(systemColCount + leftFieldCount + i); + } + + int newRightKeyCount = 0; + for (i = 0; i < rightKeyCount; i++) { + RexNode rightKey = rightJoinKeys.get(i); + + if (rightKey instanceof RexInputRef) { + // already added to the projected left fields + // only need to remember the index in the join key list + rightKeys.add(((RexInputRef) rightKey).getIndex()); + } else { + newRightFields.add(rightKey); + newRightFieldNames.add(null); + rightKeys.add(origRightInputSize + newRightKeyCount); + newRightKeyCount++; + } + } + + final RelBuilder relBuilder = RelFactories.LOGICAL_BUILDER.create(cluster, null); + + // added project if need to produce new keys than the original input + // fields + if (newLeftKeyCount > 0) { + leftRel = + relBuilder + .push(leftRel) + .project(newLeftFields, newLeftFieldNames, true) + .build(); + } + + if (newRightKeyCount > 0) { + rightRel = + relBuilder.push(rightRel).project(newRightFields, newRightFieldNames).build(); + } + + inputRels[0] = leftRel; + inputRels[1] = rightRel; + } + + @Deprecated // to be removed before 2.0 + public static RelNode createProjectJoinRel(List outputProj, RelNode joinRel) { + int newProjectOutputSize = outputProj.size(); + List joinOutputFields = joinRel.getRowType().getFieldList(); + + // If no projection was passed in, or the number of desired projection + // columns is the same as the number of columns returned from the + // join, then no need to create a projection + if (newProjectOutputSize > 0 && newProjectOutputSize < joinOutputFields.size()) { + final PairList newProjects = PairList.of(); + final RelBuilder relBuilder = + RelFactories.LOGICAL_BUILDER.create(joinRel.getCluster(), null); + final RexBuilder rexBuilder = relBuilder.getRexBuilder(); + for (int fieldIndex : outputProj) { + final RelDataTypeField field = joinOutputFields.get(fieldIndex); + newProjects.add( + rexBuilder.makeInputRef(field.getType(), fieldIndex), field.getName()); + } + + // Create a project rel on the output of the join. + return relBuilder + .push(joinRel) + .project(newProjects.leftList(), newProjects.rightList(), true) + .build(); + } + + return joinRel; + } + + @Deprecated // to be removed before 2.0 + public static void registerAbstractRels(RelOptPlanner planner) { + registerAbstractRules(planner); + } + + @Experimental + public static void registerAbstractRules(RelOptPlanner planner) { + RelOptRules.ABSTRACT_RULES.forEach(planner::addRule); + } + + @Experimental + public static void registerAbstractRelationalRules(RelOptPlanner planner) { + RelOptRules.ABSTRACT_RELATIONAL_RULES.forEach(planner::addRule); + if (CalciteSystemProperty.COMMUTE.value()) { + planner.addRule(CoreRules.JOIN_ASSOCIATE); + } + // todo: rule which makes Project({OrdinalRef}) disappear + } + + private static void registerEnumerableRules(RelOptPlanner planner) { + EnumerableRules.ENUMERABLE_RULES.forEach(planner::addRule); + } + + private static void registerBaseRules(RelOptPlanner planner) { + RelOptRules.BASE_RULES.forEach(planner::addRule); + } + + @SuppressWarnings("unused") + private static void registerReductionRules(RelOptPlanner planner) { + RelOptRules.CONSTANT_REDUCTION_RULES.forEach(planner::addRule); + } + + private static void registerMaterializationRules(RelOptPlanner planner) { + RelOptRules.MATERIALIZATION_RULES.forEach(planner::addRule); + } + + @SuppressWarnings("unused") + private static void registerCalcRules(RelOptPlanner planner) { + RelOptRules.CALC_RULES.forEach(planner::addRule); + } + + @Experimental + public static void registerDefaultRules( + RelOptPlanner planner, boolean enableMaterializations, boolean enableBindable) { + if (CalciteSystemProperty.ENABLE_COLLATION_TRAIT.value()) { + registerAbstractRelationalRules(planner); + } + registerAbstractRules(planner); + registerBaseRules(planner); + + if (enableMaterializations) { + registerMaterializationRules(planner); + } + if (enableBindable) { + for (RelOptRule rule : Bindables.RULES) { + planner.addRule(rule); + } + } + // Registers this rule for default ENUMERABLE convention + // because: + // 1. ScannableTable can bind data directly; + // 2. Only BindableTable supports project push down now. + + // EnumerableInterpreterRule.INSTANCE would then transform + // the BindableTableScan to + // EnumerableInterpreter + BindableTableScan. + + // Note: the cost of EnumerableInterpreter + BindableTableScan + // is always bigger that EnumerableTableScan because of the additional + // EnumerableInterpreter node, but if there are pushing projects or filter, + // we prefer BindableTableScan instead, + // see BindableTableScan#computeSelfCost. + planner.addRule(Bindables.BINDABLE_TABLE_SCAN_RULE); + planner.addRule(CoreRules.PROJECT_TABLE_SCAN); + planner.addRule(CoreRules.PROJECT_INTERPRETER_TABLE_SCAN); + + if (CalciteSystemProperty.ENABLE_ENUMERABLE.value()) { + registerEnumerableRules(planner); + planner.addRule(EnumerableRules.TO_INTERPRETER); + } + + if (enableBindable && CalciteSystemProperty.ENABLE_ENUMERABLE.value()) { + planner.addRule(EnumerableRules.TO_BINDABLE); + } + + if (CalciteSystemProperty.ENABLE_STREAM.value()) { + for (RelOptRule rule : StreamRules.RULES) { + planner.addRule(rule); + } + } + + planner.addRule(CoreRules.FILTER_REDUCE_EXPRESSIONS); + } + + /** + * Dumps a plan as a string. + * + * @param header Header to print before the plan. Ignored if the format is XML + * @param rel Relational expression to explain + * @param format Output format + * @param detailLevel Detail level + * @return Plan + */ + public static String dumpPlan( + String header, RelNode rel, SqlExplainFormat format, SqlExplainLevel detailLevel) { + StringWriter sw = new StringWriter(); + PrintWriter pw = new PrintWriter(sw); + if (!header.equals("")) { + pw.println(header); + } + RelWriter planWriter; + switch (format) { + case XML: + planWriter = new RelXmlWriter(pw, detailLevel); + break; + case JSON: + planWriter = new RelJsonWriter(); + rel.explain(planWriter); + return ((RelJsonWriter) planWriter).asString(); + case DOT: + planWriter = new RelDotWriter(pw, detailLevel, false); + break; + default: + planWriter = new RelWriterImpl(pw, detailLevel, false); + } + rel.explain(planWriter); + pw.flush(); + return sw.toString(); + } + + @Deprecated // to be removed before 2.0 + public static String dumpPlan( + String header, RelNode rel, boolean asXml, SqlExplainLevel detailLevel) { + return dumpPlan( + header, rel, asXml ? SqlExplainFormat.XML : SqlExplainFormat.TEXT, detailLevel); + } + + /** + * Creates the row type descriptor for the result of a DML operation, which is a single column + * named ROWCOUNT of type BIGINT for INSERT; a single column named PLAN for EXPLAIN. + * + * @param kind Kind of node + * @param typeFactory factory to use for creating type descriptor + * @return created type + */ + public static RelDataType createDmlRowType(SqlKind kind, RelDataTypeFactory typeFactory) { + switch (kind) { + case INSERT: + case DELETE: + case UPDATE: + case MERGE: + return typeFactory.createStructType( + PairList.of( + AvaticaConnection.ROWCOUNT_COLUMN_NAME, + typeFactory.createSqlType(SqlTypeName.BIGINT))); + case EXPLAIN: + return typeFactory.createStructType( + PairList.of( + AvaticaConnection.PLAN_COLUMN_NAME, + typeFactory.createSqlType( + SqlTypeName.VARCHAR, RelDataType.PRECISION_NOT_SPECIFIED))); + default: + throw Util.unexpected(kind); + } + } + + /** + * Returns whether two types are equal using 'equals'. + * + * @param desc1 Description of first type + * @param type1 First type + * @param desc2 Description of second type + * @param type2 Second type + * @param litmus What to do if an error is detected (types are not equal) + * @return Whether the types are equal + */ + public static boolean eq( + final String desc1, + RelDataType type1, + final String desc2, + RelDataType type2, + Litmus litmus) { + // if any one of the types is ANY return true + if (type1.getSqlTypeName() == SqlTypeName.ANY + || type2.getSqlTypeName() == SqlTypeName.ANY) { + return litmus.succeed(); + } + + if (!type1.equals(type2)) { + return litmus.fail( + "type mismatch:\n{}:\n{}\n{}:\n{}", + desc1, + type1.getFullTypeString(), + desc2, + type2.getFullTypeString()); + } + return litmus.succeed(); + } + + // ----- FLINK MODIFICATION BEGIN ----- + // Backport from Calcite (CALCITE-6764) + public static boolean eqUpToNullability( + boolean ignoreNullability, + final String desc1, + RelDataType type1, + final String desc2, + RelDataType type2, + Litmus litmus) { + if (type1.getSqlTypeName() == SqlTypeName.ANY + || type2.getSqlTypeName() == SqlTypeName.ANY) { + return litmus.succeed(); + } + + boolean success; + if (ignoreNullability) { + success = SqlTypeUtil.equalSansNullability(type1, type2); + } else { + success = type1.equals(type2); + } + + if (!success) { + return litmus.fail( + "type mismatch:\n{}:\n{}\n{}:\n{}", + desc1, + type1.getFullTypeString(), + desc2, + type2.getFullTypeString()); + } + return litmus.succeed(); + } + + // ----- FLINK MODIFICATION END ----- + + /** + * Returns whether two types are equal using {@link #areRowTypesEqual(RelDataType, RelDataType, + * boolean)}. Both types must not be null. + * + * @param desc1 Description of role of first type + * @param type1 First type + * @param desc2 Description of role of second type + * @param type2 Second type + * @param litmus Whether to assert if they are not equal + * @return Whether the types are equal + */ + public static boolean equal( + final String desc1, + RelDataType type1, + final String desc2, + RelDataType type2, + Litmus litmus) { + if (!areRowTypesEqual(type1, type2, false)) { + return litmus.fail(getFullTypeDifferenceString(desc1, type1, desc2, type2)); + } + return litmus.succeed(); + } + + /** + * Returns the detailed difference of two types. + * + * @param sourceDesc description of role of source type + * @param sourceType source type + * @param targetDesc description of role of target type + * @param targetType target type + * @return the detailed difference of two types + */ + public static String getFullTypeDifferenceString( + final String sourceDesc, + RelDataType sourceType, + final String targetDesc, + RelDataType targetType) { + if (sourceType == targetType) { + return ""; + } + + final int sourceFieldCount = sourceType.getFieldCount(); + final int targetFieldCount = targetType.getFieldCount(); + if (sourceFieldCount != targetFieldCount) { + return "Type mismatch: the field sizes are not equal.\n" + + sourceDesc + + ": " + + sourceType.getFullTypeString() + + "\n" + + targetDesc + + ": " + + targetType.getFullTypeString(); + } + + final StringBuilder stringBuilder = new StringBuilder(); + final List f1 = sourceType.getFieldList(); + final List f2 = targetType.getFieldList(); + for (Pair pair : Pair.zip(f1, f2)) { + final RelDataType t1 = pair.left.getType(); + final RelDataType t2 = pair.right.getType(); + // If one of the types is ANY comparison should succeed + if (sourceType.getSqlTypeName() == SqlTypeName.ANY + || targetType.getSqlTypeName() == SqlTypeName.ANY) { + continue; + } + if (!t1.equals(t2)) { + stringBuilder.append(pair.left.getName()); + stringBuilder.append(": "); + stringBuilder.append(t1.getFullTypeString()); + stringBuilder.append(" -> "); + stringBuilder.append(t2.getFullTypeString()); + stringBuilder.append("\n"); + } + } + final String difference = stringBuilder.toString(); + if (!difference.isEmpty()) { + return "Type mismatch:\n" + + sourceDesc + + ": " + + sourceType.getFullTypeString() + + "\n" + + targetDesc + + ": " + + targetType.getFullTypeString() + + "\n" + + "Difference:\n" + + difference; + } else { + return ""; + } + } + + /** Returns whether two relational expressions have the same row-type. */ + public static boolean equalType( + String desc0, RelNode rel0, String desc1, RelNode rel1, Litmus litmus) { + // TODO: change 'equal' to 'eq', which is stronger. + return equal(desc0, rel0.getRowType(), desc1, rel1.getRowType(), litmus); + } + + /** + * Returns a translation of the IS DISTINCT FROM (or IS + * NOT DISTINCT FROM) sql operator. + * + * @param neg if false, returns a translation of IS NOT DISTINCT FROM + */ + public static RexNode isDistinctFrom(RexBuilder rexBuilder, RexNode x, RexNode y, boolean neg) { + RexNode ret = null; + if (x.getType().isStruct()) { + assert y.getType().isStruct(); + List xFields = x.getType().getFieldList(); + List yFields = y.getType().getFieldList(); + assert xFields.size() == yFields.size(); + for (Pair pair : Pair.zip(xFields, yFields)) { + RelDataTypeField xField = pair.left; + RelDataTypeField yField = pair.right; + RexNode newX = rexBuilder.makeFieldAccess(x, xField.getIndex()); + RexNode newY = rexBuilder.makeFieldAccess(y, yField.getIndex()); + RexNode newCall = isDistinctFromInternal(rexBuilder, newX, newY, neg); + if (ret == null) { + ret = newCall; + } else { + ret = rexBuilder.makeCall(SqlStdOperatorTable.AND, ret, newCall); + } + } + } else { + ret = isDistinctFromInternal(rexBuilder, x, y, neg); + } + + // The result of IS DISTINCT FROM is NOT NULL because it can + // only return TRUE or FALSE. + assert ret != null; + assert !ret.getType().isNullable(); + + return ret; + } + + private static RexNode isDistinctFromInternal( + RexBuilder rexBuilder, RexNode x, RexNode y, boolean neg) { + + if (neg) { + // x is not distinct from y + // x=y IS TRUE or ((x is null) and (y is null)), + return rexBuilder.makeCall( + SqlStdOperatorTable.OR, + rexBuilder.makeCall( + SqlStdOperatorTable.AND, + rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, x), + rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, y)), + rexBuilder.makeCall( + SqlStdOperatorTable.IS_TRUE, + rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, x, y))); + } else { + // x is distinct from y + // x=y IS NOT TRUE and ((x is not null) or (y is not null)), + return rexBuilder.makeCall( + SqlStdOperatorTable.AND, + rexBuilder.makeCall( + SqlStdOperatorTable.OR, + rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, x), + rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, y)), + rexBuilder.makeCall( + SqlStdOperatorTable.IS_NOT_TRUE, + rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, x, y))); + } + } + + /** Converts a relational expression to a string, showing just basic attributes. */ + public static String toString(final RelNode rel) { + return toString(rel, SqlExplainLevel.EXPPLAN_ATTRIBUTES); + } + + /** + * Converts a relational expression to a string; returns null if and only if {@code rel} is + * null. + */ + public static @PolyNull String toString( + final @PolyNull RelNode rel, SqlExplainLevel detailLevel) { + if (rel == null) { + return null; + } + final StringWriter sw = new StringWriter(); + final RelWriter planWriter = new RelWriterImpl(new PrintWriter(sw), detailLevel, false); + rel.explain(planWriter); + return sw.toString(); + } + + @Deprecated // to be removed before 2.0 + public static RelNode renameIfNecessary(RelNode rel, RelDataType desiredRowType) { + final RelDataType rowType = rel.getRowType(); + if (rowType == desiredRowType) { + // Nothing to do. + return rel; + } + assert !rowType.equals(desiredRowType); + + if (!areRowTypesEqual(rowType, desiredRowType, false)) { + // The row types are different ignoring names. Nothing we can do. + return rel; + } + rel = createRename(rel, desiredRowType.getFieldNames()); + return rel; + } + + public static String dumpType(RelDataType type) { + final StringWriter sw = new StringWriter(); + final PrintWriter pw = new PrintWriter(sw); + final TypeDumper typeDumper = new TypeDumper(pw); + if (type.isStruct()) { + typeDumper.acceptFields(type.getFieldList()); + } else { + typeDumper.accept(type); + } + pw.flush(); + return sw.toString(); + } + + /** + * Returns the set of columns with unique names, with prior columns taking precedence over + * columns that appear later in the list. + */ + public static List deduplicateColumns( + List baseColumns, List extendedColumns) { + final Set dedupedFieldNames = new HashSet<>(); + final ImmutableList.Builder dedupedFields = ImmutableList.builder(); + for (RelDataTypeField f : Iterables.concat(baseColumns, extendedColumns)) { + if (dedupedFieldNames.add(f.getName())) { + dedupedFields.add(f); + } + } + return dedupedFields.build(); + } + + /** + * Decomposes a predicate into a list of expressions that are AND'ed together. + * + * @param rexPredicate predicate to be analyzed + * @param rexList list of decomposed RexNodes + */ + public static void decomposeConjunction(@Nullable RexNode rexPredicate, List rexList) { + if (rexPredicate == null || rexPredicate.isAlwaysTrue()) { + return; + } + if (rexPredicate.isA(SqlKind.AND)) { + for (RexNode operand : ((RexCall) rexPredicate).getOperands()) { + decomposeConjunction(operand, rexList); + } + } else { + rexList.add(rexPredicate); + } + } + + /** + * Decomposes a predicate into a list of expressions that are AND'ed together, and a list of + * expressions that are preceded by NOT. + * + *

For example, {@code a AND NOT b AND NOT (c and d) AND TRUE AND NOT FALSE} returns {@code + * rexList = [a], notList = [b, c AND d]}. + * + *

TRUE and NOT FALSE expressions are ignored. FALSE and NOT TRUE expressions are placed on + * {@code rexList} and {@code notList} as other expressions. + * + *

For example, {@code a AND TRUE AND NOT TRUE} returns {@code rexList = [a], notList = + * [TRUE]}. + * + * @param rexPredicate predicate to be analyzed + * @param rexList list of decomposed RexNodes (except those with NOT) + * @param notList list of decomposed RexNodes that were prefixed NOT + */ + public static void decomposeConjunction( + @Nullable RexNode rexPredicate, List rexList, List notList) { + if (rexPredicate == null || rexPredicate.isAlwaysTrue()) { + return; + } + switch (rexPredicate.getKind()) { + case AND: + for (RexNode operand : ((RexCall) rexPredicate).getOperands()) { + decomposeConjunction(operand, rexList, notList); + } + break; + case NOT: + final RexNode e = ((RexCall) rexPredicate).getOperands().get(0); + if (e.isAlwaysFalse()) { + return; + } + switch (e.getKind()) { + case OR: + final List ors = new ArrayList<>(); + decomposeDisjunction(e, ors); + for (RexNode or : ors) { + switch (or.getKind()) { + case NOT: + rexList.add(((RexCall) or).operands.get(0)); + break; + default: + notList.add(or); + } + } + break; + default: + notList.add(e); + } + break; + case LITERAL: + if (!RexLiteral.isNullLiteral(rexPredicate) + && RexLiteral.booleanValue(rexPredicate)) { + return; // ignore TRUE + } + // fall through + default: + rexList.add(rexPredicate); + break; + } + } + + /** + * Decomposes a predicate into a list of expressions that are OR'ed together. + * + * @param rexPredicate predicate to be analyzed + * @param rexList list of decomposed RexNodes + */ + public static void decomposeDisjunction(@Nullable RexNode rexPredicate, List rexList) { + if (rexPredicate == null || rexPredicate.isAlwaysFalse()) { + return; + } + if (rexPredicate.isA(SqlKind.OR)) { + for (RexNode operand : ((RexCall) rexPredicate).getOperands()) { + decomposeDisjunction(operand, rexList); + } + } else { + rexList.add(rexPredicate); + } + } + + /** + * Returns a condition decomposed by AND. + * + *

For example, {@code conjunctions(TRUE)} returns the empty list; {@code + * conjunctions(FALSE)} returns list {@code {FALSE}}. + */ + public static List conjunctions(@Nullable RexNode rexPredicate) { + final List list = new ArrayList<>(); + decomposeConjunction(rexPredicate, list); + return list; + } + + /** + * Returns a condition decomposed by OR. + * + *

For example, {@code disjunctions(FALSE)} returns the empty list. + */ + public static List disjunctions(RexNode rexPredicate) { + final List list = new ArrayList<>(); + decomposeDisjunction(rexPredicate, list); + return list; + } + + /** + * Ands two sets of join filters together, either of which can be null. + * + * @param rexBuilder rexBuilder to create AND expression + * @param left filter on the left that the right will be AND'd to + * @param right filter on the right + * @return AND'd filter + * @see org.apache.calcite.rex.RexUtil#composeConjunction + */ + public static RexNode andJoinFilters( + RexBuilder rexBuilder, @Nullable RexNode left, @Nullable RexNode right) { + // don't bother AND'ing in expressions that always evaluate to + // true + if ((left != null) && !left.isAlwaysTrue()) { + if ((right != null) && !right.isAlwaysTrue()) { + left = rexBuilder.makeCall(SqlStdOperatorTable.AND, left, right); + } + } else { + left = right; + } + + // Joins must have some filter + if (left == null) { + left = rexBuilder.makeLiteral(true); + } + return left; + } + + /** + * Decomposes the WHERE clause of a view into predicates that constraint a column to a + * particular value. + * + *

This method is key to the validation of a modifiable view. Columns that are constrained to + * a single value can be omitted from the SELECT clause of a modifiable view. + * + * @param projectMap Mapping from column ordinal to the expression that populate that column, to + * be populated by this method + * @param filters List of remaining filters, to be populated by this method + * @param constraint Constraint to be analyzed + */ + public static void inferViewPredicates( + Map projectMap, List filters, RexNode constraint) { + for (RexNode node : conjunctions(constraint)) { + switch (node.getKind()) { + case EQUALS: + final List operands = ((RexCall) node).getOperands(); + RexNode o0 = operands.get(0); + RexNode o1 = operands.get(1); + if (o0 instanceof RexLiteral) { + o0 = operands.get(1); + o1 = operands.get(0); + } + if (o0.getKind() == SqlKind.CAST) { + o0 = ((RexCall) o0).getOperands().get(0); + } + if (o0 instanceof RexInputRef && o1 instanceof RexLiteral) { + final int index = ((RexInputRef) o0).getIndex(); + if (projectMap.get(index) == null) { + projectMap.put(index, o1); + continue; + } + } + break; + default: + break; + } + filters.add(node); + } + } + + /** + * Returns a mapping of the column ordinal in the underlying table to a column constraint of the + * modifiable view. + * + * @param modifiableViewTable The modifiable view which has a constraint + * @param targetRowType The target type + */ + public static Map getColumnConstraints( + ModifiableView modifiableViewTable, + RelDataType targetRowType, + RelDataTypeFactory typeFactory) { + final RexBuilder rexBuilder = new RexBuilder(typeFactory); + final RexNode constraint = modifiableViewTable.getConstraint(rexBuilder, targetRowType); + final Map projectMap = new HashMap<>(); + final List filters = new ArrayList<>(); + RelOptUtil.inferViewPredicates(projectMap, filters, constraint); + assert filters.isEmpty(); + return projectMap; + } + + /** + * Ensures that a source value does not violate the constraint of the target column. + * + * @param sourceValue The insert value being validated + * @param targetConstraint The constraint applied to sourceValue for validation + * @param errorSupplier The function to apply when validation fails + */ + public static void validateValueAgainstConstraint( + SqlNode sourceValue, + RexNode targetConstraint, + Supplier errorSupplier) { + if (!(sourceValue instanceof SqlLiteral)) { + // We cannot guarantee that the value satisfies the constraint. + throw errorSupplier.get(); + } + final SqlLiteral insertValue = (SqlLiteral) sourceValue; + final RexLiteral columnConstraint = (RexLiteral) targetConstraint; + + final RexSqlStandardConvertletTable convertletTable = new RexSqlStandardConvertletTable(); + final RexToSqlNodeConverter sqlNodeToRexConverter = + new RexToSqlNodeConverterImpl(convertletTable); + final SqlLiteral constraintValue = + (SqlLiteral) sqlNodeToRexConverter.convertLiteral(columnConstraint); + + if (!insertValue.equals(constraintValue)) { + // The value does not satisfy the constraint. + throw errorSupplier.get(); + } + } + + /** + * Adjusts key values in a list by some fixed amount. + * + * @param keys list of key values + * @param adjustment the amount to adjust the key values by + * @return modified list + */ + public static List adjustKeys(List keys, int adjustment) { + if (adjustment == 0) { + return keys; + } + final List newKeys = new ArrayList<>(); + for (int key : keys) { + newKeys.add(key + adjustment); + } + return newKeys; + } + + /** + * Simplifies outer joins if filter above would reject nulls. + * + * @param joinRel Join + * @param aboveFilters Filters from above + * @param joinType Join type, can not be inner join + */ + public static JoinRelType simplifyJoin( + RelNode joinRel, ImmutableList aboveFilters, JoinRelType joinType) { + // No need to simplify if the join only outputs left side. + if (!joinType.projectsRight()) { + return joinType; + } + final int nTotalFields = joinRel.getRowType().getFieldCount(); + final int nSysFields = 0; + final int nFieldsLeft = joinRel.getInputs().get(0).getRowType().getFieldCount(); + final int nFieldsRight = joinRel.getInputs().get(1).getRowType().getFieldCount(); + assert nTotalFields == nSysFields + nFieldsLeft + nFieldsRight; + + // set the reference bitmaps for the left and right children + ImmutableBitSet leftBitmap = ImmutableBitSet.range(nSysFields, nSysFields + nFieldsLeft); + ImmutableBitSet rightBitmap = ImmutableBitSet.range(nSysFields + nFieldsLeft, nTotalFields); + + for (RexNode filter : aboveFilters) { + if (joinType.generatesNullsOnLeft() && Strong.isNotTrue(filter, leftBitmap)) { + joinType = joinType.cancelNullsOnLeft(); + } + if (joinType.generatesNullsOnRight() && Strong.isNotTrue(filter, rightBitmap)) { + joinType = joinType.cancelNullsOnRight(); + } + if (!joinType.isOuterJoin()) { + break; + } + } + return joinType; + } + + /** + * Classifies filters according to where they should be processed. They either stay where they + * are, are pushed to the join (if they originated from above the join), or are pushed to one of + * the children. Filters that are pushed are added to list passed in as input parameters. + * + * @param joinRel join node + * @param filters filters to be classified + * @param pushInto whether filters can be pushed into the join + * @param pushLeft true if filters can be pushed to the left + * @param pushRight true if filters can be pushed to the right + * @param joinFilters list of filters to push to the join + * @param leftFilters list of filters to push to the left child + * @param rightFilters list of filters to push to the right child + * @return whether at least one filter was pushed + */ + public static boolean classifyFilters( + RelNode joinRel, + List filters, + boolean pushInto, + boolean pushLeft, + boolean pushRight, + List joinFilters, + List leftFilters, + List rightFilters) { + RexBuilder rexBuilder = joinRel.getCluster().getRexBuilder(); + List joinFields = joinRel.getRowType().getFieldList(); + final int nSysFields = 0; // joinRel.getSystemFieldList().size(); + final List leftFields = + joinRel.getInputs().get(0).getRowType().getFieldList(); + final int nFieldsLeft = leftFields.size(); + final List rightFields = + joinRel.getInputs().get(1).getRowType().getFieldList(); + final int nFieldsRight = rightFields.size(); + final int nTotalFields = nFieldsLeft + nFieldsRight; + + // set the reference bitmaps for the left and right children + ImmutableBitSet leftBitmap = ImmutableBitSet.range(nSysFields, nSysFields + nFieldsLeft); + ImmutableBitSet rightBitmap = ImmutableBitSet.range(nSysFields + nFieldsLeft, nTotalFields); + + final List filtersToRemove = new ArrayList<>(); + for (RexNode filter : filters) { + final InputFinder inputFinder = InputFinder.analyze(filter); + final ImmutableBitSet inputBits = inputFinder.build(); + + // REVIEW - are there any expressions that need special handling + // and therefore cannot be pushed? + + if (pushLeft && leftBitmap.contains(inputBits)) { + // ignore filters that always evaluate to true + if (!filter.isAlwaysTrue()) { + // adjust the field references in the filter to reflect + // that fields in the left now shift over by the number + // of system fields + final RexNode shiftedFilter = + shiftFilter( + nSysFields, + nSysFields + nFieldsLeft, + -nSysFields, + rexBuilder, + joinFields, + nTotalFields, + leftFields, + filter); + + leftFilters.add(shiftedFilter); + } + filtersToRemove.add(filter); + } else if (pushRight && rightBitmap.contains(inputBits)) { + if (!filter.isAlwaysTrue()) { + // adjust the field references in the filter to reflect + // that fields in the right now shift over to the left + final RexNode shiftedFilter = + shiftFilter( + nSysFields + nFieldsLeft, + nTotalFields, + -(nSysFields + nFieldsLeft), + rexBuilder, + joinFields, + nTotalFields, + rightFields, + filter); + rightFilters.add(shiftedFilter); + } + filtersToRemove.add(filter); + + } else { + // If the filter can't be pushed to either child, we may push them into the join + if (pushInto) { + if (!joinFilters.contains(filter)) { + joinFilters.add(filter); + } + filtersToRemove.add(filter); + } + } + } + + // Remove filters after the loop, to prevent concurrent modification. + if (!filtersToRemove.isEmpty()) { + filters.removeAll(filtersToRemove); + } + + // Did anything change? + return !filtersToRemove.isEmpty(); + } + + /** + * Classifies filters according to where they should be processed. They either stay where they + * are, are pushed to the join (if they originated from above the join), or are pushed to one of + * the children. Filters that are pushed are added to list passed in as input parameters. + * + * @param joinRel join node + * @param filters filters to be classified + * @param joinType join type + * @param pushInto whether filters can be pushed into the ON clause + * @param pushLeft true if filters can be pushed to the left + * @param pushRight true if filters can be pushed to the right + * @param joinFilters list of filters to push to the join + * @param leftFilters list of filters to push to the left child + * @param rightFilters list of filters to push to the right child + * @return whether at least one filter was pushed + * @deprecated Use {@link RelOptUtil#classifyFilters(RelNode, List, boolean, boolean, boolean, + * List, List, List)} + */ + @Deprecated // to be removed before 2.0 + public static boolean classifyFilters( + RelNode joinRel, + List filters, + JoinRelType joinType, + boolean pushInto, + boolean pushLeft, + boolean pushRight, + List joinFilters, + List leftFilters, + List rightFilters) { + RexBuilder rexBuilder = joinRel.getCluster().getRexBuilder(); + List joinFields = joinRel.getRowType().getFieldList(); + final int nTotalFields = joinFields.size(); + final int nSysFields = 0; // joinRel.getSystemFieldList().size(); + final List leftFields = + joinRel.getInputs().get(0).getRowType().getFieldList(); + final int nFieldsLeft = leftFields.size(); + final List rightFields = + joinRel.getInputs().get(1).getRowType().getFieldList(); + final int nFieldsRight = rightFields.size(); + + // SemiJoin, CorrelateSemiJoin, CorrelateAntiJoin: right fields are not returned + assert nTotalFields + == (!joinType.projectsRight() + ? nSysFields + nFieldsLeft + : nSysFields + nFieldsLeft + nFieldsRight); + + // set the reference bitmaps for the left and right children + ImmutableBitSet leftBitmap = ImmutableBitSet.range(nSysFields, nSysFields + nFieldsLeft); + ImmutableBitSet rightBitmap = ImmutableBitSet.range(nSysFields + nFieldsLeft, nTotalFields); + + final List filtersToRemove = new ArrayList<>(); + for (RexNode filter : filters) { + final InputFinder inputFinder = InputFinder.analyze(filter); + final ImmutableBitSet inputBits = inputFinder.build(); + + // REVIEW - are there any expressions that need special handling + // and therefore cannot be pushed? + + // filters can be pushed to the left child if the left child + // does not generate NULLs and the only columns referenced in + // the filter originate from the left child + if (pushLeft && leftBitmap.contains(inputBits)) { + // ignore filters that always evaluate to true + if (!filter.isAlwaysTrue()) { + // adjust the field references in the filter to reflect + // that fields in the left now shift over by the number + // of system fields + final RexNode shiftedFilter = + shiftFilter( + nSysFields, + nSysFields + nFieldsLeft, + -nSysFields, + rexBuilder, + joinFields, + nTotalFields, + leftFields, + filter); + + leftFilters.add(shiftedFilter); + } + filtersToRemove.add(filter); + + // filters can be pushed to the right child if the right child + // does not generate NULLs and the only columns referenced in + // the filter originate from the right child + } else if (pushRight && rightBitmap.contains(inputBits)) { + if (!filter.isAlwaysTrue()) { + // adjust the field references in the filter to reflect + // that fields in the right now shift over to the left; + // since we never push filters to a NULL generating + // child, the types of the source should match the dest + // so we don't need to explicitly pass the destination + // fields to RexInputConverter + final RexNode shiftedFilter = + shiftFilter( + nSysFields + nFieldsLeft, + nTotalFields, + -(nSysFields + nFieldsLeft), + rexBuilder, + joinFields, + nTotalFields, + rightFields, + filter); + rightFilters.add(shiftedFilter); + } + filtersToRemove.add(filter); + + } else { + // If the filter can't be pushed to either child and the join + // is an inner join, push them to the join if they originated + // from above the join + if (!joinType.isOuterJoin() && pushInto) { + if (!joinFilters.contains(filter)) { + joinFilters.add(filter); + } + filtersToRemove.add(filter); + } + } + } + + // Remove filters after the loop, to prevent concurrent modification. + if (!filtersToRemove.isEmpty()) { + filters.removeAll(filtersToRemove); + } + + // Did anything change? + return !filtersToRemove.isEmpty(); + } + + private static RexNode shiftFilter( + int start, + int end, + int offset, + RexBuilder rexBuilder, + List joinFields, + int nTotalFields, + List rightFields, + RexNode filter) { + int[] adjustments = new int[nTotalFields]; + for (int i = start; i < end; i++) { + adjustments[i] = offset; + } + return filter.accept( + new RexInputConverter(rexBuilder, joinFields, rightFields, adjustments)); + } + + /** + * Splits a filter into two lists, depending on whether or not the filter only references its + * child input. + * + * @param childBitmap Fields in the child + * @param predicate filters that will be split + * @param pushable returns the list of filters that can be pushed to the child input + * @param notPushable returns the list of filters that cannot be pushed to the child input + */ + public static void splitFilters( + ImmutableBitSet childBitmap, + @Nullable RexNode predicate, + List pushable, + List notPushable) { + // for each filter, if the filter only references the child inputs, + // then it can be pushed + for (RexNode filter : conjunctions(predicate)) { + ImmutableBitSet filterRefs = InputFinder.bits(filter); + if (childBitmap.contains(filterRefs)) { + pushable.add(filter); + } else { + notPushable.add(filter); + } + } + } + + @Deprecated // to be removed before 2.0 + public static boolean checkProjAndChildInputs(Project project, boolean checkNames) { + int n = project.getProjects().size(); + RelDataType inputType = project.getInput().getRowType(); + if (inputType.getFieldList().size() != n) { + return false; + } + List projFields = project.getRowType().getFieldList(); + List inputFields = inputType.getFieldList(); + boolean namesDifferent = false; + for (int i = 0; i < n; ++i) { + RexNode exp = project.getProjects().get(i); + if (!(exp instanceof RexInputRef)) { + return false; + } + RexInputRef fieldAccess = (RexInputRef) exp; + if (i != fieldAccess.getIndex()) { + // can't support reorder yet + return false; + } + if (checkNames) { + String inputFieldName = inputFields.get(i).getName(); + String projFieldName = projFields.get(i).getName(); + if (!projFieldName.equals(inputFieldName)) { + namesDifferent = true; + } + } + } + + // inputs are the same; return value depends on the checkNames + // parameter + return !checkNames || namesDifferent; + } + + /** + * Creates projection expressions reflecting the swapping of a join's input. + * + * @param newJoin the RelNode corresponding to the join with its inputs swapped + * @param origJoin original LogicalJoin + * @param origOrder if true, create the projection expressions to reflect the original + * (pre-swapped) join projection; otherwise, create the projection to reflect the order of + * the swapped projection + * @return array of expression representing the swapped join inputs + */ + public static List createSwappedJoinExprs( + RelNode newJoin, Join origJoin, boolean origOrder) { + final List newJoinFields = newJoin.getRowType().getFieldList(); + final RexBuilder rexBuilder = newJoin.getCluster().getRexBuilder(); + final List exps = new ArrayList<>(); + final int nFields = + origOrder + ? origJoin.getRight().getRowType().getFieldCount() + : origJoin.getLeft().getRowType().getFieldCount(); + for (int i = 0; i < newJoinFields.size(); i++) { + final int source = (i + nFields) % newJoinFields.size(); + RelDataTypeField field = origOrder ? newJoinFields.get(source) : newJoinFields.get(i); + exps.add(rexBuilder.makeInputRef(field.getType(), source)); + } + return exps; + } + + @Deprecated // to be removed before 2.0 + public static RexNode pushFilterPastProject(RexNode filter, final Project projRel) { + return pushPastProject(filter, projRel); + } + + /** + * Converts an expression that is based on the output fields of a {@link Project} to an + * equivalent expression on the Project's input fields. + * + * @param node The expression to be converted + * @param project Project underneath the expression + * @return converted expression + */ + public static RexNode pushPastProject(RexNode node, Project project) { + return node.accept(pushShuttle(project)); + } + + /** + * Converts a list of expressions that are based on the output fields of a {@link Project} to + * equivalent expressions on the Project's input fields. + * + * @param nodes The expressions to be converted + * @param project Project underneath the expression + * @return converted expressions + */ + public static List pushPastProject(List nodes, Project project) { + return pushShuttle(project).visitList(nodes); + } + + /** + * As {@link #pushPastProject}, but returns null if the resulting expressions are significantly + * more complex. + * + * @param bloat Maximum allowable increase in complexity + */ + public static @Nullable List pushPastProjectUnlessBloat( + List nodes, Project project, int bloat) { + if (bloat < 0) { + // If bloat is negative never merge. + return null; + } + if (RexOver.containsOver(nodes, null) && project.containsOver()) { + // Is it valid relational algebra to apply windowed function to a windowed + // function? Possibly. But it's invalid SQL, so don't go there. + return null; + } + final List list = pushPastProject(nodes, project); + final int bottomCount = RexUtil.nodeCount(project.getProjects()); + final int topCount = RexUtil.nodeCount(nodes); + final int mergedCount = RexUtil.nodeCount(list); + if (mergedCount > bottomCount + topCount + bloat) { + // The merged expression is more complex than the input expressions. + // Do not merge. + return null; + } + return list; + } + + private static RexShuttle pushShuttle(final Project project) { + return new RexShuttle() { + @Override + public RexNode visitInputRef(RexInputRef ref) { + return project.getProjects().get(ref.getIndex()); + } + }; + } + + /** + * Converts an expression that is based on the output fields of a {@link Calc} to an equivalent + * expression on the Calc's input fields. + * + * @param node The expression to be converted + * @param calc Calc underneath the expression + * @return converted expression + */ + public static RexNode pushPastCalc(RexNode node, Calc calc) { + return node.accept(pushShuttle(calc)); + } + + private static RexShuttle pushShuttle(final Calc calc) { + final List projects = + Util.transform( + calc.getProgram().getProjectList(), calc.getProgram()::expandLocalRef); + return new RexShuttle() { + @Override + public RexNode visitInputRef(RexInputRef ref) { + return projects.get(ref.getIndex()); + } + }; + } + + /** + * Creates a new {@link org.apache.calcite.rel.rules.MultiJoin} to reflect projection references + * from a {@link Project} that is on top of the {@link org.apache.calcite.rel.rules.MultiJoin}. + * + * @param multiJoin the original MultiJoin + * @param project the Project on top of the MultiJoin + * @return the new MultiJoin + */ + public static MultiJoin projectMultiJoin(MultiJoin multiJoin, Project project) { + // Locate all input references in the projection expressions as well + // the post-join filter. Since the filter effectively sits in + // between the LogicalProject and the MultiJoin, the projection needs + // to include those filter references. + ImmutableBitSet inputRefs = + InputFinder.bits(project.getProjects(), multiJoin.getPostJoinFilter()); + + // create new copies of the bitmaps + List multiJoinInputs = multiJoin.getInputs(); + List newProjFields = new ArrayList<>(); + for (RelNode multiJoinInput : multiJoinInputs) { + newProjFields.add(new BitSet(multiJoinInput.getRowType().getFieldCount())); + } + + // set the bits found in the expressions + int currInput = -1; + int startField = 0; + int nFields = 0; + for (int bit : inputRefs) { + while (bit >= (startField + nFields)) { + startField += nFields; + currInput++; + assert currInput < multiJoinInputs.size(); + nFields = multiJoinInputs.get(currInput).getRowType().getFieldCount(); + } + newProjFields.get(currInput).set(bit - startField); + } + + // create a new MultiJoin containing the new field bitmaps + // for each input + return new MultiJoin( + multiJoin.getCluster(), + multiJoin.getInputs(), + multiJoin.getJoinFilter(), + multiJoin.getRowType(), + multiJoin.isFullOuterJoin(), + multiJoin.getOuterJoinConditions(), + multiJoin.getJoinTypes(), + Util.transform(newProjFields, ImmutableBitSet::fromBitSet), + multiJoin.getJoinFieldRefCountsMap(), + multiJoin.getPostJoinFilter()); + } + + public static T addTrait(T rel, RelTrait trait) { + //noinspection unchecked + return (T) rel.copy(rel.getTraitSet().replace(trait), rel.getInputs()); + } + + /** Returns a shallow copy of a relational expression with a particular input replaced. */ + public static RelNode replaceInput(RelNode parent, int ordinal, RelNode newInput) { + final List inputs = new ArrayList<>(parent.getInputs()); + if (inputs.get(ordinal) == newInput) { + return parent; + } + inputs.set(ordinal, newInput); + return parent.copy(parent.getTraitSet(), inputs); + } + + /** + * Creates a {@link org.apache.calcite.rel.logical.LogicalProject} that projects particular + * fields of its input, according to a mapping. + */ + public static RelNode createProject(RelNode child, Mappings.TargetMapping mapping) { + return createProject(child, Mappings.asListNonNull(mapping.inverse())); + } + + public static RelNode createProject( + RelNode child, + Mappings.TargetMapping mapping, + RelFactories.ProjectFactory projectFactory) { + return createProject(projectFactory, child, Mappings.asListNonNull(mapping.inverse())); + } + + /** + * Returns the relational table node for {@code tableName} if it occurs within a relational + * expression {@code root} otherwise an empty option is returned. + */ + public static @Nullable RelOptTable findTable(RelNode root, final String tableName) { + try { + RelShuttle visitor = + new RelHomogeneousShuttle() { + @Override + public RelNode visit(TableScan scan) { + final RelOptTable scanTable = scan.getTable(); + final List qualifiedName = scanTable.getQualifiedName(); + if (qualifiedName.get(qualifiedName.size() - 1).equals(tableName)) { + throw new Util.FoundOne(scanTable); + } + return super.visit(scan); + } + }; + root.accept(visitor); + return null; + } catch (Util.FoundOne e) { + Util.swallow(e, null); + return (RelOptTable) e.getNode(); + } + } + + /** + * Returns whether relational expression {@code target} occurs within a relational expression + * {@code ancestor}. + */ + public static boolean contains(RelNode ancestor, final RelNode target) { + if (ancestor == target) { + // Short-cut common case. + return true; + } + try { + new RelVisitor() { + @Override + public void visit(RelNode node, int ordinal, @Nullable RelNode parent) { + if (node == target) { + throw Util.FoundOne.NULL; + } + super.visit(node, ordinal, parent); + } + // CHECKSTYLE: IGNORE 1 + }.go(ancestor); + return false; + } catch (Util.FoundOne e) { + return true; + } + } + + /** + * Within a relational expression {@code query}, replaces occurrences of {@code find} with + * {@code replace}. + */ + public static RelNode replace(RelNode query, RelNode find, RelNode replace) { + if (find == replace) { + // Short-cut common case. + return query; + } + assert equalType("find", find, "replace", replace, Litmus.THROW); + if (query == find) { + // Short-cut another common case. + return replace; + } + return replaceRecurse(query, find, replace); + } + + /** Helper for {@link #replace}. */ + private static RelNode replaceRecurse(RelNode query, RelNode find, RelNode replace) { + if (query == find) { + return replace; + } + final List inputs = query.getInputs(); + if (!inputs.isEmpty()) { + final List newInputs = new ArrayList<>(); + for (RelNode input : inputs) { + newInputs.add(replaceRecurse(input, find, replace)); + } + if (!newInputs.equals(inputs)) { + return query.copy(query.getTraitSet(), newInputs); + } + } + return query; + } + + @Deprecated // to be removed before 2.0 + public static RelOptTable.ToRelContext getContext(RelOptCluster cluster) { + return ViewExpanders.simpleContext(cluster); + } + + /** Returns the number of {@link org.apache.calcite.rel.core.Join} nodes in a tree. */ + public static int countJoins(RelNode rootRel) { + /** Visitor that counts join nodes. */ + class JoinCounter extends RelVisitor { + int joinCount; + + @Override + public void visit( + RelNode node, + int ordinal, + @org.checkerframework.checker.nullness.qual.Nullable RelNode parent) { + if (node instanceof Join) { + ++joinCount; + } + super.visit(node, ordinal, parent); + } + + int run(RelNode node) { + go(node); + return joinCount; + } + } + + return new JoinCounter().run(rootRel); + } + + /** Permutes a record type according to a mapping. */ + public static RelDataType permute( + RelDataTypeFactory typeFactory, RelDataType rowType, Mapping mapping) { + return typeFactory.createStructType(Mappings.apply3(mapping, rowType.getFieldList())); + } + + @Deprecated // to be removed before 2.0 + public static RelNode createProject( + RelNode child, List exprList, List fieldNameList) { + final RelBuilder relBuilder = RelFactories.LOGICAL_BUILDER.create(child.getCluster(), null); + return relBuilder.push(child).project(exprList, fieldNameList, true).build(); + } + + @Deprecated // to be removed before 2.0 + public static RelNode createProject( + RelNode child, + List> projectList, + boolean optimize) { + final RelBuilder relBuilder = RelFactories.LOGICAL_BUILDER.create(child.getCluster(), null); + return relBuilder + .push(child) + .projectNamed(Pair.left(projectList), Pair.right(projectList), !optimize) + .build(); + } + + /** + * Creates a relational expression that projects the given fields of the input. + * + *

Optimizes if the fields are the identity projection. + * + * @param child Input relational expression + * @param posList Source of each projected field + * @return Relational expression that projects given fields + */ + public static RelNode createProject(final RelNode child, final List posList) { + return createProject(RelFactories.DEFAULT_PROJECT_FACTORY, child, posList); + } + + @Deprecated // to be removed before 2.0 + public static RelNode createProject( + RelNode child, + List exprs, + List fieldNames, + boolean optimize) { + final RelBuilder relBuilder = RelFactories.LOGICAL_BUILDER.create(child.getCluster(), null); + return relBuilder.push(child).projectNamed(exprs, fieldNames, !optimize).build(); + } + + // CHECKSTYLE: IGNORE 1 + /** + * @deprecated Use {@link RelBuilder#projectNamed(Iterable, Iterable, boolean)} + */ + @Deprecated // to be removed before 2.0 + public static RelNode createProject( + RelNode child, + List exprs, + List fieldNames, + boolean optimize, + RelBuilder relBuilder) { + return relBuilder.push(child).projectNamed(exprs, fieldNames, !optimize).build(); + } + + @Deprecated // to be removed before 2.0 + public static RelNode createRename(RelNode rel, List fieldNames) { + final List fields = rel.getRowType().getFieldList(); + assert fieldNames.size() == fields.size(); + final List refs = + new AbstractList() { + @Override + public int size() { + return fields.size(); + } + + @Override + public RexNode get(int index) { + return RexInputRef.of(index, fields); + } + }; + final RelBuilder relBuilder = RelFactories.LOGICAL_BUILDER.create(rel.getCluster(), null); + return relBuilder.push(rel).projectNamed(refs, fieldNames, false).build(); + } + + /** + * Creates a relational expression which permutes the output fields of a relational expression + * according to a permutation. + * + *

Optimizations: + * + *

    + *
  • If the relational expression is a {@link org.apache.calcite.rel.logical.LogicalCalc} or + * {@link org.apache.calcite.rel.logical.LogicalProject} that is already acting as a + * permutation, combines the new permutation with the old; + *
  • If the permutation is the identity, returns the original relational expression. + *
+ * + *

If a permutation is combined with its inverse, these optimizations would combine to remove + * them both. + * + * @param rel Relational expression + * @param permutation Permutation to apply to fields + * @param fieldNames Field names; if null, or if a particular entry is null, the name of the + * permuted field is used + * @return relational expression which permutes its input fields + */ + public static RelNode permute( + RelNode rel, Permutation permutation, @Nullable List fieldNames) { + if (permutation.isIdentity()) { + return rel; + } + if (rel instanceof LogicalCalc) { + LogicalCalc calc = (LogicalCalc) rel; + Permutation permutation1 = calc.getProgram().getPermutation(); + if (permutation1 != null) { + Permutation permutation2 = permutation.product(permutation1); + return permute(rel, permutation2, null); + } + } + if (rel instanceof LogicalProject) { + Permutation permutation1 = ((LogicalProject) rel).getPermutation(); + if (permutation1 != null) { + Permutation permutation2 = permutation.product(permutation1); + return permute(rel, permutation2, null); + } + } + final List outputTypeList = new ArrayList<>(); + final List outputNameList = new ArrayList<>(); + final List exprList = new ArrayList<>(); + final List projectRefList = new ArrayList<>(); + final List fields = rel.getRowType().getFieldList(); + final RelOptCluster cluster = rel.getCluster(); + for (int i = 0; i < permutation.getTargetCount(); i++) { + int target = permutation.getTarget(i); + final RelDataTypeField targetField = fields.get(target); + outputTypeList.add(targetField.getType()); + outputNameList.add( + ((fieldNames == null) + || (fieldNames.size() <= i) + || (fieldNames.get(i) == null)) + ? targetField.getName() + : fieldNames.get(i)); + exprList.add(cluster.getRexBuilder().makeInputRef(fields.get(i).getType(), i)); + final int source = permutation.getSource(i); + projectRefList.add(new RexLocalRef(source, fields.get(source).getType())); + } + final RelDataTypeFactory typeFactory = cluster.getTypeFactory(); + final RexProgram program = + new RexProgram( + rel.getRowType(), + exprList, + projectRefList, + null, + typeFactory.createStructType(outputTypeList, outputNameList)); + return LogicalCalc.create(rel, program); + } + + /** + * Creates a relational expression that projects the given fields of the input. + * + *

Optimizes if the fields are the identity projection. + * + * @param factory ProjectFactory + * @param child Input relational expression + * @param posList Source of each projected field + * @return Relational expression that projects given fields + */ + public static RelNode createProject( + final RelFactories.ProjectFactory factory, + final RelNode child, + final List posList) { + RelDataType rowType = child.getRowType(); + final List fieldNames = rowType.getFieldNames(); + final RelBuilder relBuilder = RelBuilder.proto(factory).create(child.getCluster(), null); + final List exprs = + new AbstractList() { + @Override + public int size() { + return posList.size(); + } + + @Override + public RexNode get(int index) { + final int pos = posList.get(index); + return relBuilder.getRexBuilder().makeInputRef(child, pos); + } + }; + final List names = Util.select(fieldNames, posList); + return relBuilder.push(child).projectNamed(exprs, names, false).build(); + } + + @Deprecated // to be removed before 2.0 + public static RelNode projectMapping( + RelNode rel, + Mapping mapping, + @Nullable List fieldNames, + RelFactories.ProjectFactory projectFactory) { + assert mapping.getMappingType().isSingleSource(); + assert mapping.getMappingType().isMandatorySource(); + if (mapping.isIdentity()) { + return rel; + } + final List outputNameList = new ArrayList<>(); + final List exprList = new ArrayList<>(); + final List fields = rel.getRowType().getFieldList(); + final RexBuilder rexBuilder = rel.getCluster().getRexBuilder(); + for (int i = 0; i < mapping.getTargetCount(); i++) { + final int source = mapping.getSource(i); + final RelDataTypeField sourceField = fields.get(source); + outputNameList.add( + ((fieldNames == null) + || (fieldNames.size() <= i) + || (fieldNames.get(i) == null)) + ? sourceField.getName() + : fieldNames.get(i)); + exprList.add(rexBuilder.makeInputRef(rel, source)); + } + return projectFactory.createProject( + rel, ImmutableList.of(), exprList, outputNameList, ImmutableSet.of()); + } + + /** Predicate for if a {@link Calc} does not contain windowed aggregates. */ + public static boolean notContainsWindowedAgg(Calc calc) { + return !calc.containsOver(); + } + + /** Predicate for if a {@link Filter} does not windowed aggregates. */ + public static boolean notContainsWindowedAgg(Filter filter) { + return !filter.containsOver(); + } + + /** Predicate for if a {@link Project} does not contain windowed aggregates. */ + public static boolean notContainsWindowedAgg(Project project) { + return !project.containsOver(); + } + + /** Policies for handling two- and three-valued boolean logic. */ + public enum Logic { + /** Three-valued boolean logic. */ + TRUE_FALSE_UNKNOWN, + + /** Nulls are not possible. */ + TRUE_FALSE, + + /** + * Two-valued logic where UNKNOWN is treated as FALSE. + * + *

"x IS TRUE" produces the same result, and "WHERE x", "JOIN ... ON x" and "HAVING x" + * have the same effect. + */ + UNKNOWN_AS_FALSE, + + /** + * Two-valued logic where UNKNOWN is treated as TRUE. + * + *

"x IS FALSE" produces the same result, as does "WHERE NOT x", etc. + * + *

In particular, this is the mode used by "WHERE k NOT IN q". If "k IN q" produces TRUE + * or UNKNOWN, "NOT k IN q" produces FALSE or UNKNOWN and the row is eliminated; if "k IN q" + * it returns FALSE, the row is retained by the WHERE clause. + */ + UNKNOWN_AS_TRUE, + + /** + * A semi-join will have been applied, so that only rows for which the value is TRUE will + * have been returned. + */ + TRUE, + + /** + * An anti-semi-join will have been applied, so that only rows for which the value is FALSE + * will have been returned. + * + *

Currently only used within {@link LogicVisitor}, to ensure that 'NOT (NOT EXISTS (q))' + * behaves the same as 'EXISTS (q)') + */ + FALSE; + + public Logic negate() { + switch (this) { + case UNKNOWN_AS_FALSE: + case TRUE: + return UNKNOWN_AS_TRUE; + case UNKNOWN_AS_TRUE: + return UNKNOWN_AS_FALSE; + default: + return this; + } + } + + /** + * Variant of {@link #negate()} to be used within {@link LogicVisitor}, where FALSE values + * may exist. + */ + public Logic negate2() { + switch (this) { + case FALSE: + return TRUE; + case TRUE: + return FALSE; + case UNKNOWN_AS_FALSE: + return UNKNOWN_AS_TRUE; + case UNKNOWN_AS_TRUE: + return UNKNOWN_AS_FALSE; + default: + return this; + } + } + } + + /** + * Pushes down expressions in "equal" join condition. + * + *

For example, given "emp JOIN dept ON emp.deptno + 1 = dept.deptno", adds a project above + * "emp" that computes the expression "emp.deptno + 1". The resulting join condition is a simple + * combination of AND, equals, and input fields, plus the remaining non-equal conditions. + * + * @param originalJoin Join whose condition is to be pushed down + * @param relBuilder Factory to create project operator + */ + public static RelNode pushDownJoinConditions(Join originalJoin, RelBuilder relBuilder) { + RexNode joinCond = originalJoin.getCondition(); + final JoinRelType joinType = originalJoin.getJoinType(); + + final List extraLeftExprs = new ArrayList<>(); + final List extraRightExprs = new ArrayList<>(); + final int leftCount = originalJoin.getLeft().getRowType().getFieldCount(); + final int rightCount = originalJoin.getRight().getRowType().getFieldCount(); + + // You cannot push a 'get' because field names might change. + // + // Pushing sub-queries is OK in principle (if they don't reference both + // sides of the join via correlating variables) but we'd rather not do it + // yet. + if (!containsGet(joinCond) && RexUtil.SubQueryFinder.find(joinCond) == null) { + joinCond = + pushDownEqualJoinConditions( + joinCond, + leftCount, + rightCount, + extraLeftExprs, + extraRightExprs, + relBuilder.getRexBuilder()); + } + + final PairList pairs = PairList.of(); + relBuilder.push(originalJoin.getLeft()); + if (!extraLeftExprs.isEmpty()) { + final List fields = relBuilder.peek().getRowType().getFieldList(); + for (int i = 0, n = leftCount + extraLeftExprs.size(); i < n; i++) { + if (i < leftCount) { + RelDataTypeField field = fields.get(i); + pairs.add(new RexInputRef(i, field.getType()), field.getName()); + } else { + pairs.add(extraLeftExprs.get(i - leftCount), null); + } + } + relBuilder.project(pairs.leftList(), pairs.rightList()); + pairs.clear(); + } + + relBuilder.push(originalJoin.getRight()); + if (!extraRightExprs.isEmpty()) { + final List fields = relBuilder.peek().getRowType().getFieldList(); + final int newLeftCount = leftCount + extraLeftExprs.size(); + for (int i = 0, n = rightCount + extraRightExprs.size(); i < n; i++) { + if (i < rightCount) { + RelDataTypeField field = fields.get(i); + pairs.add(new RexInputRef(i, field.getType()), field.getName()); + } else { + pairs.add( + RexUtil.shift(extraRightExprs.get(i - rightCount), -newLeftCount), + null); + } + } + relBuilder.project(pairs.leftList(), pairs.rightList()); + pairs.clear(); + } + + final RelNode right = relBuilder.build(); + final RelNode left = relBuilder.build(); + relBuilder.push( + originalJoin.copy( + originalJoin.getTraitSet(), + joinCond, + left, + right, + joinType, + originalJoin.isSemiJoinDone())); + if (!extraLeftExprs.isEmpty() || !extraRightExprs.isEmpty()) { + final int totalFields = + joinType.projectsRight() + ? leftCount + + extraLeftExprs.size() + + rightCount + + extraRightExprs.size() + : leftCount + extraLeftExprs.size(); + final int[] mappingRanges = + joinType.projectsRight() + ? new int[] { + 0, + 0, + leftCount, + leftCount, + leftCount + extraLeftExprs.size(), + rightCount + } + : new int[] {0, 0, leftCount}; + Mappings.TargetMapping mapping = + Mappings.createShiftMapping(totalFields, mappingRanges); + relBuilder.project(relBuilder.fields(mapping.inverse())); + } + return relBuilder.build(); + } + + @Deprecated // to be removed before 2.0 + public static RelNode pushDownJoinConditions(Join originalJoin) { + return pushDownJoinConditions(originalJoin, RelFactories.LOGICAL_BUILDER); + } + + @Deprecated // to be removed before 2.0 + public static RelNode pushDownJoinConditions( + Join originalJoin, RelFactories.ProjectFactory projectFactory) { + return pushDownJoinConditions(originalJoin, RelBuilder.proto(projectFactory)); + } + + private static RelNode pushDownJoinConditions( + Join originalJoin, RelBuilderFactory relBuilderFactory) { + return pushDownJoinConditions( + originalJoin, relBuilderFactory.create(originalJoin.getCluster(), null)); + } + + private static boolean containsGet(RexNode node) { + try { + node.accept( + new RexVisitorImpl(true) { + @Override + public Void visitCall(RexCall call) { + if (call.getOperator() == RexBuilder.GET_OPERATOR) { + throw Util.FoundOne.NULL; + } + return super.visitCall(call); + } + }); + return false; + } catch (Util.FoundOne e) { + return true; + } + } + + /** + * Pushes down parts of a join condition. + * + *

For example, given "emp JOIN dept ON emp.deptno + 1 = dept.deptno", adds a project above + * "emp" that computes the expression "emp.deptno + 1". The resulting join condition is a simple + * combination of AND, equals, and input fields. + */ + private static RexNode pushDownEqualJoinConditions( + RexNode condition, + int leftCount, + int rightCount, + List extraLeftExprs, + List extraRightExprs, + RexBuilder builder) { + // Normalize the condition first + RexNode node = + (condition instanceof RexCall) + ? collapseExpandedIsNotDistinctFromExpr((RexCall) condition, builder) + : condition; + + switch (node.getKind()) { + case EQUALS: + case IS_NOT_DISTINCT_FROM: + final RexCall call0 = (RexCall) node; + final RexNode leftRex = call0.getOperands().get(0); + final RexNode rightRex = call0.getOperands().get(1); + final ImmutableBitSet leftBits = RelOptUtil.InputFinder.bits(leftRex); + final ImmutableBitSet rightBits = RelOptUtil.InputFinder.bits(rightRex); + final int pivot = leftCount + extraLeftExprs.size(); + Side lside = Side.of(leftBits, pivot); + Side rside = Side.of(rightBits, pivot); + if (!lside.opposite(rside)) { + return call0; + } + // fall through + case AND: + final RexCall call = (RexCall) node; + final List list = new ArrayList<>(); + List operands = Lists.newArrayList(call.getOperands()); + for (int i = 0; i < operands.size(); i++) { + RexNode operand = operands.get(i); + if (operand instanceof RexCall) { + operand = collapseExpandedIsNotDistinctFromExpr((RexCall) operand, builder); + } + if (node.getKind() == SqlKind.AND + && operand.getKind() != SqlKind.EQUALS + && operand.getKind() != SqlKind.IS_NOT_DISTINCT_FROM) { + // one of the join condition is neither EQ nor INDF + list.add(operand); + } else { + final int left2 = leftCount + extraLeftExprs.size(); + final RexNode e = + pushDownEqualJoinConditions( + operand, + leftCount, + rightCount, + extraLeftExprs, + extraRightExprs, + builder); + if (!e.equals(operand)) { + final List remainingOperands = Util.skip(operands, i + 1); + final int left3 = leftCount + extraLeftExprs.size(); + fix(remainingOperands, left2, left3); + fix(list, left2, left3); + } + list.add(e); + } + } + if (!list.equals(call.getOperands())) { + return call.clone(call.getType(), list); + } + return call; + case OR: + case INPUT_REF: + case LITERAL: + case NOT: + return node; + default: + final ImmutableBitSet bits = RelOptUtil.InputFinder.bits(node); + final int mid = leftCount + extraLeftExprs.size(); + switch (Side.of(bits, mid)) { + case LEFT: + fix(extraRightExprs, mid, mid + 1); + extraLeftExprs.add(node); + return new RexInputRef(mid, node.getType()); + case RIGHT: + final int index2 = mid + rightCount + extraRightExprs.size(); + extraRightExprs.add(node); + return new RexInputRef(index2, node.getType()); + case BOTH: + case EMPTY: + default: + return node; + } + } + } + + private static void fix(List operands, int before, int after) { + if (before == after) { + return; + } + for (int i = 0; i < operands.size(); i++) { + RexNode node = operands.get(i); + operands.set(i, RexUtil.shift(node, before, after - before)); + } + } + + /** + * Determines whether any of the fields in a given relational expression may contain null + * values, taking into account constraints on the field types and also deduced predicates. + * + *

The method is cautious: It may sometimes return {@code true} when the actual answer is + * {@code false}. In particular, it does this when there is no executor, or the executor is not + * a sub-class of {@link RexExecutorImpl}. + */ + private static boolean containsNullableFields(RelNode r) { + final RexBuilder rexBuilder = r.getCluster().getRexBuilder(); + final RelDataType rowType = r.getRowType(); + final List list = new ArrayList<>(); + final RelMetadataQuery mq = r.getCluster().getMetadataQuery(); + for (RelDataTypeField field : rowType.getFieldList()) { + if (field.getType().isNullable()) { + list.add( + rexBuilder.makeCall( + SqlStdOperatorTable.IS_NOT_NULL, + rexBuilder.makeInputRef(field.getType(), field.getIndex()))); + } + } + if (list.isEmpty()) { + // All columns are declared NOT NULL. + return false; + } + final RelOptPredicateList predicates = mq.getPulledUpPredicates(r); + if (RelOptPredicateList.isEmpty(predicates)) { + // We have no predicates, so cannot deduce that any of the fields + // declared NULL are really NOT NULL. + return true; + } + final RexExecutor executor = r.getCluster().getPlanner().getExecutor(); + if (!(executor instanceof RexExecutorImpl)) { + // Cannot proceed without an executor. + return true; + } + final RexImplicationChecker checker = + new RexImplicationChecker(rexBuilder, executor, rowType); + final RexNode first = RexUtil.composeConjunction(rexBuilder, predicates.pulledUpPredicates); + final RexNode second = RexUtil.composeConjunction(rexBuilder, list); + // Suppose we have EMP(empno INT NOT NULL, mgr INT), + // and predicates [empno > 0, mgr > 0]. + // We make first: "empno > 0 AND mgr > 0" + // and second: "mgr IS NOT NULL" + // and ask whether first implies second. + // It does, so we have no nullable columns. + return !checker.implies(first, second); + } + + // ~ Inner Classes ---------------------------------------------------------- + + /** + * A {@code RelShuttle} which propagates all the hints of relational expression to their + * children nodes. + * + *

Given a plan: + * + *

+ * + *
+     *            Filter (Hint1)
+     *                |
+     *               Join
+     *              /    \
+     *            Scan  Project (Hint2)
+     *                     |
+     *                    Scan2
+     * 
+ * + *
+ * + *

Every hint has a {@code inheritPath} (integers list) which records its propagate path, + * number `0` represents the hint is propagated from the first(left) child, number `1` + * represents the hint is propagated from the second(right) child, so the plan would have hints + * path as follows (assumes each hint can be propagated to all child nodes): + * + *

    + *
  • Filter would have hints {Hint1[]} + *
  • Join would have hints {Hint1[0]} + *
  • Scan would have hints {Hint1[0, 0]} + *
  • Project would have hints {Hint1[0,1], Hint2[]} + *
  • Scan2 would have hints {[Hint1[0, 1, 0], Hint2[0]} + *
+ */ + private static class RelHintPropagateShuttle extends RelHomogeneousShuttle { + /** Stack recording the hints and its current inheritPath. */ + private final Deque, Deque>> inheritPaths = new ArrayDeque<>(); + + /** + * The hint strategies to decide if a hint should be attached to a relational expression. + */ + private final HintStrategyTable hintStrategies; + + RelHintPropagateShuttle(HintStrategyTable hintStrategies) { + this.hintStrategies = hintStrategies; + } + + /** Visits a particular child of a parent. */ + @Override + protected RelNode visitChild(RelNode parent, int i, RelNode child) { + inheritPaths.forEach(inheritPath -> inheritPath.right.push(i)); + try { + RelNode child2 = child.accept(this); + if (child2 != child) { + final List newInputs = new ArrayList<>(parent.getInputs()); + newInputs.set(i, child2); + return parent.copy(parent.getTraitSet(), newInputs); + } + return parent; + } finally { + inheritPaths.forEach(inheritPath -> inheritPath.right.pop()); + } + } + + @Override + public RelNode visit(RelNode other) { + if (other instanceof Hintable) { + return visitHintable(other); + } else { + return visitChildren(other); + } + } + + /** + * Handle the {@link Hintable}s. + * + *

There are two cases to handle hints: + * + *

    + *
  • For TableScan: table scan is always a leaf node, attach the hints of the + * propagation path directly; + *
  • For other {@link Hintable}s: if the node has hints itself, that means, these hints + * are query hints that need to propagate to its children, so we do these things: + *
      + *
    1. push the hints with empty inheritPath to the stack + *
    2. visit the children nodes and propagate the hints + *
    3. pop the hints pushed in step1 + *
    4. attach the hints of the propagation path + *
    + * if the node does not have hints, attach the hints of the propagation path directly. + *
+ * + * @param node {@link Hintable} to handle + * @return New copy of the {@code hintable} with propagated hints attached + */ + private RelNode visitHintable(RelNode node) { + final List topHints = ((Hintable) node).getHints(); + final boolean hasHints = topHints != null && topHints.size() > 0; + final boolean hasQueryHints = hasHints && !(node instanceof TableScan); + if (hasQueryHints) { + inheritPaths.push(Pair.of(topHints, new ArrayDeque<>())); + } + final RelNode node1 = visitChildren(node); + if (hasQueryHints) { + inheritPaths.pop(); + } + return attachHints(node1); + } + + private RelNode attachHints(RelNode original) { + assert original instanceof Hintable; + if (inheritPaths.size() > 0) { + final List hints = + inheritPaths.stream() + .sorted(Comparator.comparingInt(o -> o.right.size())) + .map(path -> copyWithInheritPath(path.left, path.right)) + .reduce( + new ArrayList<>(), + (acc, hints1) -> { + acc.addAll(hints1); + return acc; + }); + final List filteredHints = hintStrategies.apply(hints, original); + if (filteredHints.size() > 0) { + return ((Hintable) original).attachHints(filteredHints); + } + } + return original; + } + + private static List copyWithInheritPath( + List hints, Deque inheritPath) { + // Copy the Dequeue in reverse order. + final List path = new ArrayList<>(); + final Iterator iterator = inheritPath.descendingIterator(); + while (iterator.hasNext()) { + path.add(iterator.next()); + } + return hints.stream().map(hint -> hint.copy(path)).collect(Collectors.toList()); + } + } + + /** + * A {@code RelShuttle} which propagates the given hints to the sub-tree from the root node. It + * stops the search of current path if the node already has hints or the whole propagation if + * there is already a matched node. + * + *

Given a plan: + * + *

+ * + *
+     *            Filter
+     *                |
+     *               Join
+     *              /    \
+     *            Scan  Project (Hint2)
+     *                     |
+     *                    Scan2
+     * 
+ * + *
+ * + *

The [Filter, Join, Scan] are the candidates(in sequence) to propagate, the whole + * propagation ends if we append the given hints to a node successfully. + */ + private static class SubTreeHintPropagateShuttle extends RelHomogeneousShuttle { + /** Stack recording the appended inheritPath. */ + private final List appendPath = new ArrayList<>(); + + /** + * The hint strategies to decide if a hint should be attached to a relational expression. + */ + private final HintStrategyTable hintStrategies; + + /** Hints to propagate. */ + private final List hints; + + SubTreeHintPropagateShuttle(HintStrategyTable hintStrategies, List hints) { + this.hintStrategies = hintStrategies; + this.hints = hints; + } + + /** Visits a particular child of a parent. */ + @Override + protected RelNode visitChild(RelNode parent, int i, RelNode child) { + appendPath.add(i); + try { + RelNode child2 = child.accept(this); + if (child2 != child) { + final List newInputs = new ArrayList<>(parent.getInputs()); + newInputs.set(i, child2); + return parent.copy(parent.getTraitSet(), newInputs); + } + return parent; + } finally { + // Remove the last element. + appendPath.remove(appendPath.size() - 1); + } + } + + @Override + public RelNode visit(RelNode other) { + if (this.appendPath.size() > 3) { + // Returns early if the visiting depth is bigger than 3 + return other; + } + if (other instanceof Hintable) { + return visitHintable(other); + } else { + return visitChildren(other); + } + } + + /** + * Handle the {@link Hintable}s. + * + *

Try to propagate the given hints to the node, the propagation finishes if: + * + *

    + *
  • This hintable already has hints, that means, the rel is definitely not created by a + * planner rule(or copied by the planner rule) + *
  • This hintable appended the hints successfully + *
+ * + * @param node {@link Hintable} to handle + * @return New copy of the {@code hintable} with propagated hints attached + */ + private RelNode visitHintable(RelNode node) { + final List topHints = ((Hintable) node).getHints(); + final boolean hasHints = topHints != null && topHints.size() > 0; + if (hasHints) { + // This node is definitely not created by the planner, returns early. + return node; + } + final RelNode node1 = attachHints(node); + if (node1 != node) { + return node1; + } + return visitChildren(node); + } + + private RelNode attachHints(RelNode original) { + assert original instanceof Hintable; + final List hints = + this.hints.stream() + .map(hint -> copyWithAppendPath(hint, appendPath)) + .collect(Collectors.toList()); + final List filteredHints = hintStrategies.apply(hints, original); + if (filteredHints.size() > 0) { + return ((Hintable) original).attachHints(filteredHints); + } + return original; + } + + private static RelHint copyWithAppendPath(RelHint hint, List appendPaths) { + if (appendPaths.size() == 0) { + return hint; + } else { + List newPath = new ArrayList<>(hint.inheritPath); + newPath.addAll(appendPaths); + return hint.copy(newPath); + } + } + } + + /** + * A {@code RelShuttle} which resets all the hints of a relational expression to what they are + * originally like. + * + *

This would trigger a reverse transformation of what {@link RelHintPropagateShuttle} does. + * + *

Transformation rules: + * + *

    + *
  • Project: remove the hints that have non-empty inherit path (which means the hint was + * not originally declared from it); + *
  • Aggregate: remove the hints that have non-empty inherit path; + *
  • Join: remove all the hints; + *
  • TableScan: remove the hints that have non-empty inherit path. + *
+ */ + private static class ResetHintsShuttle extends RelHomogeneousShuttle { + @Override + public RelNode visit(RelNode node) { + node = visitChildren(node); + if (node instanceof Hintable) { + node = resetHints((Hintable) node); + } + return node; + } + + private static RelNode resetHints(Hintable hintable) { + if (hintable.getHints().size() > 0) { + final List resetHints = + hintable.getHints().stream() + .filter(hint -> hint.inheritPath.size() == 0) + .collect(Collectors.toList()); + return hintable.withHints(resetHints); + } else { + return (RelNode) hintable; + } + } + } + + /** Visitor that finds all variables used but not stopped in an expression. */ + private static class VariableSetVisitor extends RelVisitor { + final Set variables = new HashSet<>(); + + // implement RelVisitor + @Override + public void visit( + RelNode p, + int ordinal, + @org.checkerframework.checker.nullness.qual.Nullable RelNode parent) { + super.visit(p, ordinal, parent); + p.collectVariablesUsed(variables); + + // Important! Remove stopped variables AFTER we visit children + // (which what super.visit() does) + variables.removeAll(p.getVariablesSet()); + } + } + + /** Visitor that finds all variables used in an expression. */ + public static class VariableUsedVisitor extends RexShuttle { + public final Set variables = new LinkedHashSet<>(); + public final Multimap variableFields = LinkedHashMultimap.create(); + @NotOnlyInitialized private final @Nullable RelShuttle relShuttle; + + public VariableUsedVisitor(@UnknownInitialization @Nullable RelShuttle relShuttle) { + this.relShuttle = relShuttle; + } + + @Override + public RexNode visitCorrelVariable(RexCorrelVariable p) { + variables.add(p.id); + variableFields.put(p.id, -1); + return p; + } + + @Override + public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { + if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable) { + final RexCorrelVariable v = (RexCorrelVariable) fieldAccess.getReferenceExpr(); + variableFields.put(v.id, fieldAccess.getField().getIndex()); + } + return super.visitFieldAccess(fieldAccess); + } + + @Override + public RexNode visitSubQuery(RexSubQuery subQuery) { + if (relShuttle != null) { + subQuery.rel.accept(relShuttle); // look inside sub-queries + } + return super.visitSubQuery(subQuery); + } + } + + /** Shuttle that finds the set of inputs that are used. */ + public static class InputReferencedVisitor extends RexShuttle { + public final NavigableSet inputPosReferenced = new TreeSet<>(); + + @Override + public RexNode visitInputRef(RexInputRef inputRef) { + inputPosReferenced.add(inputRef.getIndex()); + return inputRef; + } + } + + /** Converts types to descriptive strings. */ + public static class TypeDumper { + private String indent; + private final PrintWriter pw; + + TypeDumper(PrintWriter pw) { + this.pw = pw; + this.indent = ""; + } + + void accept(RelDataType type) { + if (type.isStruct()) { + final List fields = type.getFieldList(); + + // RECORD ( + // I INTEGER NOT NULL, + // J VARCHAR(240)) + pw.println("RECORD ("); + String prevIndent = indent; + String extraIndent = " "; + this.indent = indent + extraIndent; + acceptFields(fields); + this.indent = prevIndent; + pw.print(")"); + if (!type.isNullable()) { + pw.print(NON_NULLABLE_SUFFIX); + } + } else if (type instanceof MultisetSqlType) { + // E.g. "INTEGER NOT NULL MULTISET NOT NULL" + RelDataType componentType = + requireNonNull( + type.getComponentType(), + () -> "type.getComponentType() for " + type); + accept(componentType); + pw.print(" MULTISET"); + if (!type.isNullable()) { + pw.print(NON_NULLABLE_SUFFIX); + } + } else { + // E.g. "INTEGER" E.g. "VARCHAR(240) CHARACTER SET "ISO-8859-1" + // COLLATE "ISO-8859-1$en_US$primary" NOT NULL" + pw.print(type.getFullTypeString()); + } + } + + private void acceptFields(final List fields) { + for (int i = 0; i < fields.size(); i++) { + RelDataTypeField field = fields.get(i); + if (i > 0) { + pw.println(","); + } + pw.print(indent); + pw.print(field.getName()); + pw.print(" "); + accept(field.getType()); + } + } + } + + /** Visitor which builds a bitmap of the inputs used by an expression. */ + public static class InputFinder extends RexVisitorImpl { + private final ImmutableBitSet.Builder bitBuilder; + private final @Nullable Set extraFields; + + private InputFinder( + @Nullable Set extraFields, ImmutableBitSet.Builder bitBuilder) { + super(true); + this.bitBuilder = bitBuilder; + this.extraFields = extraFields; + } + + public InputFinder() { + this(null); + } + + public InputFinder(@Nullable Set extraFields) { + this(extraFields, ImmutableBitSet.builder()); + } + + public InputFinder( + @Nullable Set extraFields, ImmutableBitSet initialBits) { + this(extraFields, initialBits.rebuild()); + } + + /** Returns an input finder that has analyzed a given expression. */ + public static InputFinder analyze(RexNode node) { + final InputFinder inputFinder = new InputFinder(); + node.accept(inputFinder); + return inputFinder; + } + + /** Returns a bit set describing the inputs used by an expression. */ + public static ImmutableBitSet bits(RexNode node) { + return analyze(node).build(); + } + + /** + * Returns a bit set describing the inputs used by a collection of project expressions and + * an optional condition. + */ + public static ImmutableBitSet bits(List exprs, @Nullable RexNode expr) { + final InputFinder inputFinder = new InputFinder(); + RexUtil.apply(inputFinder, exprs, expr); + return inputFinder.build(); + } + + /** + * Returns the bit set. + * + *

After calling this method, you cannot do any more visits or call this method again. + */ + public ImmutableBitSet build() { + return bitBuilder.build(); + } + + @Override + public Void visitInputRef(RexInputRef inputRef) { + bitBuilder.set(inputRef.getIndex()); + return null; + } + + @Override + public Void visitCall(RexCall call) { + if (call.getOperator() == RexBuilder.GET_OPERATOR) { + RexLiteral literal = (RexLiteral) call.getOperands().get(1); + if (extraFields != null) { + requireNonNull(literal, () -> "first operand in " + call); + String value2 = (String) literal.getValue2(); + requireNonNull(value2, () -> "value of the first operand in " + call); + extraFields.add(new RelDataTypeFieldImpl(value2, -1, call.getType())); + } + } + return super.visitCall(call); + } + } + + /** + * Walks an expression tree, converting the index of RexInputRefs based on some adjustment + * factor. + */ + public static class RexInputConverter extends RexShuttle { + protected final RexBuilder rexBuilder; + private final @Nullable List srcFields; + protected final @Nullable List destFields; + private final @Nullable List leftDestFields; + private final @Nullable List rightDestFields; + private final int nLeftDestFields; + private final int[] adjustments; + + /** + * Creates a RexInputConverter. + * + * @param rexBuilder builder for creating new RexInputRefs + * @param srcFields fields where the RexInputRefs originated from; if null, a new + * RexInputRef is always created, referencing the input from destFields corresponding to + * its current index value + * @param destFields fields that the new RexInputRefs will be referencing; if null, use the + * type information from the source field when creating the new RexInputRef + * @param leftDestFields in the case where the destination is a join, these are the fields + * from the left join input + * @param rightDestFields in the case where the destination is a join, these are the fields + * from the right join input + * @param adjustments the amount to adjust each field by + */ + private RexInputConverter( + RexBuilder rexBuilder, + @Nullable List srcFields, + @Nullable List destFields, + @Nullable List leftDestFields, + @Nullable List rightDestFields, + int[] adjustments) { + this.rexBuilder = rexBuilder; + this.srcFields = srcFields; + this.destFields = destFields; + this.adjustments = adjustments; + this.leftDestFields = leftDestFields; + this.rightDestFields = rightDestFields; + if (leftDestFields == null) { + nLeftDestFields = 0; + } else { + assert destFields == null; + nLeftDestFields = leftDestFields.size(); + } + } + + public RexInputConverter( + RexBuilder rexBuilder, + @Nullable List srcFields, + @Nullable List leftDestFields, + @Nullable List rightDestFields, + int[] adjustments) { + this(rexBuilder, srcFields, null, leftDestFields, rightDestFields, adjustments); + } + + public RexInputConverter( + RexBuilder rexBuilder, + @Nullable List srcFields, + @Nullable List destFields, + int[] adjustments) { + this(rexBuilder, srcFields, destFields, null, null, adjustments); + } + + public RexInputConverter( + RexBuilder rexBuilder, + @Nullable List srcFields, + int[] adjustments) { + this(rexBuilder, srcFields, null, null, null, adjustments); + } + + @Override + public RexNode visitInputRef(RexInputRef var) { + int srcIndex = var.getIndex(); + int destIndex = srcIndex + adjustments[srcIndex]; + + RelDataType type; + if (destFields != null) { + type = destFields.get(destIndex).getType(); + } else if (leftDestFields != null) { + if (destIndex < nLeftDestFields) { + type = leftDestFields.get(destIndex).getType(); + } else { + type = + requireNonNull(rightDestFields, "rightDestFields") + .get(destIndex - nLeftDestFields) + .getType(); + } + } else { + type = requireNonNull(srcFields, "srcFields").get(srcIndex).getType(); + } + if ((adjustments[srcIndex] != 0) + || (srcFields == null) + || (type != srcFields.get(srcIndex).getType())) { + return rexBuilder.makeInputRef(type, destIndex); + } else { + return var; + } + } + } + + /** What kind of sub-query. */ + public enum SubQueryType { + EXISTS, + IN, + SCALAR + } + + /** Categorizes whether a bit set contains bits left and right of a line. */ + enum Side { + LEFT, + RIGHT, + BOTH, + EMPTY; + + static Side of(ImmutableBitSet bitSet, int middle) { + final int firstBit = bitSet.nextSetBit(0); + if (firstBit < 0) { + return EMPTY; + } + if (firstBit >= middle) { + return RIGHT; + } + if (bitSet.nextSetBit(middle) < 0) { + return LEFT; + } + return BOTH; + } + + public boolean opposite(Side side) { + return (this == LEFT && side == RIGHT) || (this == RIGHT && side == LEFT); + } + } + + /** + * Shuttle that finds correlation variables inside a given relational expression, including + * those that are inside {@link RexSubQuery sub-queries}. + */ + private static class CorrelationCollector extends RelHomogeneousShuttle { + @SuppressWarnings("assignment.type.incompatible") + private final VariableUsedVisitor vuv = new VariableUsedVisitor(this); + + @Override + public RelNode visit(RelNode other) { + other.collectVariablesUsed(vuv.variables); + other.accept(vuv); + RelNode result = super.visit(other); + // Important! Remove stopped variables AFTER we visit + // children. (which what super.visit() does) + vuv.variables.removeAll(other.getVariablesSet()); + return result; + } + } + + /** Result of calling {@link org.apache.calcite.plan.RelOptUtil#createExistsPlan}. */ + public static class Exists { + public final RelNode r; + public final boolean indicator; + public final boolean outerJoin; + + private Exists(RelNode r, boolean indicator, boolean outerJoin) { + this.r = r; + this.indicator = indicator; + this.outerJoin = outerJoin; + } + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/type/RelDataTypeFactoryImpl.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/type/RelDataTypeFactoryImpl.java index 6dc339db70f91..464cb1eb6dd82 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/type/RelDataTypeFactoryImpl.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/type/RelDataTypeFactoryImpl.java @@ -58,8 +58,8 @@ *

FLINK modifications are at lines * *

    - *
  1. Should be removed after fixing CALCITE-5199: Lines 243-245 - *
  2. Might be a subject to reconsider after bump to Calcite 1.38.0: Lines 568-570 + *
  3. Should be removed after fixing CALCITE-5199: Lines 242-244 + *
  4. Added in FLINK-39695 (backport of CALCITE-6764): Lines 407 ~ 438 *
*/ public abstract class RelDataTypeFactoryImpl implements RelDataTypeFactory { @@ -405,6 +405,39 @@ public RelDataType createTypeWithNullability(final RelDataType type, final boole return canonize(newType); } + // ----- FLINK MODIFICATION BEGIN ----- + // Backport from Calcite (CALCITE-6764): creates a type with specified nullability + // without deep-copying record field types. For record types, makes the struct + // itself nullable/not-nullable while keeping field types unchanged. + public RelDataType enforceTypeWithNullability(final RelDataType type, final boolean nullable) { + requireNonNull(type, "type"); + RelDataType newType; + if (type.isNullable() == nullable) { + newType = type; + } else if (type instanceof RelRecordType) { + return createStructType( + type.getStructKind(), + new AbstractList() { + @Override + public RelDataType get(int index) { + return type.getFieldList().get(index).getType(); + } + + @Override + public int size() { + return type.getFieldCount(); + } + }, + type.getFieldNames(), + nullable); + } else { + newType = copySimpleType(type, nullable); + } + return canonize(newType); + } + + // ----- FLINK MODIFICATION END ----- + /** * Registers a type, or returns the existing type if it is already registered. * diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexBuilder.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexBuilder.java new file mode 100644 index 0000000000000..e6e34927cabca --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexBuilder.java @@ -0,0 +1,1915 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.rex; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableRangeSet; +import com.google.common.collect.Range; +import com.google.common.collect.RangeSet; +import com.google.common.collect.TreeRangeSet; +import org.apache.calcite.avatica.util.ByteString; +import org.apache.calcite.avatica.util.DateTimeUtils; +import org.apache.calcite.avatica.util.Spaces; +import org.apache.calcite.avatica.util.TimeUnit; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.CorrelationId; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.runtime.FlatLists; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlCollation; +import org.apache.calcite.sql.SqlIntervalQualifier; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlSpecialOperator; +import org.apache.calcite.sql.SqlUtil; +import org.apache.calcite.sql.fun.SqlCountAggFunction; +import org.apache.calcite.sql.fun.SqlLibraryOperators; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.ArraySqlType; +import org.apache.calcite.sql.type.MapSqlType; +import org.apache.calcite.sql.type.MultisetSqlType; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.type.SqlTypeUtil; +import org.apache.calcite.util.DateString; +import org.apache.calcite.util.NlsString; +import org.apache.calcite.util.Pair; +import org.apache.calcite.util.Sarg; +import org.apache.calcite.util.TimeString; +import org.apache.calcite.util.TimeWithTimeZoneString; +import org.apache.calcite.util.TimestampString; +import org.apache.calcite.util.TimestampWithTimeZoneString; +import org.apache.calcite.util.Util; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.PolyNull; +import org.locationtech.jts.geom.Geometry; + +import java.math.BigDecimal; +import java.math.MathContext; +import java.math.RoundingMode; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Calendar; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.function.IntPredicate; +import java.util.stream.Collectors; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verifyNotNull; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; +import static org.apache.calcite.linq4j.Nullness.castNonNull; + +/** + * Factory for row expressions. + * + *

Some common literal values (NULL, TRUE, FALSE, 0, 1, '') are cached. + * + *

FLINK modifications (backport of CALCITE-6764): Lines 234, 241, 245, 249 ~ 253 + */ +public class RexBuilder { + /** + * Special operator that accesses an unadvertised field of an input record. This operator cannot + * be used in SQL queries; it is introduced temporarily during sql-to-rel translation, then + * replaced during the process that trims unwanted fields. + */ + public static final SqlSpecialOperator GET_OPERATOR = + new SqlSpecialOperator("_get", SqlKind.OTHER_FUNCTION); + + /** The smallest valid {@code int} value, as a {@link BigDecimal}. */ + private static final BigDecimal INT_MIN = BigDecimal.valueOf(Integer.MIN_VALUE); + + /** The largest valid {@code int} value, as a {@link BigDecimal}. */ + private static final BigDecimal INT_MAX = BigDecimal.valueOf(Integer.MAX_VALUE); + + // ~ Instance fields -------------------------------------------------------- + + protected final RelDataTypeFactory typeFactory; + private final RexLiteral booleanTrue; + private final RexLiteral booleanFalse; + private final RexLiteral charEmpty; + private final RexLiteral constantNull; + private final SqlStdOperatorTable opTab = SqlStdOperatorTable.instance(); + + // ~ Constructors ----------------------------------------------------------- + + /** + * Creates a RexBuilder. + * + * @param typeFactory Type factory + */ + @SuppressWarnings("method.invocation.invalid") + public RexBuilder(RelDataTypeFactory typeFactory) { + this.typeFactory = typeFactory; + this.booleanTrue = + makeLiteral( + Boolean.TRUE, + typeFactory.createSqlType(SqlTypeName.BOOLEAN), + SqlTypeName.BOOLEAN); + this.booleanFalse = + makeLiteral( + Boolean.FALSE, + typeFactory.createSqlType(SqlTypeName.BOOLEAN), + SqlTypeName.BOOLEAN); + this.charEmpty = + makeLiteral( + new NlsString("", null, null), + typeFactory.createSqlType(SqlTypeName.CHAR, 0), + SqlTypeName.CHAR); + this.constantNull = + makeLiteral(null, typeFactory.createSqlType(SqlTypeName.NULL), SqlTypeName.NULL); + } + + /** + * Creates a list of {@link org.apache.calcite.rex.RexInputRef} expressions, projecting the + * fields of a given record type. + */ + public List identityProjects(final RelDataType rowType) { + return Util.transform( + rowType.getFieldList(), + input -> new RexInputRef(input.getIndex(), input.getType())); + } + + // ~ Methods ---------------------------------------------------------------- + + /** + * Returns this RexBuilder's type factory. + * + * @return type factory + */ + public RelDataTypeFactory getTypeFactory() { + return typeFactory; + } + + /** + * Returns this RexBuilder's operator table. + * + * @return operator table + */ + public SqlStdOperatorTable getOpTab() { + return opTab; + } + + /** + * Creates an expression accessing a given named field from a record. + * + *

NOTE: Be careful choosing the value of {@code caseSensitive}. If the field name was + * supplied by an end-user (e.g. as a column alias in SQL), use your session's case-sensitivity + * setting. Only hard-code {@code true} if you are sure that the field name is internally + * generated. Hard-coding {@code false} is almost certainly wrong. + * + * @param expr Expression yielding a record + * @param fieldName Name of field in record + * @param caseSensitive Whether match is case-sensitive + * @return Expression accessing a given named field + */ + public RexNode makeFieldAccess(RexNode expr, String fieldName, boolean caseSensitive) { + final RelDataType type = expr.getType(); + final RelDataTypeField field = type.getField(fieldName, caseSensitive, false); + if (field == null) { + throw new AssertionError("Type '" + type + "' has no field '" + fieldName + "'"); + } + return makeFieldAccessInternal(expr, field); + } + + /** + * Creates an expression accessing a field with a given ordinal from a record. + * + * @param expr Expression yielding a record + * @param i Ordinal of field + * @return Expression accessing given field + */ + public RexNode makeFieldAccess(RexNode expr, int i) { + final RelDataType type = expr.getType(); + final List fields = type.getFieldList(); + if ((i < 0) || (i >= fields.size())) { + throw new AssertionError( + "Field ordinal " + i + " is invalid for " + " type '" + type + "'"); + } + return makeFieldAccessInternal(expr, fields.get(i)); + } + + /** + * Creates an expression accessing a given field from a record. + * + * @param expr Expression yielding a record + * @param field Field + * @return Expression accessing given field + */ + private RexNode makeFieldAccessInternal(RexNode expr, final RelDataTypeField field) { + RelDataType fieldType = field.getType(); + if (expr instanceof RexRangeRef) { + RexRangeRef range = (RexRangeRef) expr; + if (field.getIndex() < 0) { + return makeCall( + fieldType, + GET_OPERATOR, + ImmutableList.of(expr, makeLiteral(field.getName()))); + } + return new RexInputRef(range.getOffset() + field.getIndex(), fieldType); + } + + if (expr.getType().isNullable()) { + fieldType = typeFactory.createTypeWithNullability(fieldType, true); + } + return new RexFieldAccess(expr, field, fieldType); + } + + /** Creates a call with a list of arguments and a predetermined type. */ + public RexNode makeCall(RelDataType returnType, SqlOperator op, List exprs) { + return new RexCall(returnType, op, exprs); + } + + /** + * Creates a call with an array of arguments. + * + *

If you already know the return type of the call, then {@link + * #makeCall(org.apache.calcite.rel.type.RelDataType, org.apache.calcite.sql.SqlOperator, + * java.util.List)} is preferred. + */ + public RexNode makeCall(SqlOperator op, List exprs) { + final RelDataType type = deriveReturnType(op, exprs); + return new RexCall(type, op, exprs); + } + + /** + * Creates a call with a list of arguments. + * + *

Equivalent to makeCall(op, exprList.toArray(new RexNode[exprList.size()])). + */ + public final RexNode makeCall(SqlOperator op, RexNode... exprs) { + return makeCall(op, ImmutableList.copyOf(exprs)); + } + + /** + * Derives the return type of a call to an operator. + * + * @param op the operator being called + * @param exprs actual operands + * @return derived type + */ + public RelDataType deriveReturnType(SqlOperator op, List exprs) { + return op.inferReturnType(new RexCallBinding(typeFactory, op, exprs, ImmutableList.of())); + } + + /** + * Creates a reference to an aggregate call, checking for repeated calls. + * + *

Argument types help to optimize for repeated aggregates. For instance count(42) is + * equivalent to count(*). + * + * @param aggCall aggregate call to be added + * @param groupCount number of groups in the aggregate relation + * @param aggCalls destination list of aggregate calls + * @param aggCallMapping the dictionary of already added calls + * @param isNullable Whether input field i is nullable + * @return Rex expression for the given aggregate call + */ + public RexNode addAggCall( + AggregateCall aggCall, + int groupCount, + List aggCalls, + Map aggCallMapping, + IntPredicate isNullable) { + if (aggCall.getAggregation() instanceof SqlCountAggFunction && !aggCall.isDistinct()) { + final List args = aggCall.getArgList(); + final List nullableArgs = nullableArgs(args, isNullable); + aggCall = aggCall.withArgList(nullableArgs); + } + RexNode rex = aggCallMapping.get(aggCall); + if (rex == null) { + int index = aggCalls.size() + groupCount; + aggCalls.add(aggCall); + rex = makeInputRef(aggCall.getType(), index); + aggCallMapping.put(aggCall, rex); + } + return rex; + } + + @Deprecated // to be removed before 2.0 + public RexNode addAggCall( + final AggregateCall aggCall, + int groupCount, + List aggCalls, + Map aggCallMapping, + final @Nullable List aggArgTypes) { + return addAggCall( + aggCall, + groupCount, + aggCalls, + aggCallMapping, + i -> + requireNonNull(aggArgTypes, "aggArgTypes") + .get(aggCall.getArgList().indexOf(i)) + .isNullable()); + } + + /** Creates a reference to an aggregate call, checking for repeated calls. */ + @Deprecated // to be removed before 2.0 + public RexNode addAggCall( + AggregateCall aggCall, + int groupCount, + boolean indicator, + List aggCalls, + Map aggCallMapping, + final @Nullable List aggArgTypes) { + checkArgument(!indicator, "indicator is deprecated, use GROUPING function instead"); + return addAggCall(aggCall, groupCount, aggCalls, aggCallMapping, aggArgTypes); + } + + private static List nullableArgs(List list0, IntPredicate isNullable) { + return list0.stream().filter(isNullable::test).collect(toImmutableList()); + } + + @Deprecated // to be removed before 2.0 + public RexNode makeOver( + RelDataType type, + SqlAggFunction operator, + List exprs, + List partitionKeys, + ImmutableList orderKeys, + RexWindowBound lowerBound, + RexWindowBound upperBound, + boolean rows, + boolean allowPartial, + boolean nullWhenCountZero, + boolean distinct) { + return makeOver( + type, + operator, + exprs, + partitionKeys, + orderKeys, + lowerBound, + upperBound, + rows, + allowPartial, + nullWhenCountZero, + distinct, + false); + } + + /** Creates a call to a windowed agg. */ + public RexNode makeOver( + RelDataType type, + SqlAggFunction operator, + List exprs, + List partitionKeys, + ImmutableList orderKeys, + RexWindowBound lowerBound, + RexWindowBound upperBound, + boolean rows, + boolean allowPartial, + boolean nullWhenCountZero, + boolean distinct, + boolean ignoreNulls) { + final RexWindow window = makeWindow(partitionKeys, orderKeys, lowerBound, upperBound, rows); + final RexOver over = new RexOver(type, operator, exprs, window, distinct, ignoreNulls); + RexNode result = over; + + // This should be correct but need time to go over test results. + // Also want to look at combing with section below. + if (nullWhenCountZero) { + final RelDataType bigintType = typeFactory.createSqlType(SqlTypeName.BIGINT); + result = + makeCall( + SqlStdOperatorTable.CASE, + makeCall( + SqlStdOperatorTable.GREATER_THAN, + new RexOver( + bigintType, + SqlStdOperatorTable.COUNT, + exprs, + window, + distinct, + ignoreNulls), + makeLiteral(BigDecimal.ZERO, bigintType, SqlTypeName.DECIMAL)), + ensureType( + type, // SUM0 is non-nullable, thus need a cast + new RexOver( + typeFactory.createTypeWithNullability(type, false), + operator, + exprs, + window, + distinct, + ignoreNulls), + false), + makeNullLiteral(type)); + } + if (!allowPartial) { + checkArgument(rows, "DISALLOW PARTIAL over RANGE"); + final RelDataType bigintType = typeFactory.createSqlType(SqlTypeName.BIGINT); + // todo: read bound + result = + makeCall( + SqlStdOperatorTable.CASE, + makeCall( + SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, + new RexOver( + bigintType, + SqlStdOperatorTable.COUNT, + ImmutableList.of(), + window, + distinct, + ignoreNulls), + makeLiteral( + BigDecimal.valueOf(2), + bigintType, + SqlTypeName.DECIMAL)), + result, + constantNull); + } + return result; + } + + /** + * Creates a window specification. + * + * @param partitionKeys Partition keys + * @param orderKeys Order keys + * @param lowerBound Lower bound + * @param upperBound Upper bound + * @param rows Whether physical. True if row-based, false if range-based + * @return window specification + */ + public RexWindow makeWindow( + List partitionKeys, + ImmutableList orderKeys, + RexWindowBound lowerBound, + RexWindowBound upperBound, + boolean rows) { + if (lowerBound.isUnbounded() + && lowerBound.isPreceding() + && upperBound.isUnbounded() + && upperBound.isFollowing()) { + // RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + // is equivalent to + // ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + // but we prefer "RANGE" + rows = false; + } + return new RexWindow(partitionKeys, orderKeys, lowerBound, upperBound, rows); + } + + /** + * Creates a constant for the SQL NULL value. + * + * @deprecated Use {@link #makeNullLiteral(RelDataType)}, which produces a NULL of the correct + * type + */ + @Deprecated // to be removed before 2.0 + public RexLiteral constantNull() { + return constantNull; + } + + /** + * Creates an expression referencing a correlation variable. + * + * @param id Name of variable + * @param type Type of variable + * @return Correlation variable + */ + public RexNode makeCorrel(RelDataType type, CorrelationId id) { + return new RexCorrelVariable(id, type); + } + + /** + * Creates an invocation of the NEW operator. + * + * @param type Type to be instantiated + * @param exprs Arguments to NEW operator + * @return Expression invoking NEW operator + */ + public RexNode makeNewInvocation(RelDataType type, List exprs) { + return new RexCall(type, SqlStdOperatorTable.NEW, exprs); + } + + /** + * Creates a call to the CAST operator. + * + * @param type Type to cast to + * @param exp Expression being cast + * @return Call to CAST operator + */ + public RexNode makeCast(RelDataType type, RexNode exp) { + return makeCast(type, exp, false, false, constantNull); + } + + @Deprecated // to be removed before 2.0 + public RexNode makeCast(RelDataType type, RexNode exp, boolean matchNullability) { + return makeCast(type, exp, matchNullability, false, constantNull); + } + + /** + * Creates a call to the CAST operator, expanding if possible, and optionally also preserving + * nullability, and optionally in safe mode. + * + *

Tries to expand the cast, and therefore the result may be something other than a {@link + * RexCall} to the CAST operator, such as a {@link RexLiteral}. + * + * @param type Type to cast to + * @param exp Expression being cast + * @param matchNullability Whether to ensure the result has the same nullability as {@code type} + * @param safe Whether to return NULL if cast fails + * @return Call to CAST operator + */ + public RexNode makeCast(RelDataType type, RexNode exp, boolean matchNullability, boolean safe) { + return makeCast(type, exp, matchNullability, safe, constantNull); + } + + /** + * Creates a call to the CAST operator, expanding if possible, and optionally also preserving + * nullability, and optionally in safe mode. + * + *

Tries to expand the cast, and therefore the result may be something other than a {@link + * RexCall} to the CAST operator, such as a {@link RexLiteral}. + * + * @param type Type to cast to + * @param exp Expression being cast + * @param matchNullability Whether to ensure the result has the same nullability as {@code type} + * @param safe Whether to return NULL if cast fails + * @param format Type Format to cast into + * @return Call to CAST operator + */ + public RexNode makeCast( + RelDataType type, + RexNode exp, + boolean matchNullability, + boolean safe, + RexLiteral format) { + final SqlTypeName sqlType = type.getSqlTypeName(); + if (exp instanceof RexLiteral) { + RexLiteral literal = (RexLiteral) exp; + Comparable value = literal.getValueAs(Comparable.class); + SqlTypeName typeName = literal.getTypeName(); + + // Allow casting boolean literals to integer types. + if (exp.getType().getSqlTypeName() == SqlTypeName.BOOLEAN + && SqlTypeUtil.isExactNumeric(type)) { + return makeCastBooleanToExact(type, exp); + } + if (canRemoveCastFromLiteral(type, value, typeName)) { + switch (typeName) { + case INTERVAL_YEAR: + case INTERVAL_YEAR_MONTH: + case INTERVAL_MONTH: + case INTERVAL_DAY: + case INTERVAL_DAY_HOUR: + case INTERVAL_DAY_MINUTE: + case INTERVAL_DAY_SECOND: + case INTERVAL_HOUR: + case INTERVAL_HOUR_MINUTE: + case INTERVAL_HOUR_SECOND: + case INTERVAL_MINUTE: + case INTERVAL_MINUTE_SECOND: + case INTERVAL_SECOND: + assert value instanceof BigDecimal; + typeName = type.getSqlTypeName(); + switch (typeName) { + case BIGINT: + case INTEGER: + case SMALLINT: + case TINYINT: + case DOUBLE: + case FLOAT: + case REAL: + case DECIMAL: + BigDecimal value2 = (BigDecimal) value; + final BigDecimal multiplier = + baseUnit(literal.getTypeName()).multiplier; + final BigDecimal divider = + literal.getTypeName().getEndUnit().multiplier; + value = + value2.multiply(multiplier) + .divide(divider, 0, RoundingMode.HALF_DOWN); + break; + default: + break; + } + + // Not all types are allowed for literals + switch (typeName) { + case INTEGER: + typeName = SqlTypeName.BIGINT; + break; + default: + break; + } + break; + default: + break; + } + final RexLiteral literal2 = makeLiteral(value, type, typeName); + if (type.isNullable() && !literal2.getType().isNullable() && matchNullability) { + return makeAbstractCast(type, literal2, safe, format); + } + return literal2; + } + } else if (SqlTypeUtil.isExactNumeric(type) && SqlTypeUtil.isInterval(exp.getType())) { + return makeCastIntervalToExact(type, exp); + } else if (sqlType == SqlTypeName.BOOLEAN && SqlTypeUtil.isExactNumeric(exp.getType())) { + return makeCastExactToBoolean(type, exp); + } else if (exp.getType().getSqlTypeName() == SqlTypeName.BOOLEAN + && SqlTypeUtil.isExactNumeric(type)) { + return makeCastBooleanToExact(type, exp); + } + return makeAbstractCast(type, exp, safe, format); + } + + /** + * Returns the lowest granularity unit for the given unit. YEAR and MONTH intervals are stored + * as months; HOUR, MINUTE, SECOND intervals are stored as milliseconds. + */ + protected static TimeUnit baseUnit(SqlTypeName unit) { + if (unit.isYearMonth()) { + return TimeUnit.MONTH; + } else { + return TimeUnit.MILLISECOND; + } + } + + boolean canRemoveCastFromLiteral( + RelDataType toType, @Nullable Comparable value, SqlTypeName fromTypeName) { + if (value == null) { + return true; + } + final SqlTypeName sqlType = toType.getSqlTypeName(); + if (!RexLiteral.valueMatchesType(value, sqlType, false)) { + return false; + } + if (toType.getSqlTypeName() != fromTypeName + && SqlTypeFamily.DATETIME.getTypeNames().contains(fromTypeName)) { + return false; + } + if (value instanceof NlsString) { + final int length = ((NlsString) value).getValue().length(); + switch (toType.getSqlTypeName()) { + case CHAR: + return SqlTypeUtil.comparePrecision(toType.getPrecision(), length) == 0; + case VARCHAR: + return SqlTypeUtil.comparePrecision(toType.getPrecision(), length) >= 0; + default: + throw new AssertionError(toType); + } + } + if (value instanceof ByteString) { + final int length = ((ByteString) value).length(); + switch (toType.getSqlTypeName()) { + case BINARY: + return SqlTypeUtil.comparePrecision(toType.getPrecision(), length) == 0; + case VARBINARY: + return SqlTypeUtil.comparePrecision(toType.getPrecision(), length) >= 0; + default: + throw new AssertionError(toType); + } + } + + if (toType.getSqlTypeName() == SqlTypeName.DECIMAL) { + final BigDecimal decimalValue = (BigDecimal) value; + return SqlTypeUtil.isValidDecimalValue(decimalValue, toType); + } + + if (SqlTypeName.INT_TYPES.contains(sqlType)) { + final BigDecimal decimalValue = (BigDecimal) value; + final int s = decimalValue.scale(); + if (s != 0) { + return false; + } + long l = decimalValue.longValue(); + switch (sqlType) { + case TINYINT: + return l >= Byte.MIN_VALUE && l <= Byte.MAX_VALUE; + case SMALLINT: + return l >= Short.MIN_VALUE && l <= Short.MAX_VALUE; + case INTEGER: + return l >= Integer.MIN_VALUE && l <= Integer.MAX_VALUE; + case BIGINT: + default: + return true; + } + } + + return true; + } + + private RexNode makeCastExactToBoolean(RelDataType toType, RexNode exp) { + return makeCall( + toType, + SqlStdOperatorTable.NOT_EQUALS, + ImmutableList.of(exp, makeZeroLiteral(exp.getType()))); + } + + private RexNode makeCastBooleanToExact(RelDataType toType, RexNode exp) { + final RexNode casted = + makeCall( + SqlStdOperatorTable.CASE, + exp, + makeExactLiteral(BigDecimal.ONE, toType), + makeZeroLiteral(toType)); + if (!exp.getType().isNullable()) { + return casted; + } + return makeCall( + toType, + SqlStdOperatorTable.CASE, + ImmutableList.of( + makeCall(SqlStdOperatorTable.IS_NOT_NULL, exp), + casted, + makeNullLiteral(toType))); + } + + private RexNode makeCastIntervalToExact(RelDataType toType, RexNode exp) { + final TimeUnit endUnit = exp.getType().getSqlTypeName().getEndUnit(); + final TimeUnit baseUnit = baseUnit(exp.getType().getSqlTypeName()); + final BigDecimal multiplier = baseUnit.multiplier; + final BigDecimal divider = endUnit.multiplier; + RexNode value = multiplyDivide(decodeIntervalOrDecimal(exp), multiplier, divider); + return ensureType(toType, value, false); + } + + public RexNode multiplyDivide(RexNode e, BigDecimal multiplier, BigDecimal divider) { + assert multiplier.signum() > 0; + assert divider.signum() > 0; + switch (multiplier.compareTo(divider)) { + case 0: + return e; + case 1: + // E.g. multiplyDivide(e, 1000, 10) ==> e * 100 + return makeCall( + SqlStdOperatorTable.MULTIPLY, + e, + makeExactLiteral(multiplier.divide(divider, RoundingMode.UNNECESSARY))); + case -1: + // E.g. multiplyDivide(e, 10, 1000) ==> e / 100 + return makeCall( + SqlStdOperatorTable.DIVIDE_INTEGER, + e, + makeExactLiteral(divider.divide(multiplier, RoundingMode.UNNECESSARY))); + default: + throw new AssertionError(multiplier + "/" + divider); + } + } + + /** + * Casts a decimal's integer representation to a decimal node. If the expression is not the + * expected integer type, then it is casted first. + * + *

An overflow check may be requested to ensure the internal value does not exceed the + * maximum value of the decimal type. + * + * @param value integer representation of decimal + * @param type type integer will be reinterpreted as + * @param checkOverflow indicates whether an overflow check is required when reinterpreting this + * particular value as the decimal type. A check usually not required for arithmetic, but is + * often required for rounding and explicit casts. + * @return the integer reinterpreted as an opaque decimal type + */ + public RexNode encodeIntervalOrDecimal(RexNode value, RelDataType type, boolean checkOverflow) { + RelDataType bigintType = typeFactory.createSqlType(SqlTypeName.BIGINT); + RexNode cast = ensureType(bigintType, value, true); + return makeReinterpretCast(type, cast, makeLiteral(checkOverflow)); + } + + /** + * Retrieves an INTERVAL or DECIMAL node's integer representation. + * + * @param node the interval or decimal value as an opaque type + * @return an integer representation of the decimal value + */ + public RexNode decodeIntervalOrDecimal(RexNode node) { + assert SqlTypeUtil.isDecimal(node.getType()) || SqlTypeUtil.isInterval(node.getType()); + RelDataType bigintType = typeFactory.createSqlType(SqlTypeName.BIGINT); + return makeReinterpretCast(matchNullability(bigintType, node), node, makeLiteral(false)); + } + + @Deprecated // to be removed before 2.0 + public RexNode makeAbstractCast(RelDataType type, RexNode exp) { + return makeAbstractCast(type, exp, false); + } + + /** + * Creates a call to CAST or SAFE_CAST operator. + * + * @param type Type to cast to + * @param exp Expression being cast + * @param safe Whether to return NULL if cast fails + * @return Call to CAST operator + */ + public RexNode makeAbstractCast(RelDataType type, RexNode exp, boolean safe) { + final SqlOperator operator = + safe ? SqlLibraryOperators.SAFE_CAST : SqlStdOperatorTable.CAST; + return new RexCall(type, operator, ImmutableList.of(exp)); + } + + /** + * Creates a call to CAST or SAFE_CAST operator with a FORMAT clause. + * + * @param type Type to cast to + * @param exp Expression being cast + * @param safe Whether to return NULL if cast fails + * @param format Conversion format for target type + * @return Call to CAST operator + */ + public RexNode makeAbstractCast( + RelDataType type, RexNode exp, boolean safe, RexLiteral format) { + final SqlOperator operator = + safe ? SqlLibraryOperators.SAFE_CAST : SqlStdOperatorTable.CAST; + if (format.isNull()) { + return new RexCall(type, operator, ImmutableList.of(exp)); + } + return new RexCall(type, operator, ImmutableList.of(exp, format)); + } + + /** + * Makes a reinterpret cast. + * + * @param type type returned by the cast + * @param exp expression to be casted + * @param checkOverflow whether an overflow check is required + * @return a RexCall with two operands and a special return type + */ + public RexNode makeReinterpretCast(RelDataType type, RexNode exp, RexNode checkOverflow) { + List args; + if ((checkOverflow != null) && checkOverflow.isAlwaysTrue()) { + args = ImmutableList.of(exp, checkOverflow); + } else { + args = ImmutableList.of(exp); + } + return new RexCall(type, SqlStdOperatorTable.REINTERPRET, args); + } + + /** Makes a cast of a value to NOT NULL; no-op if the type already has NOT NULL. */ + public RexNode makeNotNull(RexNode exp) { + final RelDataType type = exp.getType(); + if (!type.isNullable()) { + return exp; + } + final RelDataType notNullType = typeFactory.createTypeWithNullability(type, false); + return makeAbstractCast(notNullType, exp, false); + } + + /** + * Creates a reference to all the fields in the row. That is, the whole row as a single record + * object. + * + * @param input Input relational expression + */ + public RexNode makeRangeReference(RelNode input) { + return new RexRangeRef(input.getRowType(), 0); + } + + /** + * Creates a reference to all the fields in the row. + * + *

For example, if the input row has type T{f0,f1,f2,f3,f4} then + * makeRangeReference(T{f0,f1,f2,f3,f4}, S{f3,f4}, 3) is an expression which yields the + * last 2 fields. + * + * @param type Type of the resulting range record. + * @param offset Index of first field. + * @param nullable Whether the record is nullable. + */ + public RexRangeRef makeRangeReference(RelDataType type, int offset, boolean nullable) { + if (nullable && !type.isNullable()) { + type = typeFactory.createTypeWithNullability(type, nullable); + } + return new RexRangeRef(type, offset); + } + + /** + * Creates a reference to a given field of the input record. + * + * @param type Type of field + * @param i Ordinal of field + * @return Reference to field + */ + public RexInputRef makeInputRef(RelDataType type, int i) { + type = SqlTypeUtil.addCharsetAndCollation(type, typeFactory); + return new RexInputRef(i, type); + } + + /** + * Creates a reference to a given field of the input relational expression. + * + * @param input Input relational expression + * @param i Ordinal of field + * @return Reference to field + * @see #identityProjects(RelDataType) + */ + public RexInputRef makeInputRef(RelNode input, int i) { + return makeInputRef(input.getRowType().getFieldList().get(i).getType(), i); + } + + /** + * Creates a reference to a given field of the pattern. + * + * @param alpha the pattern name + * @param type Type of field + * @param i Ordinal of field + * @return Reference to field of pattern + */ + public RexPatternFieldRef makePatternFieldRef(String alpha, RelDataType type, int i) { + type = SqlTypeUtil.addCharsetAndCollation(type, typeFactory); + return new RexPatternFieldRef(alpha, i, type); + } + + /** + * Create a reference to local variable. + * + * @param type Type of variable + * @param i Ordinal of variable + * @return Reference to local variable + */ + public RexLocalRef makeLocalRef(RelDataType type, int i) { + type = SqlTypeUtil.addCharsetAndCollation(type, typeFactory); + return new RexLocalRef(i, type); + } + + /** + * Creates a literal representing a flag. + * + * @param flag Flag value + */ + public RexLiteral makeFlag(Enum flag) { + assert flag != null; + return makeLiteral(flag, typeFactory.createSqlType(SqlTypeName.SYMBOL), SqlTypeName.SYMBOL); + } + + /** + * Internal method to create a call to a literal. Code outside this package should call one of + * the type-specific methods such as {@link #makeDateLiteral(DateString)}, {@link + * #makeLiteral(boolean)}, {@link #makeLiteral(String)}. + * + * @param o Value of literal, must be appropriate for the type + * @param type Type of literal + * @param typeName SQL type of literal + * @return Literal + */ + protected RexLiteral makeLiteral( + @Nullable Comparable o, RelDataType type, SqlTypeName typeName) { + // All literals except NULL have NOT NULL types. + type = typeFactory.createTypeWithNullability(type, o == null); + int p; + switch (typeName) { + case CHAR: + // Character literals must have a charset and collation. Populate + // from the type if necessary. + assert o instanceof NlsString; + NlsString nlsString = (NlsString) o; + if (nlsString.getCollation() == null + || nlsString.getCharset() == null + || !Objects.equals(nlsString.getCharset(), type.getCharset()) + || !Objects.equals(nlsString.getCollation(), type.getCollation())) { + assert type.getSqlTypeName() == SqlTypeName.CHAR + || type.getSqlTypeName() == SqlTypeName.VARCHAR; + Charset charset = type.getCharset(); + assert charset != null : "type.getCharset() must not be null"; + assert type.getCollation() != null : "type.getCollation() must not be null"; + o = new NlsString(nlsString.getValue(), charset.name(), type.getCollation()); + } + break; + case TIME: + case TIME_WITH_LOCAL_TIME_ZONE: + assert o instanceof TimeString; + p = type.getPrecision(); + if (p == RelDataType.PRECISION_NOT_SPECIFIED) { + p = 0; + } + o = ((TimeString) o).round(p); + break; + case TIME_TZ: + assert o instanceof TimeWithTimeZoneString; + p = type.getPrecision(); + if (p == RelDataType.PRECISION_NOT_SPECIFIED) { + p = 0; + } + o = ((TimeWithTimeZoneString) o).round(p); + break; + case TIMESTAMP: + case TIMESTAMP_WITH_LOCAL_TIME_ZONE: + assert o instanceof TimestampString; + p = type.getPrecision(); + if (p == RelDataType.PRECISION_NOT_SPECIFIED) { + p = 0; + } + o = ((TimestampString) o).round(p); + break; + case TIMESTAMP_TZ: + assert o instanceof TimestampWithTimeZoneString; + p = type.getPrecision(); + if (p == RelDataType.PRECISION_NOT_SPECIFIED) { + p = 0; + } + o = ((TimestampWithTimeZoneString) o).round(p); + break; + default: + break; + } + if (typeName == SqlTypeName.DECIMAL + && !SqlTypeUtil.isValidDecimalValue((BigDecimal) o, type)) { + throw new IllegalArgumentException( + "Cannot convert " + o + " to " + type + " due to overflow"); + } + return new RexLiteral(o, type, typeName); + } + + /** Creates a boolean literal. */ + public RexLiteral makeLiteral(boolean b) { + return b ? booleanTrue : booleanFalse; + } + + /** Creates a numeric literal. */ + public RexLiteral makeExactLiteral(BigDecimal bd) { + RelDataType relType; + int scale = bd.scale(); + assert scale >= 0; + assert scale <= typeFactory.getTypeSystem().getMaxNumericScale() : scale; + if (scale == 0) { + if (bd.compareTo(INT_MIN) >= 0 && bd.compareTo(INT_MAX) <= 0) { + relType = typeFactory.createSqlType(SqlTypeName.INTEGER); + } else { + relType = typeFactory.createSqlType(SqlTypeName.BIGINT); + } + } else { + int precision = bd.unscaledValue().abs().toString().length(); + if (precision > scale) { + // bd is greater than or equal to 1 + relType = typeFactory.createSqlType(SqlTypeName.DECIMAL, precision, scale); + } else { + // bd is less than 1 + relType = typeFactory.createSqlType(SqlTypeName.DECIMAL, scale + 1, scale); + } + } + return makeExactLiteral(bd, relType); + } + + /** Creates a BIGINT literal. */ + public RexLiteral makeBigintLiteral(@Nullable BigDecimal bd) { + RelDataType bigintType = typeFactory.createSqlType(SqlTypeName.BIGINT); + return makeLiteral(bd, bigintType, SqlTypeName.DECIMAL); + } + + /** Creates a numeric literal. */ + public RexLiteral makeExactLiteral(@Nullable BigDecimal bd, RelDataType type) { + return makeLiteral(bd, type, SqlTypeName.DECIMAL); + } + + /** Creates a byte array literal. */ + public RexLiteral makeBinaryLiteral(ByteString byteString) { + return makeLiteral( + byteString, + typeFactory.createSqlType(SqlTypeName.BINARY, byteString.length()), + SqlTypeName.BINARY); + } + + /** Creates a double-precision literal. */ + public RexLiteral makeApproxLiteral(BigDecimal bd) { + // Validator should catch if underflow is allowed + // If underflow is allowed, let underflow become zero + if (bd.doubleValue() == 0) { + bd = BigDecimal.ZERO; + } + return makeApproxLiteral(bd, typeFactory.createSqlType(SqlTypeName.DOUBLE)); + } + + /** + * Creates an approximate numeric literal (double or float). + * + * @param bd literal value + * @param type approximate numeric type + * @return new literal + */ + public RexLiteral makeApproxLiteral(@Nullable BigDecimal bd, RelDataType type) { + assert SqlTypeFamily.APPROXIMATE_NUMERIC.getTypeNames().contains(type.getSqlTypeName()); + return makeLiteral(bd, type, SqlTypeName.DOUBLE); + } + + /** Creates a search argument literal. */ + public RexLiteral makeSearchArgumentLiteral(Sarg s, RelDataType type) { + return makeLiteral(requireNonNull(s, "s"), type, SqlTypeName.SARG); + } + + /** Creates a character string literal. */ + public RexLiteral makeLiteral(String s) { + assert s != null; + return makePreciseStringLiteral(s); + } + + /** + * Creates a character string literal with type CHAR and default charset and collation. + * + * @param s String value + * @return Character string literal + */ + protected RexLiteral makePreciseStringLiteral(String s) { + assert s != null; + if (s.equals("")) { + return charEmpty; + } + return makeCharLiteral(new NlsString(s, null, null)); + } + + /** + * Creates a character string literal with type CHAR. + * + * @param value String value in bytes + * @param charsetName SQL-level charset name + * @param collation Sql collation + * @return String literal + */ + protected RexLiteral makePreciseStringLiteral( + ByteString value, String charsetName, SqlCollation collation) { + return makeCharLiteral(new NlsString(value, charsetName, collation)); + } + + /** + * Ensures expression is interpreted as a specified type. The returned expression may be wrapped + * with a cast. + * + * @param type desired type + * @param node expression + * @param matchNullability whether to correct nullability of specified type to match the + * expression; this usually should be true, except for explicit casts which can override + * default nullability + * @return a casted expression or the original expression + */ + public RexNode ensureType(RelDataType type, RexNode node, boolean matchNullability) { + RelDataType targetType = type; + if (matchNullability) { + targetType = matchNullability(type, node); + } + + if (targetType.getSqlTypeName() == SqlTypeName.ANY + && (!matchNullability || targetType.isNullable() == node.getType().isNullable())) { + return node; + } + + if (!node.getType().equals(targetType)) { + return makeCast(targetType, node); + } + return node; + } + + /** Ensures that a type's nullability matches a value's nullability. */ + public RelDataType matchNullability(RelDataType type, RexNode value) { + boolean typeNullability = type.isNullable(); + boolean valueNullability = value.getType().isNullable(); + if (typeNullability != valueNullability) { + return typeFactory.createTypeWithNullability(type, valueNullability); + } + return type; + } + + /** + * Creates a character string literal from an {@link NlsString}. + * + *

If the string's charset and collation are not set, uses the system defaults. + */ + public RexLiteral makeCharLiteral(NlsString str) { + assert str != null; + RelDataType type = SqlUtil.createNlsStringType(typeFactory, str); + return makeLiteral(str, type, SqlTypeName.CHAR); + } + + // CHECKSTYLE: IGNORE 1 + /** + * @deprecated Use {@link #makeDateLiteral(DateString)}. + */ + @Deprecated // to be removed before 2.0 + public RexLiteral makeDateLiteral(Calendar calendar) { + return makeDateLiteral(DateString.fromCalendarFields(calendar)); + } + + /** Creates a Date literal. */ + public RexLiteral makeDateLiteral(DateString date) { + return makeLiteral( + requireNonNull(date, "date"), + typeFactory.createSqlType(SqlTypeName.DATE), + SqlTypeName.DATE); + } + + // CHECKSTYLE: IGNORE 1 + /** + * @deprecated Use {@link #makeTimeLiteral(TimeString, int)}. + */ + @Deprecated // to be removed before 2.0 + public RexLiteral makeTimeLiteral(Calendar calendar, int precision) { + return makeTimeLiteral(TimeString.fromCalendarFields(calendar), precision); + } + + /** Creates a Time literal. */ + public RexLiteral makeTimeLiteral(TimeString time, int precision) { + return makeLiteral( + requireNonNull(time, "time"), + typeFactory.createSqlType(SqlTypeName.TIME, precision), + SqlTypeName.TIME); + } + + /** Creates a Time with local time-zone literal. */ + public RexLiteral makeTimeWithLocalTimeZoneLiteral(TimeString time, int precision) { + return makeLiteral( + requireNonNull(time, "time"), + typeFactory.createSqlType(SqlTypeName.TIME_WITH_LOCAL_TIME_ZONE, precision), + SqlTypeName.TIME_WITH_LOCAL_TIME_ZONE); + } + + /** Creates a Time with time-zone literal. */ + public RexLiteral makeTimeTzLiteral(TimeWithTimeZoneString time, int precision) { + return makeLiteral( + requireNonNull(time, "time"), + typeFactory.createSqlType(SqlTypeName.TIME_TZ, precision), + SqlTypeName.TIME_TZ); + } + + // CHECKSTYLE: IGNORE 1 + /** + * @deprecated Use {@link #makeTimestampLiteral(TimestampString, int)}. + */ + @Deprecated // to be removed before 2.0 + public RexLiteral makeTimestampLiteral(Calendar calendar, int precision) { + return makeTimestampLiteral(TimestampString.fromCalendarFields(calendar), precision); + } + + /** Creates a Timestamp literal. */ + public RexLiteral makeTimestampLiteral(TimestampString timestamp, int precision) { + return makeLiteral( + requireNonNull(timestamp, "timestamp"), + typeFactory.createSqlType(SqlTypeName.TIMESTAMP, precision), + SqlTypeName.TIMESTAMP); + } + + /** Creates a Timestamp with local time-zone literal. */ + public RexLiteral makeTimestampWithLocalTimeZoneLiteral( + TimestampString timestamp, int precision) { + return makeLiteral( + requireNonNull(timestamp, "timestamp"), + typeFactory.createSqlType(SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE, precision), + SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE); + } + + public RexLiteral makeTimestampTzLiteral(TimestampWithTimeZoneString timestamp, int precision) { + return makeLiteral( + requireNonNull(timestamp, "timestamp"), + typeFactory.createSqlType(SqlTypeName.TIMESTAMP_TZ, precision), + SqlTypeName.TIMESTAMP_TZ); + } + + /** + * Creates a literal representing an interval type, for example {@code YEAR TO MONTH} or {@code + * DOW}. + */ + public RexLiteral makeIntervalLiteral(SqlIntervalQualifier intervalQualifier) { + verifyNotNull(intervalQualifier); + if (intervalQualifier.timeFrameName != null) { + return makePreciseStringLiteral(intervalQualifier.timeFrameName); + } + return makeFlag(intervalQualifier.timeUnitRange); + } + + /** + * Creates a literal representing an interval value, for example {@code INTERVAL '3-7' YEAR TO + * MONTH}. + */ + public RexLiteral makeIntervalLiteral( + @Nullable BigDecimal v, SqlIntervalQualifier intervalQualifier) { + return makeLiteral( + v, + typeFactory.createSqlIntervalType(intervalQualifier), + intervalQualifier.typeName()); + } + + /** + * Creates a reference to a dynamic parameter. + * + * @param type Type of dynamic parameter + * @param index Index of dynamic parameter + * @return Expression referencing dynamic parameter + */ + public RexDynamicParam makeDynamicParam(RelDataType type, int index) { + return new RexDynamicParam(type, index); + } + + /** + * Creates a literal whose value is NULL, with a particular type. + * + *

The typing is necessary because RexNodes are strictly typed. For example, in the Rex world + * the NULL parameter to + * SUBSTRING(NULL FROM 2 FOR 4) must have a valid VARCHAR type so that the result type + * can be determined. + * + * @param type Type to cast NULL to + * @return NULL literal of given type + */ + public RexLiteral makeNullLiteral(RelDataType type) { + if (!type.isNullable()) { + type = typeFactory.createTypeWithNullability(type, true); + } + return (RexLiteral) makeCast(type, constantNull); + } + + // CHECKSTYLE: IGNORE 1 + /** + * @deprecated Use {@link #makeNullLiteral(RelDataType)} + */ + @Deprecated // to be removed before 2.0 + public RexNode makeNullLiteral(SqlTypeName typeName, int precision) { + return makeNullLiteral(typeFactory.createSqlType(typeName, precision)); + } + + // CHECKSTYLE: IGNORE 1 + /** + * @deprecated Use {@link #makeNullLiteral(RelDataType)} + */ + @Deprecated // to be removed before 2.0 + public RexNode makeNullLiteral(SqlTypeName typeName) { + return makeNullLiteral(typeFactory.createSqlType(typeName)); + } + + /** + * Creates a {@link RexNode} representation a SQL "arg IN (point, ...)" expression. + * + *

If all of the expressions are literals, creates a call {@link Sarg} literal, "SEARCH(arg, + * SARG([point0..point0], [point1..point1], ...)"; otherwise creates a disjunction, "arg = + * point0 OR arg = point1 OR ...". + */ + public RexNode makeIn(RexNode arg, List ranges) { + if (areAssignable(arg, ranges)) { + final Sarg sarg = toSarg(Comparable.class, ranges, RexUnknownAs.UNKNOWN); + if (sarg != null) { + final List types = + ranges.stream().map(RexNode::getType).collect(Collectors.toList()); + RelDataType sargType = + requireNonNull( + typeFactory.leastRestrictive(types), + () -> "Can't find leastRestrictive type for SARG among " + types); + return makeCall( + SqlStdOperatorTable.SEARCH, arg, makeSearchArgumentLiteral(sarg, sargType)); + } + } + return RexUtil.composeDisjunction( + this, + ranges.stream() + .map(r -> makeCall(SqlStdOperatorTable.EQUALS, arg, r)) + .collect(toImmutableList())); + } + + /** + * Returns whether and argument and bounds are have types that are sufficiently compatible to be + * converted to a {@link Sarg}. + */ + private static boolean areAssignable(RexNode arg, List bounds) { + for (RexNode bound : bounds) { + if (!SqlTypeUtil.inSameFamily(arg.getType(), bound.getType()) + && !(arg.getType().isStruct() && bound.getType().isStruct())) { + return false; + } + } + return true; + } + + /** + * Creates a {@link RexNode} representation a SQL "arg BETWEEN lower AND upper" expression. + * + *

If the expressions are all literals of compatible type, creates a call to {@link Sarg} + * literal, {@code SEARCH(arg, SARG([lower..upper])}; otherwise creates a disjunction, {@code + * arg >= lower AND arg <= upper}. + */ + @SuppressWarnings("BetaApi") + public RexNode makeBetween(RexNode arg, RexNode lower, RexNode upper) { + final Comparable lowerValue = toComparable(Comparable.class, lower); + final Comparable upperValue = toComparable(Comparable.class, upper); + if (lowerValue != null + && upperValue != null + && areAssignable(arg, Arrays.asList(lower, upper))) { + final Sarg sarg = + Sarg.of( + RexUnknownAs.UNKNOWN, + ImmutableRangeSet.of(Range.closed(lowerValue, upperValue))); + List types = ImmutableList.of(lower.getType(), upper.getType()); + RelDataType sargType = + requireNonNull( + typeFactory.leastRestrictive(types), + () -> "Can't find leastRestrictive type for SARG among " + types); + return makeCall( + SqlStdOperatorTable.SEARCH, arg, makeSearchArgumentLiteral(sarg, sargType)); + } + return makeCall( + SqlStdOperatorTable.AND, + makeCall(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, arg, lower), + makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, arg, upper)); + } + + /** Converts a list of expressions to a search argument, or returns null if not possible. */ + @SuppressWarnings({"BetaApi", "UnstableApiUsage"}) + private static > @Nullable Sarg toSarg( + Class clazz, List ranges, RexUnknownAs unknownAs) { + if (ranges.isEmpty()) { + // Cannot convert an empty list to a Sarg (by this interface, at least) + // because we use the type of the first element. + return null; + } + final RangeSet rangeSet = TreeRangeSet.create(); + for (RexNode range : ranges) { + final C value = toComparable(clazz, range); + if (value == null) { + return null; + } + rangeSet.add(Range.singleton(value)); + } + return Sarg.of(unknownAs, rangeSet); + } + + private static > @Nullable C toComparable( + Class clazz, RexNode point) { + switch (point.getKind()) { + case LITERAL: + final RexLiteral literal = (RexLiteral) point; + return literal.getValueAs(clazz); + + case ROW: + final RexCall call = (RexCall) point; + final ImmutableList.Builder b = ImmutableList.builder(); + for (RexNode operand : call.operands) { + //noinspection unchecked + final Comparable value = toComparable(Comparable.class, operand); + if (value == null) { + return null; // not a constant value + } + b.add(value); + } + return clazz.cast(FlatLists.ofComparable(b.build())); + + default: + return null; // not a constant value + } + } + + /** + * Creates a copy of an expression, which may have been created using a different RexBuilder + * and/or {@link RelDataTypeFactory}, using this RexBuilder. + * + * @param expr Expression + * @return Copy of expression + * @see RelDataTypeFactory#copyType(RelDataType) + */ + public RexNode copy(RexNode expr) { + return expr.accept(new RexCopier(this)); + } + + /** + * Creates a literal of the default value for the given type. + * + *

This value is: + * + *

    + *
  • 0 for numeric types; + *
  • FALSE for BOOLEAN; + *
  • The epoch for TIMESTAMP and DATE; + *
  • Midnight for TIME; + *
  • The empty string for string types (CHAR, BINARY, VARCHAR, VARBINARY). + *
+ * + * @param type Type + * @return Simple literal + */ + public RexLiteral makeZeroLiteral(RelDataType type) { + return makeLiteral(zeroValue(type), type); + } + + private static Comparable zeroValue(RelDataType type) { + switch (type.getSqlTypeName()) { + case CHAR: + return new NlsString(Spaces.of(type.getPrecision()), null, null); + case VARCHAR: + return new NlsString("", null, null); + case BINARY: + return new ByteString(new byte[type.getPrecision()]); + case VARBINARY: + return ByteString.EMPTY; + case TINYINT: + case SMALLINT: + case INTEGER: + case BIGINT: + case DECIMAL: + case FLOAT: + case REAL: + case DOUBLE: + return BigDecimal.ZERO; + case BOOLEAN: + return false; + case TIME: + case DATE: + case TIMESTAMP: + return DateTimeUtils.ZERO_CALENDAR; + case TIME_WITH_LOCAL_TIME_ZONE: + return new TimeString(0, 0, 0); + case TIME_TZ: + return new TimeWithTimeZoneString(0, 0, 0, "GMT+00"); + case TIMESTAMP_WITH_LOCAL_TIME_ZONE: + return new TimestampString(0, 1, 1, 0, 0, 0); + case TIMESTAMP_TZ: + return new TimestampWithTimeZoneString(0, 1, 1, 0, 0, 0, "GMT+00"); + default: + throw Util.unexpected(type.getSqlTypeName()); + } + } + + /** + * Creates a literal of a given type, padding values of constant-width types to match their + * type, not allowing casts. + * + * @param value Value + * @param type Type + * @return Simple literal + */ + public RexLiteral makeLiteral(@Nullable Object value, RelDataType type) { + return (RexLiteral) makeLiteral(value, type, false, false); + } + + /** + * Creates a literal of a given type, padding values of constant-width types to match their + * type. + * + * @param value Value + * @param type Type + * @param allowCast Whether to allow a cast. If false, value is always a {@link RexLiteral} but + * may not be the exact type + * @return Simple literal, or cast simple literal + */ + public RexNode makeLiteral(@Nullable Object value, RelDataType type, boolean allowCast) { + return makeLiteral(value, type, allowCast, false); + } + + /** + * Creates a literal of a given type. The value is assumed to be compatible with the type. + * + *

The {@code trim} parameter controls whether to trim values of constant-width types such as + * {@code CHAR}. Consider a call to {@code makeLiteral("foo ", CHAR(5)}, and note that the value + * is too short for its type. If {@code trim} is true, the value is converted to "foo" and the + * type to {@code CHAR(3)}; if {@code trim} is false, the value is right-padded with spaces to + * {@code "foo "}, to match the type {@code CHAR(5)}. + * + * @param value Value + * @param type Type + * @param allowCast Whether to allow a cast. If false, value is always a {@link RexLiteral} but + * may not be the exact type + * @param trim Whether to trim values and type to the shortest equivalent value; for example + * whether to convert CHAR(4) 'foo ' to CHAR(3) 'foo' + * @return Simple literal, or cast simple literal + */ + public RexNode makeLiteral( + @Nullable Object value, RelDataType type, boolean allowCast, boolean trim) { + if (value == null) { + return makeCast(type, constantNull); + } + if (type.isNullable()) { + final RelDataType typeNotNull = typeFactory.createTypeWithNullability(type, false); + if (allowCast) { + RexNode literalNotNull = makeLiteral(value, typeNotNull, allowCast); + return makeAbstractCast(type, literalNotNull, false); + } + type = typeNotNull; + } + value = clean(value, type); + RexLiteral literal; + final List operands; + final SqlTypeName sqlTypeName = type.getSqlTypeName(); + switch (sqlTypeName) { + case CHAR: + final NlsString nlsString = (NlsString) value; + if (trim) { + return makeCharLiteral(nlsString.rtrim()); + } else { + return makeCharLiteral(padRight(nlsString, type.getPrecision())); + } + case VARCHAR: + literal = makeCharLiteral((NlsString) value); + if (allowCast) { + return makeCast(type, literal); + } else { + return literal; + } + case BINARY: + return makeBinaryLiteral(padRight((ByteString) value, type.getPrecision())); + case VARBINARY: + literal = makeBinaryLiteral((ByteString) value); + if (allowCast) { + return makeCast(type, literal); + } else { + return literal; + } + case TINYINT: + case SMALLINT: + case INTEGER: + case BIGINT: + case DECIMAL: + if (value instanceof RexLiteral + && ((RexLiteral) value).getTypeName() == SqlTypeName.SARG) { + return (RexNode) value; + } + return makeExactLiteral((BigDecimal) value, type); + case FLOAT: + case REAL: + case DOUBLE: + return makeApproxLiteral((BigDecimal) value, type); + case BOOLEAN: + return (Boolean) value ? booleanTrue : booleanFalse; + case TIME: + return makeTimeLiteral((TimeString) value, type.getPrecision()); + case TIME_WITH_LOCAL_TIME_ZONE: + return makeTimeWithLocalTimeZoneLiteral((TimeString) value, type.getPrecision()); + case TIME_TZ: + return makeTimeTzLiteral((TimeWithTimeZoneString) value, type.getPrecision()); + case DATE: + return makeDateLiteral((DateString) value); + case TIMESTAMP: + return makeTimestampLiteral((TimestampString) value, type.getPrecision()); + case TIMESTAMP_WITH_LOCAL_TIME_ZONE: + return makeTimestampWithLocalTimeZoneLiteral( + (TimestampString) value, type.getPrecision()); + case TIMESTAMP_TZ: + return makeTimestampTzLiteral( + (TimestampWithTimeZoneString) value, type.getPrecision()); + case INTERVAL_YEAR: + case INTERVAL_YEAR_MONTH: + case INTERVAL_MONTH: + case INTERVAL_DAY: + case INTERVAL_DAY_HOUR: + case INTERVAL_DAY_MINUTE: + case INTERVAL_DAY_SECOND: + case INTERVAL_HOUR: + case INTERVAL_HOUR_MINUTE: + case INTERVAL_HOUR_SECOND: + case INTERVAL_MINUTE: + case INTERVAL_MINUTE_SECOND: + case INTERVAL_SECOND: + return makeIntervalLiteral( + (BigDecimal) value, castNonNull(type.getIntervalQualifier())); + case SYMBOL: + return makeFlag((Enum) value); + case MAP: + final MapSqlType mapType = (MapSqlType) type; + @SuppressWarnings("unchecked") + final Map map = (Map) value; + operands = new ArrayList<>(); + for (Map.Entry entry : map.entrySet()) { + operands.add(makeLiteral(entry.getKey(), mapType.getKeyType(), allowCast)); + operands.add(makeLiteral(entry.getValue(), mapType.getValueType(), allowCast)); + } + return makeCall(SqlStdOperatorTable.MAP_VALUE_CONSTRUCTOR, operands); + case ARRAY: + final ArraySqlType arrayType = (ArraySqlType) type; + @SuppressWarnings("unchecked") + final List listValue = (List) value; + operands = new ArrayList<>(); + for (Object entry : listValue) { + operands.add(makeLiteral(entry, arrayType.getComponentType(), allowCast)); + } + return makeCall(SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, operands); + case MULTISET: + final MultisetSqlType multisetType = (MultisetSqlType) type; + operands = new ArrayList<>(); + for (Object entry : (List) value) { + final RexNode e = + entry instanceof RexLiteral + ? (RexNode) entry + : makeLiteral( + entry, multisetType.getComponentType(), allowCast); + operands.add(e); + } + if (allowCast) { + return makeCall(SqlStdOperatorTable.MULTISET_VALUE, operands); + } else { + return new RexLiteral((Comparable) FlatLists.of(operands), type, sqlTypeName); + } + case ROW: + operands = new ArrayList<>(); + //noinspection unchecked + for (Pair pair : + Pair.zip(type.getFieldList(), (List) value)) { + final RexNode e = + pair.right instanceof RexLiteral + ? (RexNode) pair.right + : makeLiteral(pair.right, pair.left.getType(), allowCast); + operands.add(e); + } + return new RexLiteral((Comparable) FlatLists.of(operands), type, sqlTypeName); + case GEOMETRY: + return new RexLiteral((Comparable) value, guessType(value), SqlTypeName.GEOMETRY); + case ANY: + return makeLiteral(value, guessType(value), allowCast); + default: + throw new IllegalArgumentException( + "Cannot create literal for type '" + sqlTypeName + "'"); + } + } + + /** + * Creates a lambda expression. + * + * @param expr expression of the lambda + * @param parameters parameters of the lambda + * @return RexNode representing the lambda + */ + public RexNode makeLambdaCall(RexNode expr, List parameters) { + return new RexLambda(parameters, expr); + } + + /** + * Converts the type of a value to comply with {@link + * org.apache.calcite.rex.RexLiteral#valueMatchesType}. + * + *

Returns null if and only if {@code o} is null. + */ + private @PolyNull Object clean(@PolyNull Object o, RelDataType type) { + if (o == null) { + return o; + } + if (o instanceof Sarg) { + return makeSearchArgumentLiteral((Sarg) o, type); + } + switch (type.getSqlTypeName()) { + case TINYINT: + case SMALLINT: + case INTEGER: + case BIGINT: + case DECIMAL: + case INTERVAL_YEAR: + case INTERVAL_YEAR_MONTH: + case INTERVAL_MONTH: + case INTERVAL_DAY: + case INTERVAL_DAY_HOUR: + case INTERVAL_DAY_MINUTE: + case INTERVAL_DAY_SECOND: + case INTERVAL_HOUR: + case INTERVAL_HOUR_MINUTE: + case INTERVAL_HOUR_SECOND: + case INTERVAL_MINUTE: + case INTERVAL_MINUTE_SECOND: + case INTERVAL_SECOND: + if (o instanceof BigDecimal) { + return o; + } + assert !(o instanceof Float || o instanceof Double) + : String.format( + Locale.ROOT, + "%s is not compatible with %s, try to use makeExactLiteral", + o.getClass().getCanonicalName(), + type.getSqlTypeName()); + return new BigDecimal(((Number) o).longValue()); + case REAL: + if (o instanceof BigDecimal) { + return o; + } + return new BigDecimal(((Number) o).doubleValue(), MathContext.DECIMAL32) + .stripTrailingZeros(); + case FLOAT: + case DOUBLE: + if (o instanceof BigDecimal) { + return o; + } + return new BigDecimal(((Number) o).doubleValue(), MathContext.DECIMAL64) + .stripTrailingZeros(); + case CHAR: + case VARCHAR: + if (o instanceof NlsString) { + return o; + } + assert type.getCharset() != null : type + ".getCharset() must not be null"; + return new NlsString((String) o, type.getCharset().name(), type.getCollation()); + case TIME: + if (o instanceof TimeString) { + return o; + } else if (o instanceof Calendar) { + if (!((Calendar) o).getTimeZone().equals(DateTimeUtils.UTC_ZONE)) { + throw new AssertionError(); + } + return TimeString.fromCalendarFields((Calendar) o); + } else { + return TimeString.fromMillisOfDay((Integer) o); + } + case TIME_TZ: + if (o instanceof TimeWithTimeZoneString) { + return o; + } else if (o instanceof Calendar) { + return TimeWithTimeZoneString.fromCalendarFields((Calendar) o); + } else { + throw new AssertionError("Value does not contain time zone"); + } + case TIME_WITH_LOCAL_TIME_ZONE: + if (o instanceof TimeString) { + return o; + } else { + return TimeString.fromMillisOfDay((Integer) o); + } + case DATE: + if (o instanceof DateString) { + return o; + } else if (o instanceof Calendar) { + if (!((Calendar) o).getTimeZone().equals(DateTimeUtils.UTC_ZONE)) { + throw new AssertionError(); + } + return DateString.fromCalendarFields((Calendar) o); + } else { + return DateString.fromDaysSinceEpoch((Integer) o); + } + case TIMESTAMP: + if (o instanceof TimestampString) { + return o; + } else if (o instanceof Calendar) { + if (!((Calendar) o).getTimeZone().equals(DateTimeUtils.UTC_ZONE)) { + throw new AssertionError(); + } + return TimestampString.fromCalendarFields((Calendar) o); + } else { + return TimestampString.fromMillisSinceEpoch((Long) o); + } + case TIMESTAMP_WITH_LOCAL_TIME_ZONE: + if (o instanceof TimestampString) { + return o; + } else { + return TimestampString.fromMillisSinceEpoch((Long) o); + } + case TIMESTAMP_TZ: + if (o instanceof TimestampWithTimeZoneString) { + return o; + } else if (o instanceof Calendar) { + return TimestampWithTimeZoneString.fromCalendarFields((Calendar) o); + } else { + throw new AssertionError("Value does not contain time zone"); + } + default: + return o; + } + } + + private RelDataType guessType(@Nullable Object value) { + if (value == null) { + return typeFactory.createSqlType(SqlTypeName.NULL); + } + if (value instanceof Float || value instanceof Double) { + return typeFactory.createSqlType(SqlTypeName.DOUBLE); + } + if (value instanceof Number) { + return typeFactory.createSqlType(SqlTypeName.BIGINT); + } + if (value instanceof Boolean) { + return typeFactory.createSqlType(SqlTypeName.BOOLEAN); + } + if (value instanceof String) { + return typeFactory.createSqlType(SqlTypeName.CHAR, ((String) value).length()); + } + if (value instanceof ByteString) { + return typeFactory.createSqlType(SqlTypeName.BINARY, ((ByteString) value).length()); + } + if (value instanceof Geometry) { + return typeFactory.createSqlType(SqlTypeName.GEOMETRY); + } + throw new AssertionError("unknown type " + value.getClass()); + } + + /** Returns an {@link NlsString} with spaces to make it at least a given length. */ + private static NlsString padRight(NlsString s, int length) { + if (s.getValue().length() >= length) { + return s; + } + return s.copy(padRight(s.getValue(), length)); + } + + /** Returns a string padded with spaces to make it at least a given length. */ + private static String padRight(String s, int length) { + if (s.length() >= length) { + return s; + } + return new StringBuilder().append(s).append(Spaces.MAX, s.length(), length).toString(); + } + + /** Returns a byte-string padded with zero bytes to make it at least a given length. */ + private static ByteString padRight(ByteString s, int length) { + if (s.length() >= length) { + return s; + } + return new ByteString(Arrays.copyOf(s.getBytes(), length)); + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexChecker.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexChecker.java new file mode 100644 index 0000000000000..9d7506c20342c --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexChecker.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.rex; + +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.util.Litmus; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +/** + * Visitor which checks the validity of a {@link RexNode} expression. + * + *

FLINK modifications (backport of CALCITE-6764): Lines 121 ~ 133, 164 ~ 177 + * + *

There are two modes of operation: + * + *

    + *
  • Usefail=true to throw an {@link AssertionError} as soon as an invalid node is + * detected: + *
    + * RexNode node;
    + * RelDataType rowType;
    + * assert new RexChecker(rowType, true).isValid(node);
    + *
    + *

    This mode requires that assertions are enabled. + *

  • Use fail=false to test for validity without throwing an error. + *
    + * RexNode node;
    + * RelDataType rowType;
    + * RexChecker checker = new RexChecker(rowType, false);
    + * node.accept(checker);
    + * if (!checker.valid) {
    + *    ...
    + * }
    + *
    + *
+ * + * @see RexNode + */ +public class RexChecker extends RexVisitorImpl { + // ~ Instance fields -------------------------------------------------------- + + protected final RelNode.Context context; + protected final Litmus litmus; + protected final List inputTypeList; + protected int failCount; + + // ~ Constructors ----------------------------------------------------------- + + /** + * Creates a RexChecker with a given input row type. + * + *

If fail is true, the checker will throw an {@link AssertionError} if an + * invalid node is found and assertions are enabled. + * + *

Otherwise, each method returns whether its part of the tree is valid. + * + * @param inputRowType Input row type + * @param context Context of the enclosing {@link RelNode}, or null + * @param litmus What to do if an invalid node is detected + */ + public RexChecker(final RelDataType inputRowType, RelNode.Context context, Litmus litmus) { + this(RelOptUtil.getFieldTypeList(inputRowType), context, litmus); + } + + /** + * Creates a RexChecker with a given set of input fields. + * + *

If fail is true, the checker will throw an {@link AssertionError} if an + * invalid node is found and assertions are enabled. + * + *

Otherwise, each method returns whether its part of the tree is valid. + * + * @param inputTypeList Input row type + * @param context Context of the enclosing {@link RelNode}, or null + * @param litmus What to do if an error is detected + */ + public RexChecker(List inputTypeList, RelNode.Context context, Litmus litmus) { + super(true); + this.inputTypeList = inputTypeList; + this.context = context; + this.litmus = litmus; + } + + // ~ Methods ---------------------------------------------------------------- + + /** + * Returns the number of failures encountered. + * + * @return Number of failures + */ + public int getFailureCount() { + return failCount; + } + + @Override + public Boolean visitInputRef(RexInputRef ref) { + final int index = ref.getIndex(); + if ((index < 0) || (index >= inputTypeList.size())) { + ++failCount; + return litmus.fail( + "RexInputRef index {} out of range 0..{}", index, inputTypeList.size() - 1); + } + // Type of field and type of result can differ in nullability. See [CALCITE-6764] + if (!ref.getType().isStruct() + && !RelOptUtil.eqUpToNullability( + ref.getType().isNullable(), + "ref", + ref.getType(), + "input", + inputTypeList.get(index), + litmus)) { + ++failCount; + return litmus.fail(null); + } + return litmus.succeed(); + } + + @Override + public Boolean visitLocalRef(RexLocalRef ref) { + ++failCount; + return litmus.fail("RexLocalRef illegal outside program"); + } + + @Override + public Boolean visitCall(RexCall call) { + for (RexNode operand : call.getOperands()) { + Boolean valid = operand.accept(this); + if (valid != null && !valid) { + return litmus.fail(null); + } + } + return litmus.succeed(); + } + + @Override + public Boolean visitFieldAccess(RexFieldAccess fieldAccess) { + super.visitFieldAccess(fieldAccess); + final RelDataType refType = fieldAccess.getReferenceExpr().getType(); + assert refType.isStruct(); + final RelDataTypeField field = fieldAccess.getField(); + final int index = field.getIndex(); + if ((index < 0) || (index >= refType.getFieldList().size())) { + ++failCount; + return litmus.fail(null); + } + // Type of field may not match type of field access - they may differ in nullability + final RelDataTypeField typeField = refType.getFieldList().get(index); + if (!RelOptUtil.eqUpToNullability( + refType.isNullable(), + "type1", + typeField.getType(), + "type2", + fieldAccess.getType(), + litmus)) { + ++failCount; + return litmus.fail(null); + } + return litmus.succeed(); + } + + @Override + public Boolean visitCorrelVariable(RexCorrelVariable v) { + if (context != null && !context.correlationIds().contains(v.id)) { + ++failCount; + return litmus.fail( + "correlation id {} not found in correlation list {}", + v, + context.correlationIds()); + } + return litmus.succeed(); + } + + /** Returns whether an expression is valid. */ + public final boolean isValid(RexNode expr) { + return requireNonNull(expr.accept(this), () -> "expr.accept(RexChecker) for expr=" + expr); + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexFieldAccess.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexFieldAccess.java new file mode 100644 index 0000000000000..955cc334db62b --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexFieldAccess.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.rex; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.sql.SqlKind; +import org.checkerframework.checker.nullness.qual.Nullable; + +import static com.google.common.base.Preconditions.checkArgument; + +/** + * Access to a field of a row-expression. + * + *

FLINK modifications (backport of CALCITE-6764): Lines 59, 63 ~ 65, 67 ~ 73, 95 + * + *

You might expect to use a RexFieldAccess to access columns of relational tables, + * for example, the expression emp.empno in the query + * + *

+ * + *
SELECT emp.empno FROM emp
+ * + *
+ * + *

but there is a specialized expression {@link RexInputRef} for this purpose. So in practice, + * RexFieldAccess is usually used to access fields of correlating variables, for + * example the expression emp.deptno in + * + *

+ * + *
SELECT ename
+ * FROM dept
+ * WHERE EXISTS (
+ *     SELECT NULL
+ *     FROM emp
+ *     WHERE emp.deptno = dept.deptno
+ *     AND gender = 'F')
+ * + *
+ */ +public class RexFieldAccess extends RexNode { + // ~ Instance fields -------------------------------------------------------- + + private final RexNode expr; + private final RelDataTypeField field; + // Not always the same as the field.getType(). + private final RelDataType type; + + // ~ Constructors ----------------------------------------------------------- + + RexFieldAccess(RexNode expr, RelDataTypeField field) { + this(expr, field, field.getType()); + } + + RexFieldAccess(RexNode expr, RelDataTypeField field, RelDataType type) { + checkValid(expr, field); + this.expr = expr; + this.field = field; + this.digest = expr + "." + field.getName(); + this.type = type; + } + + // ~ Methods ---------------------------------------------------------------- + + private static void checkValid(RexNode expr, RelDataTypeField field) { + RelDataType exprType = expr.getType(); + int fieldIdx = field.getIndex(); + checkArgument( + fieldIdx >= 0 + && fieldIdx < exprType.getFieldList().size() + && exprType.getFieldList().get(fieldIdx).equals(field), + "Field %s does not exist for expression %s", + field, + expr); + } + + public RelDataTypeField getField() { + return field; + } + + @Override + public RelDataType getType() { + return type; + } + + @Override + public SqlKind getKind() { + return SqlKind.FIELD_ACCESS; + } + + @Override + public R accept(RexVisitor visitor) { + return visitor.visitFieldAccess(this); + } + + @Override + public R accept(RexBiVisitor visitor, P arg) { + return visitor.visitFieldAccess(this, arg); + } + + /** Returns the expression whose field is being accessed. */ + public RexNode getReferenceExpr() { + return expr; + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + RexFieldAccess that = (RexFieldAccess) o; + + return field.equals(that.field) && expr.equals(that.expr); + } + + @Override + public int hashCode() { + int result = expr.hashCode(); + result = 31 * result + field.hashCode(); + return result; + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexProgram.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexProgram.java new file mode 100644 index 0000000000000..04dea5bbca1b3 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexProgram.java @@ -0,0 +1,984 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.rex; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Ordering; +import org.apache.calcite.linq4j.Ord; +import org.apache.calcite.plan.RelOptPredicateList; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelCollations; +import org.apache.calcite.rel.RelFieldCollation; +import org.apache.calcite.rel.RelInput; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelWriter; +import org.apache.calcite.rel.externalize.RelJsonWriter; +import org.apache.calcite.rel.externalize.RelWriterImpl; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.sql.SqlExplainLevel; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeUtil; +import org.apache.calcite.util.Litmus; +import org.apache.calcite.util.Pair; +import org.apache.calcite.util.Permutation; +import org.apache.calcite.util.mapping.MappingType; +import org.apache.calcite.util.mapping.Mappings; +import org.checkerframework.checker.nullness.qual.Nullable; + +import java.io.PrintWriter; +import java.io.StringWriter; +import java.util.AbstractList; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static java.util.Objects.requireNonNull; +import static org.apache.calcite.linq4j.Nullness.castNonNull; + +/** + * A collection of expressions which read inputs, compute output expressions, and optionally use a + * condition to filter rows. + * + *

Programs are immutable. It may help to use a {@link RexProgramBuilder}, which has the same + * relationship to {@link RexProgram} as {@link StringBuilder} has to {@link String}. + * + *

A program can contain aggregate functions. If it does, the arguments to each aggregate + * function must be an {@link RexInputRef}. + * + *

FLINK modifications (backport of CALCITE-6764): Lines 999 ~ 1002 + * + * @see RexProgramBuilder + */ +public class RexProgram { + // ~ Instance fields -------------------------------------------------------- + + /** + * First stage of expression evaluation. The expressions in this array can refer to inputs + * (using input ordinal #0) or previous expressions in the array (using input ordinal #1). + */ + private final List exprs; + + /** With {@link #condition}, the second stage of expression evaluation. */ + private final List projects; + + /** The optional condition. If null, the calculator does not filter rows. */ + private final @Nullable RexLocalRef condition; + + private final RelDataType inputRowType; + + private final RelDataType outputRowType; + + /** Reference counts for each expression, computed on demand. */ + private int[] refCounts; + + // ~ Constructors ----------------------------------------------------------- + + /** + * Creates a program. + * + *

The expressions must be valid: they must not contain common expressions, forward + * references, or non-trivial aggregates. + * + * @param inputRowType Input row type + * @param exprs Common expressions + * @param projects Projection expressions + * @param condition Condition expression. If null, calculator does not filter rows + * @param outputRowType Description of the row produced by the program + */ + public RexProgram( + RelDataType inputRowType, + List exprs, + List projects, + @Nullable RexLocalRef condition, + RelDataType outputRowType) { + this.inputRowType = inputRowType; + this.exprs = ImmutableList.copyOf(exprs); + this.projects = ImmutableList.copyOf(projects); + this.condition = condition; + this.outputRowType = outputRowType; + assert isValid(Litmus.THROW, null); + } + + // ~ Methods ---------------------------------------------------------------- + + // REVIEW jvs 16-Oct-2006: The description below is confusing. I + // think it means "none of the entries are null, there may be none, + // and there is no further reduction into smaller common sub-expressions + // possible"? + + /** + * Returns the common sub-expressions of this program. + * + *

The list is never null but may be empty; each the expression in the list is not null; and + * no further reduction into smaller common sub-expressions is possible. + */ + public List getExprList() { + return exprs; + } + + /** + * Returns an array of references to the expressions which this program is to project. Never + * null, may be empty. + */ + public List getProjectList() { + return projects; + } + + /** Returns a list of project expressions and their field names. */ + public List> getNamedProjects() { + return new AbstractList>() { + @Override + public int size() { + return projects.size(); + } + + @Override + public Pair get(int index) { + return Pair.of( + projects.get(index), outputRowType.getFieldList().get(index).getName()); + } + }; + } + + /** + * Returns the field reference of this program's filter condition, or null if there is no + * condition. + */ + public @Nullable RexLocalRef getCondition() { + return condition; + } + + /** + * Creates a program which calculates projections and filters rows based upon a condition. Does + * not attempt to eliminate common sub-expressions. + * + * @param projectExprs Project expressions + * @param conditionExpr Condition on which to filter rows, or null if rows are not to be + * filtered + * @param outputRowType Output row type + * @param rexBuilder Builder of rex expressions + * @return A program + */ + public static RexProgram create( + RelDataType inputRowType, + List projectExprs, + @Nullable RexNode conditionExpr, + RelDataType outputRowType, + RexBuilder rexBuilder) { + return create( + inputRowType, + projectExprs, + conditionExpr, + outputRowType.getFieldNames(), + rexBuilder); + } + + /** + * Creates a program which calculates projections and filters rows based upon a condition. Does + * not attempt to eliminate common sub-expressions. + * + * @param projectExprs Project expressions + * @param conditionExpr Condition on which to filter rows, or null if rows are not to be + * filtered + * @param fieldNames Names of projected fields + * @param rexBuilder Builder of rex expressions + * @return A program + */ + public static RexProgram create( + RelDataType inputRowType, + List projectExprs, + @Nullable RexNode conditionExpr, + @Nullable List fieldNames, + RexBuilder rexBuilder) { + if (fieldNames == null) { + fieldNames = Collections.nCopies(projectExprs.size(), null); + } else { + assert fieldNames.size() == projectExprs.size() + : "fieldNames=" + fieldNames + ", exprs=" + projectExprs; + } + final RexProgramBuilder programBuilder = new RexProgramBuilder(inputRowType, rexBuilder); + for (int i = 0; i < projectExprs.size(); i++) { + programBuilder.addProject(projectExprs.get(i), fieldNames.get(i)); + } + if (conditionExpr != null) { + programBuilder.addCondition(conditionExpr); + } + return programBuilder.getProgram(); + } + + /** + * Create a program from serialized output. In this case, the input is mainly from the output + * json string of {@link RelJsonWriter} + */ + public static RexProgram create(RelInput input) { + final List exprs = requireNonNull(input.getExpressionList("exprs"), "exprs"); + final List projectRexNodes = + requireNonNull(input.getExpressionList("projects"), "projects"); + final List projects = new ArrayList<>(projectRexNodes.size()); + for (RexNode rexNode : projectRexNodes) { + projects.add((RexLocalRef) rexNode); + } + final RelDataType inputType = input.getRowType("inputRowType"); + final RelDataType outputType = input.getRowType("outputRowType"); + final RexLocalRef condition = (RexLocalRef) input.getExpression("condition"); + return new RexProgram(inputType, exprs, projects, condition, outputType); + } + + // description of this calc, chiefly intended for debugging + @Override + public String toString() { + // Intended to produce similar output to explainCalc, + // but without requiring a RelNode or RelOptPlanWriter. + final RelWriterImpl pw = new RelWriterImpl(new PrintWriter(new StringWriter())); + collectExplainTerms("", pw); + return pw.simple(); + } + + /** + * Writes an explanation of the expressions in this program to a plan writer. + * + * @param pw Plan writer + */ + public RelWriter explainCalc(RelWriter pw) { + if (pw instanceof RelJsonWriter) { + return pw.item("exprs", exprs) + .item("projects", projects) + .item("condition", condition) + .item("inputRowType", inputRowType) + .item("outputRowType", outputRowType); + } else { + return collectExplainTerms("", pw, pw.getDetailLevel()); + } + } + + public RelWriter collectExplainTerms(String prefix, RelWriter pw) { + return collectExplainTerms(prefix, pw, SqlExplainLevel.EXPPLAN_ATTRIBUTES); + } + + /** + * Collects the expressions in this program into a list of terms and values. + * + * @param prefix Prefix for term names, usually the empty string, but useful if a relational + * expression contains more than one program + * @param pw Plan writer + */ + public RelWriter collectExplainTerms(String prefix, RelWriter pw, SqlExplainLevel level) { + final List inFields = inputRowType.getFieldList(); + final List outFields = outputRowType.getFieldList(); + assert outFields.size() == projects.size() + : "outFields.length=" + outFields.size() + ", projects.length=" + projects.size(); + pw.item( + prefix + "expr#0" + ((inFields.size() > 1) ? (".." + (inFields.size() - 1)) : ""), + "{inputs}"); + for (int i = inFields.size(); i < exprs.size(); i++) { + pw.item(prefix + "expr#" + i, exprs.get(i)); + } + + // If a lot of the fields are simply projections of the underlying + // expression, try to be a bit less verbose. + int trivialCount = countTrivial(projects); + + switch (trivialCount) { + case 0: + break; + case 1: + trivialCount = 0; + break; + default: + pw.item(prefix + "proj#0.." + (trivialCount - 1), "{exprs}"); + break; + } + + final boolean withFieldNames = level != SqlExplainLevel.DIGEST_ATTRIBUTES; + // Print the non-trivial fields with their names as they appear in the + // output row type. + for (int i = trivialCount; i < projects.size(); i++) { + final String fieldName = + withFieldNames ? prefix + outFields.get(i).getName() : prefix + i; + pw.item(fieldName, projects.get(i)); + } + if (condition != null) { + pw.item(prefix + "$condition", condition); + } + return pw; + } + + /** + * Returns the number of expressions at the front of an array which are simply projections of + * the same field. + * + * @param refs References + */ + private static int countTrivial(List refs) { + for (int i = 0; i < refs.size(); i++) { + RexLocalRef ref = refs.get(i); + if (ref.getIndex() != i) { + return i; + } + } + return refs.size(); + } + + /** Returns the number of expressions in this program. */ + public int getExprCount() { + return exprs.size() + projects.size() + ((condition == null) ? 0 : 1); + } + + /** Creates the identity program. */ + public static RexProgram createIdentity(RelDataType rowType) { + return createIdentity(rowType, rowType); + } + + /** + * Creates a program that projects its input fields but with possibly different names for the + * output fields. + */ + public static RexProgram createIdentity(RelDataType rowType, RelDataType outputRowType) { + if (rowType != outputRowType + && !Pair.right(rowType.getFieldList()) + .equals(Pair.right(outputRowType.getFieldList()))) { + throw new IllegalArgumentException( + "field type mismatch: " + rowType + " vs. " + outputRowType); + } + final List fields = rowType.getFieldList(); + final List projectRefs = new ArrayList<>(); + final List refs = new ArrayList<>(); + for (int i = 0; i < fields.size(); i++) { + final RexInputRef ref = RexInputRef.of(i, fields); + refs.add(ref); + projectRefs.add(new RexLocalRef(i, ref.getType())); + } + return new RexProgram(rowType, refs, projectRefs, null, outputRowType); + } + + /** + * Returns the type of the input row to the program. + * + * @return input row type + */ + public RelDataType getInputRowType() { + return inputRowType; + } + + /** + * Returns whether this program contains windowed aggregate functions. + * + * @return whether this program contains windowed aggregate functions + */ + public boolean containsAggs() { + return RexOver.containsOver(this); + } + + /** + * Returns the type of the output row from this program. + * + * @return output row type + */ + public RelDataType getOutputRowType() { + return outputRowType; + } + + /** + * Checks that this program is valid. + * + *

If fail is true, executes assert false, so will throw an {@link + * AssertionError} if assertions are enabled. If + * fail is false, merely returns whether the program is valid. + * + * @param litmus What to do if an error is detected + * @param context Context of enclosing {@link RelNode}, for validity checking, or null if not + * known + * @return Whether the program is valid + */ + public boolean isValid(Litmus litmus, RelNode.Context context) { + if (inputRowType == null) { + return litmus.fail(null); + } + if (exprs == null) { + return litmus.fail(null); + } + if (projects == null) { + return litmus.fail(null); + } + if (outputRowType == null) { + return litmus.fail(null); + } + + // If the input row type is a struct (contains fields) then the leading + // expressions must be references to those fields. But we don't require + // this if the input row type is, say, a java class. + if (inputRowType.isStruct()) { + if (!RexUtil.containIdentity(exprs, inputRowType, litmus)) { + return litmus.fail(null); + } + + // None of the other fields should be inputRefs. + for (int i = inputRowType.getFieldCount(); i < exprs.size(); i++) { + RexNode expr = exprs.get(i); + if (expr instanceof RexInputRef) { + return litmus.fail(null); + } + } + } + // todo: enable + // CHECKSTYLE: IGNORE 1 + if (false && RexUtil.containNoCommonExprs(exprs, litmus)) { + return litmus.fail(null); + } + if (!RexUtil.containNoForwardRefs(exprs, inputRowType, litmus)) { + return litmus.fail(null); + } + if (!RexUtil.containNoNonTrivialAggs(exprs, litmus)) { + return litmus.fail(null); + } + final Checker checker = new Checker(inputRowType, RexUtil.types(exprs), null, litmus); + if (condition != null) { + if (!SqlTypeUtil.inBooleanFamily(condition.getType())) { + return litmus.fail("condition must be boolean"); + } + condition.accept(checker); + if (checker.failCount > 0) { + return litmus.fail(null); + } + } + for (RexLocalRef project : projects) { + project.accept(checker); + if (checker.failCount > 0) { + return litmus.fail(null); + } + } + for (RexNode expr : exprs) { + expr.accept(checker); + if (checker.failCount > 0) { + return litmus.fail(null); + } + } + return litmus.succeed(); + } + + /** + * Returns whether an expression always evaluates to null. + * + *

Like {@link RexUtil#isNull(RexNode)}, null literals are null, and casts of null literals + * are null. But this method also regards references to null expressions as null. + * + * @param expr Expression + * @return Whether expression always evaluates to null + */ + public boolean isNull(RexNode expr) { + switch (expr.getKind()) { + case LITERAL: + return ((RexLiteral) expr).getValue2() == null; + case LOCAL_REF: + RexLocalRef inputRef = (RexLocalRef) expr; + return isNull(exprs.get(inputRef.index)); + case CAST: + return isNull(((RexCall) expr).operands.get(0)); + default: + return false; + } + } + + /** + * Fully expands a RexLocalRef back into a pure RexNode tree containing no RexLocalRefs + * (reversing the effect of common subexpression elimination). For example, + * program.expandLocalRef(program.getCondition()) will return the expansion of a + * program's condition. + * + * @param ref a RexLocalRef from this program + * @return expanded form + */ + public RexNode expandLocalRef(RexLocalRef ref) { + return ref.accept(new ExpansionShuttle(exprs)); + } + + /** Expands a list of expressions that may contain {@link RexLocalRef}s. */ + public List expandList(List nodes) { + return new ExpansionShuttle(exprs).visitList(nodes); + } + + /** + * Splits this program into a list of project expressions and a list of filter expressions. + * + *

Neither list is null. The filters are evaluated first. + */ + public Pair, ImmutableList> split() { + final List filters = new ArrayList<>(); + if (condition != null) { + RelOptUtil.decomposeConjunction(expandLocalRef(condition), filters); + } + final ImmutableList.Builder projects = ImmutableList.builder(); + for (RexLocalRef project : this.projects) { + projects.add(expandLocalRef(project)); + } + return Pair.of(projects.build(), ImmutableList.copyOf(filters)); + } + + /** + * Given a list of collations which hold for the input to this program, returns a list of + * collations which hold for its output. The result is mutable and sorted. + */ + public List getCollations(List inputCollations) { + final List outputCollations = new ArrayList<>(); + deduceCollations(outputCollations, inputRowType.getFieldCount(), projects, inputCollations); + return outputCollations; + } + + /** + * Given a list of expressions and a description of which are ordered, populates a list of + * collations, sorted in natural order. + */ + public static void deduceCollations( + List outputCollations, + final int sourceCount, + List refs, + List inputCollations) { + int[] targets = new int[sourceCount]; + Arrays.fill(targets, -1); + for (int i = 0; i < refs.size(); i++) { + final RexLocalRef ref = refs.get(i); + final int source = ref.getIndex(); + if ((source < sourceCount) && (targets[source] == -1)) { + targets[source] = i; + } + } + loop: + for (RelCollation collation : inputCollations) { + final List fieldCollations = new ArrayList<>(0); + for (RelFieldCollation fieldCollation : collation.getFieldCollations()) { + final int source = fieldCollation.getFieldIndex(); + final int target = targets[source]; + if (target < 0) { + continue loop; + } + fieldCollations.add(fieldCollation.withFieldIndex(target)); + } + + // Success -- all of the source fields of this key are mapped + // to the output. + outputCollations.add(RelCollations.of(fieldCollations)); + } + outputCollations.sort(Ordering.natural()); + } + + /** + * Returns whether the fields on the leading edge of the project list are the input fields. + * + * @param fail Whether to throw an assert failure if does not project identity + */ + public boolean projectsIdentity(final boolean fail) { + final int fieldCount = inputRowType.getFieldCount(); + if (projects.size() < fieldCount) { + assert !fail + : "program '" + + toString() + + "' does not project identity for input row type '" + + inputRowType + + "'"; + return false; + } + for (int i = 0; i < fieldCount; i++) { + RexLocalRef project = projects.get(i); + if (project.index != i) { + assert !fail + : "program " + + toString() + + "' does not project identity for input row type '" + + inputRowType + + "', field #" + + i; + return false; + } + } + return true; + } + + /** + * Returns whether this program projects precisely its input fields. It may or may not apply a + * condition. + */ + public boolean projectsOnlyIdentity() { + if (projects.size() != inputRowType.getFieldCount()) { + return false; + } + for (int i = 0; i < projects.size(); i++) { + RexLocalRef project = projects.get(i); + if (project.index != i) { + return false; + } + } + return true; + } + + /** + * Returns whether this program returns its input exactly. + * + *

This is a stronger condition than {@link #projectsIdentity(boolean)}. + */ + public boolean isTrivial() { + return getCondition() == null && projectsOnlyIdentity(); + } + + /** + * Gets reference counts for each expression in the program, where the references are detected + * from later expressions in the same program, as well as the project list and condition. + * Expressions with references counts greater than 1 are true common sub-expressions. + * + * @return array of reference counts; the ith element in the returned array is the number of + * references to getExprList()[i] + */ + public int[] getReferenceCounts() { + if (refCounts != null) { + return refCounts; + } + refCounts = new int[exprs.size()]; + ReferenceCounter refCounter = new ReferenceCounter(refCounts); + RexUtil.apply(refCounter, exprs, null); + if (condition != null) { + refCounter.visitLocalRef(condition); + } + for (RexLocalRef project : projects) { + refCounter.visitLocalRef(project); + } + return refCounts; + } + + /** Returns whether an expression is constant. */ + public boolean isConstant(RexNode ref) { + return ref.accept(new ConstantFinder()); + } + + public @Nullable RexNode gatherExpr(RexNode expr) { + return expr.accept(new Marshaller()); + } + + /** + * Returns the input field that an output field is populated from, or -1 if it is populated from + * an expression. + */ + public int getSourceField(int outputOrdinal) { + assert (outputOrdinal >= 0) && (outputOrdinal < this.projects.size()); + RexLocalRef project = projects.get(outputOrdinal); + int index = project.index; + while (true) { + RexNode expr = exprs.get(index); + if (expr instanceof RexCall + && ((RexCall) expr).getOperator() == SqlStdOperatorTable.IN_FENNEL) { + // drill through identity function + expr = ((RexCall) expr).getOperands().get(0); + } + if (expr instanceof RexLocalRef) { + index = ((RexLocalRef) expr).index; + } else if (expr instanceof RexInputRef) { + return ((RexInputRef) expr).index; + } else { + return -1; + } + } + } + + /** Returns whether this program is a permutation of its inputs. */ + public boolean isPermutation() { + if (projects.size() != inputRowType.getFieldList().size()) { + return false; + } + for (int i = 0; i < projects.size(); ++i) { + if (getSourceField(i) < 0) { + return false; + } + } + return true; + } + + /** Returns a permutation, if this program is a permutation, otherwise null. */ + public @Nullable Permutation getPermutation() { + Permutation permutation = new Permutation(projects.size()); + if (projects.size() != inputRowType.getFieldList().size()) { + return null; + } + for (int i = 0; i < projects.size(); ++i) { + int sourceField = getSourceField(i); + if (sourceField < 0) { + return null; + } + permutation.set(i, sourceField); + } + return permutation; + } + + /** + * Returns the set of correlation variables used (read) by this program. + * + * @return set of correlation variable names + */ + public Set getCorrelVariableNames() { + final Set paramIdSet = new HashSet<>(); + RexUtil.apply( + new RexVisitorImpl(true) { + @Override + public Void visitCorrelVariable(RexCorrelVariable correlVariable) { + paramIdSet.add(correlVariable.getName()); + return null; + } + }, + exprs, + null); + return paramIdSet; + } + + /** + * Returns whether this program is in canonical form. + * + * @param litmus What to do if an error is detected (program is not in canonical form) + * @param rexBuilder Rex builder + * @return whether in canonical form + */ + public boolean isNormalized(Litmus litmus, RexBuilder rexBuilder) { + final RexProgram normalizedProgram = normalize(rexBuilder, null); + String normalized = normalizedProgram.toString(); + String string = toString(); + if (!normalized.equals(string)) { + final String message = + "Program is not normalized:\n" + "program: {}\n" + "normalized: {}\n"; + return litmus.fail(message, string, normalized); + } + return litmus.succeed(); + } + + /** + * Creates a simplified/normalized copy of this program. + * + * @param rexBuilder Rex builder + * @param simplify Simplifier to simplify (in addition to normalizing), or null to not simplify + * @return Normalized program + */ + public RexProgram normalize(RexBuilder rexBuilder, @Nullable RexSimplify simplify) { + // Normalize program by creating program builder from the program, then + // converting to a program. getProgram does not need to normalize + // because the builder was normalized on creation. + assert isValid(Litmus.THROW, null); + final RexProgramBuilder builder = + RexProgramBuilder.create( + rexBuilder, + inputRowType, + exprs, + projects, + condition, + outputRowType, + true, + simplify); + return builder.getProgram(false); + } + + @Deprecated // to be removed before 2.0 + public RexProgram normalize(RexBuilder rexBuilder, boolean simplify) { + final RelOptPredicateList predicates = RelOptPredicateList.EMPTY; + return normalize( + rexBuilder, + simplify ? new RexSimplify(rexBuilder, predicates, RexUtil.EXECUTOR) : null); + } + + /** + * Returns a partial mapping of a set of project expressions. + * + *

The mapping is an inverse function. Every target has a source field, but a source might + * have 0, 1 or more targets. Project expressions that do not consist of a mapping are ignored. + * + * @param inputFieldCount Number of input fields + * @return Mapping of a set of project expressions, never null + */ + public Mappings.TargetMapping getPartialMapping(int inputFieldCount) { + Mappings.TargetMapping mapping = + Mappings.create(MappingType.INVERSE_FUNCTION, inputFieldCount, projects.size()); + for (Ord exp : Ord.zip(projects)) { + RexNode rexNode = expandLocalRef(exp.e); + if (rexNode instanceof RexInputRef) { + mapping.set(((RexInputRef) rexNode).getIndex(), exp.i); + } + } + return mapping; + } + + // ~ Inner Classes ---------------------------------------------------------- + + /** Visitor which walks over a program and checks validity. */ + static class Checker extends RexChecker { + private final List internalExprTypeList; + + /** + * Creates a Checker. + * + * @param inputRowType Types of the input fields + * @param internalExprTypeList Types of the internal expressions + * @param context Context of the enclosing {@link RelNode}, or null + * @param litmus Whether to fail + */ + Checker( + RelDataType inputRowType, + List internalExprTypeList, + RelNode.Context context, + Litmus litmus) { + super(inputRowType, context, litmus); + this.internalExprTypeList = internalExprTypeList; + } + + /** + * Overrides {@link RexChecker} method, because {@link RexLocalRef} is is illegal in most + * rex expressions, but legal in a program. + */ + @Override + public Boolean visitLocalRef(RexLocalRef localRef) { + final int index = localRef.getIndex(); + if ((index < 0) || (index >= internalExprTypeList.size())) { + ++failCount; + return litmus.fail(null); + } + if (!RelOptUtil.eq( + "type1", + localRef.getType(), + "type2", + internalExprTypeList.get(index), + litmus)) { + ++failCount; + return litmus.fail(null); + } + return litmus.succeed(); + } + } + + /** A RexShuttle used in the implementation of {@link RexProgram#expandLocalRef}. */ + static class ExpansionShuttle extends RexShuttle { + private final List exprs; + + ExpansionShuttle(List exprs) { + this.exprs = exprs; + } + + @Override + public RexNode visitLocalRef(RexLocalRef localRef) { + RexNode tree = exprs.get(localRef.getIndex()); + return tree.accept(this); + } + } + + /** Walks over an expression and determines whether it is constant. */ + private class ConstantFinder extends RexUtil.ConstantFinder { + @Override + public Boolean visitLocalRef(RexLocalRef localRef) { + final RexNode expr = exprs.get(localRef.index); + return expr.accept(this); + } + + @Override + public Boolean visitOver(RexOver over) { + return false; + } + + @Override + public Boolean visitCorrelVariable(RexCorrelVariable correlVariable) { + // Correlating variables are constant WITHIN A RESTART, so that's + // good enough. + return true; + } + } + + /** + * Given an expression in a program, creates a clone of the expression with sub-expressions + * (represented by {@link RexLocalRef}s) fully expanded. + */ + private class Marshaller extends RexVisitorImpl<@Nullable RexNode> { + Marshaller() { + super(false); + } + + @Override + public RexNode visitInputRef(RexInputRef inputRef) { + return inputRef; + } + + @Override + public @Nullable RexNode visitLocalRef(RexLocalRef localRef) { + final RexNode expr = exprs.get(localRef.index); + return expr.accept(this); + } + + @Override + public RexNode visitLiteral(RexLiteral literal) { + return literal; + } + + @Override + public RexNode visitCall(RexCall call) { + final List newOperands = new ArrayList<>(); + for (RexNode operand : call.getOperands()) { + newOperands.add(castNonNull(operand.accept(this))); + } + return call.clone(call.getType(), newOperands); + } + + @Override + public RexNode visitOver(RexOver over) { + return visitCall(over); + } + + @Override + public RexNode visitCorrelVariable(RexCorrelVariable correlVariable) { + return correlVariable; + } + + @Override + public RexNode visitDynamicParam(RexDynamicParam dynamicParam) { + return dynamicParam; + } + + @Override + public RexNode visitRangeRef(RexRangeRef rangeRef) { + return rangeRef; + } + + @Override + public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { + final RexNode referenceExpr = fieldAccess.getReferenceExpr().accept(this); + return new RexFieldAccess( + requireNonNull(referenceExpr, "referenceExpr must not be null"), + fieldAccess.getField(), + fieldAccess.getType()); + } + } + + /** Visitor which marks which expressions are used. */ + private static class ReferenceCounter extends RexVisitorImpl { + private final int[] refCounts; + + ReferenceCounter(int[] refCounts) { + super(true); + this.refCounts = refCounts; + } + + @Override + public Void visitLocalRef(RexLocalRef localRef) { + final int index = localRef.getIndex(); + refCounts[index]++; + return null; + } + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexShuttle.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexShuttle.java new file mode 100644 index 0000000000000..5a20c53c23927 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexShuttle.java @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.rex; + +import com.google.common.collect.ImmutableList; + +import java.util.ArrayList; +import java.util.List; + +import static java.util.Objects.requireNonNull; + +/** + * Passes over a row-expression, calling a handler method for each node, appropriate to the type of + * the node. + * + *

Like {@link RexVisitor}, this is an instance of the {@link + * org.apache.calcite.util.Glossary#VISITOR_PATTERN Visitor Pattern}. Use RexShuttle + * if you would like your methods to return a value. + * + *

FLINK modifications (backport of CALCITE-6764): Lines 208 ~ 211 + */ +public class RexShuttle implements RexVisitor { + // ~ Methods ---------------------------------------------------------------- + + @Override + public RexNode visitOver(RexOver over) { + boolean[] update = {false}; + List clonedOperands = visitList(over.operands, update); + RexWindow window = visitWindow(over.getWindow()); + if (update[0] || (window != over.getWindow())) { + // REVIEW jvs 8-Mar-2005: This doesn't take into account + // the fact that a rewrite may have changed the result type. + // To do that, we would need to take a RexBuilder and + // watch out for special operators like CAST and NEW where + // the type is embedded in the original call. + return new RexOver( + over.getType(), + over.getAggOperator(), + clonedOperands, + window, + over.isDistinct(), + over.ignoreNulls()); + } else { + return over; + } + } + + public RexWindow visitWindow(RexWindow window) { + boolean[] update = {false}; + List clonedOrderKeys = visitFieldCollations(window.orderKeys, update); + List clonedPartitionKeys = visitList(window.partitionKeys, update); + final RexWindowBound lowerBound = window.getLowerBound().accept(this); + final RexWindowBound upperBound = window.getUpperBound().accept(this); + if (lowerBound == null + || upperBound == null + || !update[0] + && lowerBound == window.getLowerBound() + && upperBound == window.getUpperBound()) { + return window; + } + boolean rows = window.isRows(); + if (lowerBound.isUnbounded() + && lowerBound.isPreceding() + && upperBound.isUnbounded() + && upperBound.isFollowing()) { + // RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + // is equivalent to + // ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + // but we prefer "RANGE" + rows = false; + } + return new RexWindow(clonedPartitionKeys, clonedOrderKeys, lowerBound, upperBound, rows); + } + + @Override + public RexNode visitSubQuery(RexSubQuery subQuery) { + boolean[] update = {false}; + List clonedOperands = visitList(subQuery.operands, update); + if (update[0]) { + return subQuery.clone(subQuery.getType(), clonedOperands); + } else { + return subQuery; + } + } + + @Override + public RexNode visitTableInputRef(RexTableInputRef ref) { + return ref; + } + + @Override + public RexNode visitPatternFieldRef(RexPatternFieldRef fieldRef) { + return fieldRef; + } + + @Override + public RexNode visitCall(final RexCall call) { + boolean[] update = {false}; + List clonedOperands = visitList(call.operands, update); + if (update[0]) { + // REVIEW jvs 8-Mar-2005: This doesn't take into account + // the fact that a rewrite may have changed the result type. + // To do that, we would need to take a RexBuilder and + // watch out for special operators like CAST and NEW where + // the type is embedded in the original call. + return call.clone(call.getType(), clonedOperands); + } else { + return call; + } + } + + /** + * Visits each of an array of expressions and returns an array of the results. + * + * @param exprs Array of expressions + * @param update If not null, sets this to true if any of the expressions was modified + * @return Array of visited expressions + */ + protected RexNode[] visitArray(RexNode[] exprs, boolean[] update) { + RexNode[] clonedOperands = new RexNode[exprs.length]; + for (int i = 0; i < exprs.length; i++) { + RexNode operand = exprs[i]; + RexNode clonedOperand = operand.accept(this); + if ((clonedOperand != operand) && (update != null)) { + update[0] = true; + } + clonedOperands[i] = clonedOperand; + } + return clonedOperands; + } + + /** + * Visits each of a list of expressions and returns a list of the results. + * + * @param exprs List of expressions + * @param update If not null, sets this to true if any of the expressions was modified + * @return Array of visited expressions + */ + protected List visitList(List exprs, boolean[] update) { + ImmutableList.Builder clonedOperands = ImmutableList.builder(); + for (RexNode operand : exprs) { + RexNode clonedOperand = operand.accept(this); + if ((clonedOperand != operand) && (update != null)) { + update[0] = true; + } + clonedOperands.add(clonedOperand); + } + return clonedOperands.build(); + } + + /** + * Visits each of a list of field collations and returns a list of the results. + * + * @param collations List of field collations + * @param update If not null, sets this to true if any of the expressions was modified + * @return Array of visited field collations + */ + protected List visitFieldCollations( + List collations, boolean[] update) { + ImmutableList.Builder clonedOperands = ImmutableList.builder(); + for (RexFieldCollation collation : collations) { + RexNode clonedOperand = collation.left.accept(this); + if ((clonedOperand != collation.left) && (update != null)) { + update[0] = true; + collation = new RexFieldCollation(clonedOperand, requireNonNull(collation.right)); + } + clonedOperands.add(collation); + } + return clonedOperands.build(); + } + + @Override + public RexNode visitCorrelVariable(RexCorrelVariable variable) { + return variable; + } + + @Override + public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { + RexNode before = fieldAccess.getReferenceExpr(); + RexNode after = before.accept(this); + + if (before == after) { + return fieldAccess; + } else { + return new RexFieldAccess(after, fieldAccess.getField(), fieldAccess.getType()); + } + } + + @Override + public RexNode visitInputRef(RexInputRef inputRef) { + return inputRef; + } + + @Override + public RexNode visitLocalRef(RexLocalRef localRef) { + return localRef; + } + + @Override + public RexNode visitLiteral(RexLiteral literal) { + return literal; + } + + @Override + public RexNode visitDynamicParam(RexDynamicParam dynamicParam) { + return dynamicParam; + } + + @Override + public RexNode visitRangeRef(RexRangeRef rangeRef) { + return rangeRef; + } + + @Override + public RexNode visitLambda(RexLambda lambda) { + lambda.getExpression().accept(this); + return lambda; + } + + @Override + public RexNode visitLambdaRef(RexLambdaRef lambdaRef) { + return lambdaRef; + } + + /** + * Applies this shuttle to each expression in a list. + * + * @return whether any of the expressions changed + */ + public final boolean mutate(List exprList) { + int changeCount = 0; + for (int i = 0; i < exprList.size(); i++) { + T expr = exprList.get(i); + T expr2 = (T) apply(expr); // Avoid NPE if expr is null + if (expr != expr2) { + ++changeCount; + exprList.set(i, expr2); + } + } + return changeCount > 0; + } + + /** + * Applies this shuttle to each expression in a list and returns the resulting list. Does not + * modify the initial list. + * + *

Returns null if and only if {@code exprList} is null. + */ + public final List apply(List exprList) { + if (exprList == null) { + return exprList; + } + final List list2 = new ArrayList<>(exprList); + if (mutate(list2)) { + return list2; + } else { + return exprList; + } + } + + /** Applies this shuttle to an expression, or returns null if the expression is null. */ + public final RexNode apply(RexNode expr) { + return (expr == null) ? expr : expr.accept(this); + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexUtil.java index 77cb8af4323cf..bddc6a15239ad 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexUtil.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexUtil.java @@ -81,6 +81,8 @@ * *

Lines 402 ~ 404, Use Calcite 1.32.0 behavior for {@link RexUtil#gatherConstraints(Class, * RexNode, Map, Set, RexBuilder)}. + * + *

FLINK modifications (backport of CALCITE-6764): Line 2481~2485 */ public class RexUtil { @@ -2489,7 +2491,11 @@ public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { expr.accept(this); final RexNode normalizedExpr = lookup(expr); if (normalizedExpr != expr) { - fieldAccess = new RexFieldAccess(normalizedExpr, fieldAccess.getField()); + // ----- FLINK MODIFICATION BEGIN ----- + fieldAccess = + new RexFieldAccess( + normalizedExpr, fieldAccess.getField(), fieldAccess.getType()); + // ----- FLINK MODIFICATION END ----- } return register(fieldAccess); } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/type/SqlTypeFactoryImpl.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/type/SqlTypeFactoryImpl.java index 3fa27ea4016a0..50a979faa517f 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/type/SqlTypeFactoryImpl.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/type/SqlTypeFactoryImpl.java @@ -41,6 +41,7 @@ *

  • Should be removed after fixing CALCITE-6342: Lines 100-102 *
  • Should be removed after fixing CALCITE-6342: Lines 484-496 *
  • Should be removed after fix of FLINK-31350: Lines 563-575. + *
  • Added in FLINK-39695 (backport of CALCITE-6764): Lines 225 ~ 248 * */ public class SqlTypeFactoryImpl extends RelDataTypeFactoryImpl { @@ -227,6 +228,31 @@ public RelDataType createTypeWithNullability(final RelDataType type, final boole return canonize(newType); } + // ----- FLINK MODIFICATION BEGIN ----- + // Backport from Calcite (CALCITE-6764) + @Override + public RelDataType enforceTypeWithNullability(final RelDataType type, final boolean nullable) { + final RelDataType newType; + if (type instanceof BasicSqlType) { + newType = ((BasicSqlType) type).createWithNullability(nullable); + } else if (type instanceof MapSqlType) { + newType = copyMapType(type, nullable); + } else if (type instanceof ArraySqlType) { + newType = copyArrayType(type, nullable); + } else if (type instanceof MultisetSqlType) { + newType = copyMultisetType(type, nullable); + } else if (type instanceof IntervalSqlType) { + newType = copyIntervalType(type, nullable); + } else if (type instanceof ObjectSqlType) { + newType = copyObjectType(type, nullable); + } else { + return super.enforceTypeWithNullability(type, nullable); + } + return canonize(newType); + } + + // ----- FLINK MODIFICATION END ----- + private static void assertBasic(SqlTypeName typeName) { assert typeName != null; assert typeName != SqlTypeName.MULTISET : "use createMultisetType() instead"; diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java index b002c45b3a5aa..b7d9b6eb48824 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java @@ -32,6 +32,7 @@ import org.apache.calcite.rel.type.RelCrossType; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeFactoryImpl; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rel.type.RelDataTypeSystem; import org.apache.calcite.rel.type.RelRecordType; @@ -178,37 +179,42 @@ * Default implementation of {@link SqlValidator}, the class was copied over because of * CALCITE-4554. * - *

    Lines 230 ~ 233, Flink improves error message for functions without appropriate arguments in + *

    Lines 234 ~ 237, Flink improves error message for functions without appropriate arguments in * handleUnresolvedFunction. * - *

    Lines 1319 ~ 1321, CALCITE-7217, should be removed after upgrading Calcite to 1.41.0. + *

    Lines 1325 ~ 1327, CALCITE-7217, should be removed after upgrading Calcite to 1.41.0. * - *

    Lines 2080 ~ 2094, Flink improves error message for functions without appropriate arguments in + *

    Lines 2086 ~ 2100, Flink improves error message for functions without appropriate arguments in * handleUnresolvedFunction at {@link SqlValidatorImpl#handleUnresolvedFunction}. * - *

    Lines 2507 ~ 2509, CALCITE-7471 should be removed after upgrading Calcite to 1.42.0. + *

    Lines 2348 ~ 2359 * - *

    Lines 2622 ~ 2641, CALCITE-7217, CALCITE-7312 should be removed after upgrading Calcite to + *

    Lines 2513 ~ 2515, CALCITE-7471 should be removed after upgrading Calcite to 1.42.0. + * + *

    Lines 2628 ~ 2647, CALCITE-7217, CALCITE-7312 should be removed after upgrading Calcite to * 1.42.0. * - *

    Line 2672 ~2690, set the correct scope for VECTOR_SEARCH. + *

    Line 2678 ~2696, set the correct scope for VECTOR_SEARCH. * - *

    Lines 4072 ~ 4076, 6766 ~ 6772 Flink improves Optimize the retrieval of sub-operands in + *

    Lines 4078 ~ 4082, 6766 ~ 6772 Flink improves Optimize the retrieval of sub-operands in * SqlCall when using NamedParameters at {@link SqlValidatorImpl#checkRollUp}. * - *

    Lines 5492 ~ 5498, FLINK-24352 Add null check for temporal table check on SqlSnapshot. + *

    Lines 5498 ~ 5504, FLINK-24352 Add null check for temporal table check on SqlSnapshot. + * + *

    Lines 5919-5934, CALCITE-7538 should be removed after upgrading Calcite to 1.42.0. * - *

    Lines 5913-5928, CALCITE-7538 should be removed after upgrading Calcite to 1.42.0. + *

    Lines 5944-5946, CALCITE-7466 should be removed after upgrading Calcite to 1.42.0. * - *

    Lines 5938-5940, CALCITE-7466 should be removed after upgrading Calcite to 1.42.0. + *

    Lines 6000-6002, CALCITE-7470 should be removed after upgrading Calcite to 1.42.0. * - *

    Lines 5994-5996, CALCITE-7470 should be removed after upgrading Calcite to 1.42.0. + *

    Lines 6863-6873, Added in FLINK-39695 (backport of CALCITE-6764): propagate parent record + * nullability to nested fields. * - *

    Lines 7422-7445, CALCITE-7486 should be removed after upgrading Calcite to 1.42.0. + *

    Lines 7438-7461, CALCITE-7486 should be removed after upgrading Calcite to 1.42.0. * - *

    Lines 7492-7509, CALCITE-7486 should be removed after upgrading Calcite to 1.42.0. + *

    Lines 7508-7525, CALCITE-7486 should be removed after upgrading Calcite to 1.42.0. * - *

    Lines 7554-7562, CALCITE-7486 should be removed after upgrading Calcite to 1.42.0. + *

    Lines 7570-7578, CALCITE-7486 should be removed after upgrading Calcite to 1.42.0. */ public class SqlValidatorImpl implements SqlValidatorWithHints { // ~ Static fields/initializers --------------------------------------------- @@ -6854,7 +6860,17 @@ public RelDataType visit(SqlIdentifier id) { if (field == null) { throw newValidationError(id.getComponent(i), RESOURCE.unknownField(name)); } + // ----- FLINK MODIFICATION BEGIN ----- + // Backport from Calcite (CALCITE-6764): if the parent record is + // nullable, the field must also be nullable. + boolean recordIsNullable = type.isNullable(); type = field.getType(); + if (recordIsNullable) { + type = + ((RelDataTypeFactoryImpl) getTypeFactory()) + .enforceTypeWithNullability(type, true); + } + // ----- FLINK MODIFICATION END ----- } type = SqlTypeUtil.addCharsetAndCollation(type, getTypeFactory()); return type; diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java index da6359b697763..2b9589203dbc9 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java @@ -4609,6 +4609,14 @@ private RexNode convertIdentifier(Blackboard bb, SqlIdentifier identifier) { } } + // ----- FLINK MODIFICATION BEGIN ----- + // For nested field access (e.g. b.r.order_id in a LEFT JOIN), the result + // is a RexFieldAccess (or CAST of one) whose reference RexInputRef still + // has the pre-join type. Adjust the reference so that field access on a + // nullable ROW from the non-preserved side produces a nullable type. + e = adjustFieldAccessInputRef(bb, e); + // ----- FLINK MODIFICATION END ----- + if (e0.left instanceof RexCorrelVariable) { // ----- FLINK MODIFICATION BEGIN ----- // adjust the type to account for nulls introduced by FlinkRexBuilder#makeFieldAccess @@ -4621,6 +4629,65 @@ private RexNode convertIdentifier(Blackboard bb, SqlIdentifier identifier) { } // ----- FLINK MODIFICATION BEGIN ----- + /** + * Adjusts the nullability of a nested field access based on the nullability of the enclosing + * ROW after an outer join. For instance if there are tables + * + *

    {@code
    +     * CREATE TABLE orders (order_id BIGINT NOT NULL, PRIMARY KEY (order_id) NOT ENFORCED);
    +     * CREATE TABLE details (
    +     *   r ROW NOT NULL,
    +     *   PRIMARY KEY (r) NOT ENFORCED
    +     * );
    +     * }
    + * + *

    and then there is a SQL query + * + *

    {@code
    +     * SELECT b.r.order_id
    +     * FROM orders a LEFT JOIN details b ON a.order_id = b.r.order_id
    +     * }
    + * + *

    The field {@code r.order_id} is declared {@code NOT NULL} inside a {@code NOT NULL} ROW. + * However, in a LEFT JOIN when there is no match on the right side, the entire {@code b} row is + * null-padded — so {@code b.r} is null, and {@code b.r.order_id} must produce {@code null}. + * + *

    Without this adjustment, the {@link RexFieldAccess} built by {@code convertIdentifier} + * still carries the pre-join {@link RexInputRef} type ({@code NOT NULL}), so the field access + * result is also typed as {@code NOT NULL}. At runtime this causes the codegen to read a + * default value (e.g. {@code -1} for {@code BIGINT}) from the null-padded row instead of {@code + * null}. + */ + private RexNode adjustFieldAccessInputRef(Blackboard bb, RexNode e) { + final RexFieldAccess fieldAccess; + if (e instanceof RexFieldAccess) { + fieldAccess = (RexFieldAccess) e; + } else if (e instanceof RexCall + && e.getKind() == SqlKind.CAST + && ((RexCall) e).getOperands().get(0) instanceof RexFieldAccess) { + fieldAccess = (RexFieldAccess) ((RexCall) e).getOperands().get(0); + } else { + return e; + } + + final RexNode ref = fieldAccess.getReferenceExpr(); + if (!(ref instanceof RexInputRef)) { + return e; + } + + final RexNode adjusted = adjustInputRef(bb, (RexInputRef) ref); + if (adjusted.getType().isNullable() == ref.getType().isNullable()) { + return e; + } + // Rebuild the field access with the adjusted ref and wrap in CAST to nullable type. + // The CAST is required for Flink codegen to emit null-checking code at runtime. + final RelDataType nullableFieldType = + typeFactory.createTypeWithNullability(fieldAccess.getField().getType(), true); + return rexBuilder.makeCast( + nullableFieldType, + rexBuilder.makeFieldAccess(adjusted, fieldAccess.getField().getIndex())); + } + private RexFieldAccess adjustRexFieldAccess(RexNode rexNode) { // Either RexFieldAccess or CAST of RexFieldAccess to nullable assert rexNode instanceof RexFieldAccess diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/calcite/FlinkRexBuilder.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/calcite/FlinkRexBuilder.java index 814dfa1d2af9c..0c7d62ca47f0b 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/calcite/FlinkRexBuilder.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/calcite/FlinkRexBuilder.java @@ -21,13 +21,7 @@ import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; -import org.apache.calcite.rex.RexCall; -import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexLiteral; -import org.apache.calcite.rex.RexNode; -import org.apache.calcite.rex.RexShuttle; -import org.apache.calcite.rex.RexUtil; -import org.apache.calcite.sql.SqlKind; import org.apache.calcite.util.TimestampString; /** A slim extension over a {@link RexBuilder}. See the overridden methods for more explanation. */ @@ -36,32 +30,6 @@ public FlinkRexBuilder(RelDataTypeFactory typeFactory) { super(typeFactory); } - /** - * Compared to the original method we adjust the nullability of the nested column based on the - * nullability of the enclosing type. - * - *

    If the fields type is NOT NULL, but the enclosing ROW is nullable we still can produce - * nulls. - */ - @Override - public RexNode makeFieldAccess(RexNode expr, String fieldName, boolean caseSensitive) { - final RexNode field = super.makeFieldAccess(expr, fieldName, caseSensitive); - return makeFieldAccess(expr, field); - } - - /** - * Compared to the original method we adjust the nullability of the nested column based on the - * nullability of the enclosing type. - * - *

    If the fields type is NOT NULL, but the enclosing ROW is nullable we still can produce - * nulls. - */ - @Override - public RexNode makeFieldAccess(RexNode expr, int i) { - final RexNode field = super.makeFieldAccess(expr, i); - return makeFieldAccess(expr, field); - } - /** * Creates a literal of the default value for the given type. * @@ -91,89 +59,4 @@ public RexLiteral makeZeroLiteral(RelDataType type) { return super.makeZeroLiteral(type); } } - - /** - * Adjust the nullability of the nested column based on the nullability of the enclosing type. - * However, if there is former nullability {@code CAST} present then it will be dropped and - * replaced with a new one (if needed). For instance if there is a table - * - *

    {@code
    -     * CREATE TABLE MyTable (
    -     * `field1` ROW<`data` ROW<`nested` ROW<`trId` STRING>>NOT NULL>
    -     * WITH ('connector' = 'datagen')
    -     * }
    - * - *

    and then there is a SQL query - * - *

    {@code
    -     * SELECT `field1`.`data`.`nested`.`trId` AS transactionId FROM MyTable
    -     * }
    - * - *

    The {@code SELECT} picks a nested field only. In this case it should go step by step - * checking each level. - * - *

      - *
    1. Looking at {@code `field1`} type it is nullable, then no changes. - *
    2. {@code `field1`.`data`} is {@code NOT NULL}, however keeping in mind that enclosing - * type @{code `field1`} is nullable then need to change nullability with {@code CAST} - *
    3. {@code `field1`.`data`.`nested`} is nullable that means that in this case no need for - * extra {@code CAST} inserted in previous step, so it will be dropped. - *
    4. {@code `field1`.`data`.`nested`.`trId`} is also nullable, so no changes. - *
    - */ - private RexNode makeFieldAccess(RexNode expr, RexNode field) { - final RexNode fieldWithRemovedCast = removeCastNullableFromFieldAccess(field); - final boolean nullabilityShouldChange = - field.getType().isNullable() != fieldWithRemovedCast.getType().isNullable() - || expr.getType().isNullable() && !field.getType().isNullable(); - - if (nullabilityShouldChange) { - return makeCast( - typeFactory.createTypeWithNullability(field.getType(), true), - fieldWithRemovedCast, - true, - false); - } - - return expr.getType().isNullable() && fieldWithRemovedCast.getType().isNullable() - ? fieldWithRemovedCast - : field; - } - - /** - * {@link FlinkRexBuilder#makeFieldAccess} will adjust nullability based on nullability of the - * enclosing type. However, it might be a deeply nested column and for every step {@link - * FlinkRexBuilder#makeFieldAccess} will try to insert a cast. This method will remove previous - * cast in order to keep only one. - */ - private RexNode removeCastNullableFromFieldAccess(RexNode rexFieldAccess) { - if (!(rexFieldAccess instanceof RexFieldAccess)) { - return rexFieldAccess; - } - RexNode rexNode = rexFieldAccess; - while (rexNode instanceof RexFieldAccess) { - rexNode = ((RexFieldAccess) rexNode).getReferenceExpr(); - } - if (rexNode.getKind() != SqlKind.CAST) { - return rexFieldAccess; - } - RexShuttle visitor = - new RexShuttle() { - @Override - public RexNode visitCall(final RexCall call) { - if (call.getKind() == SqlKind.CAST - && !call.operands.get(0).getType().isNullable() - && call.getType().isNullable() - && call.getOperands() - .get(0) - .getType() - .getFieldList() - .equals(call.getType().getFieldList())) { - return RexUtil.removeCast(call); - } - return call; - } - }; - return RexUtil.apply(visitor, new RexNode[] {rexFieldAccess})[0]; - } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/common/CalcTestPrograms.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/common/CalcTestPrograms.java index 6300c24369b0b..0a92bc0f41a52 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/common/CalcTestPrograms.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/common/CalcTestPrograms.java @@ -268,4 +268,36 @@ public class CalcTestPrograms { .build()) .runSql("INSERT INTO sink_t SELECT name, ts, CURRENT_WATERMARK(ts) AS w FROM t") .build(); + + public static final TableTestProgram COALESCE_NESTED_ROW_LEFT_JOIN = + TableTestProgram.of( + "calc-coalesce-nested-row-left-join", + "validates coalesce on nested ROW field from LEFT JOIN") + .setupTableSource( + SourceTestStep.newBuilder("orders") + .addSchema( + "`order_id` BIGINT NOT NULL", + "`amount` DOUBLE", + "PRIMARY KEY (`order_id`) NOT ENFORCED") + .producedValues(Row.of(1L, 10.0), Row.of(2L, 20.0)) + .build()) + .setupTableSource( + SourceTestStep.newBuilder("order_details_row") + .addSchema( + "`r` ROW<`order_id` BIGINT NOT NULL, `name` STRING NOT NULL> NOT NULL", + "`detail` STRING", + "PRIMARY KEY (`r`) NOT ENFORCED") + .producedValues(Row.of(Row.of(1L, "first"), "d1")) + .build()) + .setupTableSink( + SinkTestStep.newBuilder("coalesce_sink") + .addSchema("order_id_str STRING") + .consumedValues("+I[1]", "+I[2]") + .build()) + .runSql( + "INSERT INTO coalesce_sink " + + "SELECT CAST(COALESCE(b.r.order_id, a.order_id) AS STRING) AS order_id_str " + + "FROM orders a LEFT JOIN order_details_row b " + + "ON a.order_id = b.r.order_id") + .build(); } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/common/JoinSemanticTestPrograms.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/common/JoinSemanticTestPrograms.java index b687f5d4a18b3..d67b997b597f2 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/common/JoinSemanticTestPrograms.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/common/JoinSemanticTestPrograms.java @@ -104,10 +104,205 @@ public class JoinSemanticTestPrograms { .setupTableSink( SinkTestStep.newBuilder("sink_t") .addSchema("output STRING") + .testMaterializedData() .consumedValues("+I[test_diff]") .build()) .runSql( "INSERT INTO sink_t SELECT t2.ext.nested.nested1.nested2 FROM source_t2 t2 WHERE" + " NOT EXISTS (SELECT 1 FROM source_t1 t1 WHERE t1.ext.nested = t2.ext.nested.nested1.nested2)") .build(); + + // --- NOT NULL ROW field, join on nested field --- + + public static final TableTestProgram LEFT_JOIN_NOT_NULL_NESTED_ROW = + TableTestProgram.of( + "left-join-not-null-nested-row", + "NOT NULL ROW field from non-preserved side of LEFT JOIN must be nullable") + .setupTableSource( + SourceTestStep.newBuilder("lj_nn_orders") + .addSchema( + "`order_id` BIGINT NOT NULL", + "PRIMARY KEY (`order_id`) NOT ENFORCED") + .producedValues(Row.of(1L), Row.of(2L)) + .build()) + .setupTableSource( + SourceTestStep.newBuilder("lj_nn_details") + .addSchema( + "`r` ROW<`order_id` BIGINT NOT NULL, `name` STRING NOT NULL> NOT NULL", + "PRIMARY KEY (`r`) NOT ENFORCED") + .producedValues(Row.of(Row.of(1L, "first"))) + .build()) + .setupTableSink( + SinkTestStep.newBuilder("lj_nn_sink") + .addSchema("order_id BIGINT", "detail_id BIGINT") + .testMaterializedData() + .consumedValues("+I[1, 1]", "+I[2, null]") + .build()) + .runSql( + "INSERT INTO lj_nn_sink " + + "SELECT a.order_id, b.r.order_id " + + "FROM lj_nn_orders a LEFT JOIN lj_nn_details b " + + "ON a.order_id = b.r.order_id") + .build(); + + public static final TableTestProgram RIGHT_JOIN_NOT_NULL_NESTED_ROW = + TableTestProgram.of( + "right-join-not-null-nested-row", + "NOT NULL ROW field from non-preserved side of RIGHT JOIN must be nullable") + .setupTableSource( + SourceTestStep.newBuilder("rj_nn_orders") + .addSchema( + "`order_id` BIGINT NOT NULL", + "PRIMARY KEY (`order_id`) NOT ENFORCED") + .producedValues(Row.of(1L), Row.of(2L)) + .build()) + .setupTableSource( + SourceTestStep.newBuilder("rj_nn_details") + .addSchema( + "`r` ROW<`order_id` BIGINT NOT NULL, `name` STRING NOT NULL> NOT NULL", + "PRIMARY KEY (`r`) NOT ENFORCED") + .producedValues(Row.of(Row.of(1L, "first"))) + .build()) + .setupTableSink( + SinkTestStep.newBuilder("rj_nn_sink") + .addSchema("detail_id BIGINT", "order_id BIGINT") + .testMaterializedData() + .consumedValues("+I[1, 1]", "+I[null, 2]") + .build()) + .runSql( + "INSERT INTO rj_nn_sink " + + "SELECT b.r.order_id, a.order_id " + + "FROM rj_nn_details b RIGHT JOIN rj_nn_orders a " + + "ON a.order_id = b.r.order_id") + .build(); + + public static final TableTestProgram FULL_JOIN_NOT_NULL_NESTED_ROW = + TableTestProgram.of( + "full-join-not-null-nested-row", + "NOT NULL ROW fields from both sides of FULL JOIN must be nullable") + .setupTableSource( + SourceTestStep.newBuilder("fj_nn_left") + .addSchema( + "`r` ROW<`id` BIGINT NOT NULL> NOT NULL", + "PRIMARY KEY (`r`) NOT ENFORCED") + .producedValues(Row.of(Row.of(1L)), Row.of(Row.of(2L))) + .build()) + .setupTableSource( + SourceTestStep.newBuilder("fj_nn_right") + .addSchema( + "`r` ROW<`id` BIGINT NOT NULL> NOT NULL", + "PRIMARY KEY (`r`) NOT ENFORCED") + .producedValues(Row.of(Row.of(2L)), Row.of(Row.of(3L))) + .build()) + .setupTableSink( + SinkTestStep.newBuilder("fj_nn_sink") + .addSchema("left_id BIGINT", "right_id BIGINT") + .testMaterializedData() + .consumedValues("+I[1, null]", "+I[2, 2]", "+I[null, 3]") + .build()) + .runSql( + "INSERT INTO fj_nn_sink " + + "SELECT a.r.id, b.r.id " + + "FROM fj_nn_left a FULL JOIN fj_nn_right b " + + "ON a.r.id = b.r.id") + .build(); + + // --- nullable ROW field, join on nested field --- + + public static final TableTestProgram LEFT_JOIN_NULLABLE_NESTED_ROW = + TableTestProgram.of( + "left-join-nullable-nested-row", + "nullable ROW field from non-preserved side of LEFT JOIN must be nullable") + .setupTableSource( + SourceTestStep.newBuilder("lj_n_orders") + .addSchema( + "`order_id` BIGINT NOT NULL", + "PRIMARY KEY (`order_id`) NOT ENFORCED") + .producedValues(Row.of(1L), Row.of(2L)) + .build()) + .setupTableSource( + SourceTestStep.newBuilder("lj_n_details") + .addSchema( + "`id` BIGINT NOT NULL", + "`r` ROW<`order_id` BIGINT NOT NULL, `name` STRING>", + "PRIMARY KEY (`id`) NOT ENFORCED") + .producedValues(Row.of(1L, Row.of(1L, "first"))) + .build()) + .setupTableSink( + SinkTestStep.newBuilder("lj_n_sink") + .addSchema("order_id BIGINT", "detail_id BIGINT") + .testMaterializedData() + .consumedValues("+I[1, 1]", "+I[2, null]") + .build()) + .runSql( + "INSERT INTO lj_n_sink " + + "SELECT a.order_id, b.r.order_id " + + "FROM lj_n_orders a LEFT JOIN lj_n_details b " + + "ON a.order_id = b.r.order_id") + .build(); + + public static final TableTestProgram RIGHT_JOIN_NULLABLE_NESTED_ROW = + TableTestProgram.of( + "right-join-nullable-nested-row", + "nullable ROW field from non-preserved side of RIGHT JOIN must be nullable") + .setupTableSource( + SourceTestStep.newBuilder("rj_n_orders") + .addSchema( + "`order_id` BIGINT NOT NULL", + "PRIMARY KEY (`order_id`) NOT ENFORCED") + .producedValues(Row.of(1L), Row.of(2L)) + .build()) + .setupTableSource( + SourceTestStep.newBuilder("rj_n_details") + .addSchema( + "`id` BIGINT NOT NULL", + "`r` ROW<`order_id` BIGINT NOT NULL, `name` STRING>", + "PRIMARY KEY (`id`) NOT ENFORCED") + .producedValues(Row.of(1L, Row.of(1L, "first"))) + .build()) + .setupTableSink( + SinkTestStep.newBuilder("rj_n_sink") + .addSchema("detail_id BIGINT", "order_id BIGINT") + .testMaterializedData() + .consumedValues("+I[1, 1]", "+I[null, 2]") + .build()) + .runSql( + "INSERT INTO rj_n_sink " + + "SELECT b.r.order_id, a.order_id " + + "FROM rj_n_details b RIGHT JOIN rj_n_orders a " + + "ON a.order_id = b.r.order_id") + .build(); + + public static final TableTestProgram FULL_JOIN_NULLABLE_NESTED_ROW = + TableTestProgram.of( + "full-join-nullable-nested-row", + "nullable ROW fields from both sides of FULL JOIN must be nullable") + .setupTableSource( + SourceTestStep.newBuilder("fj_n_left") + .addSchema( + "`id` BIGINT NOT NULL", + "`r` ROW<`id` BIGINT NOT NULL>", + "PRIMARY KEY (`id`) NOT ENFORCED") + .producedValues(Row.of(1L, Row.of(1L)), Row.of(2L, Row.of(2L))) + .build()) + .setupTableSource( + SourceTestStep.newBuilder("fj_n_right") + .addSchema( + "`id` BIGINT NOT NULL", + "`r` ROW<`id` BIGINT NOT NULL>", + "PRIMARY KEY (`id`) NOT ENFORCED") + .producedValues(Row.of(2L, Row.of(2L)), Row.of(3L, Row.of(3L))) + .build()) + .setupTableSink( + SinkTestStep.newBuilder("fj_n_sink") + .addSchema("left_id BIGINT", "right_id BIGINT") + .testMaterializedData() + .consumedValues("+I[1, null]", "+I[2, 2]", "+I[null, 3]") + .build()) + .runSql( + "INSERT INTO fj_n_sink " + + "SELECT a.r.id, b.r.id " + + "FROM fj_n_left a FULL JOIN fj_n_right b " + + "ON a.r.id = b.r.id") + .build(); } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/JoinSemanticTests.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/JoinSemanticTests.java index 2501a7e626140..4988c34a93cf4 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/JoinSemanticTests.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/JoinSemanticTests.java @@ -30,6 +30,12 @@ public class JoinSemanticTests extends SemanticTestBase { public List programs() { return List.of( JoinSemanticTestPrograms.OUTER_JOIN_CHANGELOG_TEST, - JoinSemanticTestPrograms.ANTI_JOIN_ON_NESTED); + JoinSemanticTestPrograms.ANTI_JOIN_ON_NESTED, + JoinSemanticTestPrograms.LEFT_JOIN_NOT_NULL_NESTED_ROW, + JoinSemanticTestPrograms.RIGHT_JOIN_NOT_NULL_NESTED_ROW, + JoinSemanticTestPrograms.FULL_JOIN_NOT_NULL_NESTED_ROW, + JoinSemanticTestPrograms.LEFT_JOIN_NULLABLE_NESTED_ROW, + JoinSemanticTestPrograms.RIGHT_JOIN_NULLABLE_NESTED_ROW, + JoinSemanticTestPrograms.FULL_JOIN_NULLABLE_NESTED_ROW); } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MiscSemanticTests.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MiscSemanticTests.java index 7467e7cf45ec9..5b35692c77bd1 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MiscSemanticTests.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MiscSemanticTests.java @@ -31,6 +31,7 @@ public class MiscSemanticTests extends SemanticTestBase { public List programs() { return List.of( WindowRankTestPrograms.WINDOW_RANK_HOP_TVF_NAMED_MIN_TOP_1, - CalcTestPrograms.CURRENT_WATERMARK); + CalcTestPrograms.CURRENT_WATERMARK, + CalcTestPrograms.COALESCE_NESTED_ROW_LEFT_JOIN); } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/SimplifyCoalesceWithEquiJoinConditionRuleTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/SimplifyCoalesceWithEquiJoinConditionRuleTest.java index 0a3d8c7b81bbd..398c3dd583800 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/SimplifyCoalesceWithEquiJoinConditionRuleTest.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/SimplifyCoalesceWithEquiJoinConditionRuleTest.java @@ -18,13 +18,20 @@ package org.apache.flink.table.planner.plan.rules.logical; +import org.apache.flink.table.api.Table; import org.apache.flink.table.api.TableConfig; import org.apache.flink.table.planner.utils.StreamTableTestUtil; import org.apache.flink.table.planner.utils.TableTestBase; +import org.apache.flink.table.planner.utils.TableTestUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.logical.LogicalProject; +import org.apache.calcite.rex.RexNode; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import static org.assertj.core.api.Assertions.assertThat; + /** Test rule {@link SimplifyCoalesceWithEquiJoinConditionRule}. */ class SimplifyCoalesceWithEquiJoinConditionRuleTest extends TableTestBase { @@ -176,4 +183,23 @@ void testCoalesceOnNestedRowScalarField() { "SELECT CAST(COALESCE(b.r.order_id, a.order_id) AS STRING) AS order_id_str " + "FROM orders a LEFT JOIN order_details_row b ON a.order_id = b.r.order_id"); } + + @Test + void testNestedRowFieldAccessNullableAfterLeftJoin() { + // b.r.order_id must be nullable after a LEFT JOIN even though + // the ROW field is declared NOT NULL, because the ROW itself + // is null-padded on the non-preserved side. + Table table = + util.tableEnv() + .sqlQuery( + "SELECT b.r.order_id " + + "FROM orders a LEFT JOIN order_details_row b " + + "ON a.order_id = b.r.order_id"); + RelNode relNode = TableTestUtil.toRelNode(table); + LogicalProject project = (LogicalProject) relNode; + RexNode expr = project.getProjects().get(0); + assertThat(expr.getType().isNullable()) + .as("Field access on nullable ROW from LEFT JOIN should be nullable, got: " + expr) + .isTrue(); + } } diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/table/CalcTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/table/CalcTest.xml index c68409972cea8..d71673a4f3e0b 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/table/CalcTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/table/CalcTest.xml @@ -64,13 +64,13 @@ Calc(select=[a._1 AS a$_1, a._2 AS a$_2, b]) diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/PushProjectIntoTableSourceScanRuleTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/PushProjectIntoTableSourceScanRuleTest.xml index 92ec23fe38c73..ae3eb0f4948ff 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/PushProjectIntoTableSourceScanRuleTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/PushProjectIntoTableSourceScanRuleTest.xml @@ -268,13 +268,13 @@ LogicalProject(EXPR$0=[ITEM($0, 2).value], data_arr=[$0]) diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SimplifyCoalesceWithEquiJoinConditionRuleTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SimplifyCoalesceWithEquiJoinConditionRuleTest.xml index cf6cbf60a2df1..e2b46196f6ca8 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SimplifyCoalesceWithEquiJoinConditionRuleTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SimplifyCoalesceWithEquiJoinConditionRuleTest.xml @@ -96,7 +96,7 @@ LogicalProject(order_id_str=[CAST(COALESCE($3.order_id, $0)):VARCHAR(2147483647) @@ -134,7 +134,7 @@ LogicalProject(field2=[$1], transactionId=[COALESCE(ITEM($0.data, 0).nested.trId @@ -295,8 +295,8 @@ GROUP BY t.data[1].nested[0].trId]]> NOT NULL" + }, { + "type" : "ReadingMetadata", + "metadataKeys" : [ ], + "producedType" : "ROW<`order_id` BIGINT NOT NULL> NOT NULL" + } ] + }, + "outputType" : "ROW<`order_id` BIGINT NOT NULL>", + "description" : "TableSourceScan(table=[[default_catalog, default_database, orders, project=[order_id], metadata=[]]], fields=[order_id])" + }, { + "id" : 28, + "type" : "stream-exec-exchange_1", + "inputProperties" : [ { + "requiredDistribution" : { + "type" : "HASH", + "keys" : [ 0 ] + }, + "damBehavior" : "PIPELINED", + "priority" : 0 + } ], + "outputType" : "ROW<`order_id` BIGINT NOT NULL>", + "description" : "Exchange(distribution=[hash[order_id]])" + }, { + "id" : 29, + "type" : "stream-exec-table-source-scan_2", + "scanTableSource" : { + "table" : { + "identifier" : "`default_catalog`.`default_database`.`order_details_row`", + "resolvedTable" : { + "schema" : { + "columns" : [ { + "name" : "r", + "dataType" : "ROW<`order_id` BIGINT NOT NULL, `name` VARCHAR(2147483647) NOT NULL> NOT NULL" + }, { + "name" : "detail", + "dataType" : "VARCHAR(2147483647)" + } ], + "primaryKey" : { + "name" : "PK_r", + "type" : "PRIMARY_KEY", + "columns" : [ "r" ] + } + } + } + }, + "abilities" : [ { + "type" : "ProjectPushDown", + "projectedFields" : [ [ 0 ] ], + "producedType" : "ROW<`r` ROW<`order_id` BIGINT NOT NULL, `name` VARCHAR(2147483647) NOT NULL> NOT NULL> NOT NULL" + }, { + "type" : "ReadingMetadata", + "metadataKeys" : [ ], + "producedType" : "ROW<`r` ROW<`order_id` BIGINT NOT NULL, `name` VARCHAR(2147483647) NOT NULL> NOT NULL> NOT NULL" + } ] + }, + "outputType" : "ROW<`r` ROW<`order_id` BIGINT NOT NULL, `name` VARCHAR(2147483647) NOT NULL> NOT NULL>", + "description" : "TableSourceScan(table=[[default_catalog, default_database, order_details_row, project=[r], metadata=[]]], fields=[r])" + }, { + "id" : 30, + "type" : "stream-exec-calc_1", + "projection" : [ { + "kind" : "INPUT_REF", + "inputIndex" : 0, + "type" : "ROW<`order_id` BIGINT NOT NULL, `name` VARCHAR(2147483647) NOT NULL> NOT NULL" + }, { + "kind" : "FIELD_ACCESS", + "name" : "order_id", + "expr" : { + "kind" : "INPUT_REF", + "inputIndex" : 0, + "type" : "ROW<`order_id` BIGINT NOT NULL, `name` VARCHAR(2147483647) NOT NULL> NOT NULL" + } + } ], + "condition" : null, + "inputProperties" : [ { + "requiredDistribution" : { + "type" : "UNKNOWN" + }, + "damBehavior" : "PIPELINED", + "priority" : 0 + } ], + "outputType" : "ROW<`r` ROW<`order_id` BIGINT NOT NULL, `name` VARCHAR(2147483647) NOT NULL> NOT NULL, `$f2` BIGINT NOT NULL>", + "description" : "Calc(select=[r, r.order_id AS $f2])" + }, { + "id" : 31, + "type" : "stream-exec-exchange_1", + "inputProperties" : [ { + "requiredDistribution" : { + "type" : "HASH", + "keys" : [ 1 ] + }, + "damBehavior" : "PIPELINED", + "priority" : 0 + } ], + "outputType" : "ROW<`r` ROW<`order_id` BIGINT NOT NULL, `name` VARCHAR(2147483647) NOT NULL> NOT NULL, `$f2` BIGINT NOT NULL>", + "description" : "Exchange(distribution=[hash[$f2]])" + }, { + "id" : 32, + "type" : "stream-exec-join_1", + "joinSpec" : { + "joinType" : "LEFT", + "leftKeys" : [ 0 ], + "rightKeys" : [ 1 ], + "filterNulls" : [ true ], + "nonEquiCondition" : null + }, + "leftUpsertKeys" : [ [ 0 ] ], + "rightUpsertKeys" : [ [ 0 ] ], + "state" : [ { + "index" : 0, + "ttl" : "0 ms", + "name" : "leftState" + }, { + "index" : 1, + "ttl" : "0 ms", + "name" : "rightState" + } ], + "inputProperties" : [ { + "requiredDistribution" : { + "type" : "UNKNOWN" + }, + "damBehavior" : "PIPELINED", + "priority" : 0 + }, { + "requiredDistribution" : { + "type" : "UNKNOWN" + }, + "damBehavior" : "PIPELINED", + "priority" : 0 + } ], + "outputType" : "ROW<`order_id` BIGINT NOT NULL, `r` ROW<`order_id` BIGINT NOT NULL, `name` VARCHAR(2147483647) NOT NULL>, `$f2` BIGINT>", + "description" : "Join(joinType=[LeftOuterJoin], where=[(order_id = $f2)], select=[order_id, r, $f2], leftInputSpec=[JoinKeyContainsUniqueKey], rightInputSpec=[HasUniqueKey])" + }, { + "id" : 33, + "type" : "stream-exec-calc_1", + "projection" : [ { + "kind" : "CALL", + "syntax" : "SPECIAL", + "internalName" : "$CAST$1", + "operands" : [ { + "kind" : "CALL", + "internalName" : "$COALESCE$1", + "operands" : [ { + "kind" : "CALL", + "syntax" : "SPECIAL", + "internalName" : "$CAST$1", + "operands" : [ { + "kind" : "INPUT_REF", + "inputIndex" : 1, + "type" : "ROW<`order_id` BIGINT NOT NULL, `name` VARCHAR(2147483647) NOT NULL>" + } ], + "type" : "BIGINT" + }, { + "kind" : "INPUT_REF", + "inputIndex" : 0, + "type" : "BIGINT NOT NULL" + } ], + "type" : "BIGINT NOT NULL" + } ], + "type" : "VARCHAR(2147483647) NOT NULL" + } ], + "condition" : null, + "inputProperties" : [ { + "requiredDistribution" : { + "type" : "UNKNOWN" + }, + "damBehavior" : "PIPELINED", + "priority" : 0 + } ], + "outputType" : "ROW<`order_id_str` VARCHAR(2147483647) NOT NULL>", + "description" : "Calc(select=[CAST(COALESCE(CAST(r AS BIGINT), order_id) AS VARCHAR(2147483647)) AS order_id_str])" + }, { + "id" : 34, + "type" : "stream-exec-sink_2", + "configuration" : { + "table.exec.sink.keyed-shuffle" : "AUTO", + "table.exec.sink.not-null-enforcer" : "ERROR", + "table.exec.sink.rowtime-inserter" : "ENABLED", + "table.exec.sink.type-length-enforcer" : "IGNORE", + "table.exec.sink.upsert-materialize" : "AUTO" + }, + "dynamicTableSink" : { + "table" : { + "identifier" : "`default_catalog`.`default_database`.`coalesce_sink`", + "resolvedTable" : { + "schema" : { + "columns" : [ { + "name" : "order_id_str", + "dataType" : "VARCHAR(2147483647)" + } ] + } + } + } + }, + "inputChangelogMode" : [ "INSERT", "UPDATE_BEFORE", "UPDATE_AFTER", "DELETE" ], + "inputProperties" : [ { + "requiredDistribution" : { + "type" : "UNKNOWN" + }, + "damBehavior" : "PIPELINED", + "priority" : 0 + } ], + "outputType" : "ROW<`order_id_str` VARCHAR(2147483647) NOT NULL>", + "description" : "Sink(table=[default_catalog.default_database.coalesce_sink], fields=[order_id_str])" + } ], + "edges" : [ { + "source" : 27, + "target" : 28, + "shuffle" : { + "type" : "FORWARD" + }, + "shuffleMode" : "PIPELINED" + }, { + "source" : 29, + "target" : 30, + "shuffle" : { + "type" : "FORWARD" + }, + "shuffleMode" : "PIPELINED" + }, { + "source" : 30, + "target" : 31, + "shuffle" : { + "type" : "FORWARD" + }, + "shuffleMode" : "PIPELINED" + }, { + "source" : 28, + "target" : 32, + "shuffle" : { + "type" : "FORWARD" + }, + "shuffleMode" : "PIPELINED" + }, { + "source" : 31, + "target" : 32, + "shuffle" : { + "type" : "FORWARD" + }, + "shuffleMode" : "PIPELINED" + }, { + "source" : 32, + "target" : 33, + "shuffle" : { + "type" : "FORWARD" + }, + "shuffleMode" : "PIPELINED" + }, { + "source" : 33, + "target" : 34, + "shuffle" : { + "type" : "FORWARD" + }, + "shuffleMode" : "PIPELINED" + } ] +} \ No newline at end of file