Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
Expand Down Expand Up @@ -299,8 +300,7 @@ private static RelNode convertSinkToRel(
isOverwrite,
sink,
contextResolvedTable.getResolvedTable(),
sinkAbilitySpecs,
targetColumns);
sinkAbilitySpecs);

// rewrite rel node for delete
if (isDelete) {
Expand All @@ -314,7 +314,7 @@ private static RelNode convertSinkToRel(
typeFactory,
sinkAbilitySpecs);
} else if (isUpdate) {
input =
Tuple2<RelNode, int[]> updateResult =
convertUpdate(
(LogicalTableModify) input,
sink,
Expand All @@ -323,7 +323,12 @@ private static RelNode convertSinkToRel(
dataTypeFactory,
typeFactory,
sinkAbilitySpecs);
input = updateResult.f0;
// align target columns with the projected row delivered to the sink
targetColumns = toNestedIndexPaths(updateResult.f1);
}
// apply target columns after UPDATE rewrite so the sink sees the final column set
validateAndApplyTargetColumns(sink, targetColumns, sinkAbilitySpecs);

sinkAbilitySpecs.forEach(spec -> spec.apply(sink));

Expand Down Expand Up @@ -490,7 +495,7 @@ private static RelNode convertDelete(
return deleteRelNodeAndRequireIndices.f0;
}

private static RelNode convertUpdate(
private static Tuple2<RelNode, int[]> convertUpdate(
LogicalTableModify tableModify,
DynamicTableSink sink,
ContextResolvedTable contextResolvedTable,
Expand Down Expand Up @@ -522,6 +527,7 @@ private static RelNode convertUpdate(
tableModify,
contextResolvedTable,
updateInfo,
updatedColumns,
tableDebugName,
dataTypeFactory,
typeFactory);
Expand All @@ -531,7 +537,21 @@ private static RelNode convertUpdate(
updateInfo.getRowLevelUpdateMode(),
context,
updateRelNodeAndRequireIndices.f1));
return updateRelNodeAndRequireIndices.f0;
return updateRelNodeAndRequireIndices;
}

/**
 * Combines the sink-declared required columns with the updated columns.
 *
 * <p>The required columns come first, in their original order. Each updated column whose name
 * does not occur among the required columns is appended afterwards, in the order given. Names
 * are only compared against the required columns, not among the updated columns themselves.
 *
 * @param requiredColumns columns the sink declared as required
 * @param updatedColumns columns that are being updated
 * @return a newly allocated list; neither argument is modified
 */
private static List<Column> mergeRequiredAndUpdatedColumns(
        List<Column> requiredColumns, List<Column> updatedColumns) {
    final Set<String> requiredNames =
            requiredColumns.stream().map(Column::getName).collect(Collectors.toSet());
    final List<Column> result = new ArrayList<>(requiredColumns);
    updatedColumns.stream()
            .filter(column -> !requiredNames.contains(column.getName()))
            .forEach(result::add);
    return result;
}

private static List<Column> getUpdatedColumns(
Expand Down Expand Up @@ -590,6 +610,14 @@ private static Tuple2<RelNode, int[]> convertToRowLevelDelete(
getPhysicalColumnIndices(colIndexes, resolvedSchema));
}

/**
 * Wraps every column index in its own single-element array.
 *
 * @param columnIndices flat column indices
 * @return an array of the same length whose i-th entry contains only {@code columnIndices[i]}
 */
private static int[][] toNestedIndexPaths(int[] columnIndices) {
    return IntStream.of(columnIndices)
            .mapToObj(index -> new int[] {index})
            .toArray(int[][]::new);
}

/** Return the indices from {@code colIndexes} that belong to physical columns. */
private static int[] getPhysicalColumnIndices(List<Integer> colIndexes, ResolvedSchema schema) {
return colIndexes.stream()
Expand Down Expand Up @@ -709,13 +737,17 @@ private static Tuple2<RelNode, int[]> convertToRowLevelUpdate(
LogicalTableModify tableModify,
ContextResolvedTable contextResolvedTable,
SupportsRowLevelUpdate.RowLevelUpdateInfo rowLevelUpdateInfo,
List<Column> updatedColumns,
String tableDebugName,
DataTypeFactory dataTypeFactory,
FlinkTypeFactory typeFactory) {
// get the required columns
ResolvedSchema resolvedSchema = contextResolvedTable.getResolvedSchema();
Optional<List<Column>> optionalColumns = rowLevelUpdateInfo.requiredColumns();
List<Column> requiredColumns = optionalColumns.orElse(resolvedSchema.getColumns());
List<Column> requiredColumns =
optionalColumns
.map(cols -> mergeRequiredAndUpdatedColumns(cols, updatedColumns))
.orElse(resolvedSchema.getColumns());
// get the root table scan, which we may need to rewrite
LogicalTableScan tableScan = getSourceTableScan(tableModify);
Tuple2<List<Integer>, List<MetadataColumn>> colsIndexAndExtraMetaCols =
Expand Down Expand Up @@ -1026,8 +1058,7 @@ private static void prepareDynamicSink(
boolean isOverwrite,
DynamicTableSink sink,
ResolvedCatalogTable table,
List<SinkAbilitySpec> sinkAbilitySpecs,
int[][] targetColumns) {
List<SinkAbilitySpec> sinkAbilitySpecs) {
table.getDistribution()
.ifPresent(
distribution ->
Expand All @@ -1039,8 +1070,6 @@ private static void prepareDynamicSink(
validateAndApplyOverwrite(tableDebugName, isOverwrite, sink, sinkAbilitySpecs);

validateAndApplyMetadata(tableDebugName, sink, table.getResolvedSchema(), sinkAbilitySpecs);

validateAndApplyTargetColumns(sink, targetColumns, sinkAbilitySpecs);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.apache.flink.table.connector.sink.abilities.SupportsDeletePushDown;
import org.apache.flink.table.connector.sink.abilities.SupportsRowLevelDelete;
import org.apache.flink.table.connector.sink.abilities.SupportsRowLevelUpdate;
import org.apache.flink.table.connector.sink.abilities.SupportsTargetColumnWriting;
import org.apache.flink.table.connector.sink.abilities.SupportsTruncate;
import org.apache.flink.table.connector.source.DynamicTableSource;
import org.apache.flink.table.connector.source.ScanTableSource;
Expand Down Expand Up @@ -76,6 +77,7 @@
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.apache.flink.table.data.RowData.createFieldGetter;
Expand Down Expand Up @@ -156,12 +158,31 @@ public class TestUpdateDeleteTableFactory
private static final AtomicInteger idCounter = new AtomicInteger(0);
private static final Map<String, Collection<RowData>> registeredRowData = new HashMap<>();

private static final Map<ObjectIdentifier, Optional<int[][]>> capturedUpdateTargetColumns =
new HashMap<>();

private static final Map<ObjectIdentifier, int[][]> capturedAppliedTargetColumns =
new HashMap<>();

/**
 * Stores the given rows under a freshly generated id.
 *
 * @param data rows to register
 * @return the id, i.e. the decimal string of the incremented counter, under which {@code data}
 *     was stored
 */
public static String registerRowData(Collection<RowData> data) {
    final String dataId = Integer.toString(idCounter.incrementAndGet());
    registeredRowData.put(dataId, data);
    return dataId;
}

/**
 * Returns the target columns that were recorded for the given table.
 *
 * @param id identifier of the table to look up
 * @return the recorded value, or {@link Optional#empty()} when nothing has been recorded for
 *     {@code id}; never {@code null}
 */
public static Optional<int[][]> getCapturedUpdateTargetColumns(ObjectIdentifier id) {
    // Map#get yields null for an unknown key; an Optional-returning method must not leak that
    // null to callers, so fall back to an empty Optional.
    final Optional<int[][]> captured = capturedUpdateTargetColumns.get(id);
    return captured != null ? captured : Optional.empty();
}

/**
 * Looks up the target columns recorded through {@code applyTargetColumns} for the given table.
 *
 * @param id identifier of the table to look up
 * @return the recorded columns, or {@code null} if the map holds no entry for {@code id}
 */
public static int[][] getCapturedAppliedTargetColumns(ObjectIdentifier id) {
    final int[][] applied = capturedAppliedTargetColumns.get(id);
    return applied;
}

/**
 * Drops whatever target columns were recorded for the given table from both capture maps.
 *
 * @param id identifier of the table whose recorded entries are removed
 */
public static void clearCapturedTargetColumns(ObjectIdentifier id) {
    capturedAppliedTargetColumns.remove(id);
    capturedUpdateTargetColumns.remove(id);
}

@Override
public DynamicTableSink createDynamicTableSink(Context context) {
FactoryUtil.TableFactoryHelper helper = FactoryUtil.createTableFactoryHelper(this, context);
Expand Down Expand Up @@ -323,7 +344,7 @@ private static class TestScanContext implements RowLevelModificationScanContext

/** A sink that supports row-level update. */
private static class SupportsRowLevelUpdateSink
implements DynamicTableSink, SupportsRowLevelUpdate {
implements DynamicTableSink, SupportsRowLevelUpdate, SupportsTargetColumnWriting {

protected final ObjectIdentifier tableIdentifier;
protected final ResolvedCatalogTable resolvedCatalogTable;
Expand Down Expand Up @@ -376,6 +397,7 @@ public ChangelogMode getChangelogMode(ChangelogMode requestedMode) {

@Override
public SinkRuntimeProvider getSinkRuntimeProvider(Context context) {
capturedUpdateTargetColumns.put(tableIdentifier, context.getTargetColumns());
return new DataStreamSinkProvider() {

@Override
Expand Down Expand Up @@ -417,6 +439,12 @@ public String asSummaryString() {
return "SupportsRowLevelUpdateSink";
}

/**
 * Records the target columns passed to this sink under its table identifier so that tests can
 * read them back later.
 *
 * @param targetColumns the target columns to record; stored as-is, without copying
 * @return always {@code false} — NOTE(review): confirm against the {@code
 *     SupportsTargetColumnWriting} contract that callers do not act on this value
 */
@Override
public boolean applyTargetColumns(int[][] targetColumns) {
    final int[][] received = targetColumns;
    capturedAppliedTargetColumns.put(tableIdentifier, received);
    return false;
}

@Override
public RowLevelUpdateInfo applyRowLevelUpdate(
List<Column> updatedColumns, @Nullable RowLevelModificationScanContext context) {
Expand All @@ -437,7 +465,9 @@ public Optional<List<Column>> requiredColumns() {
resolvedCatalogTable.getResolvedSchema());
}
requiredColumnIndices =
getRequiredColumnIndexes(resolvedCatalogTable, requiredCols);
getRequiredColumnIndexes(
resolvedCatalogTable,
mergeRequiredWithUpdatedColumns(requiredCols, updatedColumns));
return Optional.ofNullable(requiredCols);
}

Expand Down Expand Up @@ -750,6 +780,22 @@ public void executeTruncation() {
}
}

/**
 * Extends the required columns with every updated column whose name is not already among them.
 *
 * @param requiredColumns required columns, may be {@code null}
 * @param updatedColumns columns that are being updated
 * @return {@code null} when {@code requiredColumns} is {@code null}; otherwise a new list
 *     holding the required columns followed by the updated columns that were missing
 */
private static List<Column> mergeRequiredWithUpdatedColumns(
        @Nullable List<Column> requiredColumns, List<Column> updatedColumns) {
    if (requiredColumns == null) {
        return null;
    }
    final Set<String> knownNames =
            requiredColumns.stream().map(Column::getName).collect(Collectors.toSet());
    final List<Column> missing =
            updatedColumns.stream()
                    .filter(column -> !knownNames.contains(column.getName()))
                    .collect(Collectors.toList());
    final List<Column> result = new ArrayList<>(requiredColumns);
    result.addAll(missing);
    return result;
}

private static int[] getRequiredColumnIndexes(
ResolvedCatalogTable resolvedCatalogTable, @Nullable List<Column> columns) {
if (columns == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,21 @@ void testUpdateWithCustomColumns() {
util.verifyExplainInsert("UPDATE t SET b = 'v2' WHERE b = '123'", explainDetails);
}

/**
 * Verifies the plan of an UPDATE whose updated column ({@code b}) is not among the columns the
 * sink declares as required for update ({@code a}).
 */
@TestTemplate
void testUpdateColumnDisjointFromRequired() {
    final String createTable =
            String.format(
                    "CREATE TABLE t (a int PRIMARY KEY NOT ENFORCED, b string, c double) WITH"
                            + " ("
                            + "'connector' = 'test-update-delete', "
                            + "'required-columns-for-update' = 'a', "
                            + "'update-mode' = '%s'"
                            + ") ",
                    updateMode);
    util.tableEnv().executeSql(createTable);

    final String updateStatement = "UPDATE t SET b = 'v2' WHERE a = 1";
    util.verifyExplainInsert(updateStatement, explainDetails);
}

@TestTemplate
void testUpdateWithMetaColumns() {
util.tableEnv()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.flink.table.api.StatementSet;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.api.TableResult;
import org.apache.flink.table.catalog.ObjectIdentifier;
import org.apache.flink.table.connector.sink.DynamicTableSink;
import org.apache.flink.table.connector.sink.abilities.SupportsRowLevelUpdate;
import org.apache.flink.table.data.GenericRowData;
Expand All @@ -41,6 +42,7 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

import static org.assertj.core.api.Assertions.assertThat;
Expand All @@ -62,6 +64,8 @@ private static Collection<SupportsRowLevelUpdate.RowLevelUpdateMode> data() {
@TestTemplate
void testUpdate() throws Exception {
String dataId = registerData();
ObjectIdentifier tableId = ObjectIdentifier.of("default_catalog", "default_database", "t");
TestUpdateDeleteTableFactory.clearCapturedTargetColumns(tableId);
tEnv().executeSql(
String.format(
"CREATE TABLE t ("
Expand All @@ -77,6 +81,14 @@ void testUpdate() throws Exception {
assertThat(rows.toString())
.isEqualTo("[+I[0, b_0, 0.0], +I[1, uaa, 4.0], +I[2, uaa, 16.0]]");

// Sink receives the full row, so both targetColumns paths should be [a,b,c].
Optional<int[][]> captured =
TestUpdateDeleteTableFactory.getCapturedUpdateTargetColumns(tableId);
assertThat(captured).isPresent();
assertThat(captured.get()).isEqualTo(new int[][] {{0}, {1}, {2}});
int[][] applied = TestUpdateDeleteTableFactory.getCapturedAppliedTargetColumns(tableId);
assertThat(applied).isEqualTo(new int[][] {{0}, {1}, {2}});

tEnv().executeSql("UPDATE t SET b = 'uab' WHERE a > (SELECT count(1) FROM t WHERE a > 1)")
.await();
rows = toSortedResults(tEnv().executeSql("SELECT * FROM t"));
Expand All @@ -87,6 +99,8 @@ void testUpdate() throws Exception {
@TestTemplate
void testPartialUpdate() throws Exception {
String dataId = registerData();
ObjectIdentifier tableId = ObjectIdentifier.of("default_catalog", "default_database", "t");
TestUpdateDeleteTableFactory.clearCapturedTargetColumns(tableId);
tEnv().executeSql(
String.format(
"CREATE TABLE t ("
Expand All @@ -103,6 +117,14 @@ void testPartialUpdate() throws Exception {
assertThat(rows.toString())
.isEqualTo("[+I[0, b_0, 0.0], +I[1, uaa, 2.0], +I[2, uaa, 4.0]]");

// Sink-required columns [a,b] should drive both targetColumns paths.
Optional<int[][]> captured =
TestUpdateDeleteTableFactory.getCapturedUpdateTargetColumns(tableId);
assertThat(captured).isPresent();
assertThat(captured.get()).isEqualTo(new int[][] {{0}, {1}});
int[][] applied = TestUpdateDeleteTableFactory.getCapturedAppliedTargetColumns(tableId);
assertThat(applied).isEqualTo(new int[][] {{0}, {1}});

// test partial update with requiring partial primary keys
dataId = registerData();
tEnv().executeSql(
Expand All @@ -124,6 +146,36 @@ void testPartialUpdate() throws Exception {
.isEqualTo("[+I[0, b_0, 0.0], +I[1, uaa, 2.0], +I[2, uaa, 4.0]]");
}

/**
 * Runs an UPDATE on column {@code b} while the sink only requires column {@code a}, then checks
 * both the resulting rows and the target columns captured by the test sink.
 */
@TestTemplate
void testUpdateColumnDisjointFromRequired() throws Exception {
    // The planner must merge updated columns into required columns for targetColumns.
    final String dataId = registerData();
    final ObjectIdentifier tableId =
            ObjectIdentifier.of("default_catalog", "default_database", "t");
    // start from a clean slate so captures left by earlier runs cannot satisfy the assertions
    TestUpdateDeleteTableFactory.clearCapturedTargetColumns(tableId);

    final String createTable =
            String.format(
                    "CREATE TABLE t ("
                            + " a int PRIMARY KEY NOT ENFORCED,"
                            + " b string not null,"
                            + " c double not null) WITH"
                            + " ('connector' = 'test-update-delete', "
                            + "'data-id' = '%s',"
                            + " 'required-columns-for-update' = 'a', "
                            + " 'update-mode' = '%s')",
                    dataId, updateMode);
    tEnv().executeSql(createTable);
    tEnv().executeSql("UPDATE t SET b = 'uaa' WHERE a >= 1").await();

    final String actualRows = toSortedResults(tEnv().executeSql("SELECT * FROM t")).toString();
    assertThat(actualRows).isEqualTo("[+I[0, b_0, 0.0], +I[1, uaa, 2.0], +I[2, uaa, 4.0]]");

    // required column a (index 0) plus updated column b (index 1)
    final int[][] expectedTargetColumns = {{0}, {1}};
    final Optional<int[][]> captured =
            TestUpdateDeleteTableFactory.getCapturedUpdateTargetColumns(tableId);
    assertThat(captured).isPresent();
    assertThat(captured.get()).isEqualTo(expectedTargetColumns);
    assertThat(TestUpdateDeleteTableFactory.getCapturedAppliedTargetColumns(tableId))
            .isEqualTo(expectedTargetColumns);
}

@TestTemplate
void testStatementSetContainUpdateAndInsert() {
tEnv().executeSql(
Expand Down
Loading