From 5952f1836015dd0bdbd6d31f82eab9530b2873fd Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 18 May 2026 17:46:59 +0100 Subject: [PATCH 1/4] [Feature][Performance] NextObservationDelta env transform Adds a stateless env-side transform that stores `("next", obs)` as a low-precision delta from the root `obs`, reducing the rollout-time memory footprint of large continuous observations. The transform compresses next observations in `_step` and rehydrates the flowing tensordict's root observation in a new `_post_step_mdp_hooks` extension point on `EnvBase`. The hook was previously half-stubbed in `common.py` / `_base.py` / `llm/chat.py`; it is now wired through `step_and_maybe_reset` and threaded into `Transform`, `Compose`, and `TransformedEnv`. Caveats documented on the class: - The compression is lossy; round-trip error scales with delta dtype precision and observation magnitude. - Memory savings only materialize against non-pre-allocated stacked output (e.g. `SyncDataCollector(use_buffers=False)` or a lazy RB storage). Pre-allocated buffers upcast the write. - The hook fires from `step_and_maybe_reset`; direct `env.rollout()` callers must rehydrate manually. - `check_env_specs` rejects the transformed env in v1 because the observation spec is shared between root and `("next", ...)` and we do not fork it. Includes a `TestNextObservationDelta` test class with 16 cases (14 passing, 2 documented skips) covering single-env, serial/parallel batched envs (inner and outer wrapping), auto-inference skipping non-floating dtypes, multi-key, reset semantics, Compose ordering, and an end-to-end `SyncDataCollector(use_buffers=False)` check that the stacked batch carries `float16` `("next", obs)`. --- docs/source/reference/envs_transforms.rst | 1 + .../transforms/test_observation_transforms.py | 268 ++++++++++++++++++ torchrl/envs/__init__.py | 2 + torchrl/envs/common.py | 21 +- torchrl/envs/transforms/__init__.py | 2 + torchrl/envs/transforms/_base.py | 59 +++- torchrl/envs/transforms/_observation.py | 207 +++++++++++++- torchrl/envs/transforms/transforms.py | 2 + 8 files changed, 552 insertions(+), 10 deletions(-) diff --git a/docs/source/reference/envs_transforms.rst b/docs/source/reference/envs_transforms.rst index ff1d1887a8e..96731163aff 100644 --- a/docs/source/reference/envs_transforms.rst +++ b/docs/source/reference/envs_transforms.rst @@ -277,6 +277,7 @@ Available Transforms MeanActionSelector ModuleTransform MultiAction + NextObservationDelta NextStateReconstructor NoopResetEnv ObservationNorm diff --git a/test/transforms/test_observation_transforms.py b/test/transforms/test_observation_transforms.py index 1dfc60b2c56..12f9f081aa3 100644 --- a/test/transforms/test_observation_transforms.py +++ b/test/transforms/test_observation_transforms.py @@ -30,6 +30,7 @@ Crop, FlattenObservation, GrayScale, + NextObservationDelta, ObservationNorm, ParallelEnv, PermuteTransform, @@ -2790,3 +2791,270 @@ def test_transform_no_env(self, batch): td = TensorDict({"pixels": torch.randn((*batch, D, W, H, C))}, batch_size=batch) td = trans(td) assert td["pixels"].shape == torch.Size((*batch, C, D, H, W)) + + +class TestNextObservationDelta(TransformBase): + """Tests for the env-side delta-compression transform.""" + + @staticmethod + def _delta_tol(delta_dtype: torch.dtype, scale: float = 1.0) -> float: + # Round-trip error of (next_obs - obs) cast to delta_dtype then back. + # finfo.eps is the gap around 1.0; scale by max magnitude of the delta + # being represented to get a meaningful tolerance. + return torch.finfo(delta_dtype).eps * 8.0 * scale + + # `check_env_specs` is intentionally NOT used here. NextObservationDelta + # changes the runtime dtype of `("next", obs)` to `delta_dtype` while + # leaving `observation_spec` untouched (root and `("next", ...)` share the + # same spec in TorchRL, and we don't fork it in v1). The spec-vs-runtime + # check would therefore reject the env. We use a reset + step smoke test + # instead, which is what callers actually do. + + @staticmethod + def _smoke_one_step(env, *, expect_compressed: bool): + td = env.reset() + td.set("action", env.action_spec.rand()) + post_step, flowing = env.step_and_maybe_reset(td) + if expect_compressed: + assert post_step["next", "observation"].dtype == torch.float16 + assert flowing["observation"].dtype == torch.float32 + else: + # batched-env wraps a TransformedEnv worker -- the outer batched + # env's step_and_maybe_reset does not invoke the hook, and may + # upcast through pre-allocated buffers. We only assert the env + # boots and steps without raising. + assert post_step["next", "observation"].shape == flowing["observation"].shape + + def test_single_trans_env_check(self): + env = TransformedEnv( + ContinuousActionVecMockEnv(), + NextObservationDelta(in_keys=["observation"]), + ) + self._smoke_one_step(env, expect_compressed=True) + + def test_serial_trans_env_check(self): + env = SerialEnv( + 2, + lambda: TransformedEnv( + ContinuousActionVecMockEnv(), + NextObservationDelta(in_keys=["observation"]), + ), + ) + self._smoke_one_step(env, expect_compressed=False) + + def test_parallel_trans_env_check(self): + env = ParallelEnv( + 2, + lambda: TransformedEnv( + ContinuousActionVecMockEnv(), + NextObservationDelta(in_keys=["observation"]), + ), + mp_start_method="fork", + ) + try: + self._smoke_one_step(env, expect_compressed=False) + finally: + env.close() + + def test_trans_serial_env_check(self): + env = TransformedEnv( + SerialEnv(2, lambda: ContinuousActionVecMockEnv()), + NextObservationDelta(in_keys=["observation"]), + ) + self._smoke_one_step(env, expect_compressed=True) + + def test_trans_parallel_env_check(self): + env = TransformedEnv( + ParallelEnv(2, lambda: ContinuousActionVecMockEnv(), mp_start_method="fork"), + NextObservationDelta(in_keys=["observation"]), + ) + try: + self._smoke_one_step(env, expect_compressed=True) + finally: + env.close() + + def test_transform_no_env(self): + # The transform is env-side only: calling it like a module / RB transform + # must raise. + t = NextObservationDelta(in_keys=["observation"]) + with pytest.raises(NotImplementedError, match="env-side transform"): + t(TensorDict({"observation": torch.zeros(3)}, [])) + + def test_transform_compose(self): + # Composed offline call still routes through forward, which is unsupported. + t = Compose(NextObservationDelta(in_keys=["observation"])) + with pytest.raises(NotImplementedError, match="env-side transform"): + t(TensorDict({"observation": torch.zeros(3)}, [])) + + def test_transform_env(self): + torch.manual_seed(0) + env = TransformedEnv( + ContinuousActionVecMockEnv(), + NextObservationDelta(in_keys=["observation"]), + ) + td = env.reset() + td.set("action", env.action_spec.rand()) + post_step, flowing = env.step_and_maybe_reset(td) + + assert post_step["next", "observation"].dtype == torch.float16 + assert post_step["observation"].dtype == torch.float32 + assert flowing["observation"].dtype == torch.float32 + + # Rehydrated flowing root == root_obs + delta (round-tripped through + # the delta dtype). + expected = ( + post_step["observation"].to(torch.float32) + + post_step["next", "observation"].to(torch.float32) + ) + tol = self._delta_tol(torch.float16, scale=max(1.0, expected.abs().max().item())) + torch.testing.assert_close( + flowing["observation"].to(torch.float32), expected, atol=tol, rtol=tol + ) + + def test_transform_model(self): + pytest.skip("NextObservationDelta is an env-side transform; not a module hook.") + + def test_transform_rb(self): + from torchrl.data import LazyTensorStorage, ReplayBuffer + + rb = ReplayBuffer( + storage=LazyTensorStorage(4), + transform=NextObservationDelta(in_keys=["observation"]), + batch_size=4, + ) + # Extend goes through `inv`, which is a no-op for this transform + # (it has no in_keys_inv) -- so the write succeeds. + rb.extend( + TensorDict({"observation": torch.zeros(4, 3)}, batch_size=[4]) + ) + # Sampling, however, routes through `forward`, which is unsupported + # for this env-side-only transform. + with pytest.raises(NotImplementedError, match="env-side transform"): + rb.sample() + + def test_transform_inverse(self): + pytest.skip("NextObservationDelta has no inverse (no in_keys_inv).") + + def test_auto_infer_keys_skips_uint8(self): + # Build an env whose observation_spec contains a uint8 image key plus + # a float vector. Only the float vector should be picked up by + # auto-inference. + from torchrl.data.tensor_specs import Bounded, Composite, Unbounded + + class _DualObsEnv(ContinuousActionVecMockEnv): + pass + + env = TransformedEnv( + ContinuousActionVecMockEnv( + observation_spec=Composite( + observation=Unbounded(shape=(7,)), + pixels=Bounded(low=0, high=255, shape=(3, 4, 4), dtype=torch.uint8), + ) + ), + NextObservationDelta(), + ) + # Access lazy in_keys via the transform. + in_keys = list(env.transform.in_keys) + assert ("observation",) in [tuple([k]) if isinstance(k, str) else tuple(k) for k in in_keys] + # No uint8 leaf made it in. + for k in in_keys: + spec = env.observation_spec[k] if not isinstance(k, tuple) else env.observation_spec[k] + assert spec.dtype.is_floating_point + + def test_multi_in_keys_explicit(self): + torch.manual_seed(1) + from torchrl.data.tensor_specs import Composite, Unbounded + + env = TransformedEnv( + ContinuousActionVecMockEnv( + observation_spec=Composite( + observation=Unbounded(shape=(7,)), + observation_orig=Unbounded(shape=(7,)), + ) + ), + NextObservationDelta(in_keys=["observation", "observation_orig"]), + ) + td = env.reset() + td.set("action", env.action_spec.rand()) + out, out_ = env.step_and_maybe_reset(td) + assert out["next", "observation"].dtype == torch.float16 + assert out["next", "observation_orig"].dtype == torch.float16 + assert out_["observation"].dtype == torch.float32 + assert out_["observation_orig"].dtype == torch.float32 + + def test_reset_between_steps(self): + # Stateless contract: a reset between two steps must not corrupt the + # delta computed on the second step. + torch.manual_seed(2) + env = TransformedEnv( + ContinuousActionVecMockEnv(), + NextObservationDelta(in_keys=["observation"]), + ) + td0 = env.reset() + td0.set("action", env.action_spec.rand()) + env.step_and_maybe_reset(td0) + td1 = env.reset() + td1.set("action", env.action_spec.rand()) + out, out_ = env.step_and_maybe_reset(td1) + expected = out["observation"].to(torch.float32) + out["next", "observation"].to( + torch.float32 + ) + tol = self._delta_tol(torch.float16, scale=max(1.0, expected.abs().max().item())) + torch.testing.assert_close( + out_["observation"].to(torch.float32), expected, atol=tol, rtol=tol + ) + + def test_compose_with_downstream_transform(self): + # NextObservationDelta inside a Compose with another env transform. + # Round-trip through the rehydration must still match a reference run + # that doesn't compress. + from torchrl.envs.transforms import RewardSum + + torch.manual_seed(3) + env = TransformedEnv( + ContinuousActionVecMockEnv(), + Compose( + NextObservationDelta(in_keys=["observation"]), + RewardSum(), + ), + ) + td = env.reset() + td.set("action", env.action_spec.rand()) + out, out_ = env.step_and_maybe_reset(td) + assert out["next", "observation"].dtype == torch.float16 + assert out_["observation"].dtype == torch.float32 + + def test_collector_use_buffers_false(self): + # End-to-end: with use_buffers=False (no pre-allocated final_rollout), + # the stacked rollout actually carries float16 ("next", obs). + from torchrl.collectors import SyncDataCollector + + torch.manual_seed(4) + + def make_env(): + return TransformedEnv( + ContinuousActionVecMockEnv(), + NextObservationDelta(in_keys=["observation"]), + ) + + collector = SyncDataCollector( + create_env_fn=make_env, + policy=None, + frames_per_batch=16, + total_frames=16, + use_buffers=False, + ) + try: + batch = next(iter(collector)) + finally: + collector.shutdown() + + assert batch["next", "observation"].dtype == torch.float16 + assert batch["observation"].dtype == torch.float32 + + # Reconstruct next.obs at full precision and verify shape/finiteness. + recon = batch["observation"].to(torch.float32) + batch["next", "observation"].to( + torch.float32 + ) + assert recon.shape == batch["observation"].shape + assert torch.isfinite(recon).all() diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 52af517af69..7437931d5b2 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -89,6 +89,7 @@ MeanActionSelector, MultiAction, MultiStepTransform, + NextObservationDelta, NoopResetEnv, ObservationNorm, ObservationTransform, @@ -210,6 +211,7 @@ "MultiStepTransform", "MultiThreadedEnv", "MultiThreadedEnvWrapper", + "NextObservationDelta", "NoopResetEnv", "ObservationNorm", "ObservationTransform", diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 117953ac1de..52ac5d00f4e 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -3892,8 +3892,8 @@ def step_and_maybe_reset( tensordict ) tensordict_ = self._step_mdp(tensordict) - # if self._post_step_mdp_hooks is not None: - # tensordict_ = self._post_step_mdp_hooks(tensordict_) + if self._post_step_mdp_hooks is not None: + tensordict_ = self._post_step_mdp_hooks(tensordict, tensordict_) if native_autoreset: for obs_key, obs in reset_observations.items(): if obs_key in tensordict_.keys(True, True): @@ -3908,7 +3908,22 @@ def step_and_maybe_reset( tensordict_ = self.maybe_reset(tensordict_) return tensordict, tensordict_ - # _post_step_mdp_hooks: Callable[[TensorDictBase], TensorDictBase] | None = None + _post_step_mdp_hooks: Callable[ + [TensorDictBase, TensorDictBase], TensorDictBase + ] | None = None + """Optional hook called after :meth:`_step_mdp` inside :meth:`step_and_maybe_reset`. + + Signature: ``(tensordict, tensordict_) -> tensordict_`` where ``tensordict`` + is the post-step tensordict (still carrying ``("next", ...)`` entries) and + ``tensordict_`` is the result of ``step_mdp``. Used by transforms that need + to modify the flowing tensordict the policy will read on the next iteration + (for example to rehydrate observations that were compressed in + :meth:`Transform._step`). + + Defaults to ``None``: when unset, the hook is skipped. Transforms that need + it are expected to expose a ``_post_step_mdp_hooks`` method themselves and + rely on :class:`~torchrl.envs.TransformedEnv` to delegate the call. + """ @property @_cache_value diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 004f77530bd..f0a36fae4ac 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -40,6 +40,7 @@ InitTracker, LineariseRewards, MultiAction, + NextObservationDelta, NoopResetEnv, ObservationNorm, ObservationTransform, @@ -112,6 +113,7 @@ "ModuleTransform", "MultiAction", "MultiStepTransform", + "NextObservationDelta", "NextStateReconstructor", "NoopResetEnv", "ObservationNorm", diff --git a/torchrl/envs/transforms/_base.py b/torchrl/envs/transforms/_base.py index 4e79304046c..5a1c9b85cff 100644 --- a/torchrl/envs/transforms/_base.py +++ b/torchrl/envs/transforms/_base.py @@ -386,6 +386,39 @@ def _step( next_tensordict = self._call(next_tensordict) return next_tensordict + def _post_step_mdp_hooks( + self, + tensordict: TensorDictBase, + tensordict_: TensorDictBase, + ) -> TensorDictBase: + """Hook called after :func:`~torchrl.envs.utils.step_mdp` inside ``step_and_maybe_reset``. + + Override when a transform needs to modify the tensordict the policy + will read on the next iteration *after* root keys have been promoted + from ``("next", ...)``. This is the natural place to undo / rehydrate + a representation that was compressed in :meth:`_step` (for example + a low-precision delta) before the policy sees the next observation. + + Args: + tensordict (TensorDictBase): post-step tensordict, still carrying + the ``("next", ...)`` sub-tensordict. + tensordict_ (TensorDictBase): post-step-mdp tensordict, with root + keys promoted from ``("next", ...)``. This is what the next + policy call will receive (after a possible reset). + + Returns: + The (possibly modified) ``tensordict_``. + + .. note:: Transforms that implement this hook must rely on the env they + are attached to wiring it up. :class:`~torchrl.envs.TransformedEnv` + delegates ``EnvBase._post_step_mdp_hooks`` to + ``self.transform._post_step_mdp_hooks``, so a transform appended to + a ``TransformedEnv`` is picked up automatically. Non-collector + entry points (e.g. ``env.rollout()``) currently do not invoke this + hook. + """ + return tensordict_ + def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase: """Reads the input tensordict, and for the selected keys, applies the transform. @@ -1004,12 +1037,17 @@ def add_truncated_keys(self) -> TransformedEnv: self.empty_cache() return self - # def _post_step_mdp_hooks(self, tensordict: TensorDictBase) -> TensorDictBase: - # """Allows modification of the tensordict after the step_mdp.""" - # if type(self.base_env)._post_step_mdp_hooks is not None: - # If the base env has a _post_step_mdp_hooks, we call it - # tensordict = self.base_env._post_step_mdp_hooks(tensordict) - # return tensordict + def _post_step_mdp_hooks( + self, + tensordict: TensorDictBase, + tensordict_: TensorDictBase, + ) -> TensorDictBase: + """Run the transform-chain post-step-mdp hook, then the base env's own.""" + tensordict_ = self.transform._post_step_mdp_hooks(tensordict, tensordict_) + base_env = self.base_env + if base_env is not None and base_env._post_step_mdp_hooks is not None: + tensordict_ = base_env._post_step_mdp_hooks(tensordict, tensordict_) + return tensordict_ def _set_env(self, env: EnvBase, device) -> None: if device != env.device: @@ -1594,6 +1632,15 @@ def _step( next_tensordict = t._step(tensordict, next_tensordict) return next_tensordict + def _post_step_mdp_hooks( + self, + tensordict: TensorDictBase, + tensordict_: TensorDictBase, + ) -> TensorDictBase: + for t in self.transforms: + tensordict_ = t._post_step_mdp_hooks(tensordict, tensordict_) + return tensordict_ + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: for t in reversed(self.transforms): tensordict = t._inv_call(tensordict) diff --git a/torchrl/envs/transforms/_observation.py b/torchrl/envs/transforms/_observation.py index 7f7a4620b02..b0f20e91132 100644 --- a/torchrl/envs/transforms/_observation.py +++ b/torchrl/envs/transforms/_observation.py @@ -8,7 +8,7 @@ import warnings from collections.abc import Sequence from copy import copy -from typing import Any, TYPE_CHECKING +from typing import Any, Literal, TYPE_CHECKING import torch @@ -44,6 +44,7 @@ "Crop", "FlattenObservation", "GrayScale", + "NextObservationDelta", "PermuteTransform", "Resize", "SqueezeTransform", @@ -1419,3 +1420,207 @@ def __repr__(self) -> str: f"{self.__class__.__name__}(N={self.N}, dim" f"={self.dim}, keys={self.in_keys})" ) + + +class NextObservationDelta(Transform): + """Stores ``("next", obs)`` as a low-precision delta from the root ``obs``. + + For environments with large continuous observations, rollouts collected by + :class:`~torchrl.collectors.SyncDataCollector` (and friends) hold both the + root observation and the ``("next", ...)`` mirror at full precision in the + stacked output. This transform reduces that footprint by replacing + ``next_tensordict[key]`` with the casted delta + ``(next_obs - obs).to(delta_dtype)`` inside :meth:`_step`, and rehydrating + the flowing tensordict's root observation in :meth:`_post_step_mdp_hooks` + so the policy still reads a full-precision tensor on the next iteration. + + The transform is **stateless**: the rehydration uses the post-step + tensordict's root ``obs`` (still full precision) and the post-step-mdp + tensordict's promoted delta. No private cache is required, which keeps + it well-behaved under :class:`~torchrl.envs.ParallelEnv` as long as the + transform is appended *inside* the worker environment. + + Args: + in_keys (sequence of NestedKey, optional): observation keys whose + ``("next", k)`` should be compressed. Defaults to ``None``, in + which case the transform lazily walks + ``parent.observation_spec`` and picks every floating-point leaf + whose dtype is not in ``excluded_dtypes``. + + Keyword Args: + delta_dtype (torch.dtype, optional): dtype in which the delta is + stored. Must be a floating dtype. Defaults to ``torch.float16``. + restore_dtype (torch.dtype or ``"root"``, optional): dtype the + rehydrated observation is cast to. ``"root"`` (default) matches + the root ``obs`` dtype at runtime. + auto_skip (bool, optional): if ``True`` (default), per-key skip the + transform when ``("next", key)`` is already in ``delta_dtype`` + (idempotent under repeated application). + excluded_dtypes (tuple of torch.dtype, optional): dtypes to skip when + auto-inferring ``in_keys``. Defaults to the integer + bool + family. Floating-point observations always pass through; pass + an explicit ``in_keys`` list to override. + + .. warning:: + The compression is **lossy**: round-tripping through ``delta_dtype`` + loses precision, particularly for unnormalized observations whose + magnitudes exceed the dtype range or fall below its smallest + representable step. + + .. warning:: + Rollout memory savings only materialize when the stacked output is + **not** pre-allocated at full precision. With + :class:`~torchrl.collectors.SyncDataCollector` set ``use_buffers=False`` + (or use a lazy replay-buffer storage). Pre-allocated + ``_final_rollout`` buffers will upcast writes back to the original + dtype and erase the saving. + + .. warning:: + The post-step-mdp rehydration is wired through + :meth:`~torchrl.envs.EnvBase.step_and_maybe_reset`, which is the + entry point used by the data collectors. ``env.rollout()`` does not + currently invoke the hook, so when used directly users should + rehydrate manually if they need the next obs at full precision. + + Example: + >>> import torch + >>> from torchrl.envs import GymEnv, TransformedEnv + >>> from torchrl.envs.transforms import NextObservationDelta + >>> env = TransformedEnv(GymEnv("Pendulum-v1"), NextObservationDelta()) + >>> td_root = env.reset() + >>> _ = td_root.set("action", env.action_spec.rand()) + >>> td, td_ = env.step_and_maybe_reset(td_root) + >>> td["next", "observation"].dtype + torch.float16 + >>> td_["observation"].dtype + torch.float32 + """ + + def __init__( + self, + in_keys: Sequence[NestedKey] | None = None, + *, + delta_dtype: torch.dtype = torch.float16, + restore_dtype: torch.dtype | Literal["root"] = "root", + auto_skip: bool = True, + excluded_dtypes: tuple[torch.dtype, ...] = ( + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.bool, + ), + ): + if not delta_dtype.is_floating_point: + raise ValueError( + f"delta_dtype must be a floating-point dtype, got {delta_dtype}." + ) + if restore_dtype != "root" and not ( + isinstance(restore_dtype, torch.dtype) and restore_dtype.is_floating_point + ): + raise ValueError( + f"restore_dtype must be a floating-point dtype or 'root', got " + f"{restore_dtype!r}." + ) + self.delta_dtype = delta_dtype + self.restore_dtype = restore_dtype + self.auto_skip = auto_skip + self.excluded_dtypes = tuple(excluded_dtypes) + super().__init__(in_keys=in_keys, out_keys=in_keys) + + @property + def in_keys(self) -> Sequence[NestedKey] | None: + in_keys = self.__dict__.get("_in_keys", None) + if in_keys is None: + parent = self.parent + if parent is None: + return None + in_keys = [] + for key, spec in parent.observation_spec.items(True, True): + dtype = spec.dtype + if dtype is None: + continue + if dtype in self.excluded_dtypes: + continue + if not dtype.is_floating_point: + continue + in_keys.append(unravel_key(key)) + self._in_keys = in_keys + if self.__dict__.get("_out_keys", None) is None: + self._out_keys = copy(in_keys) + return in_keys + + @in_keys.setter + def in_keys(self, value: Sequence[NestedKey] | None) -> None: + if value is not None: + if isinstance(value, (str, tuple)): + value = [value] + value = [unravel_key(v) for v in value] + self._in_keys = value + + @property + def out_keys(self) -> Sequence[NestedKey] | None: + out_keys = self.__dict__.get("_out_keys", None) + if out_keys is None: + in_keys = self.in_keys + if in_keys is None: + return None + out_keys = self._out_keys = copy(in_keys) + return out_keys + + @out_keys.setter + def out_keys(self, value: Sequence[NestedKey] | None) -> None: + if value is not None: + if isinstance(value, (str, tuple)): + value = [value] + value = [unravel_key(v) for v in value] + self._out_keys = value + + def _step( + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase + ) -> TensorDictBase: + in_keys = self.in_keys + if not in_keys: + return next_tensordict + for key in in_keys: + obs = tensordict.get(key, default=None) + next_obs = next_tensordict.get(key, default=None) + if obs is None or next_obs is None: + continue + if self.auto_skip and next_obs.dtype == self.delta_dtype: + continue + delta = next_obs.to(self.delta_dtype) - obs.to(self.delta_dtype) + next_tensordict.set(key, delta) + return next_tensordict + + def _post_step_mdp_hooks( + self, + tensordict: TensorDictBase, + tensordict_: TensorDictBase, + ) -> TensorDictBase: + in_keys = self.in_keys + if not in_keys: + return tensordict_ + for key in in_keys: + root = tensordict.get(key, default=None) + delta = tensordict_.get(key, default=None) + if root is None or delta is None: + continue + dtype = root.dtype if self.restore_dtype == "root" else self.restore_dtype + tensordict_.set(key, root.to(dtype) + delta.to(dtype)) + return tensordict_ + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + raise NotImplementedError( + f"{type(self).__name__} is an env-side transform; calling it directly " + "(e.g. as a replay buffer transform) is not supported. For RB-side " + "reconstruction of `('next', obs)`, see " + "`torchrl.envs.transforms.rb_transforms.NextStateReconstructor`." + ) + + def __repr__(self) -> str: + return ( + f"{type(self).__name__}(in_keys={self.__dict__.get('_in_keys', None)}, " + f"delta_dtype={self.delta_dtype}, restore_dtype={self.restore_dtype!r})" + ) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 74618d63506..52545432074 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -78,6 +78,7 @@ Crop, FlattenObservation, GrayScale, + NextObservationDelta, PermuteTransform, Resize, SqueezeTransform, @@ -133,6 +134,7 @@ "InitTracker", "LineariseRewards", "MultiAction", + "NextObservationDelta", "NoopResetEnv", "ObservationNorm", "ObservationTransform", From de715c4afa62fa0be5d1913a79a3a9c5436b817e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 18 May 2026 18:05:58 +0100 Subject: [PATCH 2/4] Address review: env.rollout hook, pre-alloc buffers, batched-env guard - Wire `_post_step_mdp_hooks` in `EnvBase._rollout_stop_early` so `env.rollout(..., break_when_any_done=True)` rehydrates the flowing td just like `step_and_maybe_reset` already did. The non-stop path already routed through `step_and_maybe_reset` and is unchanged. - Add `Transform.transform_fake_tensordict(td)` hook (no-op default), iterated by `Compose`, called by a new `TransformedEnv.fake_tensordict` override. `NextObservationDelta` overrides it to cast the `("next", key)` leaves to `delta_dtype` in the spec-derived fake td. Pre-allocated `_final_rollout` in `SyncDataCollector(use_buffers=True)` now reserves storage at the compressed dtype rather than upcasting writes; the collector test covers both `use_buffers={True, False}`. - Add `Transform._check_batched_worker_compat()` (no-op default). `NextObservationDelta` raises with a clear message pointing at the correct usage pattern. `BatchedEnvBase._get_metadata` builds a transient probe env and runs the validator via a new `env_validator` kwarg on `get_env_metadata`, so the inner-batched configuration fails loudly at construction time instead of silently upcasting at runtime. The remaining v1 caveat in the docstring is that `check_env_specs` still does not pass: it calls `observation_spec.contains(("next", obs))` and TorchRL shares `observation_spec` between root and `("next", ...)` leaves, so a compressed dtype is rejected. Working around this properly requires forking the spec system, which is out of scope for this PR. Tests use a reset+step smoke instead. --- .../transforms/test_observation_transforms.py | 109 +++++++++++------- torchrl/envs/batched_envs.py | 23 +++- torchrl/envs/common.py | 5 +- torchrl/envs/env_creator.py | 35 +++++- torchrl/envs/transforms/_base.py | 57 ++++++++- torchrl/envs/transforms/_observation.py | 51 +++++--- 6 files changed, 215 insertions(+), 65 deletions(-) diff --git a/test/transforms/test_observation_transforms.py b/test/transforms/test_observation_transforms.py index 12f9f081aa3..39f906d525d 100644 --- a/test/transforms/test_observation_transforms.py +++ b/test/transforms/test_observation_transforms.py @@ -2803,65 +2803,60 @@ def _delta_tol(delta_dtype: torch.dtype, scale: float = 1.0) -> float: # being represented to get a meaningful tolerance. return torch.finfo(delta_dtype).eps * 8.0 * scale - # `check_env_specs` is intentionally NOT used here. NextObservationDelta - # changes the runtime dtype of `("next", obs)` to `delta_dtype` while - # leaving `observation_spec` untouched (root and `("next", ...)` share the - # same spec in TorchRL, and we don't fork it in v1). The spec-vs-runtime - # check would therefore reject the env. We use a reset + step smoke test - # instead, which is what callers actually do. + # `check_env_specs` is not used here. It enforces + # ``observation_spec.contains(("next", obs))``, and TorchRL's spec system + # shares observation_spec between root and ("next", ...) leaves. Since + # this transform deliberately compresses ("next", obs) below the spec + # dtype, `spec.contains` fails. We assert end-to-end behavior instead: + # the env steps, the rollout / collector batch carries the compressed + # dtype, and the rehydrated flowing td is full precision. @staticmethod - def _smoke_one_step(env, *, expect_compressed: bool): + def _smoke_one_step(env): td = env.reset() td.set("action", env.action_spec.rand()) post_step, flowing = env.step_and_maybe_reset(td) - if expect_compressed: - assert post_step["next", "observation"].dtype == torch.float16 - assert flowing["observation"].dtype == torch.float32 - else: - # batched-env wraps a TransformedEnv worker -- the outer batched - # env's step_and_maybe_reset does not invoke the hook, and may - # upcast through pre-allocated buffers. We only assert the env - # boots and steps without raising. - assert post_step["next", "observation"].shape == flowing["observation"].shape + assert post_step["next", "observation"].dtype == torch.float16 + assert flowing["observation"].dtype == torch.float32 def test_single_trans_env_check(self): env = TransformedEnv( ContinuousActionVecMockEnv(), NextObservationDelta(in_keys=["observation"]), ) - self._smoke_one_step(env, expect_compressed=True) + self._smoke_one_step(env) def test_serial_trans_env_check(self): - env = SerialEnv( - 2, - lambda: TransformedEnv( - ContinuousActionVecMockEnv(), - NextObservationDelta(in_keys=["observation"]), - ), - ) - self._smoke_one_step(env, expect_compressed=False) + # NextObservationDelta inside a batched-env worker must raise at + # construction time -- the outer batched env's step_and_maybe_reset + # does not propagate the worker's transform hook and pre-allocated + # batched buffers would upcast the write anyway. + with pytest.raises(RuntimeError, match="cannot live inside a SerialEnv"): + SerialEnv( + 2, + lambda: TransformedEnv( + ContinuousActionVecMockEnv(), + NextObservationDelta(in_keys=["observation"]), + ), + ) def test_parallel_trans_env_check(self): - env = ParallelEnv( - 2, - lambda: TransformedEnv( - ContinuousActionVecMockEnv(), - NextObservationDelta(in_keys=["observation"]), - ), - mp_start_method="fork", - ) - try: - self._smoke_one_step(env, expect_compressed=False) - finally: - env.close() + with pytest.raises(RuntimeError, match="cannot live inside a SerialEnv"): + ParallelEnv( + 2, + lambda: TransformedEnv( + ContinuousActionVecMockEnv(), + NextObservationDelta(in_keys=["observation"]), + ), + mp_start_method="fork", + ) def test_trans_serial_env_check(self): env = TransformedEnv( SerialEnv(2, lambda: ContinuousActionVecMockEnv()), NextObservationDelta(in_keys=["observation"]), ) - self._smoke_one_step(env, expect_compressed=True) + self._smoke_one_step(env) def test_trans_parallel_env_check(self): env = TransformedEnv( @@ -2869,7 +2864,7 @@ def test_trans_parallel_env_check(self): NextObservationDelta(in_keys=["observation"]), ) try: - self._smoke_one_step(env, expect_compressed=True) + self._smoke_one_step(env) finally: env.close() @@ -3024,9 +3019,11 @@ def test_compose_with_downstream_transform(self): assert out["next", "observation"].dtype == torch.float16 assert out_["observation"].dtype == torch.float32 - def test_collector_use_buffers_false(self): - # End-to-end: with use_buffers=False (no pre-allocated final_rollout), - # the stacked rollout actually carries float16 ("next", obs). + @pytest.mark.parametrize("use_buffers", [True, False]) + def test_collector_compressed(self, use_buffers): + # End-to-end: the stacked rollout actually carries float16 + # ("next", obs) in both the pre-allocated (use_buffers=True) and the + # lazy (use_buffers=False) collector paths. from torchrl.collectors import SyncDataCollector torch.manual_seed(4) @@ -3042,7 +3039,7 @@ def make_env(): policy=None, frames_per_batch=16, total_frames=16, - use_buffers=False, + use_buffers=use_buffers, ) try: batch = next(iter(collector)) @@ -3052,9 +3049,33 @@ def make_env(): assert batch["next", "observation"].dtype == torch.float16 assert batch["observation"].dtype == torch.float32 - # Reconstruct next.obs at full precision and verify shape/finiteness. recon = batch["observation"].to(torch.float32) + batch["next", "observation"].to( torch.float32 ) assert recon.shape == batch["observation"].shape assert torch.isfinite(recon).all() + + def test_env_rollout_hook_fires(self): + # env.rollout() must invoke _post_step_mdp_hooks so the flowing td + # carries full-precision root obs even on the stop-early path. + torch.manual_seed(5) + env = TransformedEnv( + ContinuousActionVecMockEnv(), + NextObservationDelta(in_keys=["observation"]), + ) + out = env.rollout(8, break_when_any_done=True) + # Stacked ("next", obs) is compressed; root obs is full precision + # because each iteration's flowing td was rehydrated by the hook + # before becoming the next iteration's root. + assert out["next", "observation"].dtype == torch.float16 + assert out["observation"].dtype == torch.float32 + + def test_env_rollout_nonstop_hook_fires(self): + torch.manual_seed(6) + env = TransformedEnv( + ContinuousActionVecMockEnv(), + NextObservationDelta(in_keys=["observation"]), + ) + out = env.rollout(8, break_when_any_done=False) + assert out["next", "observation"].dtype == torch.float16 + assert out["observation"].dtype == torch.float32 diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index b3248c4eb33..d74c8028d46 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -714,13 +714,28 @@ def __getstate__(self): def _has_dynamic_specs(self): return not self._use_buffers + @staticmethod + def _validate_worker_env(env) -> None: + """Walk transforms on a worker env and call each transform's + :meth:`~torchrl.envs.transforms.Transform._check_batched_worker_compat`. + + Transforms that should not live inside a batched-env worker raise here + with a clear message so the user gets immediate feedback rather than + silently-wrong runtime behavior. + """ + transform = getattr(env, "transform", None) + if transform is not None: + transform._check_batched_worker_compat() + def _get_metadata( self, create_env_fn: list[Callable], create_env_kwargs: list[dict] ): if self._single_task: # if EnvCreator, the metadata are already there meta_data: EnvMetaData = get_env_metadata( - create_env_fn[0], create_env_kwargs[0] + create_env_fn[0], + create_env_kwargs[0], + env_validator=self._validate_worker_env, ) self.meta_data = meta_data.expand( *(self.num_workers, *meta_data.batch_size) @@ -740,7 +755,11 @@ def _get_metadata( self.meta_data: list[EnvMetaData] = [] for i in range(n_tasks): self.meta_data.append( - get_env_metadata(create_env_fn[i], create_env_kwargs[i]).clone() + get_env_metadata( + create_env_fn[i], + create_env_kwargs[i], + env_validator=self._validate_worker_env, + ).clone() ) if self.share_individual_td is not True: share_individual_td = not _stackable( diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 52ac5d00f4e..896a4106503 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -3756,7 +3756,10 @@ def _rollout_stop_early( if i == max_steps - 1: # we don't truncate as one could potentially continue the run break - tensordict = self._step_mdp(tensordict) + post_step_td = tensordict + tensordict = self._step_mdp(post_step_td) + if self._post_step_mdp_hooks is not None: + tensordict = self._post_step_mdp_hooks(post_step_td, tensordict) if break_when_any_done: # done and truncated are in done_keys diff --git a/torchrl/envs/env_creator.py b/torchrl/envs/env_creator.py index caa2d759a8b..7b86a945656 100644 --- a/torchrl/envs/env_creator.py +++ b/torchrl/envs/env_creator.py @@ -233,9 +233,27 @@ def env_creator(fun: Callable) -> EnvCreator: return EnvCreator(fun) -def get_env_metadata(env_or_creator: EnvBase | Callable, kwargs: dict | None = None): - """Retrieves a EnvMetaData object from an env.""" +def get_env_metadata( + env_or_creator: EnvBase | Callable, + kwargs: dict | None = None, + *, + env_validator: Callable[[EnvBase], None] | None = None, +): + """Retrieves a EnvMetaData object from an env. + + Args: + env_or_creator: env instance or a creator callable / :class:`EnvCreator`. + kwargs: optional kwargs forwarded to a creator. + + Keyword Args: + env_validator: optional callable invoked on the (possibly transient) + env instance before its metadata is extracted. Used by batched + envs to fail loudly on worker-incompatible transforms at + construction time. + """ if isinstance(env_or_creator, (EnvBase,)): + if env_validator is not None: + env_validator(env_or_creator) return EnvMetaData.metadata_from_env(env_or_creator) elif not isinstance(env_or_creator, EnvBase) and not isinstance( env_or_creator, EnvCreator @@ -244,6 +262,8 @@ def get_env_metadata(env_or_creator: EnvBase | Callable, kwargs: dict | None = N if kwargs is None: kwargs = {} env = env_or_creator(**kwargs) + if env_validator is not None: + env_validator(env) return EnvMetaData.metadata_from_env(env) elif isinstance(env_or_creator, EnvCreator): if not ( @@ -256,6 +276,17 @@ def get_env_metadata(env_or_creator: EnvBase | Callable, kwargs: dict | None = N f"got EnvCreator.create_env_kwargs={env_or_creator.create_env_kwargs} and " f"kwargs = {kwargs}" ) + if env_validator is not None: + # EnvCreator caches the env on .meta_data; build a transient + # instance to validate the live transform chain. + transient = env_or_creator() + try: + env_validator(transient) + finally: + try: + transient.close() + except Exception: + pass return env_or_creator.meta_data.clone() else: raise NotImplementedError( diff --git a/torchrl/envs/transforms/_base.py b/torchrl/envs/transforms/_base.py index 5a1c9b85cff..25c74ae0eb0 100644 --- a/torchrl/envs/transforms/_base.py +++ b/torchrl/envs/transforms/_base.py @@ -413,12 +413,45 @@ def _post_step_mdp_hooks( are attached to wiring it up. :class:`~torchrl.envs.TransformedEnv` delegates ``EnvBase._post_step_mdp_hooks`` to ``self.transform._post_step_mdp_hooks``, so a transform appended to - a ``TransformedEnv`` is picked up automatically. Non-collector - entry points (e.g. ``env.rollout()``) currently do not invoke this - hook. + a ``TransformedEnv`` is picked up automatically. The hook fires + from :meth:`~torchrl.envs.EnvBase.step_and_maybe_reset` (used by + data collectors and the non-stop path of :meth:`~torchrl.envs.EnvBase.rollout`) + and from the stop-early path of :meth:`~torchrl.envs.EnvBase.rollout`. """ return tensordict_ + def _check_batched_worker_compat(self) -> None: + """Raise if this transform should not live inside a batched-env worker. + + :class:`~torchrl.envs.SerialEnv` and :class:`~torchrl.envs.ParallelEnv` + call this on every transform of every worker env at construction + time. Transforms whose semantics rely on + :meth:`~torchrl.envs.EnvBase.step_and_maybe_reset` or + :meth:`~torchrl.envs.EnvBase.rollout` hooks running on the *outer* + env (rather than the worker) override this to raise a clear error. + + The default is a no-op. + """ + return None + + def transform_fake_tensordict( + self, fake_tensordict: TensorDictBase + ) -> TensorDictBase: + """Adjust the env's ``fake_tensordict`` after it is built from specs. + + :meth:`~torchrl.envs.EnvBase.fake_tensordict` constructs a zero-filled + tensordict from the env's specs, which is used by data collectors to + pre-allocate the rollout storage. The TorchRL spec system shares the + observation spec between the root and ``("next", ...)`` leaves, so + transforms that want the runtime ``("next", k)`` dtype to differ from + the root ``k`` dtype need a way to fix up the fake tensordict here. + + The default is a no-op. Override only when the runtime tensordict your + transform produces does not match what the spec-derived fake + tensordict would imply. + """ + return fake_tensordict + def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase: """Reads the input tensordict, and for the selected keys, applies the transform. @@ -1049,6 +1082,13 @@ def _post_step_mdp_hooks( tensordict_ = base_env._post_step_mdp_hooks(tensordict, tensordict_) return tensordict_ + def fake_tensordict(self) -> TensorDictBase: + """Build a fake tensordict and let the transform chain post-process it.""" + fake_td = super().fake_tensordict() + if self.transform is not None: + fake_td = self.transform.transform_fake_tensordict(fake_td) + return fake_td + def _set_env(self, env: EnvBase, device) -> None: if device != env.device: env = env.to(device) @@ -1641,6 +1681,17 @@ def _post_step_mdp_hooks( tensordict_ = t._post_step_mdp_hooks(tensordict, tensordict_) return tensordict_ + def transform_fake_tensordict( + self, fake_tensordict: TensorDictBase + ) -> TensorDictBase: + for t in self.transforms: + fake_tensordict = t.transform_fake_tensordict(fake_tensordict) + return fake_tensordict + + def _check_batched_worker_compat(self) -> None: + for t in self.transforms: + t._check_batched_worker_compat() + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: for t in reversed(self.transforms): tensordict = t._inv_call(tensordict) diff --git a/torchrl/envs/transforms/_observation.py b/torchrl/envs/transforms/_observation.py index b0f20e91132..84bf3e0d79a 100644 --- a/torchrl/envs/transforms/_observation.py +++ b/torchrl/envs/transforms/_observation.py @@ -1468,19 +1468,11 @@ class NextObservationDelta(Transform): representable step. .. warning:: - Rollout memory savings only materialize when the stacked output is - **not** pre-allocated at full precision. With - :class:`~torchrl.collectors.SyncDataCollector` set ``use_buffers=False`` - (or use a lazy replay-buffer storage). Pre-allocated - ``_final_rollout`` buffers will upcast writes back to the original - dtype and erase the saving. - - .. warning:: - The post-step-mdp rehydration is wired through - :meth:`~torchrl.envs.EnvBase.step_and_maybe_reset`, which is the - entry point used by the data collectors. ``env.rollout()`` does not - currently invoke the hook, so when used directly users should - rehydrate manually if they need the next obs at full precision. + The transform must live **outside** any batched env + (``TransformedEnv(ParallelEnv(N, factory), NextObservationDelta())``). + Building a :class:`~torchrl.envs.SerialEnv` / + :class:`~torchrl.envs.ParallelEnv` whose worker contains a + ``NextObservationDelta`` raises at construction time. Example: >>> import torch @@ -1611,6 +1603,39 @@ def _post_step_mdp_hooks( tensordict_.set(key, root.to(dtype) + delta.to(dtype)) return tensordict_ + def _check_batched_worker_compat(self) -> None: + raise RuntimeError( + f"{type(self).__name__} cannot live inside a SerialEnv/ParallelEnv " + "worker: the post-step-mdp rehydration relies on the outer env's " + "`step_and_maybe_reset` invoking the hook, but a batched env's " + "`step_and_maybe_reset` does not propagate the worker's transform " + "hook, and the batched output is upcast through the shared spec " + "buffer. Place the transform OUTSIDE the batched env instead, " + "e.g. `TransformedEnv(ParallelEnv(N, base_env_factory), " + f"{type(self).__name__}(...))`." + ) + + def transform_fake_tensordict( + self, fake_tensordict: TensorDictBase + ) -> TensorDictBase: + # Cast the ("next", key) leaves to delta_dtype so collectors (and any + # other consumer of env.fake_tensordict()) pre-allocate storage for + # the compressed dtype rather than the spec dtype. Root keys are left + # at the spec dtype because the transform writes full-precision obs + # at root (via rehydration on the flowing tensordict). + in_keys = self.in_keys + if not in_keys: + return fake_tensordict + next_td = fake_tensordict.get("next", default=None) + if next_td is None: + return fake_tensordict + for key in in_keys: + leaf = next_td.get(key, default=None) + if leaf is None: + continue + next_td.set(key, leaf.to(self.delta_dtype)) + return fake_tensordict + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: raise NotImplementedError( f"{type(self).__name__} is an env-side transform; calling it directly " From 062d97e8f37fa19f454e3461011e1d28f93cf592 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 18 May 2026 18:52:30 +0100 Subject: [PATCH 3/4] Compute delta in source dtype, cast once Subtracting in delta_dtype (float16 by default) risks catastrophic cancellation when next_obs and obs are close. Doing the subtraction in the operands' source dtype and casting the result once preserves significand bits and is strictly more accurate on round-trip. The stored root obs is unchanged, so there is no asymmetry to preserve between the on-the-fly delta and the value reconstructed from storage. --- torchrl/envs/transforms/_observation.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchrl/envs/transforms/_observation.py b/torchrl/envs/transforms/_observation.py index 84bf3e0d79a..1512233b23b 100644 --- a/torchrl/envs/transforms/_observation.py +++ b/torchrl/envs/transforms/_observation.py @@ -1582,7 +1582,11 @@ def _step( continue if self.auto_skip and next_obs.dtype == self.delta_dtype: continue - delta = next_obs.to(self.delta_dtype) - obs.to(self.delta_dtype) + # Subtract in the source (typically full-precision) dtype, then + # cast once. This loses fewer significant bits than casting each + # operand to ``delta_dtype`` first and subtracting in low precision + # (which would risk catastrophic cancellation for nearby values). + delta = (next_obs - obs).to(self.delta_dtype) next_tensordict.set(key, delta) return next_tensordict From 0ef498344853a25465a3a02a0799ca00fc068d7c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 18 May 2026 18:55:58 +0100 Subject: [PATCH 4/4] Lint fixes: flake8 C409, pydocstyle D205/D415/D417, ufmt --- .../transforms/test_observation_transforms.py | 37 ++++++++++++------- torchrl/envs/batched_envs.py | 9 +++-- torchrl/envs/env_creator.py | 2 - 3 files changed, 28 insertions(+), 20 deletions(-) diff --git a/test/transforms/test_observation_transforms.py b/test/transforms/test_observation_transforms.py index 39f906d525d..a2d47d81daa 100644 --- a/test/transforms/test_observation_transforms.py +++ b/test/transforms/test_observation_transforms.py @@ -2860,7 +2860,9 @@ def test_trans_serial_env_check(self): def test_trans_parallel_env_check(self): env = TransformedEnv( - ParallelEnv(2, lambda: ContinuousActionVecMockEnv(), mp_start_method="fork"), + ParallelEnv( + 2, lambda: ContinuousActionVecMockEnv(), mp_start_method="fork" + ), NextObservationDelta(in_keys=["observation"]), ) try: @@ -2897,11 +2899,12 @@ def test_transform_env(self): # Rehydrated flowing root == root_obs + delta (round-tripped through # the delta dtype). - expected = ( - post_step["observation"].to(torch.float32) - + post_step["next", "observation"].to(torch.float32) + expected = post_step["observation"].to(torch.float32) + post_step[ + "next", "observation" + ].to(torch.float32) + tol = self._delta_tol( + torch.float16, scale=max(1.0, expected.abs().max().item()) ) - tol = self._delta_tol(torch.float16, scale=max(1.0, expected.abs().max().item())) torch.testing.assert_close( flowing["observation"].to(torch.float32), expected, atol=tol, rtol=tol ) @@ -2919,9 +2922,7 @@ def test_transform_rb(self): ) # Extend goes through `inv`, which is a no-op for this transform # (it has no in_keys_inv) -- so the write succeeds. - rb.extend( - TensorDict({"observation": torch.zeros(4, 3)}, batch_size=[4]) - ) + rb.extend(TensorDict({"observation": torch.zeros(4, 3)}, batch_size=[4])) # Sampling, however, routes through `forward`, which is unsupported # for this env-side-only transform. with pytest.raises(NotImplementedError, match="env-side transform"): @@ -2950,10 +2951,16 @@ class _DualObsEnv(ContinuousActionVecMockEnv): ) # Access lazy in_keys via the transform. in_keys = list(env.transform.in_keys) - assert ("observation",) in [tuple([k]) if isinstance(k, str) else tuple(k) for k in in_keys] + assert ("observation",) in [ + (k,) if isinstance(k, str) else tuple(k) for k in in_keys + ] # No uint8 leaf made it in. for k in in_keys: - spec = env.observation_spec[k] if not isinstance(k, tuple) else env.observation_spec[k] + spec = ( + env.observation_spec[k] + if not isinstance(k, tuple) + else env.observation_spec[k] + ) assert spec.dtype.is_floating_point def test_multi_in_keys_explicit(self): @@ -2994,7 +3001,9 @@ def test_reset_between_steps(self): expected = out["observation"].to(torch.float32) + out["next", "observation"].to( torch.float32 ) - tol = self._delta_tol(torch.float16, scale=max(1.0, expected.abs().max().item())) + tol = self._delta_tol( + torch.float16, scale=max(1.0, expected.abs().max().item()) + ) torch.testing.assert_close( out_["observation"].to(torch.float32), expected, atol=tol, rtol=tol ) @@ -3049,9 +3058,9 @@ def make_env(): assert batch["next", "observation"].dtype == torch.float16 assert batch["observation"].dtype == torch.float32 - recon = batch["observation"].to(torch.float32) + batch["next", "observation"].to( - torch.float32 - ) + recon = batch["observation"].to(torch.float32) + batch[ + "next", "observation" + ].to(torch.float32) assert recon.shape == batch["observation"].shape assert torch.isfinite(recon).all() diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index d74c8028d46..12ad0fc652a 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -716,11 +716,12 @@ def _has_dynamic_specs(self): @staticmethod def _validate_worker_env(env) -> None: - """Walk transforms on a worker env and call each transform's - :meth:`~torchrl.envs.transforms.Transform._check_batched_worker_compat`. + """Check that each transform on a worker env is batched-env compatible. - Transforms that should not live inside a batched-env worker raise here - with a clear message so the user gets immediate feedback rather than + Walks ``env.transform`` and invokes + :meth:`~torchrl.envs.transforms.Transform._check_batched_worker_compat` + on each entry. Transforms that should not live inside a batched-env + worker raise here so the user gets immediate feedback rather than silently-wrong runtime behavior. """ transform = getattr(env, "transform", None) diff --git a/torchrl/envs/env_creator.py b/torchrl/envs/env_creator.py index 7b86a945656..2078e8cb305 100644 --- a/torchrl/envs/env_creator.py +++ b/torchrl/envs/env_creator.py @@ -244,8 +244,6 @@ def get_env_metadata( Args: env_or_creator: env instance or a creator callable / :class:`EnvCreator`. kwargs: optional kwargs forwarded to a creator. - - Keyword Args: env_validator: optional callable invoked on the (possibly transient) env instance before its metadata is extracted. Used by batched envs to fail loudly on worker-incompatible transforms at