Skip to content
This repository was archived by the owner on Oct 10, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion extension/vector/src/function/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ add_library(kuzu_hnsw_function
OBJECT
create_hnsw_index.cpp
drop_hnsw_index.cpp
query_hnsw_index.cpp)
query_hnsw_index.cpp
drop_all_hnsw_indexes.cpp)

set(VECTOR_EXTENSION_OBJECT_FILES
${VECTOR_EXTENSION_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_hnsw_function>
Expand Down
23 changes: 17 additions & 6 deletions extension/vector/src/function/create_hnsw_index.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "catalog/catalog_entry/function_catalog_entry.h"
#include "catalog/catalog_entry/node_table_catalog_entry.h"
#include "catalog/hnsw_index_catalog_entry.h"
#include "common/exception/binder.h"
#include "function/built_in_function_utils.h"
#include "function/hnsw_index_functions.h"
#include "function/table/bind_data.h"
Expand Down Expand Up @@ -44,14 +45,21 @@ static std::unique_ptr<TableFuncBindData> createInMemHNSWBindFunc(main::ClientCo
const auto tableName = input->getLiteralVal<std::string>(0);
const auto indexName = input->getLiteralVal<std::string>(1);
const auto columnName = input->getLiteralVal<std::string>(2);
auto tableEntry = HNSWIndexUtils::bindNodeTable(*context, tableName, indexName,
HNSWIndexUtils::IndexOperation::CREATE);
const auto tableID = tableEntry->getTableID();
HNSWIndexUtils::validateColumnType(*tableEntry, columnName);
auto config = HNSWIndexConfig{input->optionalParams};
const auto operation = config.skipIfExists ?
HNSWIndexUtils::IndexOperation::CREATE_IF_NOT_EXISTS :
HNSWIndexUtils::IndexOperation::CREATE;
const auto tableEntry = HNSWIndexUtils::bindTable(*context, tableName);
if (HNSWIndexUtils::validateIndexExistence(*context, tableEntry, indexName, operation)) {
return std::make_unique<CreateHNSWIndexBindData>(context, indexName, nullptr, 0, 0,
Comment thread
carminite marked this conversation as resolved.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We inline comments as follows:

Suggested change
return std::make_unique<CreateHNSWIndexBindData>(context, indexName, nullptr, 0, 0,
return std::make_unique<CreateHNSWIndexBindData>(context, indexName, nullptr /*nodeTableEntry*/, 0 /*propertyID*/, 0 /*numNodes*/,

std::move(config), true);
}
const auto nodeTableEntry = tableEntry->ptrCast<catalog::NodeTableCatalogEntry>();
const auto tableID = nodeTableEntry->getTableID();
HNSWIndexUtils::validateColumnType(*nodeTableEntry, columnName);
const auto& table =
storage::StorageManager::Get(*context)->getTable(tableID)->cast<storage::NodeTable>();
auto propertyID = tableEntry->getPropertyID(columnName);
auto config = HNSWIndexConfig{input->optionalParams};
auto propertyID = nodeTableEntry->getPropertyID(columnName);
auto numNodes = table.getStats(context->getTransaction()).getTableCard();
return std::make_unique<CreateHNSWIndexBindData>(context, indexName, tableEntry, propertyID,
numNodes, std::move(config));
Expand Down Expand Up @@ -326,6 +334,9 @@ static std::string rewriteCreateHNSWQuery(main::ClientContext& context,
const TableFuncBindData& bindData) {
context.setUseInternalCatalogEntry(true /* useInternalCatalogEntry */);
const auto hnswBindData = bindData.constPtrCast<CreateHNSWIndexBindData>();
if (hnswBindData->skipAfterBind) {
return "";
}
std::string query = "BEGIN TRANSACTION;";
auto indexName = hnswBindData->indexName;
auto tableName = hnswBindData->tableEntry->getName();
Expand Down
87 changes: 87 additions & 0 deletions extension/vector/src/function/drop_all_hnsw_indexes.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#include "catalog/catalog.h"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need such a function right now. @ray6080 can comment on this as well.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#include "catalog/catalog_entry/index_catalog_entry.h"
#include "catalog/catalog_entry/node_table_catalog_entry.h"
#include "catalog/hnsw_index_catalog_entry.h"
#include "common/exception/binder.h"
#include "function/hnsw_index_functions.h"
#include "function/table/bind_data.h"
#include "index/hnsw_index_utils.h"
#include "main/client_context.h"
#include "processor/execution_context.h"
#include "storage/storage_manager.h"

using namespace kuzu::function;

namespace kuzu {
namespace vector_extension {

// Bind-time payload of the drop-all-vector-indexes table function: the bound node table and the
// names of the indexes that were collected for it.
struct DropAllHNSWIndexesBindData final : TableFuncBindData {
    // Catalog entry of the table whose indexes are dropped. Raw pointer; this struct declares no
    // destructor and copy() shares the same pointer.
    catalog::NodeTableCatalogEntry* tableEntry;
    // Names of the indexes to drop, in the order they were collected.
    std::vector<std::string> indexNames;

    // The name list is taken by value and moved into the member; the base is constructed with 0.
    DropAllHNSWIndexesBindData(catalog::NodeTableCatalogEntry* entry,
        std::vector<std::string> names)
        : TableFuncBindData{0}, tableEntry{entry}, indexNames{std::move(names)} {}

    // Returns a new bind-data object with the same table pointer and a copy of the name list.
    std::unique_ptr<TableFuncBindData> copy() const override {
        std::unique_ptr<TableFuncBindData> duplicate =
            std::make_unique<DropAllHNSWIndexesBindData>(tableEntry, indexNames);
        return duplicate;
    }
};

// Bind step of the drop-all-vector-indexes table function.
//
// Reads the table name from the first literal argument, binds the table through
// HNSWIndexUtils::bindTable, and collects the name of every index catalog entry on that table
// whose index type equals HNSWIndexCatalogEntry::TYPE_NAME.
//
// Returns a DropAllHNSWIndexesBindData holding the node table entry and the collected names
// (possibly empty when the table has no such index).
static std::unique_ptr<TableFuncBindData> bindFunc(main::ClientContext* context,
    const TableFuncBindInput* input) {
    const auto tableName = input->getLiteralVal<std::string>(0);
    const auto tableEntry = HNSWIndexUtils::bindTable(*context, tableName);
    const auto nodeTableEntry = tableEntry->ptrCast<catalog::NodeTableCatalogEntry>();
    auto indexEntries = catalog::Catalog::Get(*context)->getIndexEntries(
        context->getTransaction(), tableEntry->getTableID());
    std::vector<std::string> vectorIndexes;
    for (const auto& indexEntry : indexEntries) {
        if (indexEntry->getIndexType() == HNSWIndexCatalogEntry::TYPE_NAME) {
            vectorIndexes.push_back(indexEntry->getIndexName());
        }
    }
    // The local list is not used afterwards, so hand it over instead of copying every name.
    return std::make_unique<DropAllHNSWIndexesBindData>(nodeTableEntry,
        std::move(vectorIndexes));
}

// Rewrite hook of the drop-all-vector-indexes table function.
//
// Builds the statement string that replaces the original call. For every bound index name it
// emits a CALL _DROP_HNSW_INDEX(...) followed by DROP TABLE statements for the index's upper and
// lower graph tables. When the client context has no active transaction, each index's statements
// are wrapped in their own BEGIN TRANSACTION / COMMIT pair.
//
// Returns the concatenated statements, or an empty string when there is nothing to drop.
static std::string dropAllHNSWIndexesTables(main::ClientContext& context,
    const TableFuncBindData& bindData) {
    const auto dropAllHNSWIndexesBindData = bindData.constPtrCast<DropAllHNSWIndexesBindData>();
    context.setUseInternalCatalogEntry(true /* useInternalCatalogEntry */);
    std::string query;
    const auto& indexNames = dropAllHNSWIndexesBindData->indexNames;
    if (indexNames.empty()) {
        // Nothing to drop: leave the table entry and transaction context untouched.
        return query;
    }
    // None of these values change while the string is being assembled, so compute them once
    // instead of once per index.
    const auto requireNewTransaction = !context.getTransactionContext()->hasActiveTransaction();
    const auto tableEntry = dropAllHNSWIndexesBindData->tableEntry;
    const auto nodeTableID = tableEntry->getTableID();
    for (const auto& indexName : indexNames) {
        if (requireNewTransaction) {
            query += "BEGIN TRANSACTION;";
        }
        query += common::stringFormat("CALL _DROP_HNSW_INDEX('{}', '{}');", tableEntry->getName(),
            indexName);
        query += common::stringFormat("DROP TABLE {};",
            HNSWIndexUtils::getUpperGraphTableName(nodeTableID, indexName));
        query += common::stringFormat("DROP TABLE {};",
            HNSWIndexUtils::getLowerGraphTableName(nodeTableID, indexName));
        if (requireNewTransaction) {
            query += "COMMIT;";
        }
    }
    return query;
}

// Registers the single overload of the drop-all-vector-indexes table function: one STRING
// argument, bound by bindFunc and turned into statements by dropAllHNSWIndexesTables through the
// rewrite hook. The function is marked as not parallelisable.
function_set DropAllVectorIndexesFunction::getFunctionSet() {
    std::vector<common::LogicalTypeID> parameterTypes{common::LogicalTypeID::STRING};
    auto dropAllFunc = std::make_unique<TableFunction>(name, parameterTypes);

    // Binding and rewriting carry the function-specific logic.
    dropAllFunc->bindFunc = bindFunc;
    dropAllFunc->rewriteFunc = dropAllHNSWIndexesTables;

    // Execution-side callbacks use the shared defaults.
    dropAllFunc->tableFunc = TableFunction::emptyTableFunc;
    dropAllFunc->initSharedStateFunc = SimpleTableFunc::initSharedState;
    dropAllFunc->initLocalStateFunc = TableFunction::initEmptyLocalState;
    dropAllFunc->canParallelFunc = []() { return false; };

    function_set result;
    result.push_back(std::move(dropAllFunc));
    return result;
}

} // namespace vector_extension
} // namespace kuzu
25 changes: 19 additions & 6 deletions extension/vector/src/function/drop_hnsw_index.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "catalog/catalog.h"
#include "catalog/catalog_entry/node_table_catalog_entry.h"
#include "common/exception/binder.h"
#include "function/hnsw_index_functions.h"
#include "function/table/bind_data.h"
#include "index/hnsw_index_utils.h"
Expand All @@ -15,22 +16,31 @@ namespace vector_extension {
struct DropHNSWIndexBindData final : TableFuncBindData {
catalog::NodeTableCatalogEntry* tableEntry;
std::string indexName;
bool skipAfterBind;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename to something along the lines of skipIndexDropping.


DropHNSWIndexBindData(catalog::NodeTableCatalogEntry* tableEntry, std::string indexName)
: TableFuncBindData{0}, tableEntry{tableEntry}, indexName{std::move(indexName)} {}
DropHNSWIndexBindData(catalog::NodeTableCatalogEntry* tableEntry, std::string indexName,
bool skipAfterBind = false)
: TableFuncBindData{0}, tableEntry{tableEntry}, indexName{std::move(indexName)},
skipAfterBind{skipAfterBind} {}

std::unique_ptr<TableFuncBindData> copy() const override {
return std::make_unique<DropHNSWIndexBindData>(tableEntry, indexName);
return std::make_unique<DropHNSWIndexBindData>(tableEntry, indexName, skipAfterBind);
}
};

static std::unique_ptr<TableFuncBindData> bindFunc(main::ClientContext* context,
const TableFuncBindInput* input) {
const auto tableName = input->getLiteralVal<std::string>(0);
const auto indexName = input->getLiteralVal<std::string>(1);
const auto tableEntry = HNSWIndexUtils::bindNodeTable(*context, tableName, indexName,
HNSWIndexUtils::IndexOperation::DROP);
return std::make_unique<DropHNSWIndexBindData>(tableEntry, indexName);
auto config = DropHNSWConfig{input->optionalParams};
const auto operation = config.skipIfNotExists ? HNSWIndexUtils::IndexOperation::DROP_IF_EXISTS :
HNSWIndexUtils::IndexOperation::DROP;
const auto tableEntry = HNSWIndexUtils::bindTable(*context, tableName);
if (!HNSWIndexUtils::validateIndexExistence(*context, tableEntry, indexName, operation)) {
return std::make_unique<DropHNSWIndexBindData>(nullptr, indexName, true);
}
const auto nodeTableEntry = tableEntry->ptrCast<catalog::NodeTableCatalogEntry>();
return std::make_unique<DropHNSWIndexBindData>(nodeTableEntry, indexName);
}

static common::offset_t internalTableFunc(const TableFuncInput& input, TableFuncOutput&) {
Expand All @@ -48,6 +58,9 @@ static std::string dropHNSWIndexTables(main::ClientContext& context,
const TableFuncBindData& bindData) {
const auto dropHNSWIndexBindData = bindData.constPtrCast<DropHNSWIndexBindData>();
context.setUseInternalCatalogEntry(true /* useInternalCatalogEntry */);
if (dropHNSWIndexBindData->skipAfterBind) {
return "";
}
std::string query = "";
const auto requireNewTransaction = !context.getTransactionContext()->hasActiveTransaction();
if (requireNewTransaction) {
Expand Down
4 changes: 3 additions & 1 deletion extension/vector/src/function/query_hnsw_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,10 @@ static std::unique_ptr<TableFuncBindData> bindFunc(main::ClientContext* context,
throw BinderException(
stringFormat("Cannot find table or graph named as {}.", tableOrGraphName));
}
auto nodeTableEntry = HNSWIndexUtils::bindNodeTable(*context, tableName, indexName,
auto tableEntry = HNSWIndexUtils::bindTable(*context, tableName);
(void)HNSWIndexUtils::validateIndexExistence(*context, tableEntry, indexName,
HNSWIndexUtils::IndexOperation::QUERY);
auto nodeTableEntry = tableEntry->ptrCast<NodeTableCatalogEntry>();
// Bind columns
auto columnNames = std::vector<std::string>{QueryVectorIndexFunction::nnColumnName,
QueryVectorIndexFunction::distanceColumnName};
Expand Down
14 changes: 11 additions & 3 deletions extension/vector/src/include/function/hnsw_index_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,18 @@ struct CreateHNSWIndexBindData final : function::TableFuncBindData {
catalog::TableCatalogEntry* tableEntry;
common::property_id_t propertyID;
HNSWIndexConfig config;
bool skipAfterBind;
Comment thread
acquamarin marked this conversation as resolved.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would rename this to skipIndexCreation, which is more accurate.


CreateHNSWIndexBindData(main::ClientContext* context, std::string indexName,
catalog::TableCatalogEntry* tableEntry, common::property_id_t propertyID,
common::offset_t numNodes, HNSWIndexConfig config)
common::offset_t numNodes, HNSWIndexConfig config, bool skipAfterBind = false)
: TableFuncBindData{numNodes}, context{context}, indexName{std::move(indexName)},
tableEntry{tableEntry}, propertyID{propertyID}, config{std::move(config)} {}
tableEntry{tableEntry}, propertyID{propertyID}, config{std::move(config)},
skipAfterBind{skipAfterBind} {}

std::unique_ptr<TableFuncBindData> copy() const override {
return std::make_unique<CreateHNSWIndexBindData>(context, indexName, tableEntry, propertyID,
numRows, config.copy());
numRows, config.copy(), skipAfterBind);
}
};

Expand Down Expand Up @@ -141,5 +143,11 @@ struct QueryVectorIndexFunction final {
static function::function_set getFunctionSet();
};

struct DropAllVectorIndexesFunction final {
Comment thread
carminite marked this conversation as resolved.
Outdated
static constexpr const char* name = "DROP_ALL_VECTOR_INDEXES";

static function::function_set getFunctionSet();
};

} // namespace vector_extension
} // namespace kuzu
24 changes: 23 additions & 1 deletion extension/vector/src/include/index/hnsw_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,18 @@ struct CacheEmbeddings {
static constexpr bool DEFAULT_VALUE = true;
};

// Descriptor of the `skip_if_exists` optional parameter: its key, the logical type its value is
// validated against, and the value used when the parameter is not supplied.
struct SkipIfExists {
    static constexpr const char* NAME{"skip_if_exists"};
    static constexpr common::LogicalTypeID TYPE{common::LogicalTypeID::BOOL};
    static constexpr bool DEFAULT_VALUE{false};
};

// Descriptor of the `skip_if_not_exists` optional parameter: its key, the logical type its value
// is validated against, and the value used when the parameter is not supplied.
struct SkipIfNotExists {
    static constexpr const char* NAME{"skip_if_not_exists"};
    static constexpr common::LogicalTypeID TYPE{common::LogicalTypeID::BOOL};
    static constexpr bool DEFAULT_VALUE{false};
};

struct BlindSearchUpSelThreshold {
static constexpr const char* NAME = "blind_search_up_sel";
static constexpr common::LogicalTypeID TYPE = common::LogicalTypeID::DOUBLE;
Expand All @@ -103,6 +115,7 @@ struct HNSWIndexConfig {
double alpha = Alpha::DEFAULT_VALUE;
int64_t efc = Efc::DEFAULT_VALUE;
bool cacheEmbeddingsColumn = CacheEmbeddings::DEFAULT_VALUE;
bool skipIfExists = SkipIfExists::DEFAULT_VALUE;

HNSWIndexConfig() = default;

Expand All @@ -119,11 +132,20 @@ struct HNSWIndexConfig {
private:
HNSWIndexConfig(const HNSWIndexConfig& other)
: mu{other.mu}, ml{other.ml}, pu{other.pu}, metric{other.metric}, alpha{other.alpha},
efc{other.efc}, cacheEmbeddingsColumn(other.cacheEmbeddingsColumn) {}
efc{other.efc}, cacheEmbeddingsColumn(other.cacheEmbeddingsColumn),
skipIfExists(other.skipIfExists) {}
Comment thread
carminite marked this conversation as resolved.
Outdated

static MetricType getMetricType(const std::string& metricName);
};

struct DropHNSWConfig {
bool skipIfNotExists = SkipIfNotExists::DEFAULT_VALUE;

DropHNSWConfig() = default;

explicit DropHNSWConfig(const function::optional_params_t& optionalParams);
};

struct QueryHNSWConfig {
int64_t efs = Efs::DEFAULT_VALUE;
double blindSearchUpSelThreshold = BlindSearchUpSelThreshold::DEFAULT_VALUE;
Expand Down
17 changes: 13 additions & 4 deletions extension/vector/src/include/index/hnsw_index_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,23 @@ concept VectorElementType = std::is_floating_point_v<T>;
using metric_func_t = std::function<double(const void*, const void*, uint32_t)>;

struct HNSWIndexUtils {
enum class KUZU_API IndexOperation { CREATE, QUERY, DROP };
enum class KUZU_API IndexOperation {
Comment thread
carminite marked this conversation as resolved.
Outdated
CREATE,
CREATE_IF_NOT_EXISTS,
QUERY,
DROP,
DROP_IF_EXISTS
};

static void validateIndexExistence(const main::ClientContext& context,
static bool indexExists(const main::ClientContext& context,
const catalog::TableCatalogEntry* tableEntry, const std::string& indexName);

static bool validateIndexExistence(const main::ClientContext& context,
const catalog::TableCatalogEntry* tableEntry, const std::string& indexName,
IndexOperation indexOperation);

static catalog::NodeTableCatalogEntry* bindNodeTable(const main::ClientContext& context,
const std::string& tableName, const std::string& indexName, IndexOperation indexOperation);
static catalog::TableCatalogEntry* bindTable(const main::ClientContext& context,
const std::string& tableName);

static void validateAutoTransaction(const main::ClientContext& context,
const std::string& funcName);
Expand Down
16 changes: 16 additions & 0 deletions extension/vector/src/index/hnsw_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ HNSWIndexConfig::HNSWIndexConfig(const function::optional_params_t& optionalPara
} else if (CacheEmbeddings::NAME == lowerCaseName) {
value.validateType(CacheEmbeddings::TYPE);
cacheEmbeddingsColumn = value.getValue<bool>();
} else if (SkipIfExists::NAME == lowerCaseName) {
value.validateType(SkipIfExists::TYPE);
skipIfExists = value.getValue<bool>();
} else {
throw common::BinderException{
common::stringFormat("Unrecognized optional parameter {} in {}.", name,
Expand Down Expand Up @@ -188,6 +191,19 @@ MetricType HNSWIndexConfig::getMetricType(const std::string& metricName) {
KU_UNREACHABLE;
}

// Parses the optional parameters accepted when dropping an HNSW index.
//
// Parameter names are matched case-insensitively (the name is lower-cased before comparison).
// The only recognised option is `skip_if_not_exists`, whose value is type-checked against
// SkipIfNotExists::TYPE and stored in skipIfNotExists.
//
// Throws common::BinderException for any other parameter name.
DropHNSWConfig::DropHNSWConfig(const function::optional_params_t& optionalParams) {
    for (const auto& [name, value] : optionalParams) {
        const auto lowerCaseName = common::StringUtils::getLower(name);
        if (SkipIfNotExists::NAME == lowerCaseName) {
            value.validateType(SkipIfNotExists::TYPE);
            skipIfNotExists = value.getValue<bool>();
        } else {
            // Report the drop function here: these are its options, not the query function's.
            throw common::BinderException{common::stringFormat(
                "Unrecognized optional parameter {} in {}.", name, DropVectorIndexFunction::name)};
        }
    }
}

QueryHNSWConfig::QueryHNSWConfig(const function::optional_params_t& optionalParams) {
for (auto& [name, value] : optionalParams) {
auto lowerCaseName = common::StringUtils::getLower(name);
Expand Down
Loading