[Feature][Performance] NextObservationDelta env transform #3777

Open
vmoens wants to merge 4 commits into pytorch:main from vmoens:worktree-next-obs-delta

vmoens (Collaborator) commented May 18, 2026

Summary

  • Adds NextObservationDelta, a stateless env-side transform that stores ("next", obs) as a low-precision delta from the root obs for rollout memory savings on large continuous observations.
  • Wires up the previously-stubbed _post_step_mdp_hooks extension point in EnvBase.step_and_maybe_reset and threads it through Transform, Compose, and TransformedEnv. The hook receives both the post-step and post-step-mdp tensordicts so a transform can rehydrate the flowing td that the policy reads on the next iteration.
  • NextObservationDelta._step writes (next_obs - obs).to(delta_dtype) (default float16); _post_step_mdp_hooks reconstructs obs + delta in restore_dtype (default: match root). Stateless — no caching across steps.
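The compress/rehydrate round trip described above can be sketched without TorchRL. This is a minimal illustration, not the transform's implementation: it assumes plain Python lists in place of tensors, emulates the `float16` cast with the standard-library `struct` module, and the helper names `to_float16`, `compress`, and `rehydrate` are hypothetical.

```python
import struct


def to_float16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def compress(obs, next_obs):
    """Store the next observation as a float16 delta from the root observation.

    The subtraction happens at full precision; only the result is cast.
    """
    return [to_float16(n - o) for o, n in zip(obs, next_obs)]


def rehydrate(obs, delta):
    """Rebuild the next observation as obs + delta."""
    return [o + d for o, d in zip(obs, delta)]


obs = [0.5, -1.25, 3.0]
next_obs = [0.5625, -1.0, 3.001]

delta = compress(obs, next_obs)
restored = rehydrate(obs, delta)

# Deltas that float16 can represent exactly round-trip exactly;
# the last one (about 0.001) picks up a small quantization error.
max_err = max(abs(r - n) for r, n in zip(restored, next_obs))
```

Nothing is cached between calls: each step only needs the root observation and the stored delta, which is what makes the scheme stateless.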

Why this shape

The existing compact_obs collector flag + NextStateReconstructor RB transform attack the same problem by dropping ("next", obs) entirely and shifting at sample time. That is zero-storage but lossy at trajectory boundaries (which become NaN). The delta variant trades a small precision loss for boundary-preserving reconstruction and an env-side hook that does not need to know about collector internals.

The _post_step_mdp_hooks mechanism was already stubbed (commented out) in common.py, transforms/_base.py, and llm/chat.py. This PR enables it. The signature was changed from the original comment ((tensordict_,) -> tensordict_) to (tensordict, tensordict_) -> tensordict_ because rehydration needs read access to the post-step root obs. No caller existed before, so this is not a breaking change.
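A toy sketch of why the hook takes two arguments, using plain dicts in place of tensordicts. The functions `step`, `step_mdp`, and `post_step_mdp_hook` below are hypothetical stand-ins that only mirror the names used in this PR; they are not TorchRL code, and the dynamics are invented for illustration.

```python
def step(td):
    """Toy env step: write the next observation under "next" as a delta."""
    next_obs = td["obs"] + td["action"]  # hypothetical dynamics
    out = dict(td)
    out["next"] = {"obs": next_obs - td["obs"]}  # compressed form: the delta
    return out


def step_mdp(td):
    """Promote the "next" entries to the root for the following iteration."""
    return dict(td["next"])


def post_step_mdp_hook(tensordict, tensordict_):
    """(tensordict, tensordict_) -> tensordict_.

    After step_mdp, the previous root obs is no longer in tensordict_,
    so the hook reads it from the post-step tensordict to rebuild obs + delta.
    """
    tensordict_ = dict(tensordict_)
    tensordict_["obs"] = tensordict["obs"] + tensordict_["obs"]
    return tensordict_


td = step({"obs": 2.0, "action": 0.5})
td_ = step_mdp(td)                 # root "obs" is still the delta here
td_ = post_step_mdp_hook(td, td_)  # root "obs" is the full observation again
```

With the single-argument form `(tensordict_,) -> tensordict_`, the hook in this sketch would only see the delta and would have nothing to add it to.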

v1 limitations (documented on the class)

  • Lossy. Round-trip error scales with delta_dtype precision and observation magnitude.
  • Memory savings require non-pre-allocated stacked output, such as SyncDataCollector(use_buffers=False) or a lazy RB storage. Pre-allocated _final_rollout upcasts the write back to the original dtype and erases the saving.
  • Hook fires from step_and_maybe_reset only. env.rollout() is not wired in v1; direct rollout callers must rehydrate manually.
  • check_env_specs does not pass on the transformed env. observation_spec is shared between root and ("next", ...) in TorchRL; the transform does not fork it in v1 (a follow-up could). Tests use a reset+step smoke test instead.
  • Batched-env composition. For SerialEnv/ParallelEnv, the transform belongs outside the batched env (i.e. TransformedEnv(ParallelEnv(...), NextObservationDelta())) — that path uses the outer step_and_maybe_reset and the hook fires. Putting the transform inside each worker is allowed and runs without error, but the outer batched env's step_and_maybe_reset does not currently propagate the hook so the stacked output upcasts.
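As a rough back-of-the-envelope check on the memory claim, the arithmetic below assumes a `float32` root observation, a `float16` delta for `("next", obs)`, and a stacked output that keeps the compressed dtype. The rollout length and observation size are hypothetical, and `rollout_obs_bytes` is an illustrative helper, not part of TorchRL.

```python
def rollout_obs_bytes(steps, obs_dim, root_itemsize=4, next_itemsize=4):
    """Bytes used by the root obs and ("next", obs) entries of a stacked rollout."""
    return steps * obs_dim * (root_itemsize + next_itemsize)


steps, obs_dim = 1000, 4096  # hypothetical rollout length and observation size

baseline = rollout_obs_bytes(steps, obs_dim)                     # float32 + float32
compressed = rollout_obs_bytes(steps, obs_dim, next_itemsize=2)  # float32 + float16

saving = 1 - compressed / baseline  # fraction of observation storage saved
```

Under these assumptions the saving is a quarter of the observation storage. If the write is upcast into a pre-allocated `float32` buffer, `next_itemsize` is effectively 4 again and the saving disappears.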

Out of scope (potential follow-ups)

  • Forking observation_spec so pre-allocated _final_rollout benefits from the compression.
  • Wiring the hook in _rollout_stop_early and in batched_envs / async_envs / envpool step_and_maybe_reset.
  • A replay-buffer-side delta transform paired with this one.
  • Benchmark entry under benchmarks/.

Test plan

  • pytest test/transforms/test_observation_transforms.py::TestNextObservationDelta — 14 passed, 2 documented skips.
  • pytest --doctest-modules torchrl/envs/transforms/_observation.py -k NextObservationDelta — passes.
  • pytest test/envs/test_env_base.py — 47 passed, 4 skipped (no regressions from the hook wiring).
  • Manual smoke against GymEnv("Pendulum-v1") confirms ("next", "observation").dtype == torch.float16 post-step and torch.float32 on the flowing td, with bitwise-exact rehydration (max diff 0.0).
  • Compose(NextObservationDelta, RewardSum) works in both orderings.
  • Wider CI sweep (compose + env-transforms suites) — local disk filled before completing; relying on CI.

Adds a stateless env-side transform that stores `("next", obs)` as a
low-precision delta from the root `obs`, reducing the rollout-time
memory footprint of large continuous observations.

The transform compresses next observations in `_step` and rehydrates
the flowing tensordict's root observation in a new
`_post_step_mdp_hooks` extension point on `EnvBase`. The hook was
previously half-stubbed in `common.py` / `_base.py` / `llm/chat.py`;
it is now wired through `step_and_maybe_reset` and threaded into
`Transform`, `Compose`, and `TransformedEnv`.

Caveats documented on the class:

- The compression is lossy; round-trip error scales with delta dtype
  precision and observation magnitude.
- Memory savings only materialize against non-pre-allocated stacked
  output (e.g. `SyncDataCollector(use_buffers=False)` or a lazy RB
  storage). Pre-allocated buffers upcast the write.
- The hook fires from `step_and_maybe_reset`; direct `env.rollout()`
  callers must rehydrate manually.
- `check_env_specs` rejects the transformed env in v1 because the
  observation spec is shared between root and `("next", ...)` and we
  do not fork it.

Includes a `TestNextObservationDelta` test class with 16 cases
(14 passing, 2 documented skips) covering single-env, serial/parallel
batched envs (inner and outer wrapping), auto-inference skipping
non-floating dtypes, multi-key, reset semantics, Compose ordering,
and an end-to-end `SyncDataCollector(use_buffers=False)` check that
the stacked batch carries `float16` `("next", obs)`.

pytorch-bot Bot commented May 18, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3777

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 4 New Failures

As of commit 0ef4983 with merge base 996387f:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label May 18, 2026
@github-actions github-actions Bot added the Documentation, Transforms, and Feature labels May 18, 2026
vmoens added 3 commits May 18, 2026 18:05
- Wire `_post_step_mdp_hooks` in `EnvBase._rollout_stop_early` so
  `env.rollout(..., break_when_any_done=True)` rehydrates the flowing
  td just like `step_and_maybe_reset` already did. The non-stop path
  already routed through `step_and_maybe_reset` and is unchanged.

- Add `Transform.transform_fake_tensordict(td)` hook (no-op default),
  iterated by `Compose`, called by a new `TransformedEnv.fake_tensordict`
  override. `NextObservationDelta` overrides it to cast the
  `("next", key)` leaves to `delta_dtype` in the spec-derived fake td.
  Pre-allocated `_final_rollout` in `SyncDataCollector(use_buffers=True)`
  now reserves storage at the compressed dtype rather than upcasting
  writes; the collector test covers both `use_buffers={True, False}`.

- Add `Transform._check_batched_worker_compat()` (no-op default).
  `NextObservationDelta` raises with a clear message pointing at the
  correct usage pattern. `BatchedEnvBase._get_metadata` builds a
  transient probe env and runs the validator via a new `env_validator`
  kwarg on `get_env_metadata`, so the inner-batched configuration
  fails loudly at construction time instead of silently upcasting at
  runtime.

The remaining v1 caveat in the docstring is that `check_env_specs`
still does not pass: it calls `observation_spec.contains(("next", obs))`
and TorchRL shares `observation_spec` between root and `("next", ...)`
leaves, so a compressed dtype is rejected. Working around this
properly requires forking the spec system, which is out of scope for
this PR. Tests use a reset+step smoke test instead.

Subtracting in delta_dtype (float16 by default) risks catastrophic
cancellation when next_obs and obs are close. Doing the subtraction
in the operands' source dtype and casting the result once preserves
significand bits and is strictly more accurate on round-trip.

The stored root obs is unchanged, so there is no asymmetry to
preserve between the on-the-fly delta and the value reconstructed
from storage.
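The difference between the two orderings can be illustrated with a small standard-library sketch. It emulates `float16` and `float32` rounding with `struct` rather than using torch dtypes, and the observation values are made up to sit where `float16` spacing is coarse (0.5 between 512 and 1024).

```python
import struct


def to_float16(x: float) -> float:
    """Round to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_float32(x: float) -> float:
    """Round to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


# Hypothetical float32 observations that are close to each other.
obs = to_float32(1000.1)
next_obs = to_float32(1000.3)
true_delta = next_obs - obs  # about 0.2

# Subtract in float16: both operands are rounded to float16 first,
# so they land on 1000.0 and 1000.5 and the delta becomes 0.5.
delta_f16_sub = to_float16(to_float16(next_obs) - to_float16(obs))

# Subtract in the source dtype, then cast the result once.
delta_cast_once = to_float16(to_float32(next_obs - obs))

err_f16_sub = abs(delta_f16_sub - true_delta)
err_cast_once = abs(delta_cast_once - true_delta)
```

In this example the float16 subtraction is off by roughly 0.3, while casting once keeps the error below 1e-4.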