[Feature][Performance] NextObservationDelta env transform #3777
Open
vmoens wants to merge 4 commits into
Adds a stateless env-side transform that stores `("next", obs)` as a
low-precision delta from the root `obs`, reducing the rollout-time
memory footprint of large continuous observations.
The transform compresses next observations in `_step` and rehydrates
the flowing tensordict's root observation in a new
`_post_step_mdp_hooks` extension point on `EnvBase`. The hook was
previously half-stubbed in `common.py` / `_base.py` / `llm/chat.py`;
it is now wired through `step_and_maybe_reset` and threaded into
`Transform`, `Compose`, and `TransformedEnv`.
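
The compress/rehydrate round trip described above can be sketched with plain dicts. This is a minimal illustration, not the PR's code: it uses the standard-library `struct` module to emulate `float16` rounding instead of TorchRL tensordicts, and the names `to_f16`, `step_compress`, `post_step_mdp_hook`, and the flat `next_obs` key are hypothetical stand-ins.

```python
import struct


def to_f16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def step_compress(td: dict) -> dict:
    """Stand-in for `_step`: store next obs as a float16 delta from the root obs."""
    out = dict(td)
    out["next_obs"] = [to_f16(n - o) for n, o in zip(td["next_obs"], td["obs"])]
    return out


def post_step_mdp_hook(td: dict, td_next: dict) -> dict:
    """Stand-in for the hook: rebuild the flowing dict's root obs as obs + delta.

    `td` is the post-step dict (root obs plus stored delta); `td_next` is the
    dict that flows into the next iteration.
    """
    out = dict(td_next)
    out["obs"] = [o + d for o, d in zip(td["obs"], td["next_obs"])]
    return out


td = {"obs": [0.5, -1.25, 2.0], "next_obs": [0.75, -1.0, 2.5]}
stored = step_compress(td)  # what a collector would stack
# A step_mdp-like move promotes the (compressed) next obs to the root ...
flowing = {"obs": stored["next_obs"]}
# ... and the hook rehydrates it before the policy reads it.
flowing = post_step_mdp_hook(stored, flowing)
print(stored["next_obs"], flowing["obs"])
```

The deltas in this example happen to be exactly representable in half precision, so the rebuilt observation matches the original; in general the round trip is lossy, as the caveats below note.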
Caveats documented on the class:
- The compression is lossy; round-trip error scales with delta dtype
precision and observation magnitude.
- Memory savings only materialize against non-pre-allocated stacked
output (e.g. `SyncDataCollector(use_buffers=False)` or a lazy RB
storage). Pre-allocated buffers upcast the write.
- The hook fires from `step_and_maybe_reset`; direct `env.rollout()`
callers must rehydrate manually.
- `check_env_specs` rejects the transformed env in v1 because the
observation spec is shared between root and `("next", ...)` and we
do not fork it.
Includes a `TestNextObservationDelta` test class with 16 cases
(14 passing, 2 documented skips) covering single-env, serial/parallel
batched envs (inner and outer wrapping), auto-inference skipping
non-floating dtypes, multi-key, reset semantics, Compose ordering,
and an end-to-end `SyncDataCollector(use_buffers=False)` check that
the stacked batch carries `float16` `("next", obs)`.
- Wire `_post_step_mdp_hooks` in `EnvBase._rollout_stop_early` so
`env.rollout(..., break_when_any_done=True)` rehydrates the flowing
td just like `step_and_maybe_reset` already did. The non-stop path
already routed through `step_and_maybe_reset` and is unchanged.
- Add `Transform.transform_fake_tensordict(td)` hook (no-op default),
iterated by `Compose`, called by a new `TransformedEnv.fake_tensordict`
override. `NextObservationDelta` overrides it to cast the
`("next", key)` leaves to `delta_dtype` in the spec-derived fake td.
Pre-allocated `_final_rollout` in `SyncDataCollector(use_buffers=True)`
now reserves storage at the compressed dtype rather than upcasting
writes; the collector test covers both `use_buffers={True, False}`.
- Add `Transform._check_batched_worker_compat()` (no-op default).
`NextObservationDelta` raises with a clear message pointing at the
correct usage pattern. `BatchedEnvBase._get_metadata` builds a
transient probe env and runs the validator via a new `env_validator`
kwarg on `get_env_metadata`, so the inner-batched configuration
fails loudly at construction time instead of silently upcasting at
runtime.
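
The fail-at-construction pattern from the last bullet can be sketched as follows. The classes here are simplified stand-ins that only borrow the names used in this PR; the probe env is modelled as a plain list of transforms, and `validate_worker` and the returned metadata dict are hypothetical.

```python
class Transform:
    """Simplified stand-in, not the TorchRL class."""

    def _check_batched_worker_compat(self) -> None:
        # No-op default: most transforms are fine inside a batched worker.
        return None


class NextObservationDelta(Transform):
    def _check_batched_worker_compat(self) -> None:
        # Fail loudly and point at the supported configuration.
        raise RuntimeError(
            "NextObservationDelta must wrap the batched env from the outside, "
            "e.g. TransformedEnv(ParallelEnv(...), NextObservationDelta())."
        )


def validate_worker(transforms: list) -> None:
    """Run every transform's compatibility check on the probe env."""
    for t in transforms:
        t._check_batched_worker_compat()


def get_env_metadata(make_env, env_validator=None) -> dict:
    """Build a transient probe env, validate it, then return its metadata."""
    probe = make_env()
    if env_validator is not None:
        env_validator(probe)
    return {"num_transforms": len(probe)}


# A worker without the delta transform passes validation.
ok = get_env_metadata(lambda: [Transform()], env_validator=validate_worker)

# A worker that embeds the delta transform is rejected at construction time.
try:
    get_env_metadata(lambda: [NextObservationDelta()], env_validator=validate_worker)
    rejected = False
except RuntimeError:
    rejected = True
print(ok, rejected)
```

Running the check on a throwaway probe keeps the error at construction time, before any worker starts stepping.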
The remaining v1 caveat in the docstring is that `check_env_specs`
still does not pass: it calls `observation_spec.contains(("next", obs))`
and TorchRL shares `observation_spec` between root and `("next", ...)`
leaves, so a compressed dtype is rejected. Working around this
properly requires forking the spec system, which is out of scope for
this PR. Tests use a reset+step smoke test instead.
Subtracting in `delta_dtype` (`float16` by default) risks catastrophic cancellation when `next_obs` and `obs` are close. Doing the subtraction in the operands' source dtype and casting the result once preserves significand bits and is strictly more accurate on round-trip. The stored root `obs` is unchanged, so there is no asymmetry to preserve between the on-the-fly delta and the value reconstructed from storage.
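
A small numeric illustration of the difference, assuming `float32` observations near 1000 and using the standard-library `struct` module to emulate `float16`/`float32` rounding (the helper names and the chosen values are illustrative, not taken from the PR):

```python
import struct


def to_f16(x: float) -> float:
    """Round to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def to_f32(x: float) -> float:
    """Round to the nearest IEEE 754 single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]


# Two close float32 observations; float16 spacing near 1000 is 0.5.
obs = to_f32(1000.3)
next_obs = to_f32(1000.4)

# Variant 1: cast both operands to float16, then subtract in float16.
delta_low = to_f16(to_f16(next_obs) - to_f16(obs))

# Variant 2: subtract in the source dtype (float32), cast the result once.
delta_src = to_f16(to_f32(next_obs - obs))

# Rehydrate from the unchanged float32 root obs in both cases.
err_low = abs(to_f32(obs + delta_low) - next_obs)
err_src = abs(to_f32(obs + delta_src) - next_obs)
print(f"delta_low={delta_low} err_low={err_low:.6f}")
print(f"delta_src={delta_src} err_src={err_src:.6f}")
```

In this example both operands round to the same `float16` value, so the first variant stores a zero delta and loses the whole step, while the second variant keeps the small difference at the resolution `float16` offers near 0.1.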
Summary

- Adds `NextObservationDelta`, a stateless env-side transform that stores `("next", obs)` as a low-precision delta from the root `obs` for rollout memory savings on large continuous observations.
- Wires the `_post_step_mdp_hooks` extension point in `EnvBase.step_and_maybe_reset` and threads it through `Transform`, `Compose`, and `TransformedEnv`. The hook receives both the post-step and post-step-mdp tensordicts so a transform can rehydrate the flowing td that the policy reads on the next iteration.
- `NextObservationDelta._step` writes `(next_obs - obs).to(delta_dtype)` (default `float16`); `_post_step_mdp_hooks` reconstructs `obs + delta` in `restore_dtype` (default: match root). Stateless: no caching across steps.

Why this shape

The existing `compact_obs` collector flag + `NextStateReconstructor` RB transform attack the same problem by dropping `("next", obs)` entirely and shifting at sample time. That is zero-storage but lossy at trajectory boundaries (which become `NaN`). The delta variant trades a small precision loss for boundary-preserving reconstruction and an env-side hook that does not need to know about collector internals.

The `_post_step_mdp_hooks` mechanism was already stubbed (commented out) in `common.py`, `transforms/_base.py`, and `llm/chat.py`. This PR enables it. The signature was changed from the original comment (`(tensordict_,) -> tensordict_`) to `(tensordict, tensordict_) -> tensordict_` because rehydration needs read access to the post-step root obs. No caller existed before, so this is not a breaking change.

v1 limitations (documented on the class)

- The compression is lossy; round-trip error scales with `delta_dtype` precision and observation magnitude.
- Memory savings only materialize against non-pre-allocated stacked output, e.g. `SyncDataCollector(use_buffers=False)` or a lazy RB storage. Pre-allocated `_final_rollout` upcasts the write back to the original dtype and erases the saving.
- The hook fires from `step_and_maybe_reset` only. `env.rollout()` is not wired in v1; direct rollout callers must rehydrate manually.
- `check_env_specs` does not pass on the transformed env. `observation_spec` is shared between root and `("next", ...)` in TorchRL; the transform does not fork it in v1 (a follow-up could). Tests use a reset+step smoke test instead.
- With `SerialEnv` / `ParallelEnv`, the transform belongs outside the batched env (i.e. `TransformedEnv(ParallelEnv(...), NextObservationDelta())`): that path uses the outer `step_and_maybe_reset` and the hook fires. Putting the transform inside each worker is allowed and runs without error, but the outer batched env's `step_and_maybe_reset` does not currently propagate the hook, so the stacked output upcasts.

Out of scope (potential follow-ups)

- Forking `observation_spec` so pre-allocated `_final_rollout` benefits from the compression.
- Wiring the hook in `_rollout_stop_early` and in `batched_envs` / `async_envs` / `envpool` `step_and_maybe_reset`.
- A benchmark under `benchmarks/`.

Test plan

- `pytest test/transforms/test_observation_transforms.py::TestNextObservationDelta`: 14 passed, 2 documented skips.
- `pytest --doctest-modules torchrl/envs/transforms/_observation.py -k NextObservationDelta`: passes.
- `pytest test/envs/test_env_base.py`: 47 passed, 4 skipped (no regressions from the hook wiring).
- A run on `GymEnv("Pendulum-v1")` confirms `("next", "observation").dtype == torch.float16` post-step and `torch.float32` on the flowing td, with bitwise-exact rehydration (max diff 0.0).
- `Compose(NextObservationDelta, RewardSum)` works in both orderings.