"""
Utilities for locating and loading ChatterBot corpus data files.
"""
import io
import os
from pathlib import Path
from dataclasses import dataclass
from typing import (
    Dict, Generator, Iterator, List, Optional, Sequence, Tuple, Union
)

from chatterbot.exceptions import OptionalDependencyImportError

# Try to import the ChatterBot corpus data directory
try:
    from chatterbot_corpus.corpus import DATA_DIRECTORY
except ImportError:
    # Default to the home directory of the current user when the corpus
    # package is not installed. Kept as a string so that it has the same
    # type as the value exported by ``chatterbot_corpus``.
    DATA_DIRECTORY = str(Path.home() / 'chatterbot_corpus' / 'data')

# Only YAML formats are supported for now
CORPUS_EXTENSIONS = ['yml', 'yaml']

# Retained so that code importing the original single-extension constant
# keeps working.
CORPUS_EXTENSION = CORPUS_EXTENSIONS[0]

# Cache of parsed corpus files, keyed by resolved path. Each value is a
# ``(modification time in ns, parsed data)`` pair so that a file edited on
# disk is re-read instead of being served stale.
_corpus_cache: Dict[Path, Tuple[int, dict]] = {}


@dataclass
class CorpusData:
    """
    The contents of a single corpus file.

    Instances can be unpacked like the ``(conversations, categories,
    file_path)`` tuple that ``load_corpus`` used to yield.
    """
    conversations: List[List[str]]
    categories: List[str]
    file_path: str

    def __iter__(self) -> Iterator:
        # Allows ``for corpus, categories, path in load_corpus(...)``.
        return iter((self.conversations, self.categories, self.file_path))


def get_file_path(
    dotted_path: Union[str, Path],
    extensions: Optional[Union[str, Sequence[str]]] = None
) -> Path:
    """
    Convert a dotted path or filesystem path into an actual file path.

    :param dotted_path: Either an existing filesystem path, or a dotted
        path such as ``chatterbot.corpus.english.greetings``. The leading
        ``chatterbot.corpus`` is replaced by ``DATA_DIRECTORY``.
    :param extensions: File extensions (without the dot) to try, in order.
        Defaults to ``CORPUS_EXTENSIONS``. A single string is accepted.
    :returns: The path of the matching corpus file, or of the matching
        directory if no file with a supported extension exists.
    :raises FileNotFoundError: If neither a file nor a directory matches.
    """
    if extensions is None:
        extensions = CORPUS_EXTENSIONS
    elif isinstance(extensions, str):
        # A bare string would otherwise be iterated character by character.
        extensions = [extensions]

    raw_path = str(dotted_path)

    # If the path already exists, return it
    existing_path = Path(raw_path)
    if existing_path.exists():
        return existing_path

    not_found_message = (
        f"Corpus file or directory not found for: {dotted_path}"
    )

    # Anything containing a path separator is a filesystem path, not a
    # dotted path; splitting it on "." would only produce nonsense.
    if os.sep in raw_path or '/' in raw_path:
        raise FileNotFoundError(not_found_message)

    parts = raw_path.split('.')
    if parts[0] == 'chatterbot':
        # "chatterbot.corpus.<rest>" maps to "<DATA_DIRECTORY>/<rest>":
        # both the "chatterbot" and "corpus" segments are dropped.
        parts = [str(DATA_DIRECTORY)] + parts[2:]

    base_path = Path(*parts)

    # Check for a file with one of the supported extensions. The extension
    # is appended rather than set with ``with_suffix`` so that a dot inside
    # the final path component (e.g. in the data directory name) is kept.
    for extension in extensions:
        candidate = base_path.parent / '{}.{}'.format(
            base_path.name, extension.lstrip('.')
        )
        if candidate.is_file():
            return candidate

    # If a directory exists, return it
    if base_path.is_dir():
        return base_path

    raise FileNotFoundError(not_found_message)


def read_corpus(file_path: Union[str, Path]) -> dict:
    """
    Read a YAML corpus file and return its contents.

    Results are cached per resolved path and re-read when the file's
    modification time changes. The returned dictionary is the cached
    object itself, so callers should treat it as read-only.

    :param file_path: Path of the YAML file to read.
    :returns: The parsed data; an empty file yields an empty dictionary.
    :raises FileNotFoundError: If the file does not exist.
    :raises OptionalDependencyImportError: If PyYAML is not installed.
    :raises RuntimeError: If the file is not valid YAML.
    :raises ValueError: If the top-level YAML value is not a mapping.
    """
    # Resolve so "a.yml", "./a.yml" and Path("a.yml") share a cache entry.
    cache_key = Path(file_path).resolve()
    modified_ns = cache_key.stat().st_mtime_ns

    cached = _corpus_cache.get(cache_key)
    if cached is not None and cached[0] == modified_ns:
        return cached[1]

    try:
        import yaml
    except ImportError:
        # NOTE(review): the text of this message lies outside the hunks of
        # the reviewed patch; confirm it against the existing file.
        message = (
            'Unable to import "yaml".\n'
            'Please install "pyyaml" to enable chatterbot corpus functionality:\n'
            'pip3 install pyyaml'
        )
        raise OptionalDependencyImportError(message)

    # Only parse errors are wrapped; OSError subclasses propagate unchanged
    # so callers can still tell a missing file from a malformed one.
    try:
        with io.open(cache_key, encoding='utf-8') as data_file:
            data = yaml.safe_load(data_file)
    except yaml.YAMLError as e:
        raise RuntimeError(
            f"Failed to read corpus file {file_path}: {e}"
        ) from e

    # An empty document parses to None; treat it as an empty corpus.
    if data is None:
        data = {}

    if not isinstance(data, dict):
        raise ValueError(
            f"Corpus file {file_path} did not return a dictionary."
        )

    _corpus_cache[cache_key] = (modified_ns, data)
    return data


def list_corpus_files(dotted_path: Union[str, Path]) -> List[Path]:
    """
    Return a sorted list of all corpus files (with supported extensions)
    in the given dotted path or directory.

    :param dotted_path: A dotted corpus path, a directory or a file.
    :returns: Every matching file below a directory (searched
        recursively), or a one-element list for a single file.
    :raises FileNotFoundError: If the path cannot be resolved.
    """
    path = get_file_path(dotted_path)

    if not path.is_dir():
        return [path]

    return sorted(
        found
        for extension in CORPUS_EXTENSIONS
        for found in path.rglob(f'*.{extension}')
        if found.is_file()
    )


def load_corpus(
    *data_file_paths: Union[str, Path]
) -> Generator[CorpusData, None, None]:
    """
    Yield CorpusData objects for each specified corpus file.

    :param data_file_paths: Dotted paths, directories or files. Each
        directory is expanded to the corpus files it contains.
    :raises FileNotFoundError: If a path cannot be resolved.
    """
    for data_file_path in data_file_paths:
        # list_corpus_files resolves the path and handles both files and
        # directories, so no separate resolution step is needed here.
        for corpus_file in list_corpus_files(data_file_path):
            corpus_data = read_corpus(corpus_file)

            # Copy the lists so that consumers cannot mutate the cached
            # data; ``or []`` also covers keys present with a null value.
            conversations = list(corpus_data.get('conversations') or [])
            categories = list(corpus_data.get('categories') or [])

            yield CorpusData(conversations, categories, str(corpus_file))