diff --git a/extension/vector/src/function/create_hnsw_index.cpp b/extension/vector/src/function/create_hnsw_index.cpp index 605be13ff91..38a526f4f9d 100644 --- a/extension/vector/src/function/create_hnsw_index.cpp +++ b/extension/vector/src/function/create_hnsw_index.cpp @@ -1,6 +1,7 @@ #include "catalog/catalog_entry/function_catalog_entry.h" #include "catalog/catalog_entry/node_table_catalog_entry.h" #include "catalog/hnsw_index_catalog_entry.h" +#include "common/exception/binder.h" #include "function/built_in_function_utils.h" #include "function/hnsw_index_functions.h" #include "function/table/bind_data.h" @@ -44,17 +45,22 @@ static std::unique_ptr createInMemHNSWBindFunc(main::ClientCo const auto tableName = input->getLiteralVal(0); const auto indexName = input->getLiteralVal(1); const auto columnName = input->getLiteralVal(2); - auto tableEntry = HNSWIndexUtils::bindNodeTable(*context, tableName, indexName, - HNSWIndexUtils::IndexOperation::CREATE); - const auto tableID = tableEntry->getTableID(); - HNSWIndexUtils::validateColumnType(*tableEntry, columnName); + auto config = HNSWIndexConfig{input->optionalParams}; + const auto nodeTableEntry = HNSWIndexUtils::bindNodeTable(*context, tableName); + if (HNSWIndexUtils::validateIndexExistence(*context, nodeTableEntry, indexName, + HNSWIndexUtils::IndexOperation::CREATE, config.conflictAction)) { + return std::make_unique(context, indexName, nullptr, 0, 0, + std::move(config), + true); // Placeholders for nodeTableEntry, propertyID, numNodes - WILL NOT BE ACCESSED + } + const auto tableID = nodeTableEntry->getTableID(); + HNSWIndexUtils::validateColumnType(*nodeTableEntry, columnName); const auto& table = storage::StorageManager::Get(*context)->getTable(tableID)->cast(); - auto propertyID = tableEntry->getPropertyID(columnName); - auto config = HNSWIndexConfig{input->optionalParams}; + auto propertyID = nodeTableEntry->getPropertyID(columnName); auto transaction = transaction::Transaction::Get(*context); auto numNodes = table.getStats(transaction).getTableCard(); - return std::make_unique(context, indexName, tableEntry, propertyID, + return std::make_unique(context, indexName, nodeTableEntry, propertyID, numNodes, std::move(config)); } @@ -327,6 +333,9 @@ static std::string rewriteCreateHNSWQuery(main::ClientContext& context, const TableFuncBindData& bindData) { context.setUseInternalCatalogEntry(true /* useInternalCatalogEntry */); const auto hnswBindData = bindData.constPtrCast(); + if (hnswBindData->skipAfterBind) { + return ""; + } std::string query = "BEGIN TRANSACTION;"; auto indexName = hnswBindData->indexName; auto tableName = hnswBindData->tableEntry->getName(); diff --git a/extension/vector/src/function/drop_hnsw_index.cpp b/extension/vector/src/function/drop_hnsw_index.cpp index 9e2eed01275..2cad8505822 100644 --- a/extension/vector/src/function/drop_hnsw_index.cpp +++ b/extension/vector/src/function/drop_hnsw_index.cpp @@ -1,5 +1,6 @@ #include "catalog/catalog.h" #include "catalog/catalog_entry/node_table_catalog_entry.h" +#include "common/exception/binder.h" #include "function/hnsw_index_functions.h" #include "function/table/bind_data.h" #include "index/hnsw_index_utils.h" @@ -16,12 +17,15 @@ namespace vector_extension { struct DropHNSWIndexBindData final : TableFuncBindData { catalog::NodeTableCatalogEntry* tableEntry; std::string indexName; + bool skipAfterBind; - DropHNSWIndexBindData(catalog::NodeTableCatalogEntry* tableEntry, std::string indexName) - : TableFuncBindData{0}, tableEntry{tableEntry}, indexName{std::move(indexName)} {} + DropHNSWIndexBindData(catalog::NodeTableCatalogEntry* tableEntry, std::string indexName, + bool skipAfterBind = false) + : TableFuncBindData{0}, tableEntry{tableEntry}, indexName{std::move(indexName)}, + skipAfterBind{skipAfterBind} {} std::unique_ptr copy() const override { - return std::make_unique(tableEntry, indexName); + return std::make_unique(tableEntry, indexName, skipAfterBind); } }; @@ -29,9 +33,13 @@ static std::unique_ptr bindFunc(main::ClientContext* context, const TableFuncBindInput* input) { const auto tableName = input->getLiteralVal(0); const auto indexName = input->getLiteralVal(1); - const auto tableEntry = HNSWIndexUtils::bindNodeTable(*context, tableName, indexName, - HNSWIndexUtils::IndexOperation::DROP); - return std::make_unique(tableEntry, indexName); + auto config = DropHNSWConfig{input->optionalParams}; + const auto nodeTableEntry = HNSWIndexUtils::bindNodeTable(*context, tableName); + if (!HNSWIndexUtils::validateIndexExistence(*context, nodeTableEntry, indexName, + HNSWIndexUtils::IndexOperation::DROP, config.conflictAction)) { + return std::make_unique(nullptr, indexName, true); + } + return std::make_unique(nodeTableEntry, indexName); } static common::offset_t internalTableFunc(const TableFuncInput& input, TableFuncOutput&) { @@ -49,6 +57,9 @@ static std::string dropHNSWIndexTables(main::ClientContext& context, const TableFuncBindData& bindData) { const auto dropHNSWIndexBindData = bindData.constPtrCast(); context.setUseInternalCatalogEntry(true /* useInternalCatalogEntry */); + if (dropHNSWIndexBindData->skipAfterBind) { + return ""; + } std::string query = ""; const auto requireNewTransaction = !context.getTransactionContext()->hasActiveTransaction(); if (requireNewTransaction) { diff --git a/extension/vector/src/function/query_hnsw_index.cpp b/extension/vector/src/function/query_hnsw_index.cpp index 4e3e5aead37..7f69cbaedf3 100644 --- a/extension/vector/src/function/query_hnsw_index.cpp +++ b/extension/vector/src/function/query_hnsw_index.cpp @@ -133,7 +133,8 @@ static std::unique_ptr bindFunc(main::ClientContext* context, throw BinderException( stringFormat("Cannot find table or graph named as {}.", tableOrGraphName)); } - auto nodeTableEntry = HNSWIndexUtils::bindNodeTable(*context, tableName, indexName, + auto nodeTableEntry = HNSWIndexUtils::bindNodeTable(*context, tableName); + (void)HNSWIndexUtils::validateIndexExistence(*context, nodeTableEntry, indexName, HNSWIndexUtils::IndexOperation::QUERY); // Bind columns auto columnNames = std::vector{QueryVectorIndexFunction::nnColumnName, diff --git a/extension/vector/src/include/function/hnsw_index_functions.h b/extension/vector/src/include/function/hnsw_index_functions.h index 7824269c39d..d6155df1c9b 100644 --- a/extension/vector/src/include/function/hnsw_index_functions.h +++ b/extension/vector/src/include/function/hnsw_index_functions.h @@ -17,16 +17,18 @@ struct CreateHNSWIndexBindData final : function::TableFuncBindData { catalog::TableCatalogEntry* tableEntry; common::property_id_t propertyID; HNSWIndexConfig config; + bool skipAfterBind; CreateHNSWIndexBindData(main::ClientContext* context, std::string indexName, catalog::TableCatalogEntry* tableEntry, common::property_id_t propertyID, - common::offset_t numNodes, HNSWIndexConfig config) + common::offset_t numNodes, HNSWIndexConfig config, bool skipAfterBind = false) : TableFuncBindData{numNodes}, context{context}, indexName{std::move(indexName)}, - tableEntry{tableEntry}, propertyID{propertyID}, config{std::move(config)} {} + tableEntry{tableEntry}, propertyID{propertyID}, config{std::move(config)}, + skipAfterBind{skipAfterBind} {} std::unique_ptr copy() const override { return std::make_unique(context, indexName, tableEntry, propertyID, - numRows, config.copy()); + numRows, config.copy(), skipAfterBind); } }; diff --git a/extension/vector/src/include/index/hnsw_config.h b/extension/vector/src/include/index/hnsw_config.h index 78529332d1b..625a3f6a979 100644 --- a/extension/vector/src/include/index/hnsw_config.h +++ b/extension/vector/src/include/index/hnsw_config.h @@ -2,6 +2,7 @@ #include +#include "common/enums/conflict_action.h" #include "common/types/types.h" #include "function/table/bind_input.h" @@ -79,6 +80,20 @@ struct CacheEmbeddings { static constexpr bool DEFAULT_VALUE = true; }; +struct SkipIfExists { + static constexpr const char* NAME = "skip_if_exists"; + static constexpr common::LogicalTypeID TYPE = common::LogicalTypeID::BOOL; + static constexpr common::ConflictAction DEFAULT_VALUE = + common::ConflictAction::ON_CONFLICT_THROW; +}; + +struct SkipIfNotExists { + static constexpr const char* NAME = "skip_if_not_exists"; + static constexpr common::LogicalTypeID TYPE = common::LogicalTypeID::BOOL; + static constexpr common::ConflictAction DEFAULT_VALUE = + common::ConflictAction::ON_CONFLICT_THROW; +}; + struct BlindSearchUpSelThreshold { static constexpr const char* NAME = "blind_search_up_sel"; static constexpr common::LogicalTypeID TYPE = common::LogicalTypeID::DOUBLE; @@ -103,6 +118,7 @@ struct HNSWIndexConfig { double alpha = Alpha::DEFAULT_VALUE; int64_t efc = Efc::DEFAULT_VALUE; bool cacheEmbeddingsColumn = CacheEmbeddings::DEFAULT_VALUE; + common::ConflictAction conflictAction = SkipIfExists::DEFAULT_VALUE; HNSWIndexConfig() = default; @@ -119,11 +135,20 @@ struct HNSWIndexConfig { private: HNSWIndexConfig(const HNSWIndexConfig& other) : mu{other.mu}, ml{other.ml}, pu{other.pu}, metric{other.metric}, alpha{other.alpha}, - efc{other.efc}, cacheEmbeddingsColumn(other.cacheEmbeddingsColumn) {} + efc{other.efc}, cacheEmbeddingsColumn(other.cacheEmbeddingsColumn), + conflictAction(other.conflictAction) {} static MetricType getMetricType(const std::string& metricName); }; +struct DropHNSWConfig { + common::ConflictAction conflictAction = SkipIfNotExists::DEFAULT_VALUE; + + DropHNSWConfig() = default; + + explicit DropHNSWConfig(const function::optional_params_t& optionalParams); +}; + struct QueryHNSWConfig { int64_t efs = Efs::DEFAULT_VALUE; double blindSearchUpSelThreshold = BlindSearchUpSelThreshold::DEFAULT_VALUE; diff --git a/extension/vector/src/include/index/hnsw_index_utils.h b/extension/vector/src/include/index/hnsw_index_utils.h index 6e6d6cb7f81..6a08ed11e60 100644 --- a/extension/vector/src/include/index/hnsw_index_utils.h +++ b/extension/vector/src/include/index/hnsw_index_utils.h @@ -21,12 +21,17 @@ using metric_func_t = std::function; struct HNSWIndexUtils { enum class KUZU_API IndexOperation { CREATE, QUERY, DROP }; - static void validateIndexExistence(const main::ClientContext& context, + static bool indexExists(const main::ClientContext& context, + const transaction::Transaction* transaction, const catalog::TableCatalogEntry* tableEntry, + const std::string& indexName); + + static bool validateIndexExistence(const main::ClientContext& context, const catalog::TableCatalogEntry* tableEntry, const std::string& indexName, - IndexOperation indexOperation); + IndexOperation indexOperation, + common::ConflictAction conflictAction = common::ConflictAction::ON_CONFLICT_THROW); static catalog::NodeTableCatalogEntry* bindNodeTable(const main::ClientContext& context, - const std::string& tableName, const std::string& indexName, IndexOperation indexOperation); + const std::string& tableName); static void validateAutoTransaction(const main::ClientContext& context, const std::string& funcName); diff --git a/extension/vector/src/index/hnsw_config.cpp b/extension/vector/src/index/hnsw_config.cpp index f9957af037b..e672ce90452 100644 --- a/extension/vector/src/index/hnsw_config.cpp +++ b/extension/vector/src/index/hnsw_config.cpp @@ -111,6 +111,11 @@ HNSWIndexConfig::HNSWIndexConfig(const function::optional_params_t& optionalPara } else if (CacheEmbeddings::NAME == lowerCaseName) { value.validateType(CacheEmbeddings::TYPE); cacheEmbeddingsColumn = value.getValue(); + } else if (SkipIfExists::NAME == lowerCaseName) { + value.validateType(SkipIfExists::TYPE); + conflictAction = value.getValue() ? + common::ConflictAction::ON_CONFLICT_DO_NOTHING : + common::ConflictAction::ON_CONFLICT_THROW; } else { throw common::BinderException{ common::stringFormat("Unrecognized optional parameter {} in {}.", name, @@ -188,6 +193,21 @@ MetricType HNSWIndexConfig::getMetricType(const std::string& metricName) { KU_UNREACHABLE; } +DropHNSWConfig::DropHNSWConfig(const function::optional_params_t& optionalParams) { + for (auto& [name, value] : optionalParams) { + auto lowerCaseName = common::StringUtils::getLower(name); + if (SkipIfNotExists::NAME == lowerCaseName) { + value.validateType(SkipIfNotExists::TYPE); + conflictAction = value.getValue() ? + common::ConflictAction::ON_CONFLICT_DO_NOTHING : + common::ConflictAction::ON_CONFLICT_THROW; + } else { + throw common::BinderException{common::stringFormat( + "Unrecognized optional parameter {} in {}.", name, DropVectorIndexFunction::name)}; + } + } +} + QueryHNSWConfig::QueryHNSWConfig(const function::optional_params_t& optionalParams) { for (auto& [name, value] : optionalParams) { auto lowerCaseName = common::StringUtils::getLower(name); diff --git a/extension/vector/src/index/hnsw_index_utils.cpp b/extension/vector/src/index/hnsw_index_utils.cpp index 2e453c553b5..330fb65b13f 100644 --- a/extension/vector/src/index/hnsw_index_utils.cpp +++ b/extension/vector/src/index/hnsw_index_utils.cpp @@ -12,25 +12,53 @@ namespace kuzu { namespace vector_extension { -void HNSWIndexUtils::validateIndexExistence(const main::ClientContext& context, +bool HNSWIndexUtils::indexExists(const main::ClientContext& context, + const transaction::Transaction* transaction, const catalog::TableCatalogEntry* tableEntry, + const std::string& indexName) { + return catalog::Catalog::Get(context)->containsIndex(transaction, tableEntry->getTableID(), + indexName); +} + +bool HNSWIndexUtils::validateIndexExistence(const main::ClientContext& context, const catalog::TableCatalogEntry* tableEntry, const std::string& indexName, - IndexOperation indexOperation) { + IndexOperation indexOperation, common::ConflictAction conflictAction) { auto transaction = transaction::Transaction::Get(context); switch (indexOperation) { case IndexOperation::CREATE: { - if (catalog::Catalog::Get(context)->containsIndex(transaction, tableEntry->getTableID(), - indexName)) { - throw common::BinderException{common::stringFormat( - "Index {} already exists in table {}.", indexName, tableEntry->getName())}; + if (indexExists(context, transaction, tableEntry, indexName)) { + switch (conflictAction) { + case common::ConflictAction::ON_CONFLICT_THROW: + throw common::BinderException{common::stringFormat( + "Index {} already exists in table {}.", indexName, tableEntry->getName())}; + case common::ConflictAction::ON_CONFLICT_DO_NOTHING: + return true; + default: + KU_UNREACHABLE; + } + } + return false; + } break; + case IndexOperation::DROP: { + if (!indexExists(context, transaction, tableEntry, indexName)) { + switch (conflictAction) { + case common::ConflictAction::ON_CONFLICT_THROW: + throw common::BinderException{ + common::stringFormat("Table {} doesn't have an index with name {}.", + tableEntry->getName(), indexName)}; + case common::ConflictAction::ON_CONFLICT_DO_NOTHING: + return false; + default: + KU_UNREACHABLE; + } } + return true; } break; - case IndexOperation::DROP: case IndexOperation::QUERY: { - if (!catalog::Catalog::Get(context)->containsIndex(transaction, tableEntry->getTableID(), - indexName)) { + if (!indexExists(context, transaction, tableEntry, indexName)) { throw common::BinderException{common::stringFormat( "Table {} doesn't have an index with name {}.", tableEntry->getName(), indexName)}; } + return true; } break; default: { KU_UNREACHABLE; @@ -39,13 +67,12 @@ void HNSWIndexUtils::validateIndexExistence(const main::ClientContext& context, } catalog::NodeTableCatalogEntry* HNSWIndexUtils::bindNodeTable(const main::ClientContext& context, - const std::string& tableName, const std::string& indexName, IndexOperation indexOperation) { + const std::string& tableName) { binder::Binder::validateTableExistence(context, tableName); auto transaction = transaction::Transaction::Get(context); const auto tableEntry = catalog::Catalog::Get(context)->getTableCatalogEntry(transaction, tableName); binder::Binder::validateNodeTableType(tableEntry); - validateIndexExistence(context, tableEntry, indexName, indexOperation); return tableEntry->ptrCast(); } diff --git a/extension/vector/test/test_files/error.test b/extension/vector/test/test_files/error.test index b56cfd39fb9..d60f455629e 100644 --- a/extension/vector/test/test_files/error.test +++ b/extension/vector/test/test_files/error.test @@ -75,6 +75,11 @@ Binder exception: Unrecognized optional parameter unknown_param in QUERY_VECTOR_ -STATEMENT CALL QUERY_VECTOR_INDEX('embeddings', 'e_hnsw_index', CAST([0.1521,0.3021,0.5366,0.2774,0.5593,0.5589,0.1365,0.8557],'FLOAT[8]'), 3, efs := -1) RETURN node.id, distance ORDER BY distance; ---- error Binder exception: Efs must be a positive integer. +-STATEMENT CALL DROP_VECTOR_INDEX('embeddings', 'e_hnsw_index', unknown_param := 1); +---- error +Binder exception: Unrecognized optional parameter unknown_param in DROP_VECTOR_INDEX. +-STATEMENT CALL DROP_VECTOR_INDEX('embeddings', 'e_hnsw_index'); +---- ok -CASE CastingError -LOAD_DYNAMIC_EXTENSION vector diff --git a/extension/vector/test/test_files/error_suppress.test b/extension/vector/test/test_files/error_suppress.test new file mode 100644 index 00000000000..14539d83846 --- /dev/null +++ b/extension/vector/test/test_files/error_suppress.test @@ -0,0 +1,28 @@ +-DATASET CSV empty + +-- + +-CASE CreateSkipIfExists +-LOAD_DYNAMIC_EXTENSION vector +-STATEMENT CREATE NODE TABLE embeddings (id int64, vec FLOAT[8], PRIMARY KEY (id)); +---- ok +-STATEMENT CALL CREATE_VECTOR_INDEX('embeddings', 'e_hnsw_index', 'vec'); +---- ok +-STATEMENT CALL CREATE_VECTOR_INDEX('embeddings', 'e_hnsw_index', 'vec'); +---- error +Binder exception: Index e_hnsw_index already exists in table embeddings. +-STATEMENT CALL CREATE_VECTOR_INDEX('embeddings', 'e_hnsw_index', 'vec', skip_if_exists := true); +---- ok +-STATEMENT CALL SHOW_INDEXES() RETURN * +---- 1 +embeddings|e_hnsw_index|HNSW|[vec]|True|CALL CREATE_VECTOR_INDEX('embeddings', 'e_hnsw_index', 'vec', mu := 30, ml := 60, pu := 0.050000, metric := 'cosine', alpha := 1.100000, efc := 200); + +-CASE DropSkipIfNotExists +-LOAD_DYNAMIC_EXTENSION vector +-STATEMENT CREATE NODE TABLE embeddings (id int64, vec FLOAT[8], PRIMARY KEY (id)); +---- ok +-STATEMENT CALL DROP_VECTOR_INDEX('embeddings', 'e_hnsw_index'); +---- error +Binder exception: Table embeddings doesn't have an index with name e_hnsw_index. +-STATEMENT CALL DROP_VECTOR_INDEX('embeddings', 'e_hnsw_index', skip_if_not_exists := true); +---- ok