Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import copy
import json
import logging
import os
import pickle as pkl
Expand All @@ -12,6 +13,7 @@
import warnings
from collections.abc import Callable
from dataclasses import dataclass, fields
from importlib.metadata import PackageNotFoundError, version
from io import StringIO
from multiprocessing import cpu_count
from pathlib import Path
Expand Down Expand Up @@ -79,6 +81,63 @@
ALREADY_RAN = False

pysr_logger = logging.getLogger(__name__)
CHECKPOINT_METADATA_FILENAME = "checkpoint_metadata.json"


def _get_pysr_version() -> str:
try:
return version("pysr")
except PackageNotFoundError: # pragma: no cover
return "unknown"


def _get_expected_checkpoint_versions() -> dict[str, str]:
    """Collect the version identifiers a checkpoint is expected to match.

    Reads the pinned SymbolicRegression.jl spec from the bundled
    ``juliapkg.json`` (preferring its ``version`` field, falling back to a
    git ``rev``) and combines it with the current PySR and Julia versions.
    """
    juliapkg_path = Path(__file__).with_name("juliapkg.json")
    packages = json.loads(juliapkg_path.read_text())["packages"]
    sr_spec = packages["SymbolicRegression"]
    if "version" in sr_spec:
        backend_version = sr_spec["version"]
    else:
        backend_version = sr_spec.get("rev", "unknown")
    julia_version = f"{jl.VERSION.major}.{jl.VERSION.minor}.{jl.VERSION.patch}"
    return {
        "pysr_version": _get_pysr_version(),
        "symbolic_regression_backend": backend_version,
        "julia_version": julia_version,
    }


def _validate_checkpoint_metadata(run_directory: PathLike) -> None:
    """Validate a checkpoint's version metadata before deserialization.

    If ``run_directory`` contains a ``checkpoint_metadata.json``, compare the
    versions recorded there against the currently-installed PySR, backend, and
    Julia versions, and raise early rather than attempting to unpickle an
    incompatible checkpoint. Older checkpoints without a metadata file are
    accepted silently for backward compatibility.

    Parameters
    ----------
    run_directory : PathLike
        Directory containing ``checkpoint.pkl`` and (optionally) the
        metadata file.

    Raises
    ------
    ValueError
        If the metadata file is not valid JSON, is not a JSON object, or
        records versions that differ from the current environment.
    """
    metadata_filename = Path(run_directory) / CHECKPOINT_METADATA_FILENAME
    if not metadata_filename.exists():
        # Backward compatibility: checkpoints from older PySR versions have
        # no metadata file, so fall through to normal unpickling.
        return

    try:
        metadata = json.loads(metadata_filename.read_text())
    except json.JSONDecodeError as e:
        raise ValueError(
            f"Checkpoint metadata file `{metadata_filename}` is not valid JSON. "
            "Delete the checkpoint files and rerun, or restore a valid checkpoint."
        ) from e
    # Valid JSON that is not an object (e.g. a list) would otherwise crash
    # below with an opaque AttributeError on `.get`; give the same clear
    # guidance as for malformed JSON instead.
    if not isinstance(metadata, dict):
        raise ValueError(
            f"Checkpoint metadata file `{metadata_filename}` is not valid JSON. "
            "Delete the checkpoint files and rerun, or restore a valid checkpoint."
        )
    expected = _get_expected_checkpoint_versions()
    mismatches = []
    for key, expected_value in expected.items():
        found_value = metadata.get(key, "<missing>")
        if found_value != expected_value:
            mismatches.append((key, expected_value, found_value))

    if mismatches:
        mismatch_summary = "; ".join(
            [
                f"{key}: expected `{expected_value}`, found `{found_value}`"
                for key, expected_value, found_value in mismatches
            ]
        )
        raise ValueError(
            "Checkpoint version metadata mismatch detected before deserialization. "
            f"{mismatch_summary}. "
            "This checkpoint was created with incompatible versions. "
            "Reinstall matching PySR/Julia backend versions, set the appropriate backend "
            "version, or delete the checkpoint and rerun."
        )


