class SecurityWarning(UserWarning):
  """Warning category emitted when a legacy pickle file is loaded."""


def _encode_pytree(obj: Any) -> Any:
  """Recursively converts a pytree into msgpack-compatible builtin types.

  Arrays, dataclasses, lists and tuples are wrapped in a tagged dict of the
  form `{'__type__': <tag>, 'data': ...}`; plain dicts are encoded as bare
  maps, and every other leaf is passed through for msgpack to handle.

  Args:
    obj: the pytree (or leaf) to encode.

  Returns:
    A structure made only of dicts, lists, bytes and scalars.

  Raises:
    ValueError: if a plain dict uses the reserved `'__type__'` key, which
      would make the encoded form ambiguous on load.
  """
  # Local import: the module-level import block does not provide it.
  import dataclasses

  # np.generic covers NumPy scalars (e.g. np.float32(1.0)), which expose the
  # same tobytes/shape/dtype API as arrays but are not np.ndarray instances.
  if isinstance(obj, (jax.Array, np.ndarray, np.generic)):
    return {
        '__type__': 'array',
        'data': obj.tobytes(),
        'shape': obj.shape,
        'dtype': str(obj.dtype),
    }
  # Dataclass instances: encode declared fields under the class name.
  # Dataclasses have no `_asdict`, so they need a check of their own.
  if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
    return {
        '__type__': type(obj).__name__,
        'data': {
            field.name: _encode_pytree(getattr(obj, field.name))
            for field in dataclasses.fields(obj)
        },
    }
  # Other struct-like objects that expose `_asdict` and carry a `__dict__`.
  if hasattr(obj, '__dict__') and hasattr(obj, '_asdict'):
    return {
        '__type__': type(obj).__name__,
        'data': {k: _encode_pytree(v) for k, v in obj._asdict().items()},
    }
  if isinstance(obj, dict):
    if '__type__' in obj:
      raise ValueError(
          "Cannot serialize a dict containing the reserved key '__type__'."
      )
    return {k: _encode_pytree(v) for k, v in obj.items()}
  if isinstance(obj, (list, tuple)):
    # Tag with the base container name, not the subclass name: a namedtuple
    # tagged with its own class name cannot be rebuilt by the decoder, so it
    # is stored (and later restored) as a plain tuple.
    return {
        '__type__': 'list' if isinstance(obj, list) else 'tuple',
        'data': [_encode_pytree(x) for x in obj],
    }
  return obj


def _decode_pytree(obj: Any) -> Any:
  """Rebuilds a pytree from the structure produced by `_encode_pytree`.

  Args:
    obj: a value returned by `msgpack.unpackb` (or a nested part of it).

  Returns:
    The reconstructed pytree. Tagged nodes whose tag is not recognised fall
    back to a plain dict (mapping payload) or tuple (sequence payload).

  Raises:
    ValueError: if a tagged node has an unrecognised tag and a payload that
      is neither a mapping nor a sequence.
  """
  if not isinstance(obj, dict):
    return obj

  # Untagged maps are ordinary user dicts. Decode every key, including one
  # that happens to be called 'data', instead of descending into it.
  if '__type__' not in obj:
    return {k: _decode_pytree(v) for k, v in obj.items()}

  type_name = obj['__type__']
  if type_name == 'array':
    return jax.numpy.frombuffer(obj['data'], dtype=obj['dtype']).reshape(
        obj['shape']
    )

  data = obj.get('data')

  # Known Brax types are rebuilt from an explicit allow-list; the tag is never
  # used to import or look up an arbitrary class.
  if type_name == 'RunningStatisticsState':
    from brax.training.acme import running_statistics

    return running_statistics.RunningStatisticsState(**_decode_pytree(data))
  if type_name == 'UInt64':
    from brax.training import types

    return types.UInt64(**_decode_pytree(data))

  if type_name == 'tuple':
    return tuple(_decode_pytree(x) for x in data)
  if type_name == 'list':
    return [_decode_pytree(x) for x in data]

  # Unknown tag: keep the data, drop the type.
  if isinstance(data, dict):
    return {k: _decode_pytree(v) for k, v in data.items()}
  if isinstance(data, (list, tuple)):
    return tuple(_decode_pytree(x) for x in data)
  raise ValueError(f'Unsupported serialized type tag: {type_name!r}')


