Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions astrbot/core/knowledge_base/chunking/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from .base import BaseChunker
from .fixed_size import FixedSizeChunker
from .markdown import MarkdownChunker

__all__ = [
"BaseChunker",
"FixedSizeChunker",
"MarkdownChunker",
]
351 changes: 351 additions & 0 deletions astrbot/core/knowledge_base/chunking/markdown.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,351 @@
"""Markdown 感知分块器

根据 Markdown 标题层级结构进行分块,保持每个章节的语义完整性。
对于超过 chunk_size 的章节,内部使用递归字符分割。
"""

import re
from dataclasses import dataclass

from .base import BaseChunker
from .recursive import RecursiveCharacterChunker


@dataclass
class _Section:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): Consider refactoring heading-only section handling and short-chunk merging into simpler, single-pass flows to reduce state and make the markdown chunker easier to reason about.

You can reduce complexity in two focused spots without changing behavior: heading-only handling and short‑chunk merging.


1. Remove has_body + _merge_heading_only_chunks two‑phase handling

Right now:

  • _parse_sections packs heading line + body into Section.text and sets has_body.
  • _sections_to_chunks propagates has_body.
  • _merge_heading_only_chunks then reinterprets that flag.

You can instead keep heading‑only sections out of the chunk list in the first place and prepend them directly to the next section’s text. That lets you:

  • Drop has_body from _Section
  • Make _sections_to_chunks return plain list[str]
  • Inline the “heading‑only merging” logic where chunks are produced

Example of a narrower refactor:

@dataclass
class _Section:
    heading_path: list[str]
    heading_title: str  # current heading title
    body: str           # body WITHOUT heading line