def _process_constraints(
Expand Down Expand Up @@ -1202,6 +1261,7 @@ def from_file(
pkl_filename = Path(run_directory) / "checkpoint.pkl"
if pkl_filename.exists():
pysr_logger.info(f"Attempting to load model from {pkl_filename}...")
_validate_checkpoint_metadata(run_directory)
assert binary_operators is None
assert unary_operators is None
assert operators is None
Expand Down Expand Up @@ -1388,18 +1448,29 @@ def _checkpoint(self):
"""
# Save model state:
self.show_pickle_warnings_ = False
checkpoint_succeeded = False
with open(self.get_pkl_filename(), "wb") as f:
try:
pkl.dump(self, f)
checkpoint_succeeded = True
except Exception as e:
pysr_logger.debug(f"Error checkpointing model: {e}")
if checkpoint_succeeded:
self.get_checkpoint_metadata_filename().write_text(
json.dumps(
_get_expected_checkpoint_versions(), indent=2, sort_keys=True
)
)
self.show_pickle_warnings_ = True

def get_pkl_filename(self) -> Path:
    """Return the checkpoint pickle path for this run, creating the run
    directory (``output_directory_/run_id_``) if it does not exist yet."""
    run_dir = Path(self.output_directory_) / self.run_id_
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir / "checkpoint.pkl"

def get_checkpoint_metadata_filename(self) -> Path:
    """Return the path of the version-metadata JSON file that sits next to
    the checkpoint pickle in the run directory."""
    pkl_path = self.get_pkl_filename()
    return pkl_path.parent / CHECKPOINT_METADATA_FILENAME

@property
def equations(self): # pragma: no cover
warnings.warn(
Expand Down
62 changes: 62 additions & 0 deletions pysr/test/test_main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import importlib
import json
import os
import pickle as pkl
import platform
Expand Down Expand Up @@ -1500,6 +1501,67 @@ def test_from_file_requires_operator_configuration(self):

self.assertIn("must provide either `operators`", str(cm.exception))

def test_from_file_checkpoint_metadata_mismatch_raises_clear_error(self):
    """Loading a checkpoint recorded under different versions must fail
    before any unpickling is attempted, with an actionable message."""
    with tempfile.TemporaryDirectory() as tmpdir:
        run_dir = Path(tmpdir) / "run"
        run_dir.mkdir()
        (run_dir / "checkpoint.pkl").write_bytes(b"this is not a pickle file")
        stale_versions = {
            "pysr_version": "0.0.0",
            "symbolic_regression_backend": "v0.0.0",
            "julia_version": "0.0.0",
        }
        (run_dir / "checkpoint_metadata.json").write_text(
            json.dumps(stale_versions)
        )

        with self.assertRaises(ValueError) as cm:
            PySRRegressor.from_file(run_directory=run_dir)

        message = str(cm.exception)
        self.assertIn("Checkpoint version metadata mismatch", message)
        self.assertIn("pysr_version", message)

def test_from_file_checkpoint_without_metadata_falls_back_to_unpickling(self):
    """Backward-compat: older checkpoints won't have metadata; we should not error early."""
    with tempfile.TemporaryDirectory() as tmpdir:
        run_dir = Path(tmpdir) / "run"
        run_dir.mkdir()
        (run_dir / "checkpoint.pkl").write_bytes(b"this is not a pickle file")

        with self.assertRaises(Exception) as cm:
            PySRRegressor.from_file(run_directory=run_dir)

        # Any failure must originate from unpickling the bogus file, not
        # from the metadata guardrails (there is no metadata file here).
        message = str(cm.exception)
        self.assertNotIn("Checkpoint version metadata mismatch", message)
        self.assertNotIn("Checkpoint metadata file", message)

def test_from_file_checkpoint_invalid_metadata_json_raises_clear_error(self):
    """A corrupt metadata file should yield a clear ValueError, not a raw
    JSON decoding traceback."""
    with tempfile.TemporaryDirectory() as tmpdir:
        run_dir = Path(tmpdir) / "run"
        run_dir.mkdir()
        # Truncated JSON alongside a (bogus) pickle file.
        (run_dir / "checkpoint_metadata.json").write_text("{")
        (run_dir / "checkpoint.pkl").write_bytes(b"this is not a pickle file")

        with self.assertRaises(ValueError) as cm:
            PySRRegressor.from_file(run_directory=run_dir)

        self.assertIn("not valid JSON", str(cm.exception))

def test_checkpoint_writes_version_metadata(self):
    """``_checkpoint`` should write a metadata file next to the pickle
    containing all three expected version keys."""
    with tempfile.TemporaryDirectory() as tmpdir:
        model = PySRRegressor()
        model.output_directory_ = tmpdir
        model.run_id_ = "run"
        model._checkpoint()

        metadata_file = Path(tmpdir) / "run" / "checkpoint_metadata.json"
        self.assertTrue(metadata_file.exists())
        metadata = json.loads(metadata_file.read_text())
        for key in (
            "pysr_version",
            "symbolic_regression_backend",
            "julia_version",
        ):
            self.assertIn(key, metadata)

def test_size_warning(self):
"""Ensure that a warning is given for a large input size."""
model = PySRRegressor()
Expand Down
Loading