def _looks_like_pickle(buf: bytes) -> bool:
  """Returns True if `buf` starts like a pickle stream of protocol 2 or later.

  Such streams begin with the PROTO opcode 0x80 followed by a protocol byte.
  0x80 on its own is also the msgpack encoding of an empty map, but a valid
  msgpack buffer starting with 0x80 is exactly one byte long, so requiring
  more than one byte keeps a saved `{}` loadable without letting any pickle
  stream through.
  """
  return len(buf) > 1 and buf[:1] == b'\x80'


def save_params(path: str, params: Any):
  """Saves parameters to `path` as a msgpack-encoded pytree.

  Args:
    path: destination file path, opened through `epath.Path`.
    params: pytree of parameters to serialize.
  """
  encoded = _encode_pytree(params)
  with epath.Path(path).open('wb') as fout:
    fout.write(msgpack.packb(encoded))


def load_params(path: str, allow_pickle: bool = False) -> Any:
  """Loads parameters saved by `save_params`, gating legacy pickle files.

  Args:
    path: file to read, opened through `epath.Path`.
    allow_pickle: if True, files that look like pickle streams are loaded
      with `pickle.loads` after emitting a `SecurityWarning`. Only enable
      this for files from a trusted source.

  Returns:
    The deserialized parameter pytree.

  Raises:
    RuntimeError: if the file looks like a pickle stream and `allow_pickle`
      is False.
  """
  with epath.Path(path).open('rb') as fin:
    buf = fin.read()

  if _looks_like_pickle(buf):
    if not allow_pickle:
      raise RuntimeError(
          'SECURITY ERROR: Insecure pickle file detected. For security reasons,'
          ' loading is blocked. Use allow_pickle=True if you trust the source.'
      )

    warnings.warn(
        'SECURITY WARNING: Loading legacy pickle files is insecure and '
        'deprecated. Please migrate your models to the new secure format.',
        category=SecurityWarning,
        stacklevel=2,  # Attribute the warning to the caller of load_params.
    )
    # SECURITY: pickle.loads can execute arbitrary code; this is reached only
    # after the caller has explicitly opted in via allow_pickle=True.
    return pickle.loads(buf)

  return _decode_pytree(msgpack.unpackb(buf))
