Skip to content
This repository was archived by the owner on Oct 10, 2025. It is now read-only.
2 changes: 1 addition & 1 deletion src/binder/bind_expression/bind_comparison_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ std::shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(
// Resolve exec and select function if necessary
// Only used for decimal at the moment. See `bindDecimalCompare`.
function->bindFunc({childrenAfterCast, function, nullptr,
std::vector<std::string>{} /* optionalParams */});
binder::expression_vector{} /* optionalParams */});
}
auto bindData = std::make_unique<FunctionBindData>(LogicalType(function->returnTypeID));
auto uniqueExpressionName =
Expand Down
15 changes: 9 additions & 6 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "binder/binder.h"
#include "binder/expression/aggregate_function_expression.h"
#include "binder/expression/expression.h"
#include "binder/expression/scalar_function_expression.h"
#include "binder/expression_binder.h"
#include "catalog/catalog.h"
Expand Down Expand Up @@ -45,15 +46,17 @@ std::shared_ptr<Expression> ExpressionBinder::bindFunctionExpression(const Parse
std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
const ParsedExpression& parsedExpression, const std::string& functionName) {
expression_vector children;
expression_vector optionalParams;
for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
auto expr = bindExpression(*parsedExpression.getChild(i));
if (parsedExpression.getChild(i)->hasAlias()) {
if (!parsedExpression.getChild(i)->hasAlias()) {
children.push_back(expr);
} else {
expr->setAlias(parsedExpression.getChild(i)->getAlias());
optionalParams.push_back(expr);
}
children.push_back(expr);
}
return bindScalarFunctionExpression(children, functionName,
parsedExpression.constCast<ParsedFunctionExpression>().getOptionalArguments());
return bindScalarFunctionExpression(children, functionName, optionalParams);
}

static std::vector<LogicalType> getTypes(const expression_vector& exprs) {
Expand All @@ -66,7 +69,7 @@ static std::vector<LogicalType> getTypes(const expression_vector& exprs) {

std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
const expression_vector& children, const std::string& functionName,
std::vector<std::string> optionalArguments) {
binder::expression_vector optionalArguments) {
auto catalog = Catalog::Get(*context);
auto transaction = transaction::Transaction::Get(*context);
auto childrenTypes = getTypes(children);
Expand Down Expand Up @@ -175,7 +178,7 @@ std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
std::unique_ptr<FunctionBindData> bindData;
if (function.bindFunc) {
auto bindInput = ScalarBindFuncInput{children, &function, context,
std::vector<std::string>{} /* optionalParams */};
binder::expression_vector{} /* optionalParams */};
bindData = function.bindFunc(bindInput);
} else {
bindData = std::make_unique<FunctionBindData>(LogicalType(function.returnTypeID));
Expand Down
2 changes: 1 addition & 1 deletion src/function/struct/struct_pack_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ static std::unique_ptr<FunctionBindData> bindFunc(const ScalarBindFuncInput& inp
throw BinderException(
stringFormat("Cannot infer field name for {}.", argument->toString()));
}
auto fieldName = input.optionalArguments[i];
auto fieldName = input.optionalArguments[i]->getAlias();
if (fieldNameSet.contains(fieldName)) {
throw BinderException(stringFormat("Found duplicate field {} in STRUCT.", fieldName));
} else {
Expand Down
34 changes: 16 additions & 18 deletions src/function/union/union_value_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,37 +7,35 @@ namespace kuzu {
namespace function {

static std::unique_ptr<FunctionBindData> bindFunc(const ScalarBindFuncInput& input) {
KU_ASSERT(input.arguments.size() == 1);
KU_ASSERT(input.optionalArguments.size() == 1);
std::vector<StructField> fields;
if (input.arguments[0]->getDataType().getLogicalTypeID() == common::LogicalTypeID::ANY) {
input.arguments[0]->cast(LogicalType::STRING());
if (input.optionalArguments[0]->getDataType().getLogicalTypeID() ==
common::LogicalTypeID::ANY) {
input.optionalArguments[0]->cast(LogicalType::STRING());
}
fields.emplace_back(input.arguments[0]->getAlias(), input.arguments[0]->getDataType().copy());
fields.emplace_back(input.optionalArguments[0]->getAlias(),
input.optionalArguments[0]->getDataType().copy());
auto resultType = LogicalType::UNION(std::move(fields));
return FunctionBindData::getSimpleBindData(input.arguments, resultType);
return FunctionBindData::getSimpleBindData(input.optionalArguments, resultType);
}

static void execFunc(const std::vector<std::shared_ptr<common::ValueVector>>&,
static void execFunc(const std::vector<std::shared_ptr<common::ValueVector>>& parameters,
const std::vector<common::SelectionVector*>&, common::ValueVector& result,
common::SelectionVector* resultSelVector, void* /*dataPtr*/) {
// (Tanvir) This is broken, parameters does not include optional params so we would
// get an out of bounds error
KU_ASSERT_UNCONDITIONAL(false);
result.setState(parameters[0]->state);
UnionVector::getTagVector(&result)->setState(parameters[0]->state);
UnionVector::referenceVector(&result, UnionType::TAG_FIELD_IDX, parameters[0]);
UnionVector::setTagField(result, *resultSelVector, UnionType::TAG_FIELD_IDX);
}

static void valueCompileFunc(FunctionBindData* /*bindData*/,
const std::vector<std::shared_ptr<ValueVector>>& parameters,
std::shared_ptr<ValueVector>& result) {
KU_ASSERT(parameters.size() == 1);
result->setState(parameters[0]->state);
UnionVector::getTagVector(result.get())->setState(parameters[0]->state);
UnionVector::referenceVector(result.get(), UnionType::TAG_FIELD_IDX, parameters[0]);
}

function_set UnionValueFunction::getFunctionSet() {
function_set functionSet;
auto function = std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::ANY}, LogicalTypeID::UNION, execFunc);
auto function = std::make_unique<ScalarFunction>(name, std::vector<LogicalTypeID>{},
LogicalTypeID::UNION, execFunc);
function->bindFunc = bindFunc;
function->compileFunc = valueCompileFunc;
functionSet.push_back(std::move(function));
return functionSet;
}
Expand Down
2 changes: 1 addition & 1 deletion src/include/binder/expression_binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class ExpressionBinder {
const parser::ParsedExpression& parsedExpression, const std::string& functionName);
std::shared_ptr<Expression> bindScalarFunctionExpression(const expression_vector& children,
const std::string& functionName,
std::vector<std::string> optionalArguments = std::vector<std::string>{});
binder::expression_vector optionalArguments = binder::expression_vector{});
std::shared_ptr<Expression> bindRewriteFunctionExpression(const parser::ParsedExpression& expr);
std::shared_ptr<Expression> bindAggregateFunctionExpression(
const parser::ParsedExpression& parsedExpression, const std::string& functionName,
Expand Down
4 changes: 2 additions & 2 deletions src/include/function/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ struct ScalarBindFuncInput {
const binder::expression_vector& arguments;
Function* definition;
main::ClientContext* context;
std::vector<std::string> optionalArguments;
const binder::expression_vector optionalArguments;

ScalarBindFuncInput(const binder::expression_vector& arguments, Function* definition,
main::ClientContext* context, std::vector<std::string> optionalArguments)
main::ClientContext* context, binder::expression_vector optionalArguments)
: arguments{arguments}, definition{definition}, context{context},
optionalArguments{std::move(optionalArguments)} {}
};
Expand Down
Loading