From ce81e61ba419069cde4e6f4e72c059ba835424f7 Mon Sep 17 00:00:00 2001 From: mukunda katta Date: Fri, 15 May 2026 07:50:05 -0700 Subject: [PATCH] Use user cache directory for tiktoken downloads --- tests/test_misc.py | 31 +++++++++++++++++++++++++++++++ tiktoken/load.py | 21 +++++++++++++++++---- 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/tests/test_misc.py b/tests/test_misc.py index 0832c8ee..0f4a03a8 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -1,7 +1,11 @@ +import hashlib +import os +import stat import subprocess import sys import tiktoken +import tiktoken.load def test_encoding_for_model(): @@ -28,3 +32,30 @@ def test_optional_blobfile_dependency(): assert "blobfile" not in sys.modules """ subprocess.check_call([sys.executable, "-c", prog]) + + +def test_default_cache_dir_is_user_specific(tmp_path, monkeypatch): + data = b"token data" + expected_hash = hashlib.sha256(data).hexdigest() + blobpath = "https://openaipublic.blob.core.windows.net/encodings/example.tiktoken" + cache_key = hashlib.sha1(blobpath.encode()).hexdigest() + + monkeypatch.delenv("TIKTOKEN_CACHE_DIR", raising=False) + monkeypatch.delenv("DATA_GYM_CACHE_DIR", raising=False) + monkeypatch.setenv("XDG_CACHE_HOME", str(tmp_path / "xdg-cache")) + monkeypatch.setattr(tiktoken.load, "read_file", lambda _: data) + + assert tiktoken.load.read_file_cached(blobpath, expected_hash) == data + + cache_dir = tmp_path / "xdg-cache" / "tiktoken" + assert (cache_dir / cache_key).read_bytes() == data + assert not (tmp_path / "data-gym-cache").exists() + + def fail_read_file(_: str) -> bytes: + raise AssertionError("cached file was not used") + + monkeypatch.setattr(tiktoken.load, "read_file", fail_read_file) + assert tiktoken.load.read_file_cached(blobpath, expected_hash) == data + + if os.name != "nt": + assert stat.S_IMODE(cache_dir.stat().st_mode) == 0o700 diff --git a/tiktoken/load.py b/tiktoken/load.py index 3c76bcb3..a9e8a5f0 100644 --- a/tiktoken/load.py +++ b/tiktoken/load.py @@ -5,6 +5,19 @@ import os +def _default_cache_dir() -> str: + if os.name == "nt": + cache_home = os.environ.get("LOCALAPPDATA") + if cache_home is None: + cache_home = os.path.join(os.path.expanduser("~"), "AppData", "Local") + else: + cache_home = os.environ.get("XDG_CACHE_HOME") + if cache_home is None: + cache_home = os.path.join(os.path.expanduser("~"), ".cache") + + return os.path.join(cache_home, "tiktoken") + + def read_file(blobpath: str) -> bytes: if "://" not in blobpath: with open(blobpath, "rb", buffering=0) as f: @@ -39,9 +52,7 @@ def read_file_cached(blobpath: str, expected_hash: str | None = None) -> bytes: elif "DATA_GYM_CACHE_DIR" in os.environ: cache_dir = os.environ["DATA_GYM_CACHE_DIR"] else: - import tempfile - - cache_dir = os.path.join(tempfile.gettempdir(), "data-gym-cache") + cache_dir = _default_cache_dir() user_specified_cache = False if cache_dir == "": @@ -73,7 +84,9 @@ def read_file_cached(blobpath: str, expected_hash: str | None = None) -> bytes: import uuid try: - os.makedirs(cache_dir, exist_ok=True) + os.makedirs(cache_dir, mode=0o700, exist_ok=True) + if not user_specified_cache: + os.chmod(cache_dir, 0o700) tmp_filename = cache_path + "." + str(uuid.uuid4()) + ".tmp" with open(tmp_filename, "wb") as f: f.write(contents)