class ModelTest(absltest.TestCase):
  """Round-trip and security-gate tests for `brax.io.model`."""

  def test_save_load_params(self):
    """Verifies that a save/load round trip preserves structure and values."""
    params = {
        'policy': {
            'w': jnp.ones((4, 8)),
            'b': jnp.zeros((8,)),
        },
        'stats': (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])),
        'list': [jnp.array(1), jnp.array(2)],
    }

    with tempfile.TemporaryDirectory() as tmpdir:
      path = os.path.join(tmpdir, 'params.msgpack')
      brax_model.save_params(path, params)
      loaded_params = brax_model.load_params(path)

    # Container types must survive the round trip, not just leaf values.
    self.assertIsInstance(loaded_params['stats'], tuple)
    self.assertIsInstance(loaded_params['list'], list)

    # Kept as a local import: numpy is not among this file's top-level imports.
    import numpy as np

    # tree_map also fails if the two tree structures differ.
    jax.tree_util.tree_map(np.testing.assert_allclose, params, loaded_params)

  def test_pickle_security_block(self):
    """Verifies that legacy pickle files are blocked by default."""
    params = {'test': 123}
    with tempfile.TemporaryDirectory() as tmpdir:
      path = os.path.join(tmpdir, 'params.pkl')
      with open(path, 'wb') as f:
        f.write(pickle.dumps(params))

      with self.assertRaisesRegex(
          RuntimeError, 'SECURITY ERROR: Insecure pickle file'
      ):
        brax_model.load_params(path)

  def test_pickle_allow_explicit(self):
    """Verifies that legacy files still load when explicitly allowed."""
    params = {'test': 456}
    with tempfile.TemporaryDirectory() as tmpdir:
      path = os.path.join(tmpdir, 'params.pkl')
      with open(path, 'wb') as f:
        f.write(pickle.dumps(params))

      with self.assertWarns(brax_model.SecurityWarning):
        loaded_params = brax_model.load_params(path, allow_pickle=True)

    self.assertEqual(params, loaded_params)

  def test_rce_prevention(self):
    """Verifies that a malicious pickle is rejected without being executed."""
    with tempfile.TemporaryDirectory() as tmpdir:
      # Unpickling the payload would create this directory, which gives the
      # test an observable side effect to check.
      marker = os.path.join(tmpdir, 'rce_marker')

      class Malicious:

        def __reduce__(self):
          return (os.mkdir, (marker,))

      path = os.path.join(tmpdir, 'malicious.pkl')
      with open(path, 'wb') as f:
        f.write(pickle.dumps(Malicious()))

      # Loading must raise ...
      with self.assertRaises(RuntimeError):
        brax_model.load_params(path)

      # ... and the payload must not have run.
      self.assertFalse(os.path.exists(marker))
