diff --git a/src/binder/bind_expression/bind_comparison_expression.cpp b/src/binder/bind_expression/bind_comparison_expression.cpp index 783c87600b4..413d76fd17e 100644 --- a/src/binder/bind_expression/bind_comparison_expression.cpp +++ b/src/binder/bind_expression/bind_comparison_expression.cpp @@ -77,7 +77,7 @@ std::shared_ptr ExpressionBinder::bindComparisonExpression( // Resolve exec and select function if necessary // Only used for decimal at the moment. See `bindDecimalCompare`. function->bindFunc({childrenAfterCast, function, nullptr, - std::vector{} /* optionalParams */}); + binder::expression_vector{} /* optionalParams */}); } auto bindData = std::make_unique(LogicalType(function->returnTypeID)); auto uniqueExpressionName = diff --git a/src/binder/bind_expression/bind_function_expression.cpp b/src/binder/bind_expression/bind_function_expression.cpp index 5e014e6f117..191d4dbe43b 100644 --- a/src/binder/bind_expression/bind_function_expression.cpp +++ b/src/binder/bind_expression/bind_function_expression.cpp @@ -1,5 +1,6 @@ #include "binder/binder.h" #include "binder/expression/aggregate_function_expression.h" +#include "binder/expression/expression.h" #include "binder/expression/scalar_function_expression.h" #include "binder/expression_binder.h" #include "catalog/catalog.h" @@ -45,15 +46,17 @@ std::shared_ptr ExpressionBinder::bindFunctionExpression(const Parse std::shared_ptr ExpressionBinder::bindScalarFunctionExpression( const ParsedExpression& parsedExpression, const std::string& functionName) { expression_vector children; + expression_vector optionalParams; for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) { auto expr = bindExpression(*parsedExpression.getChild(i)); - if (parsedExpression.getChild(i)->hasAlias()) { + if (!parsedExpression.getChild(i)->hasAlias()) { + children.push_back(expr); + } else { expr->setAlias(parsedExpression.getChild(i)->getAlias()); + optionalParams.push_back(expr); } - children.push_back(expr); } - return bindScalarFunctionExpression(children, functionName, - parsedExpression.constCast().getOptionalArguments()); + return bindScalarFunctionExpression(children, functionName, optionalParams); } static std::vector getTypes(const expression_vector& exprs) { @@ -66,7 +69,7 @@ static std::vector getTypes(const expression_vector& exprs) { std::shared_ptr ExpressionBinder::bindScalarFunctionExpression( const expression_vector& children, const std::string& functionName, - std::vector optionalArguments) { + binder::expression_vector optionalArguments) { auto catalog = Catalog::Get(*context); auto transaction = transaction::Transaction::Get(*context); auto childrenTypes = getTypes(children); @@ -175,7 +178,7 @@ std::shared_ptr ExpressionBinder::bindAggregateFunctionExpression( std::unique_ptr bindData; if (function.bindFunc) { auto bindInput = ScalarBindFuncInput{children, &function, context, - std::vector{} /* optionalParams */}; + binder::expression_vector{} /* optionalParams */}; bindData = function.bindFunc(bindInput); } else { bindData = std::make_unique(LogicalType(function.returnTypeID)); diff --git a/src/function/struct/struct_pack_function.cpp b/src/function/struct/struct_pack_function.cpp index 57576f32993..5ab5cc83cc2 100644 --- a/src/function/struct/struct_pack_function.cpp +++ b/src/function/struct/struct_pack_function.cpp @@ -23,7 +23,7 @@ static std::unique_ptr bindFunc(const ScalarBindFuncInput& inp throw BinderException( stringFormat("Cannot infer field name for {}.", argument->toString())); } - auto fieldName = input.optionalArguments[i]; + auto fieldName = input.optionalArguments[i]->getAlias(); if (fieldNameSet.contains(fieldName)) { throw BinderException(stringFormat("Found duplicate field {} in STRUCT.", fieldName)); } else { diff --git a/src/function/union/union_value_function.cpp b/src/function/union/union_value_function.cpp index d98e2bfc111..fedb1ca8094 100644 --- a/src/function/union/union_value_function.cpp +++ b/src/function/union/union_value_function.cpp @@ -7,37 +7,35 @@ namespace kuzu { namespace function { static std::unique_ptr bindFunc(const ScalarBindFuncInput& input) { - KU_ASSERT(input.arguments.size() == 1); + KU_ASSERT(input.optionalArguments.size() == 1); std::vector fields; - if (input.arguments[0]->getDataType().getLogicalTypeID() == common::LogicalTypeID::ANY) { - input.arguments[0]->cast(LogicalType::STRING()); + if (input.optionalArguments[0]->getDataType().getLogicalTypeID() == + common::LogicalTypeID::ANY) { + input.optionalArguments[0]->cast(LogicalType::STRING()); } - fields.emplace_back(input.arguments[0]->getAlias(), input.arguments[0]->getDataType().copy()); + fields.emplace_back(input.optionalArguments[0]->getAlias(), + input.optionalArguments[0]->getDataType().copy()); auto resultType = LogicalType::UNION(std::move(fields)); - return FunctionBindData::getSimpleBindData(input.arguments, resultType); + return FunctionBindData::getSimpleBindData(input.optionalArguments, resultType); } -static void execFunc(const std::vector>&, +static void execFunc(const std::vector>& parameters, const std::vector&, common::ValueVector& result, common::SelectionVector* resultSelVector, void* /*dataPtr*/) { + // (Tanvir) This is broken, parameters does not include optional params so we would + // get an out of bounds error + KU_ASSERT_UNCONDITIONAL(false); + result.setState(parameters[0]->state); + UnionVector::getTagVector(&result)->setState(parameters[0]->state); + UnionVector::referenceVector(&result, UnionType::TAG_FIELD_IDX, parameters[0]); UnionVector::setTagField(result, *resultSelVector, UnionType::TAG_FIELD_IDX); } -static void valueCompileFunc(FunctionBindData* /*bindData*/, - const std::vector>& parameters, - std::shared_ptr& result) { - KU_ASSERT(parameters.size() == 1); - result->setState(parameters[0]->state); - UnionVector::getTagVector(result.get())->setState(parameters[0]->state); - UnionVector::referenceVector(result.get(), UnionType::TAG_FIELD_IDX, parameters[0]); -} - function_set UnionValueFunction::getFunctionSet() { function_set functionSet; - auto function = std::make_unique(name, - std::vector{LogicalTypeID::ANY}, LogicalTypeID::UNION, execFunc); + auto function = std::make_unique(name, std::vector{}, + LogicalTypeID::UNION, execFunc); function->bindFunc = bindFunc; - function->compileFunc = valueCompileFunc; functionSet.push_back(std::move(function)); return functionSet; } diff --git a/src/include/binder/expression_binder.h b/src/include/binder/expression_binder.h index 5564d0c4da1..3816d0e2a95 100644 --- a/src/include/binder/expression_binder.h +++ b/src/include/binder/expression_binder.h @@ -86,7 +86,7 @@ class ExpressionBinder { const parser::ParsedExpression& parsedExpression, const std::string& functionName); std::shared_ptr bindScalarFunctionExpression(const expression_vector& children, const std::string& functionName, - std::vector optionalArguments = std::vector{}); + binder::expression_vector optionalArguments = binder::expression_vector{}); std::shared_ptr bindRewriteFunctionExpression(const parser::ParsedExpression& expr); std::shared_ptr bindAggregateFunctionExpression( const parser::ParsedExpression& parsedExpression, const std::string& functionName, diff --git a/src/include/function/function.h b/src/include/function/function.h index 5559a78d4ef..aa43c2a9e4f 100644 --- a/src/include/function/function.h +++ b/src/include/function/function.h @@ -47,10 +47,10 @@ struct ScalarBindFuncInput { const binder::expression_vector& arguments; Function* definition; main::ClientContext* context; - std::vector optionalArguments; + const binder::expression_vector optionalArguments; ScalarBindFuncInput(const binder::expression_vector& arguments, Function* definition, - main::ClientContext* context, std::vector optionalArguments) + main::ClientContext* context, binder::expression_vector optionalArguments) : arguments{arguments}, definition{definition}, context{context}, optionalArguments{std::move(optionalArguments)} {} };