diff --git a/pysr/sr.py b/pysr/sr.py
index 8f8b9616a..e82ae3a08 100644
--- a/pysr/sr.py
+++ b/pysr/sr.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import copy
+import json
 import logging
 import os
 import pickle as pkl
@@ -12,6 +13,7 @@
 import warnings
 from collections.abc import Callable
 from dataclasses import dataclass, fields
+from importlib.metadata import PackageNotFoundError, version
 from io import StringIO
 from multiprocessing import cpu_count
 from pathlib import Path
@@ -79,6 +81,63 @@
 ALREADY_RAN = False
 
 pysr_logger = logging.getLogger(__name__)
+CHECKPOINT_METADATA_FILENAME = "checkpoint_metadata.json"
+
+
+def _get_pysr_version() -> str:
+    try:
+        return version("pysr")
+    except PackageNotFoundError:  # pragma: no cover
+        return "unknown"
+
+
+def _get_expected_checkpoint_versions() -> dict[str, str]:
+    juliapkg = json.loads((Path(__file__).with_name("juliapkg.json")).read_text())
+    symbolic_regression = juliapkg["packages"]["SymbolicRegression"]
+    symbolic_regression_backend = symbolic_regression.get(
+        "version", symbolic_regression.get("rev", "unknown")
+    )
+    julia_version = f"{jl.VERSION.major}.{jl.VERSION.minor}.{jl.VERSION.patch}"
+    return {
+        "pysr_version": _get_pysr_version(),
+        "symbolic_regression_backend": symbolic_regression_backend,
+        "julia_version": julia_version,
+    }
+
+
+def _validate_checkpoint_metadata(run_directory: PathLike) -> None:
+    metadata_filename = Path(run_directory) / CHECKPOINT_METADATA_FILENAME
+    if not metadata_filename.exists():
+        return
+
+    try:
+        metadata = json.loads(metadata_filename.read_text())
+    except json.JSONDecodeError as e:
+        raise ValueError(
+            f"Checkpoint metadata file `{metadata_filename}` is not valid JSON. "
+            "Delete the checkpoint files and rerun, or restore a valid checkpoint."
+        ) from e
+    expected = _get_expected_checkpoint_versions()
+    mismatches = []
+    for key, expected_value in expected.items():
+        found_value = metadata.get(key, "")
+        if found_value != expected_value:
+            mismatches.append((key, expected_value, found_value))
+
+    if mismatches:
+        mismatch_summary = "; ".join(
+            [
+                f"{key}: expected `{expected_value}`, found `{found_value}`"
+                for key, expected_value, found_value in mismatches
+            ]
+        )
+        raise ValueError(
+            "Checkpoint version metadata mismatch detected before deserialization. "
+            f"{mismatch_summary}. "
+            "This checkpoint was created with incompatible versions. "
+            "Reinstall matching PySR/Julia backend versions, set the appropriate backend "
+            "version, or delete the checkpoint and rerun."
+        )
 
 
 def _process_constraints(
@@ -1202,6 +1261,7 @@ def from_file(
     pkl_filename = Path(run_directory) / "checkpoint.pkl"
     if pkl_filename.exists():
         pysr_logger.info(f"Attempting to load model from {pkl_filename}...")
+        _validate_checkpoint_metadata(run_directory)
         assert binary_operators is None
         assert unary_operators is None
         assert operators is None
@@ -1388,11 +1448,19 @@ def _checkpoint(self):
         """
         # Save model state:
         self.show_pickle_warnings_ = False
+        checkpoint_succeeded = False
         with open(self.get_pkl_filename(), "wb") as f:
             try:
                 pkl.dump(self, f)
+                checkpoint_succeeded = True
             except Exception as e:
                 pysr_logger.debug(f"Error checkpointing model: {e}")
+        if checkpoint_succeeded:
+            self.get_checkpoint_metadata_filename().write_text(
+                json.dumps(
+                    _get_expected_checkpoint_versions(), indent=2, sort_keys=True
+                )
+            )
         self.show_pickle_warnings_ = True
 
     def get_pkl_filename(self) -> Path:
@@ -1400,6 +1468,9 @@ def get_pkl_filename(self) -> Path:
         path.parent.mkdir(parents=True, exist_ok=True)
         return path
 
+    def get_checkpoint_metadata_filename(self) -> Path:
+        return self.get_pkl_filename().with_name(CHECKPOINT_METADATA_FILENAME)
+
     @property
     def equations(self):  # pragma: no cover
         warnings.warn(
diff --git a/pysr/test/test_main.py b/pysr/test/test_main.py
index 4edccc7da..5af22bca7 100644
--- a/pysr/test/test_main.py
+++ b/pysr/test/test_main.py
@@ -1,5 +1,6 @@
 import functools
 import importlib
+import json
 import os
 import pickle as pkl
 import platform
@@ -1500,6 +1501,67 @@ def test_from_file_requires_operator_configuration(self):
 
         self.assertIn("must provide either `operators`", str(cm.exception))
 
+    def test_from_file_checkpoint_metadata_mismatch_raises_clear_error(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            run_dir = Path(tmpdir) / "run"
+            run_dir.mkdir()
+            (run_dir / "checkpoint.pkl").write_bytes(b"this is not a pickle file")
+            (run_dir / "checkpoint_metadata.json").write_text(
+                json.dumps(
+                    {
+                        "pysr_version": "0.0.0",
+                        "symbolic_regression_backend": "v0.0.0",
+                        "julia_version": "0.0.0",
+                    }
+                )
+            )
+
+            with self.assertRaises(ValueError) as cm:
+                PySRRegressor.from_file(run_directory=run_dir)
+
+            self.assertIn("Checkpoint version metadata mismatch", str(cm.exception))
+            self.assertIn("pysr_version", str(cm.exception))
+
+    def test_from_file_checkpoint_without_metadata_falls_back_to_unpickling(self):
+        """Backward-compat: older checkpoints won't have metadata; we should not error early."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            run_dir = Path(tmpdir) / "run"
+            run_dir.mkdir()
+            (run_dir / "checkpoint.pkl").write_bytes(b"this is not a pickle file")
+
+            with self.assertRaises(Exception) as cm:
+                PySRRegressor.from_file(run_directory=run_dir)
+
+            # The error should come from unpickling, not our metadata guardrails.
+            self.assertNotIn("Checkpoint version metadata mismatch", str(cm.exception))
+            self.assertNotIn("Checkpoint metadata file", str(cm.exception))
+
+    def test_from_file_checkpoint_invalid_metadata_json_raises_clear_error(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            run_dir = Path(tmpdir) / "run"
+            run_dir.mkdir()
+            (run_dir / "checkpoint.pkl").write_bytes(b"this is not a pickle file")
+            (run_dir / "checkpoint_metadata.json").write_text("{")
+
+            with self.assertRaises(ValueError) as cm:
+                PySRRegressor.from_file(run_directory=run_dir)
+
+            self.assertIn("not valid JSON", str(cm.exception))
+
+    def test_checkpoint_writes_version_metadata(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model = PySRRegressor()
+            model.output_directory_ = tmpdir
+            model.run_id_ = "run"
+            model._checkpoint()
+
+            metadata_file = Path(tmpdir) / "run" / "checkpoint_metadata.json"
+            self.assertTrue(metadata_file.exists())
+            metadata = json.loads(metadata_file.read_text())
+            self.assertIn("pysr_version", metadata)
+            self.assertIn("symbolic_regression_backend", metadata)
+            self.assertIn("julia_version", metadata)
+
     def test_size_warning(self):
         """Ensure that a warning is given for a large input size."""
         model = PySRRegressor()