diff --git a/extension/fts/src/function/CMakeLists.txt b/extension/fts/src/function/CMakeLists.txt index 644389029dc..5e693626b51 100644 --- a/extension/fts/src/function/CMakeLists.txt +++ b/extension/fts/src/function/CMakeLists.txt @@ -1,11 +1,11 @@ +add_subdirectory(query_fts) + add_library(kuzu_fts_function OBJECT create_fts_index.cpp drop_fts_index.cpp fts_config.cpp fts_index_utils.cpp - query_fts_index.cpp - query_fts_bind_data.cpp stem.cpp tokenize.cpp) diff --git a/extension/fts/src/function/create_fts_index.cpp b/extension/fts/src/function/create_fts_index.cpp index 2062137dcb0..3ee5d1a9de5 100644 --- a/extension/fts/src/function/create_fts_index.cpp +++ b/extension/fts/src/function/create_fts_index.cpp @@ -145,6 +145,20 @@ static std::string formatStrInCypher(const std::string& input) { return result; } +static std::string createTablesForExactTermMatch(const CreateFTSBindData& bindData) { + std::string query; + auto appearsInfoTableName = + FTSUtils::getAppearsInfoTableName(bindData.tableID, bindData.indexName); + auto originalTermsTableName = + FTSUtils::getOrigTermsTableName(bindData.tableID, bindData.indexName); + query += common::stringFormat("CREATE NODE TABLE `{}`(term string, primary key(term));", + originalTermsTableName); + query += + common::stringFormat("COPY `{}` FROM (match (doc:`{}`) return distinct doc.term_origin);", + originalTermsTableName, appearsInfoTableName); + return query; +} + std::string createFTSIndexQuery(ClientContext& context, const TableFuncBindData& bindData) { auto ftsBindData = bindData.constPtrCast(); auto tableID = ftsBindData->tableID; @@ -174,8 +188,9 @@ std::string createFTSIndexQuery(ClientContext& context, const TableFuncBindData& // Create the terms_in_doc table which servers as a temporary table to store the // relationship between terms and docs. auto appearsInfoTableName = FTSUtils::getAppearsInfoTableName(tableID, indexName); - query += stringFormat("CREATE NODE TABLE `{}` (ID SERIAL, term string, docID INT64, primary " - "key(ID));", + query += stringFormat( + "CREATE NODE TABLE `{}` (ID SERIAL, term string, term_origin string, docID INT64, primary " + "key(ID));", appearsInfoTableName); auto tableName = ftsBindData->tableName; auto tableEntry = catalog::Catalog::Get(context)->getTableCatalogEntry( @@ -189,7 +204,7 @@ std::string createFTSIndexQuery(ClientContext& context, const TableFuncBindData& "WITH t AS t1, id AS id1 " "WHERE t1 is NOT NULL AND SIZE(t1) > 0 AND " "NOT EXISTS {MATCH (s:`{}` {sw: t1})} " - "RETURN STEM(t1, '{}'), id1);", + "RETURN STEM(t1, '{}'), t1, id1);", appearsInfoTableName, tableName, FTSUtils::getTokenizeMacroName(tableID, indexName), propertyName, ftsBindData->createFTSConfig.stopWordsTableInfo.tableName, ftsBindData->createFTSConfig.stemmer); @@ -213,6 +228,11 @@ std::string createFTSIndexQuery(ClientContext& context, const TableFuncBindData& "RETURN t.term, CAST(count(distinct t.docID) AS UINT64));", termsTableName, appearsInfoTableName); + // If the exact_term_match is enabled, we need to create an additional tables. + if (ftsBindData->createFTSConfig.exactTermMatch) { + query += createTablesForExactTermMatch(*ftsBindData); + } + auto appearsInTableName = FTSUtils::getAppearsInTableName(tableID, indexName); // Finally, create a terms table that records the documents in which the terms appear, along // with the frequency of each term. @@ -236,8 +256,10 @@ std::string createFTSIndexQuery(ClientContext& context, const TableFuncBindData& properties += "]"; std::string params; params += stringFormat("stemmer := '{}', ", ftsBindData->createFTSConfig.stemmer); - params += stringFormat("stopWords := '{}'", + params += stringFormat("stopWords := '{}', ", ftsBindData->createFTSConfig.stopWordsTableInfo.stopWords); + params += stringFormat("exact_term_match := {}", + ftsBindData->createFTSConfig.exactTermMatch ? "true" : "false"); query += stringFormat("CALL _CREATE_FTS_INDEX('{}', '{}', {}, {});", tableName, indexName, properties, params); query += stringFormat("RETURN 'Index {} has been created.' as result;", ftsBindData->indexName); diff --git a/extension/fts/src/function/fts_config.cpp b/extension/fts/src/function/fts_config.cpp index 44ca6905518..2c1226b4c6a 100644 --- a/extension/fts/src/function/fts_config.cpp +++ b/extension/fts/src/function/fts_config.cpp @@ -149,14 +149,17 @@ CreateFTSConfig::CreateFTSConfig(main::ClientContext& context, common::table_id_ common::StringUtils::replaceAll(ignorePatternQuery, "?", ""); IgnorePattern::validate(ignorePattern); IgnorePattern::validate(ignorePatternQuery); - } else if (lowerCaseName == "tokenizer") { - value.validateType(common::LogicalTypeID::STRING); + } else if (Tokenizer::NAME == lowerCaseName) { + value.validateType(Tokenizer::TYPE); tokenizerInfo.tokenizer = common::StringUtils::getLower(value.getValue()); Tokenizer::validate(tokenizerInfo.tokenizer); } else if (lowerCaseName == "jieba_dict_dir") { value.validateType(common::LogicalTypeID::STRING); tokenizerInfo.jiebaDictDir = common::StringUtils::getLower(value.getValue()); + } else if (ExactTermMatch::NAME == lowerCaseName) { + value.validateType(ExactTermMatch::TYPE); + exactTermMatch = value.getValue(); } else { throw common::BinderException{"Unrecognized optional parameter: " + name}; } @@ -165,7 +168,8 @@ CreateFTSConfig::CreateFTSConfig(main::ClientContext& context, common::table_id_ FTSConfig CreateFTSConfig::getFTSConfig() const { return FTSConfig{stemmer, stopWordsTableInfo.tableName, stopWordsTableInfo.stopWords, - ignorePattern, ignorePatternQuery, tokenizerInfo.tokenizer, tokenizerInfo.jiebaDictDir}; + ignorePattern, ignorePatternQuery, tokenizerInfo.tokenizer, tokenizerInfo.jiebaDictDir, + exactTermMatch}; } void FTSConfig::serialize(common::Serializer& serializer) const { @@ -176,6 +180,7 @@ void FTSConfig::serialize(common::Serializer& serializer) const { serializer.serializeValue(ignorePatternQuery); serializer.serializeValue(tokenizer); serializer.serializeValue(jiebaDictDir); + serializer.serializeValue(exactTermMatch); } FTSConfig FTSConfig::deserialize(common::Deserializer& deserializer) { @@ -187,6 +192,7 @@ FTSConfig FTSConfig::deserialize(common::Deserializer& deserializer) { deserializer.deserializeValue(config.ignorePatternQuery); deserializer.deserializeValue(config.tokenizer); deserializer.deserializeValue(config.jiebaDictDir); + deserializer.deserializeValue(config.exactTermMatch); return config; } diff --git a/extension/fts/src/function/query_fts/CMakeLists.txt b/extension/fts/src/function/query_fts/CMakeLists.txt new file mode 100644 index 00000000000..93df47a9844 --- /dev/null +++ b/extension/fts/src/function/query_fts/CMakeLists.txt @@ -0,0 +1,10 @@ +add_library(kuzu_query_fts_function + OBJECT + query_fts_index.cpp + query_fts_pattern_match.cpp + query_fts_bind_data.cpp + query_fts_term_lookup.cpp) + +set(FTS_EXTENSION_OBJECT_FILES + ${FTS_EXTENSION_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/extension/fts/src/function/query_fts_bind_data.cpp b/extension/fts/src/function/query_fts/query_fts_bind_data.cpp similarity index 58% rename from extension/fts/src/function/query_fts_bind_data.cpp rename to extension/fts/src/function/query_fts/query_fts_bind_data.cpp index 11f2752b916..4b58cefe73e 100644 --- a/extension/fts/src/function/query_fts_bind_data.cpp +++ b/extension/fts/src/function/query_fts/query_fts_bind_data.cpp @@ -1,4 +1,4 @@ -#include "function/query_fts_bind_data.h" +#include "function/query_fts/query_fts_bind_data.h" #include "binder/binder.h" #include "binder/expression/expression_util.h" @@ -42,6 +42,36 @@ void QueryFTSOptionalParams::evaluateParams(main::ClientContext* context) { topK.evaluateParam(context); } +QueryFTSBindData::QueryFTSBindData(binder::expression_vector columns, + graph::NativeGraphEntry graphEntry, std::shared_ptr docs, + std::shared_ptr query, const catalog::IndexCatalogEntry& entry, + std::unique_ptr optionalParams, common::idx_t numDocs, double avgDocLen) + : GDSBindData{std::move(columns), std::move(graphEntry), binder::expression_vector{docs}}, + query{std::move(query)}, entry{entry}, + outputTableID{output[0]->constCast().getTableIDs()[0]}, + numDocs{numDocs}, avgDocLen{avgDocLen}, + patternMatchAlgo{PatternMatchFactory::getPatternMatchAlgo( + entry.getAuxInfo().cast().config.exactTermMatch ? TermMatchType::EXACT : + TermMatchType::STEM)} { + auto& nodeExpr = output[0]->constCast(); + KU_ASSERT(nodeExpr.getNumEntries() == 1); + outputTableID = nodeExpr.getEntry(0)->getTableID(); + this->optionalParams = std::move(optionalParams); +} + +catalog::TableCatalogEntry* QueryFTSBindData::getTermsEntry(main::ClientContext& context) const { + auto catalog = catalog::Catalog::Get(context); + return catalog->getTableCatalogEntry(transaction::Transaction::Get(context), + FTSUtils::getTermsTableName(entry.getTableID(), entry.getIndexName())); +} + +catalog::TableCatalogEntry* QueryFTSBindData::getOrigTermsEntry( + main::ClientContext& context) const { + auto catalog = catalog::Catalog::Get(context); + return catalog->getTableCatalogEntry(transaction::Transaction::Get(context), + FTSUtils::getOrigTermsTableName(entry.getTableID(), entry.getIndexName())); +} + std::vector QueryFTSBindData::getQueryTerms(main::ClientContext& context) const { auto queryInStr = ExpressionUtil::evaluateLiteral(&context, query, LogicalType::STRING()); diff --git a/extension/fts/src/function/query_fts_index.cpp b/extension/fts/src/function/query_fts/query_fts_index.cpp similarity index 83% rename from extension/fts/src/function/query_fts_index.cpp rename to extension/fts/src/function/query_fts/query_fts_index.cpp index f663d93fac0..bf86407f2f9 100644 --- a/extension/fts/src/function/query_fts_index.cpp +++ b/extension/fts/src/function/query_fts/query_fts_index.cpp @@ -1,4 +1,4 @@ -#include "function/query_fts_index.h" +#include "function/query_fts/query_fts_index.h" #include @@ -11,16 +11,16 @@ #include "common/types/internal_id_util.h" #include "function/fts_index_utils.h" #include "function/gds/gds_utils.h" -#include "function/query_fts_bind_data.h" +#include "function/query_fts/query_fts_bind_data.h" +#include "function/query_fts/query_fts_pattern_match.h" +#include "function/query_fts/query_fts_term_lookup.h" #include "graph/on_disk_graph.h" #include "index/fts_index.h" #include "planner/operator/logical_hash_join.h" #include "planner/operator/logical_table_function_call.h" #include "planner/planner.h" #include "processor/execution_context.h" -#include "re2.h" #include "storage/storage_manager.h" -#include "storage/table/node_table.h" #include "utils/fts_utils.h" namespace kuzu { @@ -242,71 +242,14 @@ class QFTSVertexCompute final : public VertexCompute { std::unique_ptr writer; }; -using VCQueryTerm = std::variant>; -class MatchTermsVertexCompute final : public VertexCompute { -public: - explicit MatchTermsVertexCompute(std::unordered_map& resDfs, - std::vector& queryTerms) - : resDfs{resDfs}, queryTerms{queryTerms} {} - void vertexCompute(const graph::VertexScanState::Chunk& chunk) override { - auto terms = chunk.getProperties(0); - auto dfs = chunk.getProperties(1); - auto nodeIds = chunk.getNodeIDs(); - for (auto& queryTerm : queryTerms) { - // queryTerm.index() is 0 for string, 1 for unique_ptr - if (queryTerm.index() == 0) { - std::string& queryString = std::get<0>(queryTerm); - for (auto i = 0u; i < chunk.size(); ++i) { - if (queryString == terms[i].getAsString()) { - resDfs[nodeIds[i].offset] = dfs[i]; - } - } - } else { - RE2& regex = *std::get<1>(queryTerm); - for (auto i = 0u; i < chunk.size(); ++i) { - if (RE2::FullMatch(terms[i].getAsString(), regex)) { - resDfs[nodeIds[i].offset] = dfs[i]; - } - } - } - } - } - std::unique_ptr copy() override { - return std::make_unique(resDfs, queryTerms); - } - -private: - std::unordered_map& resDfs; - std::vector& queryTerms; -}; - static constexpr char SCORE_PROP_NAME[] = "score"; -static constexpr char DOC_FREQUENCY_PROP_NAME[] = "df"; static constexpr char TERM_FREQUENCY_PROP_NAME[] = "tf"; static constexpr char DOC_LEN_PROP_NAME[] = "len"; static constexpr char DOC_ID_PROP_NAME[] = "docID"; static std::unordered_map getDFs(main::ClientContext& context, processor::ExecutionContext* executionContext, graph::Graph* graph, - catalog::TableCatalogEntry* termsEntry, std::vector& queryTerms) { - auto storageManager = StorageManager::Get(context); - auto tableID = termsEntry->getTableID(); - auto& termsNodeTable = storageManager->getTable(tableID)->cast(); - auto tx = transaction::Transaction::Get(context); - auto dfColumnID = termsEntry->getColumnID(DOC_FREQUENCY_PROP_NAME); - std::vector vectorTypes; - vectorTypes.push_back(LogicalType::INTERNAL_ID()); - vectorTypes.push_back(LogicalType::UINT64()); - auto dataChunk = Table::constructDataChunk(MemoryManager::Get(context), std::move(vectorTypes)); - dataChunk.state->getSelVectorUnsafe().setSelSize(1); - auto nodeIDVector = &dataChunk.getValueVectorMutable(0); - auto dfVector = &dataChunk.getValueVectorMutable(1); - auto termsVector = ValueVector(LogicalType::STRING(), MemoryManager::Get(context)); - termsVector.state = dataChunk.state; - auto nodeTableScanState = - NodeTableScanState(nodeIDVector, std::vector{dfVector}, dataChunk.state); - nodeTableScanState.setToTable(transaction::Transaction::Get(context), &termsNodeTable, - {dfColumnID}, {}); + const QueryFTSBindData& bindData, std::vector& queryTerms) { std::unordered_map dfs; std::vector vcQueryTerms; vcQueryTerms.reserve(queryTerms.size()); @@ -323,22 +266,17 @@ static std::unordered_map getDFs(main::ClientContext& contex vcQueryTerms.emplace_back(std::in_place_type, queryTerm); } } + if (hasWildcardQueryTerm) { - auto matchVc = MatchTermsVertexCompute{dfs, vcQueryTerms}; - GDSUtils::runVertexCompute(executionContext, GDSDensityState::DENSE, graph, matchVc, - termsEntry, std::vector{"term", DOC_FREQUENCY_PROP_NAME}); + bindData.patternMatchAlgo(dfs, vcQueryTerms, executionContext, graph, bindData); } else { + TermsDFLookup termsDFLookup{bindData.getTermsEntry(context), context}; for (auto& queryTerm : queryTerms) { - termsVector.setValue(0, queryTerm); - offset_t offset = 0; - if (!termsNodeTable.lookupPK(tx, &termsVector, 0 /* vectorPos */, offset)) { + auto offsetDFPair = termsDFLookup.lookupTermDF(queryTerm); + if (offsetDFPair.first == INVALID_OFFSET) { continue; } - auto nodeID = nodeID_t{offset, tableID}; - nodeIDVector->setValue(0, nodeID); - termsNodeTable.initScanState(tx, nodeTableScanState, tableID, offset); - [[maybe_unused]] auto res = termsNodeTable.lookup(tx, nodeTableScanState); - dfs.emplace(offset, dfVector->getValue(0)); + dfs.emplace(offsetDFPair); } } return dfs; @@ -381,7 +319,7 @@ static offset_t tableFunc(const TableFuncInput& input, TableFuncOutput&) { } auto termsEntry = graphEntry->nodeInfos[0].entry; auto queryTerms = qFTSBindData.getQueryTerms(clientContext); - auto dfs = getDFs(clientContext, input.context, graph, termsEntry, queryTerms); + auto dfs = getDFs(clientContext, input.context, graph, qFTSBindData, queryTerms); // Do edge compute to extend terms -> docs and save the term frequency and document frequency // for each term-doc pair. The reason why we store the term frequency and document frequency // is that: we need the `len` property from the docs table which is only available during the @@ -444,7 +382,6 @@ static std::unique_ptr bindFunc(main::ClientContext* context, auto inputTableName = getParamVal(*input, 0); auto indexName = getParamVal(*input, 1); auto query = input->getParam(2); - auto tableEntry = FTSIndexUtils::bindNodeTable(*context, inputTableName, indexName, FTSIndexUtils::IndexOperation::QUERY); auto catalog = catalog::Catalog::Get(*context); @@ -459,7 +396,12 @@ static std::unique_ptr bindFunc(main::ClientContext* context, FTSUtils::getDocsTableName(tableEntry->getTableID(), indexName)); auto appearsInEntry = catalog->getTableCatalogEntry(transaction, FTSUtils::getAppearsInTableName(tableEntry->getTableID(), indexName)); - auto graphEntry = graph::NativeGraphEntry({termsEntry, docsEntry}, {appearsInEntry}); + std::vector nodeEntries{termsEntry, docsEntry}; + if (ftsIndexEntry->getAuxInfo().cast().config.exactTermMatch) { + nodeEntries.push_back(catalog->getTableCatalogEntry(transaction, + FTSUtils::getOrigTermsTableName(tableEntry->getTableID(), indexName))); + } + auto graphEntry = graph::NativeGraphEntry(std::move(nodeEntries), {appearsInEntry}); expression_vector columns; auto& docsNode = nodeOutput->constCast(); diff --git a/extension/fts/src/function/query_fts/query_fts_pattern_match.cpp b/extension/fts/src/function/query_fts/query_fts_pattern_match.cpp new file mode 100644 index 00000000000..6ba01068086 --- /dev/null +++ b/extension/fts/src/function/query_fts/query_fts_pattern_match.cpp @@ -0,0 +1,138 @@ +#include "function/query_fts/query_fts_pattern_match.h" + +#include + +#include "catalog/fts_index_catalog_entry.h" +#include "function/gds/compute.h" +#include "function/gds/gds_utils.h" +#include "function/query_fts/query_fts_bind_data.h" +#include "function/query_fts/query_fts_term_lookup.h" +#include "libstemmer.h" +#include "storage/storage_manager.h" +#include "utils/fts_utils.h" + +using namespace kuzu::function; +using namespace kuzu::processor; + +namespace kuzu { +namespace fts_extension { + +class MatchTermVertexCompute : public function::VertexCompute { +public: + MatchTermVertexCompute(std::vector& queryTerms, + std::unordered_map& resDfs) + : queryTerms{queryTerms}, resDfs{resDfs} {} + + virtual void handleMatchedTerm(uint64_t itr, const graph::VertexScanState::Chunk& chunk) = 0; + + void vertexCompute(const graph::VertexScanState::Chunk& chunk) override { + auto terms = chunk.getProperties(0); + for (auto& queryTerm : queryTerms) { + // queryTerm.index() is 0 for string, 1 for unique_ptr + if (queryTerm.index() == 0) { + std::string& queryString = std::get<0>(queryTerm); + for (auto i = 0u; i < chunk.size(); ++i) { + if (queryString == terms[i].getAsString()) { + handleMatchedTerm(i, chunk); + } + } + } else { + RE2& regex = *std::get<1>(queryTerm); + for (auto i = 0u; i < chunk.size(); ++i) { + if (RE2::FullMatch(terms[i].getAsString(), regex)) { + handleMatchedTerm(i, chunk); + } + } + } + } + } + +protected: + std::vector& queryTerms; + std::unordered_map& resDfs; +}; + +class StemTermMatchVertexCompute final : public MatchTermVertexCompute { +public: + explicit StemTermMatchVertexCompute(std::unordered_map& resDfs, + std::vector& queryTerms) + : MatchTermVertexCompute{queryTerms, resDfs} {} + + void handleMatchedTerm(uint64_t itr, const graph::VertexScanState::Chunk& chunk) override { + auto dfs = chunk.getProperties(1); + auto nodeIds = chunk.getNodeIDs(); + resDfs[nodeIds[itr].offset] = dfs[itr]; + } + + std::unique_ptr copy() override { + return std::make_unique(resDfs, queryTerms); + } +}; + +class ExactTermMatchVertexCompute final : public MatchTermVertexCompute { +public: + ExactTermMatchVertexCompute(std::unordered_map& resDfs, + std::vector& queryTerms, const QueryFTSBindData& bindData, + main::ClientContext& context) + : MatchTermVertexCompute{queryTerms, resDfs}, + sbStemmer{sb_stemmer_new( + reinterpret_cast( + bindData.entry.getAuxInfo().cast().config.stemmer.c_str()), + "UTF_8")}, + bindData{bindData}, context{context}, + termsDFLookup{bindData.getTermsEntry(context), context} {} + + ~ExactTermMatchVertexCompute() override { sb_stemmer_delete(sbStemmer); } + + void handleMatchedTerm(uint64_t itr, const graph::VertexScanState::Chunk& chunk) override { + auto term = chunk.getProperties(0)[itr]; + auto stemData = sb_stemmer_stem(sbStemmer, + reinterpret_cast(term.getData()), term.len); + auto result = termsDFLookup.lookupTermDF(reinterpret_cast(stemData)); + KU_ASSERT(result.first != common::INVALID_OFFSET); + resDfs.insert(result); + } + + std::unique_ptr copy() override { + return std::make_unique(resDfs, queryTerms, bindData, context); + } + +private: + sb_stemmer* sbStemmer; + const QueryFTSBindData& bindData; + main::ClientContext& context; + TermsDFLookup termsDFLookup; +}; + +static void stemTermMatch(std::unordered_map& dfs, + std::vector& vcQueryTerms, ExecutionContext* executionContext, graph::Graph* graph, + const QueryFTSBindData& bindData) { + auto matchVc = StemTermMatchVertexCompute{dfs, vcQueryTerms}; + GDSUtils::runVertexCompute(executionContext, GDSDensityState::DENSE, graph, matchVc, + bindData.getTermsEntry(*executionContext->clientContext), + std::vector{"term", TermsDFLookup::DOC_FREQUENCY_PROP_NAME}); +} + +static void exactTermMatch(std::unordered_map& dfs, + std::vector& vcQueryTerms, ExecutionContext* executionContext, graph::Graph* graph, + const QueryFTSBindData& bindData) { + auto matchOrigTermVc = + ExactTermMatchVertexCompute{dfs, vcQueryTerms, bindData, *executionContext->clientContext}; + GDSUtils::runVertexCompute(executionContext, GDSDensityState::DENSE, graph, matchOrigTermVc, + bindData.getOrigTermsEntry(*executionContext->clientContext), + std::vector{"term"}); +} + +pattern_match_algo PatternMatchFactory::getPatternMatchAlgo(TermMatchType termMatchType) { + switch (termMatchType) { + case TermMatchType::EXACT: + return exactTermMatch; + case TermMatchType::STEM: + return stemTermMatch; + default: + KU_UNREACHABLE; + } +} + +} // namespace fts_extension +} // namespace kuzu diff --git a/extension/fts/src/function/query_fts/query_fts_term_lookup.cpp b/extension/fts/src/function/query_fts/query_fts_term_lookup.cpp new file mode 100644 index 00000000000..8c072e0f2e6 --- /dev/null +++ b/extension/fts/src/function/query_fts/query_fts_term_lookup.cpp @@ -0,0 +1,44 @@ +#include "function/query_fts/query_fts_term_lookup.h" + +#include "storage/storage_manager.h" +#include "transaction/transaction.h" + +namespace kuzu { +namespace fts_extension { + +using namespace kuzu::common; +using namespace kuzu::catalog; +using namespace kuzu::main; +using namespace kuzu::storage; +using namespace kuzu::transaction; + +TermsDFLookup::TermsDFLookup(TableCatalogEntry* termsEntry, ClientContext& context) + : dataChunkState{DataChunkState::getSingleValueDataChunkState()}, + termsVector{LogicalType::STRING(), MemoryManager::Get(context)}, + nodeIDVector{LogicalType::INTERNAL_ID()}, dfVector{LogicalType::UINT64()}, + termsTable{ + StorageManager::Get(context)->getTable(termsEntry->getTableID())->cast()}, + nodeTableScanState{&nodeIDVector, std::vector{&dfVector}, dataChunkState}, + dfColumnID{termsEntry->getColumnID(DOC_FREQUENCY_PROP_NAME)}, trx{Transaction::Get(context)} { + termsVector.state = dataChunkState; + nodeIDVector.state = dataChunkState; + dfVector.state = dataChunkState; + nodeTableScanState.setToTable(transaction::Transaction::Get(context), &termsTable, {dfColumnID}, + {}); +} + +std::pair TermsDFLookup::lookupTermDF(const std::string& term) { + termsVector.setValue(0, term); + offset_t offset = 0; + if (!termsTable.lookupPK(trx, &termsVector, 0 /* vectorPos */, offset)) { + return {INVALID_OFFSET, UINT64_MAX}; + } + auto nodeID = nodeID_t{offset, termsTable.getTableID()}; + nodeIDVector.setValue(0, nodeID); + termsTable.initScanState(trx, nodeTableScanState, termsTable.getTableID(), offset); + termsTable.lookup(trx, nodeTableScanState); + return {offset, dfVector.getValue(0)}; +} + +} // namespace fts_extension +} // namespace kuzu diff --git a/extension/fts/src/include/function/fts_config.h b/extension/fts/src/include/function/fts_config.h index 2a1d8d2628b..1fcd2476510 100644 --- a/extension/fts/src/include/function/fts_config.h +++ b/extension/fts/src/include/function/fts_config.h @@ -4,6 +4,7 @@ #include "common/types/types.h" #include "function/table/bind_input.h" +#include "function/table/optional_params.h" namespace kuzu { namespace fts_extension { @@ -16,6 +17,12 @@ struct Stemmer { static void validate(const std::string& stemmer); }; +struct ExactTermMatch { + static constexpr const char* NAME = "exact_term_match"; + static constexpr common::LogicalTypeID TYPE = common::LogicalTypeID::BOOL; + static constexpr bool DEFAULT_VALUE = false; +}; + enum class StopWordsSource : uint8_t { FILE = 0, TABLE = 1, @@ -78,6 +85,7 @@ struct CreateFTSConfig { std::string ignorePattern = IgnorePattern::DEFAULT_VALUE; std::string ignorePatternQuery = IgnorePattern::DEFAULT_VALUE_QUERY; TokenizerInfo tokenizerInfo; + bool exactTermMatch = ExactTermMatch::DEFAULT_VALUE; CreateFTSConfig() = default; CreateFTSConfig(main::ClientContext& context, common::table_id_t tableID, @@ -96,15 +104,16 @@ struct FTSConfig { std::string ignorePatternQuery = ""; std::string tokenizer = ""; std::string jiebaDictDir = ""; + bool exactTermMatch = false; FTSConfig() = default; FTSConfig(std::string stemmer, std::string stopWordsTableName, std::string stopWordsSource, std::string ignorePattern, std::string ignorePatternQuery, std::string tokenizer, - std::string jiebaDictDir) + std::string jiebaDictDir, bool exactTermMatch) : stemmer{std::move(stemmer)}, stopWordsTableName{std::move(stopWordsTableName)}, stopWordsSource{std::move(stopWordsSource)}, ignorePattern{std::move(ignorePattern)}, ignorePatternQuery{std::move(ignorePatternQuery)}, tokenizer{std::move(tokenizer)}, - jiebaDictDir{std::move(jiebaDictDir)} {} + jiebaDictDir{std::move(jiebaDictDir)}, exactTermMatch{exactTermMatch} {} void serialize(common::Serializer& serializer) const; diff --git a/extension/fts/src/include/function/query_fts_bind_data.h b/extension/fts/src/include/function/query_fts/query_fts_bind_data.h similarity index 77% rename from extension/fts/src/include/function/query_fts_bind_data.h rename to extension/fts/src/include/function/query_fts/query_fts_bind_data.h index 716fbcdc780..f92585d4781 100644 --- a/extension/fts/src/include/function/query_fts_bind_data.h +++ b/extension/fts/src/include/function/query_fts/query_fts_bind_data.h @@ -4,6 +4,7 @@ #include "catalog/catalog_entry/index_catalog_entry.h" #include "function/fts_config.h" #include "function/gds/gds.h" +#include "function/query_fts/query_fts_pattern_match.h" namespace kuzu { namespace fts_extension { @@ -35,24 +36,22 @@ struct QueryFTSBindData final : public function::GDSBindData { common::table_id_t outputTableID; common::idx_t numDocs; double avgDocLen; + pattern_match_algo patternMatchAlgo; QueryFTSBindData(binder::expression_vector columns, graph::NativeGraphEntry graphEntry, std::shared_ptr docs, std::shared_ptr query, const catalog::IndexCatalogEntry& entry, std::unique_ptr optionalParams, common::idx_t numDocs, - double avgDocLen) - : GDSBindData{std::move(columns), std::move(graphEntry), binder::expression_vector{docs}}, - query{std::move(query)}, entry{entry}, - outputTableID{output[0]->constCast().getTableIDs()[0]}, - numDocs{numDocs}, avgDocLen{avgDocLen} { - auto& nodeExpr = output[0]->constCast(); - KU_ASSERT(nodeExpr.getNumEntries() == 1); - outputTableID = nodeExpr.getEntry(0)->getTableID(); - this->optionalParams = std::move(optionalParams); - } + double avgDocLen); + QueryFTSBindData(const QueryFTSBindData& other) : GDSBindData{other}, query{other.query}, entry{other.entry}, - outputTableID{other.outputTableID}, numDocs{other.numDocs}, avgDocLen{other.avgDocLen} {} + outputTableID{other.outputTableID}, numDocs{other.numDocs}, avgDocLen{other.avgDocLen}, + patternMatchAlgo{other.patternMatchAlgo} {} + + catalog::TableCatalogEntry* getTermsEntry(main::ClientContext& context) const; + + catalog::TableCatalogEntry* getOrigTermsEntry(main::ClientContext& context) const; std::vector getQueryTerms(main::ClientContext& context) const; diff --git a/extension/fts/src/include/function/query_fts_index.h b/extension/fts/src/include/function/query_fts/query_fts_index.h similarity index 100% rename from extension/fts/src/include/function/query_fts_index.h rename to extension/fts/src/include/function/query_fts/query_fts_index.h diff --git a/extension/fts/src/include/function/query_fts/query_fts_pattern_match.h b/extension/fts/src/include/function/query_fts/query_fts_pattern_match.h new file mode 100644 index 00000000000..2a2c1040ff0 --- /dev/null +++ b/extension/fts/src/include/function/query_fts/query_fts_pattern_match.h @@ -0,0 +1,33 @@ +#pragma once + +#include +#include + +#include "graph/graph.h" +#include "processor/execution_context.h" +#include "re2.h" + +namespace kuzu { +namespace fts_extension { + +struct FTSConfig; +struct QueryFTSBindData; + +using VCQueryTerm = std::variant>; + +using pattern_match_algo = std::function& dfs, + std::vector& vcQueryTerms, processor::ExecutionContext* executionContext, + graph::Graph* graph, const QueryFTSBindData& bindData)>; + +enum class TermMatchType : uint8_t { + STEM = 0, + EXACT = 1, +}; + +class PatternMatchFactory { +public: + static pattern_match_algo getPatternMatchAlgo(TermMatchType termMatchType); +}; + +} // namespace fts_extension +} // namespace kuzu diff --git a/extension/fts/src/include/function/query_fts/query_fts_term_lookup.h b/extension/fts/src/include/function/query_fts/query_fts_term_lookup.h new file mode 100644 index 00000000000..f713ec89359 --- /dev/null +++ b/extension/fts/src/include/function/query_fts/query_fts_term_lookup.h @@ -0,0 +1,32 @@ +#pragma once + +#include "catalog/catalog_entry/table_catalog_entry.h" +#include "main/client_context.h" +#include "storage/buffer_manager/memory_manager.h" +#include "storage/table/node_table.h" + +namespace kuzu { +namespace fts_extension { + +class TermsDFLookup { +public: + static constexpr char DOC_FREQUENCY_PROP_NAME[] = "df"; + +public: + TermsDFLookup(catalog::TableCatalogEntry* termsEntry, main::ClientContext& context); + + std::pair lookupTermDF(const std::string& term); + +private: + std::shared_ptr dataChunkState; + common::ValueVector termsVector; + common::ValueVector nodeIDVector; + common::ValueVector dfVector; + storage::NodeTable& termsTable; + storage::NodeTableScanState nodeTableScanState; + common::column_id_t dfColumnID; + transaction::Transaction* trx; +}; + +} // namespace fts_extension +} // namespace kuzu diff --git a/extension/fts/src/include/utils/fts_utils.h b/extension/fts/src/include/utils/fts_utils.h index 7649733745f..709530b8dc0 100644 --- a/extension/fts/src/include/utils/fts_utils.h +++ b/extension/fts/src/include/utils/fts_utils.h @@ -48,6 +48,17 @@ struct FTSUtils { return common::stringFormat("{}_terms", getInternalTablePrefix(tableID, indexName)); } + static std::string getOrigTermsTableName(common::table_id_t tableID, + const std::string& indexName) { + return common::stringFormat("{}_orig_terms", getInternalTablePrefix(tableID, indexName)); + } + + static std::string getOrigTermsRelTableName(common::table_id_t tableID, + const std::string& indexName) { + return common::stringFormat("{}_orig_terms_rel", + getInternalTablePrefix(tableID, indexName)); + } + static std::string getAppearsInTableName(common::table_id_t tableID, const std::string& indexName) { return common::stringFormat("{}_appears_in", getInternalTablePrefix(tableID, indexName)); diff --git a/extension/fts/src/main/fts_extension.cpp b/extension/fts/src/main/fts_extension.cpp index edcbf59d6a9..201bb8f7833 100644 --- a/extension/fts/src/main/fts_extension.cpp +++ b/extension/fts/src/main/fts_extension.cpp @@ -4,7 +4,7 @@ #include "catalog/fts_index_catalog_entry.h" #include "function/create_fts_index.h" #include "function/drop_fts_index.h" -#include "function/query_fts_index.h" +#include "function/query_fts/query_fts_index.h" #include "function/stem.h" #include "function/tokenize.h" #include "index/fts_index.h" diff --git a/extension/fts/test/test_files/wildcard.test b/extension/fts/test/test_files/wildcard.test index b14ab52d8b5..72743c7cd90 100644 --- a/extension/fts/test/test_files/wildcard.test +++ b/extension/fts/test/test_files/wildcard.test @@ -51,3 +51,34 @@ Abcdefg|This book is a test?ax*alphabetical? ---- 2 Echoes of the Past|A deep dive into the history of ancient civilizations. Computers|The hiory*?story*a?b?c of computing + +-CASE exact_term_match +-LOAD_DYNAMIC_EXTENSION fts +-STATEMENT CREATE NODE TABLE news (content string, primary key(content)); +---- ok +-STATEMENT create (n:news {content: "alice is a canadian runner"}) +---- ok +-STATEMENT create (n:news {content: "carol is running in the playground"}) +---- ok +-STATEMENT CALL CREATE_FTS_INDEX('news', 'news_index_0', ['content'], exact_term_match := FALSE); +---- ok +-STATEMENT CALL QUERY_FTS_INDEX('news', 'news_index_0', 'runn*') RETURN node.content, score order by score +---- 1 +alice is a canadian runner|0.301030 +-STATEMENT CALL QUERY_FTS_INDEX('news', 'news_index_0', 'runn?ng') RETURN node.content, score order by score +---- 0 +-STATEMENT CALL QUERY_FTS_INDEX('news', 'news_index_0', 'runne?') RETURN node.content, score order by score +---- 1 +alice is a canadian runner|0.301030 +-STATEMENT CALL CREATE_FTS_INDEX('news', 'news_index_1', ['content'], exact_term_match := TRUE); +---- ok +-STATEMENT CALL QUERY_FTS_INDEX('news', 'news_index_1', 'runn*') RETURN node.content, score order by score +---- 2 +alice is a canadian runner|0.301030 +carol is running in the playground|0.301030 +-STATEMENT CALL QUERY_FTS_INDEX('news', 'news_index_1', 'runn?ng') RETURN node.content, score order by score +---- 1 +carol is running in the playground|0.301030 +-STATEMENT CALL QUERY_FTS_INDEX('news', 'news_index_1', 'runne?') RETURN node.content, score order by score +---- 1 +alice is a canadian runner|0.301030