state = env.reset(jax.random.PRNGKey(0)) diff --git a/brax/training/agents/ars/train_test.py b/brax/training/agents/ars/train_test.py index d12cd4285..1ee58789f 100644 --- a/brax/training/agents/ars/train_test.py +++ b/brax/training/agents/ars/train_test.py @@ -18,11 +18,12 @@ from absl.testing import absltest from absl.testing import parameterized +import jax + from brax import envs from brax.training.acme import running_statistics from brax.training.agents.ars import networks as ars_networks from brax.training.agents.ars import train as ars -import jax class ARSTest(parameterized.TestCase): @@ -44,8 +45,14 @@ def testModelEncoding(self, normalize_observations): env.observation_size, env.action_size, normalize_fn ) inference = ars_networks.make_inference_fn(ars_network) - byte_encoding = pickle.dumps(params) - decoded_params = pickle.loads(byte_encoding) + import tempfile + + from brax.io import model as brax_model + + with tempfile.TemporaryDirectory() as tmpdir: + path = f'{tmpdir}/params.msgpack' + brax_model.save_params(path, params) + decoded_params = brax_model.load_params(path) # Compute one action. 
state = env.reset(jax.random.PRNGKey(0)) diff --git a/brax/training/agents/bc/train_test.py b/brax/training/agents/bc/train_test.py index 8627f65bd..658df6fc8 100644 --- a/brax/training/agents/bc/train_test.py +++ b/brax/training/agents/bc/train_test.py @@ -19,11 +19,12 @@ from absl.testing import absltest from absl.testing import parameterized +import jax + from brax import envs from brax.training.acme import running_statistics from brax.training.agents.bc import networks as bc_networks from brax.training.agents.bc import train as bc -import jax class BCTest(parameterized.TestCase): @@ -107,8 +108,14 @@ def testNetworkEncoding(self): make_inference = bc_networks.make_inference_fn(bc_network) # Test serialization and deserialization - byte_encoding = pickle.dumps(params) - decoded_params = pickle.loads(byte_encoding) + import tempfile + + from brax.io import model as brax_model + + with tempfile.TemporaryDirectory() as tmpdir: + path = f'{tmpdir}/params.msgpack' + brax_model.save_params(path, params) + decoded_params = brax_model.load_params(path) # Compute one action with both the original and the reconstructed parameters state = fast.reset(jax.random.PRNGKey(0)) diff --git a/brax/training/agents/es/train_test.py b/brax/training/agents/es/train_test.py index 0b102d45b..334570b91 100644 --- a/brax/training/agents/es/train_test.py +++ b/brax/training/agents/es/train_test.py @@ -18,11 +18,12 @@ from absl.testing import absltest from absl.testing import parameterized +import jax + from brax import envs from brax.training.acme import running_statistics from brax.training.agents.es import networks as es_networks from brax.training.agents.es import train as es -import jax class ESTest(parameterized.TestCase): @@ -54,8 +55,14 @@ def testModelEncoding(self, normalize_observations): env.observation_size, env.action_size, normalize_fn ) inference = es_networks.make_inference_fn(es_network) - byte_encoding = pickle.dumps(params) - decoded_params = pickle.loads(byte_encoding) 
+ import tempfile + + from brax.io import model as brax_model + + with tempfile.TemporaryDirectory() as tmpdir: + path = f'{tmpdir}/params.msgpack' + brax_model.save_params(path, params) + decoded_params = brax_model.load_params(path) # Compute one action. state = env.reset(jax.random.PRNGKey(0)) diff --git a/brax/training/agents/ppo/train_test.py b/brax/training/agents/ppo/train_test.py index bbbc24b17..e44dafe6f 100644 --- a/brax/training/agents/ppo/train_test.py +++ b/brax/training/agents/ppo/train_test.py @@ -16,21 +16,22 @@ import functools import pickle + from absl.testing import absltest from absl.testing import parameterized +import jax +from jax import numpy as jnp + from brax import envs from brax.training.acme import running_statistics from brax.training.agents.ppo import networks as ppo_networks from brax.training.agents.ppo import networks_vision as ppo_networks_vision from brax.training.agents.ppo import train as ppo -import jax -from jax import numpy as jnp class PPOTest(parameterized.TestCase): """Tests for PPO module.""" - @parameterized.parameters('ndarray', 'dict_state') def testTrain(self, obs_mode): """Test PPO with a simple env.""" @@ -211,8 +212,14 @@ def testNetworkEncoding(self, normalize_observations): env.observation_size, env.action_size, normalize_fn ) inference = ppo_networks.make_inference_fn(ppo_network) - byte_encoding = pickle.dumps(params) - decoded_params = pickle.loads(byte_encoding) + import tempfile + + from brax.io import model as brax_model + + with tempfile.TemporaryDirectory() as tmpdir: + path = f'{tmpdir}/params.msgpack' + brax_model.save_params(path, params) + decoded_params = brax_model.load_params(path) # Compute one action. 
state = env.reset(jax.random.PRNGKey(0)) diff --git a/brax/training/agents/sac/train_test.py b/brax/training/agents/sac/train_test.py index 02ecd06de..7cf0a7336 100644 --- a/brax/training/agents/sac/train_test.py +++ b/brax/training/agents/sac/train_test.py @@ -18,17 +18,17 @@ from absl.testing import absltest from absl.testing import parameterized +import jax + from brax import envs from brax.training.acme import running_statistics from brax.training.agents.sac import networks as sac_networks from brax.training.agents.sac import train as sac -import jax class SACTest(parameterized.TestCase): """Tests for SAC module.""" - def testTrain(self): """Test SAC with a simple env.""" fast = envs.get_environment('fast') @@ -69,8 +69,14 @@ def testNetworkEncoding(self, normalize_observations): env.observation_size, env.action_size, normalize_fn ) inference = sac_networks.make_inference_fn(sac_network) - byte_encoding = pickle.dumps(params) - decoded_params = pickle.loads(byte_encoding) + import tempfile + + from brax.io import model as brax_model + + with tempfile.TemporaryDirectory() as tmpdir: + path = f'{tmpdir}/params.msgpack' + brax_model.save_params(path, params) + decoded_params = brax_model.load_params(path) # Compute one action. state = env.reset(jax.random.PRNGKey(0))