Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, PaimonSparkSession, SparkSession}
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserInterface}
import org.apache.spark.sql.catalyst.parser.extensions.PaimonSqlExtensionsParser.{NonReservedContext, QuotedIdentifierContext}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.parser.extensions.PaimonSqlExtensionsParser.{CreateTableLikeContext, MultipartIdentifierContext, NonReservedContext, QuotedIdentifierContext}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.VariableSubstitution
Expand All @@ -52,30 +52,76 @@ import scala.collection.JavaConverters._
* @param delegate
* The extension parser.
*/
// Keep this class in the Spark 4.0 module so it is compiled against Spark 4.0's ParserInterface.
abstract class AbstractPaimonSparkSqlExtensionsParser(val delegate: ParserInterface)
extends org.apache.spark.sql.catalyst.parser.ParserInterface
with Logging {

private lazy val substitutor = new VariableSubstitution()
private lazy val astBuilder = new PaimonSqlExtensionsAstBuilder(delegate)
private val nonReservedIdentifierTokenTypes = Set(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are these effects?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Zouxxyy This set is only used by the lightweight maybeCreateTableLike pre-check.

It mirrors the nonReserved identifier rule in PaimonSqlExtensions.g4. The lexer returns words such as TAG, LIKE, MAP, etc. as keyword token
types, but the grammar still allows them to be used as unquoted identifier parts. Without this set, a valid statement like CREATE TABLE paimon.default.tag LIKE ... would not be recognized by the pre-check and could fall back to Spark's parser, bypassing the Paimon parser path for
catalog-qualified CREATE TABLE LIKE.

PaimonSqlExtensionsParser.ALTER,
PaimonSqlExtensionsParser.AS,
PaimonSqlExtensionsParser.CALL,
PaimonSqlExtensionsParser.CREATE,
PaimonSqlExtensionsParser.DAYS,
PaimonSqlExtensionsParser.DELETE,
PaimonSqlExtensionsParser.EXISTS,
PaimonSqlExtensionsParser.HOURS,
PaimonSqlExtensionsParser.IF,
PaimonSqlExtensionsParser.LIKE,
PaimonSqlExtensionsParser.NOT,
PaimonSqlExtensionsParser.OF,
PaimonSqlExtensionsParser.OR,
PaimonSqlExtensionsParser.TABLE,
PaimonSqlExtensionsParser.REPLACE,
PaimonSqlExtensionsParser.RETAIN,
PaimonSqlExtensionsParser.VERSION,
PaimonSqlExtensionsParser.TAG,
PaimonSqlExtensionsParser.TRUE,
PaimonSqlExtensionsParser.FALSE,
PaimonSqlExtensionsParser.MAP,
PaimonSqlExtensionsParser.COPY,
PaimonSqlExtensionsParser.INTO,
PaimonSqlExtensionsParser.FROM,
PaimonSqlExtensionsParser.FILE_FORMAT,
PaimonSqlExtensionsParser.PATTERN,
PaimonSqlExtensionsParser.FORCE,
PaimonSqlExtensionsParser.ON_ERROR,
PaimonSqlExtensionsParser.ABORT_STATEMENT,
PaimonSqlExtensionsParser.OVERWRITE,
PaimonSqlExtensionsParser.CSV
)

/** Parses a string to a LogicalPlan. */
override def parsePlan(sqlText: String): LogicalPlan = {
val sqlTextAfterSubstitution = substitutor.substitute(sqlText)
if (isPaimonCommand(sqlTextAfterSubstitution)) {
parse(sqlTextAfterSubstitution)(parser => astBuilder.visit(parser.singleStatement()))
.asInstanceOf[LogicalPlan]
} else if (isCatalogCreateTableLike(sqlTextAfterSubstitution)) {
applyParserRules(
parse(sqlTextAfterSubstitution)(parser => astBuilder.visit(parser.singleStatement()))
.asInstanceOf[LogicalPlan])
} else {
var plan = delegate.parsePlan(sqlText)
val sparkSession = PaimonSparkSession.active
parserRules(sparkSession).foreach(
rule => {
plan = rule.apply(plan)
})
plan
parsePlanWithDelegate(sqlText)
}
}

private def parsePlanWithDelegate(sqlText: String): LogicalPlan = {
applyParserRules(delegate.parsePlan(sqlText))
}

private def applyParserRules(plan: LogicalPlan): LogicalPlan = {
var rewrittenPlan = plan
val sparkSession = PaimonSparkSession.active
parserRules(sparkSession).foreach(
rule => {
rewrittenPlan = rule.apply(rewrittenPlan)
})
rewrittenPlan
}

private def parserRules(sparkSession: SparkSession): Seq[Rule[LogicalPlan]] = {
Seq(
RewritePaimonViewCommands(sparkSession),
Expand Down Expand Up @@ -144,6 +190,154 @@ abstract class AbstractPaimonSparkSqlExtensionsParser(val delegate: ParserInterf
normalized.startsWith("copy into")
}

private def isCatalogCreateTableLike(sqlText: String): Boolean = {
if (org.apache.spark.SPARK_VERSION < "3.4") {
return false
}
if (!startsWithCreateTable(sqlText)) {
return false
}

tokenStream(sqlText) match {
case Some(tokens) if maybeCreateTableLike(tokens) =>
isParsedCatalogCreateTableLike(sqlText)
case _ => false
}
}

private def tokenStream(sqlText: String): Option[CommonTokenStream] = {
try {
val lexer = new PaimonSqlExtensionsLexer(
new UpperCaseCharStream(CharStreams.fromString(sqlText)))
lexer.removeErrorListeners()
lexer.addErrorListener(PaimonParseErrorListener)

val tokens = new CommonTokenStream(lexer)
tokens.fill()
Some(tokens)
} catch {
case _: PaimonParseException => None
}
}

private def maybeCreateTableLike(tokenStream: CommonTokenStream): Boolean = {
val tokens = tokenStream.getTokens.asScala
.filter(token => token.getChannel == Token.DEFAULT_CHANNEL)
.filterNot(token => token.getType == Token.EOF)

if (tokens.length < 5) return false
if (tokens(0).getType != PaimonSqlExtensionsParser.CREATE) return false
if (tokens(1).getType != PaimonSqlExtensionsParser.TABLE) return false

var idx = 2
if (
idx + 2 < tokens.length &&
tokens(idx).getType == PaimonSqlExtensionsParser.IF &&
tokens(idx + 1).getType == PaimonSqlExtensionsParser.NOT &&
tokens(idx + 2).getType == PaimonSqlExtensionsParser.EXISTS
) {
idx += 3
}

if (idx >= tokens.length || !isIdentifierToken(tokens(idx))) return false
idx += 1

while (
idx + 1 < tokens.length &&
tokens(idx).getText == "." &&
isIdentifierToken(tokens(idx + 1))
) {
idx += 2
}

idx < tokens.length && tokens(idx).getType == PaimonSqlExtensionsParser.LIKE
}

private def isIdentifierToken(token: Token): Boolean = {
token.getType == PaimonSqlExtensionsParser.IDENTIFIER ||
token.getType == PaimonSqlExtensionsParser.BACKQUOTED_IDENTIFIER ||
nonReservedIdentifierTokenTypes.contains(token.getType)
}

private def startsWithCreateTable(sqlText: String): Boolean = {
val createIndex = skipWhitespaceAndComments(sqlText, 0)
if (!matchesWord(sqlText, createIndex, "create")) {
return false
}

val tableIndex = skipWhitespaceAndComments(sqlText, createIndex + "create".length)
matchesWord(sqlText, tableIndex, "table")
}

private def skipWhitespaceAndComments(sqlText: String, start: Int): Int = {
var index = start
var continue = true

while (continue) {
while (index < sqlText.length && sqlText.charAt(index).isWhitespace) {
index += 1
}

if (
index + 1 < sqlText.length &&
sqlText.charAt(index) == '-' &&
sqlText.charAt(index + 1) == '-'
) {
index += 2
while (
index < sqlText.length &&
sqlText.charAt(index) != '\n' &&
sqlText.charAt(index) != '\r'
) {
index += 1
}
} else if (
index + 1 < sqlText.length &&
sqlText.charAt(index) == '/' &&
sqlText.charAt(index + 1) == '*'
) {
val close = sqlText.indexOf("*/", index + 2)
index = if (close >= 0) close + 2 else sqlText.length
} else {
continue = false
}
}

index
}

private def matchesWord(sqlText: String, index: Int, word: String): Boolean = {
index + word.length <= sqlText.length &&
sqlText.regionMatches(true, index, word, 0, word.length) &&
(index + word.length == sqlText.length ||
!isIdentifierPart(sqlText.charAt(index + word.length)))
}

private def isIdentifierPart(char: Char): Boolean = {
char.isLetterOrDigit || char == '_'
}

private def isParsedCatalogCreateTableLike(sqlText: String): Boolean = {
try {
parse(sqlText) {
parser =>
val singleStatement = parser.singleStatement()
singleStatement.statement() match {
case ctx: CreateTableLikeContext
if isCatalogIdentifier(ctx.target) || isCatalogIdentifier(ctx.source) =>
true
case _ => false
}
}
} catch {
case _: PaimonParseException => false
}
}

private def isCatalogIdentifier(identifier: MultipartIdentifierContext): Boolean = {
identifier.parts.size() >= 3
}

protected def parse[T](command: String)(toResult: PaimonSqlExtensionsParser => T): T = {
val lexer = new PaimonSqlExtensionsLexer(
new UpperCaseCharStream(CharStreams.fromString(command)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ statement
FROM multipartIdentifier
fileFormatClause
overwriteClause? #copyIntoLocation
| CREATE TABLE (IF NOT EXISTS)? target=multipartIdentifier
LIKE source=multipartIdentifier ( . )*? #createTableLike
;

callArgument
Expand Down Expand Up @@ -197,8 +199,8 @@ quotedIdentifier
;

nonReserved
: ALTER | AS | CALL | CREATE | DAYS | DELETE | EXISTS | HOURS | IF | NOT | OF | OR | TABLE
| REPLACE | RETAIN | VERSION | TAG
: ALTER | AS | CALL | CREATE | DAYS | DELETE | EXISTS | HOURS | IF | LIKE
| NOT | OF | OR | TABLE | REPLACE | RETAIN | VERSION | TAG
| TRUE | FALSE
| MAP
| COPY | INTO | FROM | FILE_FORMAT | PATTERN | FORCE | ON_ERROR | ABORT_STATEMENT | OVERWRITE
Expand All @@ -214,6 +216,7 @@ DELETE: 'DELETE';
EXISTS: 'EXISTS';
HOURS: 'HOURS';
IF : 'IF';
LIKE: 'LIKE';
MINUTES: 'MINUTES';
NOT: 'NOT';
OF: 'OF';
Expand Down
Loading