diff --git a/brax/training/agents/sac/train.py b/brax/training/agents/sac/train.py index f4b2b58a..7c052a53 100644 --- a/brax/training/agents/sac/train.py +++ b/brax/training/agents/sac/train.py @@ -77,7 +77,6 @@ def _unpmap(v): def _init_training_state( key: PRNGKey, obs_size: int, - local_devices_to_use: int, sac_network: sac_networks.SACNetworks, alpha_optimizer: optax.GradientTransformation, policy_optimizer: optax.GradientTransformation, @@ -109,16 +108,7 @@ def _init_training_state( alpha_params=log_alpha, normalizer_params=normalizer_params, ) - devices = jax.local_devices()[:local_devices_to_use] - mesh = jax.sharding.Mesh(np.array(devices), ('_device_put_sharded',)) - sharding = jax.NamedSharding(mesh, jax.P('_device_put_sharded')) - - def _replicate(x): - if isinstance(x, jax.Array): - return jax.device_put(jnp.stack([x] * len(devices)), sharding) - return jax.device_put(np.stack([x] * len(devices)), sharding) - - return jax.tree_util.tree_map(_replicate, training_state) + return training_state def train( @@ -491,7 +481,6 @@ def training_epoch_with_timing( training_state = _init_training_state( key=global_key, obs_size=obs_size, - local_devices_to_use=local_devices_to_use, sac_network=sac_network, alpha_optimizer=alpha_optimizer, policy_optimizer=policy_optimizer, @@ -506,6 +495,19 @@ def training_epoch_with_timing( policy_params=params[1], ) + # Replicate training state across devices AFTER checkpoint restoration + # so that restored params have the correct per-device shape. Fixes #659. + devices = jax.local_devices()[:local_devices_to_use] + mesh = jax.sharding.Mesh(np.array(devices), ('_device_put_sharded',)) + sharding = jax.NamedSharding(mesh, jax.P('_device_put_sharded')) + + def _replicate(x): + if isinstance(x, jax.Array): + return jax.device_put(jnp.stack([x] * len(devices)), sharding) + return jax.device_put(np.stack([x] * len(devices)), sharding) + + training_state = jax.tree_util.tree_map(_replicate, training_state) + local_key, rb_key, env_key, eval_key = jax.random.split(local_key, 4) # Env init @@ -624,4 +626,4 @@ def training_epoch_with_timing( pmap.assert_is_replicated(training_state) logging.info('total steps: %s', total_steps) pmap.synchronize_hosts() - return (make_policy, params, metrics) + return (make_policy, params, metrics) \ No newline at end of file