async def _sections_to_chunks(
    self, sections: list[_Section], chunk_size: int, chunk_overlap: int
) -> list[str]:
    chunks: list[str] = []
    pending_heading_only: list[str] = []

    for section in sections:
        # heading-only section: accumulate and continue
        if not section.body.strip():
            pending_heading_only.append(section.heading_title)
            continue

        # build effective heading_path including any pending heading-only titles
        heading_path = section.heading_path + pending_heading_only + [section.heading_title]
        pending_heading_only = []

        context_prefix = self._build_context_prefix(heading_path)
        full_text = context_prefix + section.body

        if len(full_text) <= chunk_size:
            chunks.append(full_text.strip())
        else:
            prefix_len = self._estimate_prefix_length(heading_path)
            effective_chunk_size = max(chunk_size // 4, chunk_size - prefix_len)

            sub_chunks = await self._fallback_chunker.chunk(
                section.body,
                chunk_size=effective_chunk_size,
                chunk_overlap=chunk_overlap,
            )
            for i, sub_chunk in enumerate(sub_chunks):
                chunk_text = self._apply_heading_context(
                    heading_path, sub_chunk, is_continuation=(i > 0)
                )
                chunks.append(chunk_text.strip())

    # optional: if trailing heading-only sections exist, attach to last chunk
    if pending_heading_only and chunks:
        tail = "\n\n" + "\n\n".join(pending_heading_only)
        if len(chunks[-1] + tail) <= chunk_size:
            chunks[-1] = (chunks[-1] + tail).strip()
        else:
            chunks.extend(pending_heading_only)

    return chunks

Then you can simplify chunk() to skip _merge_heading_only_chunks entirely:

raw_chunks = await self._sections_to_chunks(sections, chunk_size, chunk_overlap)
merged = self._merge_short_chunks(raw_chunks, chunk_size)

And _Section no longer needs has_body.

This keeps all existing behaviors (headings merged into nearest body) but removes the boolean flag, the two‑phase pipeline, and _merge_heading_only_chunks.


2. Simplify _merge_short_chunks control flow

The current version has a buf with branchy updates and conditional final.append calls that are hard to reason about. You can keep semantics but use a single “current chunk” accumulator.

Pattern:

  • Maintain current_chunk (string or None).
  • For each c:
    • Try to merge into current_chunk if under chunk_size.
    • Ensure that anything shorter than min_chunk_size is merged forward if possible.

A simpler outline:

def _merge_short_chunks(self, chunks: list[str], chunk_size: int) -> list[str]:
    if self.min_chunk_size <= 0 or len(chunks) <= 1:
        return chunks

    result: list[str] = []
    current: str | None = None

    for c in chunks:
        if current is None:
            current = c
            continue

        combined = current + "\n\n" + c
        # prefer merging if either side is too small and we stay within limit
        if (len(current) < self.min_chunk_size or len(c) < self.min_chunk_size) and len(combined) <= chunk_size:
            current = combined
        else:
            result.append(current)
            current = c

    if current is not None:
        # optional last merge if trailing small chunk fits into previous
        if result and len(current) < self.min_chunk_size:
            candidate = result[-1] + "\n\n" + current
            if len(candidate) <= chunk_size:
                result[-1] = candidate
                return result
        result.append(current)

    return result

This keeps all logic in one place:

  • “Small” is defined once (len(c) < min_chunk_size).
  • Merge decision is explicit (len(combined) <= chunk_size).
  • No dual role of buf vs final[-1].

You can adjust the merge heuristics to exactly mirror current behavior, but the structure will remain much easier to read and test.

"""解析后的 Markdown 章节"""

heading_path: list[str]
text: str
has_body: bool


class MarkdownChunker(BaseChunker):
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
"""Markdown 感知分块器

按照 Markdown 标题层级切分文档,每个章节作为独立的 chunk。
如果某个章节内容超过 chunk_size,则在该章节内部进行递归分割。
子章节可选继承父级标题作为上下文前缀。
"""

def __init__(
    self,
    chunk_size: int = 1024,
    chunk_overlap: int = 50,
    include_heading_context: bool = True,
    max_heading_depth: int = 4,
    min_chunk_size: int = 0,
    continuation_prefix: str = "...",
) -> None:
    """Initialize the Markdown-aware chunker.

    Args:
        chunk_size: Maximum number of characters per chunk.
        chunk_overlap: Overlap (in characters) used when a section must be
            split internally by the recursive fallback chunker.
        include_heading_context: Whether to prepend the parent heading path
            as a context prefix to section chunks.
        max_heading_depth: Deepest heading level to recognize (clamped to 1-6).
        min_chunk_size: Adjacent chunks shorter than this are merged;
            0 disables merging.
        continuation_prefix: Marker prepended to continuation chunks
            (defaults to "...").
    """
    self.chunk_size = chunk_size
    self.chunk_overlap = chunk_overlap
    self.include_heading_context = include_heading_context
    # Clamp depth into the valid Markdown heading range so an out-of-range
    # value cannot produce a broken heading regex.
    self.max_heading_depth = min(6, max(1, int(max_heading_depth)))
    self.min_chunk_size = min_chunk_size
    self.continuation_prefix = continuation_prefix
    # Used both as the fallback when no headings are found and for
    # splitting oversized sections internally.
    self._fallback_chunker = RecursiveCharacterChunker(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )

async def chunk(self, text: str, **kwargs) -> list[str]:
    """Split Markdown text into chunks following its heading hierarchy.

    Args:
        text: Markdown-formatted input text.
        chunk_size: Optional override of the default chunk size.
        chunk_overlap: Optional override of the default overlap.

    Returns:
        list[str]: The resulting chunk texts.
    """
    if not text or not text.strip():
        return []

    size = kwargs.get("chunk_size", self.chunk_size)
    overlap = kwargs.get("chunk_overlap", self.chunk_overlap)

    # Parse the Markdown heading structure first.
    sections = self._parse_sections(text)
    if not sections:
        # No heading structure detected: fall back to recursive splitting.
        return await self._fallback_chunker.chunk(
            text, chunk_size=size, chunk_overlap=overlap
        )

    # sections -> raw (text, has_body) chunks
    raw = await self._sections_to_chunks(sections, size, overlap)

    # Fold heading-only sections into the next chunk that has a body,
    # then merge adjacent chunks that fall below min_chunk_size.
    result = self._merge_heading_only_chunks(raw, size)
    return self._merge_short_chunks(result, size)

def _estimate_prefix_length(self, heading_path: list[str]) -> int:
"""估算标题上下文前缀的最大长度(用于扣除子块可用空间)"""
if not self.include_heading_context or not heading_path:
return 0
title = " > ".join(heading_path)
# 续接前缀格式: "{continuation_prefix} {title}\n\n"
continuation = f"{self.continuation_prefix} {title}\n\n"
return len(continuation)

async def _sections_to_chunks(
    self, sections: list[_Section], chunk_size: int, chunk_overlap: int
) -> list[tuple[str, bool]]:
    """Convert parsed sections into a list of (chunk_text, has_body) pairs.

    A section that fits within chunk_size (including its heading-context
    prefix) becomes one chunk. A longer section is split internally by the
    recursive fallback chunker, with the prefix length subtracted from the
    sub-chunk budget so re-attaching the prefix cannot exceed chunk_size.

    Args:
        sections: Parsed Markdown sections.
        chunk_size: Maximum characters per produced chunk.
        chunk_overlap: Overlap passed to the fallback chunker.

    Returns:
        list[tuple[str, bool]]: chunk text plus whether it has real body text.
    """
    raw_chunks: list[tuple[str, bool]] = []

    for section in sections:
        section_text = section.text
        heading_path = section.heading_path

        # Build the context-prefixed text for the whole section.
        context_prefix = self._build_context_prefix(heading_path)
        full_text = context_prefix + section_text

        if len(full_text) <= chunk_size:
            raw_chunks.append((full_text.strip(), section.has_body))
            continue

        # Section too long: split it internally. Subtract the prefix
        # length so that prefixed sub-chunks stay within chunk_size.
        # Bug fix: the previous floor of chunk_size // 4 could exceed
        # the remaining budget when the heading prefix was very long,
        # producing chunks larger than chunk_size after prefixing.
        prefix_len = self._estimate_prefix_length(heading_path)
        effective_chunk_size = max(1, chunk_size - prefix_len)

        sub_chunks = await self._fallback_chunker.chunk(
            section_text,
            chunk_size=effective_chunk_size,
            chunk_overlap=chunk_overlap,
        )
        for i, sub_chunk in enumerate(sub_chunks):
            chunk_text = self._apply_heading_context(
                heading_path, sub_chunk, is_continuation=(i > 0)
            )
            raw_chunks.append((chunk_text, True))

    return raw_chunks

def _build_context_prefix(self, heading_path: list[str]) -> str:
"""构建标题路径前缀"""
if self.include_heading_context and heading_path:
return " > ".join(heading_path) + "\n\n"
return ""

def _apply_heading_context(
self, heading_path: list[str], content: str, is_continuation: bool
) -> str:
"""为 chunk 内容添加标题上下文"""
if not self.include_heading_context or not heading_path:
return content.strip()

title = " > ".join(heading_path)
if is_continuation:
return f"{self.continuation_prefix} {title}\n\n{content}".strip()
return f"{title}\n\n{content}".strip()

def _merge_heading_only_chunks(
    self, raw_chunks: list[tuple[str, bool]], chunk_size: int
) -> list[str]:
    """Merge chunks that have no real body text into the next chunk that does.

    Heading-only chunks are accumulated in ``pending`` (each followed by a
    blank line) and prepended to the next body-bearing chunk when the
    combined text still fits within ``chunk_size``.
    """
    merged: list[str] = []
    # Accumulator of heading-only chunk text; always carries its own
    # trailing "\n\n" separator once non-empty.
    pending = ""

    for chunk_text, has_body in raw_chunks:
        if not chunk_text:
            continue
        if not has_body:
            # Heading-only chunk: stash it. But if pending is already too
            # long to absorb this heading within chunk_size, flush first.
            # (+2 accounts for the "\n\n" separator that will follow.)
            if pending and len(pending) + len(chunk_text) + 2 > chunk_size:
                merged.append(pending.strip())
                pending = ""
            pending += chunk_text + "\n\n"
        else:
            if pending:
                # pending already ends with "\n\n", so plain concatenation
                # yields "headings\n\nbody".
                combined = pending + chunk_text
                if len(combined) <= chunk_size:
                    merged.append(combined.strip())
                else:
                    # Too big together: emit the headings and the body
                    # as separate chunks rather than dropping either.
                    merged.append(pending.strip())
                    merged.append(chunk_text.strip())
                pending = ""
            else:
                merged.append(chunk_text.strip())

    # Trailing heading-only text: append to the last chunk when it fits,
    # otherwise keep it as its own chunk.
    if pending:
        pending_text = pending.strip()
        if merged and len(merged[-1] + "\n\n" + pending_text) <= chunk_size:
            merged[-1] = merged[-1] + "\n\n" + pending_text
        else:
            merged.append(pending_text)

    return [c for c in merged if c.strip()]

def _merge_short_chunks(self, chunks: list[str], chunk_size: int) -> list[str]:
"""合并过短的相邻 chunk(低于 min_chunk_size)"""
if self.min_chunk_size <= 0 or len(chunks) <= 1:
return chunks

final: list[str] = []
buf = ""

for c in chunks:
if buf:
combined = buf + "\n\n" + c
if len(combined) <= chunk_size:
buf = combined
else:
final.append(buf)
buf = c if len(c) < self.min_chunk_size else ""
if len(c) >= self.min_chunk_size:
final.append(c)
elif len(c) < self.min_chunk_size:
buf = c
else:
final.append(c)

if buf:
if final and len(final[-1] + "\n\n" + buf) <= chunk_size:
final[-1] = final[-1] + "\n\n" + buf
else:
final.append(buf)

return final

def _parse_sections(self, text: str) -> list[_Section]:
    """Parse Markdown text into a list of sections.

    Content inside fenced code blocks (``` or ~~~) is skipped so that
    ``#`` characters in code are not mistaken for headings.

    Returns:
        list[_Section]: The parsed sections (empty if no headings found).
    """
    # Pre-compute fenced code block ranges so heading matches inside
    # them can be discarded.
    fenced_ranges = self._find_fenced_code_ranges(text)

    # Match ATX heading lines up to max_heading_depth. The (?!#)
    # lookahead rejects deeper headings — bug fix: without it, a
    # level-5 heading under max depth 4 matched as level 4 with a
    # stray "#" leaking into the title. The "\s*" keeps accepting
    # headings written without a space after the hashes.
    heading_pattern = re.compile(
        r"^(#{1," + str(self.max_heading_depth) + r"})(?!#)\s*(.+)$",
        re.MULTILINE,
    )

    # Collect all headings with their positions (excluding code blocks).
    headings = []
    for match in heading_pattern.finditer(text):
        if self._is_in_fenced_block(match.start(), fenced_ranges):
            continue
        headings.append(
            {
                "level": len(match.group(1)),
                "title": match.group(2).strip(),
                "start": match.start(),
                "end": match.end(),
            }
        )

    if not headings:
        return []

    sections: list[_Section] = []

    # Content before the first heading (preamble), if any.
    preamble = text[: headings[0]["start"]].strip()
    if preamble:
        sections.append(_Section(heading_path=[], text=preamble, has_body=True))

    # Maintain a heading stack to track the hierarchical path.
    heading_stack: list[dict] = []

    for i, heading in enumerate(headings):
        # Pop headings at the same or deeper level before pushing this one.
        while heading_stack and heading_stack[-1]["level"] >= heading["level"]:
            heading_stack.pop()
        heading_stack.append({"level": heading["level"], "title": heading["title"]})

        # The section body runs up to the next heading (or end of text).
        content_start = heading["end"]
        content_end = headings[i + 1]["start"] if i + 1 < len(headings) else len(text)

        # Section text = heading line + body.
        heading_line = text[heading["start"] : heading["end"]]
        body = text[content_start:content_end].strip()
        section_text = heading_line
        if body:
            section_text += "\n" + body

        # Ancestors only — the current title is part of section_text itself.
        heading_path = [h["title"] for h in heading_stack[:-1]]

        sections.append(
            _Section(
                heading_path=heading_path,
                text=section_text,
                has_body=bool(body),
            )
        )

    return sections

@staticmethod
def _find_fenced_code_ranges(text: str) -> list[tuple[int, int]]:
"""找到所有围栏代码块的 (start, end) 范围"""
ranges: list[tuple[int, int]] = []
fence_pattern = re.compile(r"^(`{3,}|~{3,})", re.MULTILINE)
matches = list(fence_pattern.finditer(text))
Comment thread
sourcery-ai[bot] marked this conversation as resolved.

i = 0
while i < len(matches):
open_match = matches[i]
open_fence = open_match.group(1)
fence_char = open_fence[0]
fence_len = len(open_fence)

# 找到对应的关闭围栏
for j in range(i + 1, len(matches)):
close_match = matches[j]
close_fence = close_match.group(1)
if close_fence[0] == fence_char and len(close_fence) >= fence_len:
ranges.append((open_match.start(), close_match.end()))
i = j + 1
break
else:
# 没有找到关闭围栏,剩余部分都视为代码块
ranges.append((open_match.start(), len(text)))
break
continue

return ranges

@staticmethod
def _is_in_fenced_block(pos: int, ranges: list[tuple[int, int]]) -> bool:
"""判断给定位置是否在围栏代码块内"""
for start, end in ranges:
if start <= pos < end:
return True
return False
15 changes: 14 additions & 1 deletion astrbot/core/knowledge_base/kb_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)

from .chunking.base import BaseChunker
from .chunking.markdown import MarkdownChunker
from .chunking.recursive import RecursiveCharacterChunker
from .kb_db_sqlite import KBSQLiteDatabase
from .models import KBDocument, KBMedia, KnowledgeBase
Expand Down Expand Up @@ -315,7 +316,19 @@ async def upload_document(
await progress_callback("chunking", 0, 100)

try:
chunks_text = await self.chunker.chunk(
# 根据文件类型选择分块器:Markdown 文件使用结构感知分块
effective_chunker = self.chunker
file_ext = Path(file_name).suffix.lower() if file_name else ""
if file_ext in (".md", ".markdown", ".mkd", ".mdx"):
effective_chunker = MarkdownChunker(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
logger.info(
f"检测到 Markdown 文件 '{file_name}',使用 MarkdownChunker 进行结构化分块"
)

chunks_text = await effective_chunker.chunk(
text_content,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
Expand Down