Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/reference/envs_transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ Available Transforms
MeanActionSelector
ModuleTransform
MultiAction
NextObservationDelta
NextStateReconstructor
NoopResetEnv
ObservationNorm
Expand Down
298 changes: 298 additions & 0 deletions test/transforms/test_observation_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Crop,
FlattenObservation,
GrayScale,
NextObservationDelta,
ObservationNorm,
ParallelEnv,
PermuteTransform,
Expand Down Expand Up @@ -2790,3 +2791,300 @@ def test_transform_no_env(self, batch):
td = TensorDict({"pixels": torch.randn((*batch, D, W, H, C))}, batch_size=batch)
td = trans(td)
assert td["pixels"].shape == torch.Size((*batch, C, D, H, W))


class TestNextObservationDelta(TransformBase):
"""Tests for the env-side delta-compression transform."""

@staticmethod
def _delta_tol(delta_dtype: torch.dtype, scale: float = 1.0) -> float:
# Round-trip error of (next_obs - obs) cast to delta_dtype then back.
# finfo.eps is the gap around 1.0; scale by max magnitude of the delta
# being represented to get a meaningful tolerance.
return torch.finfo(delta_dtype).eps * 8.0 * scale

# `check_env_specs` is not used here. It enforces
# ``observation_spec.contains(("next", obs))``, and TorchRL's spec system
# shares observation_spec between root and ("next", ...) leaves. Since
# this transform deliberately compresses ("next", obs) below the spec
# dtype, `spec.contains` fails. We assert end-to-end behavior instead:
# the env steps, the rollout / collector batch carries the compressed
# dtype, and the rehydrated flowing td is full precision.

@staticmethod
def _smoke_one_step(env):
td = env.reset()
td.set("action", env.action_spec.rand())
post_step, flowing = env.step_and_maybe_reset(td)
assert post_step["next", "observation"].dtype == torch.float16
assert flowing["observation"].dtype == torch.float32

def test_single_trans_env_check(self):
env = TransformedEnv(
ContinuousActionVecMockEnv(),
NextObservationDelta(in_keys=["observation"]),
)
self._smoke_one_step(env)

def test_serial_trans_env_check(self):
# NextObservationDelta inside a batched-env worker must raise at
# construction time -- the outer batched env's step_and_maybe_reset
# does not propagate the worker's transform hook and pre-allocated
# batched buffers would upcast the write anyway.
with pytest.raises(RuntimeError, match="cannot live inside a SerialEnv"):
SerialEnv(
2,
lambda: TransformedEnv(
ContinuousActionVecMockEnv(),
NextObservationDelta(in_keys=["observation"]),
),
)

def test_parallel_trans_env_check(self):
with pytest.raises(RuntimeError, match="cannot live inside a SerialEnv"):
ParallelEnv(
2,
lambda: TransformedEnv(
ContinuousActionVecMockEnv(),
NextObservationDelta(in_keys=["observation"]),
),
mp_start_method="fork",
)

def test_trans_serial_env_check(self):
env = TransformedEnv(
SerialEnv(2, lambda: ContinuousActionVecMockEnv()),
NextObservationDelta(in_keys=["observation"]),
)
self._smoke_one_step(env)

def test_trans_parallel_env_check(self):
env = TransformedEnv(
ParallelEnv(
2, lambda: ContinuousActionVecMockEnv(), mp_start_method="fork"
),
NextObservationDelta(in_keys=["observation"]),
)
try:
self._smoke_one_step(env)
finally:
env.close()

def test_transform_no_env(self):
# The transform is env-side only: calling it like a module / RB transform
# must raise.
t = NextObservationDelta(in_keys=["observation"])
with pytest.raises(NotImplementedError, match="env-side transform"):
t(TensorDict({"observation": torch.zeros(3)}, []))

def test_transform_compose(self):
# Composed offline call still routes through forward, which is unsupported.
t = Compose(NextObservationDelta(in_keys=["observation"]))
with pytest.raises(NotImplementedError, match="env-side transform"):
t(TensorDict({"observation": torch.zeros(3)}, []))

def test_transform_env(self):
torch.manual_seed(0)
env = TransformedEnv(
ContinuousActionVecMockEnv(),
NextObservationDelta(in_keys=["observation"]),
)
td = env.reset()
td.set("action", env.action_spec.rand())
post_step, flowing = env.step_and_maybe_reset(td)

assert post_step["next", "observation"].dtype == torch.float16
assert post_step["observation"].dtype == torch.float32
assert flowing["observation"].dtype == torch.float32

# Rehydrated flowing root == root_obs + delta (round-tripped through
# the delta dtype).
expected = post_step["observation"].to(torch.float32) + post_step[
"next", "observation"
].to(torch.float32)
tol = self._delta_tol(
torch.float16, scale=max(1.0, expected.abs().max().item())
)
torch.testing.assert_close(
flowing["observation"].to(torch.float32), expected, atol=tol, rtol=tol
)

def test_transform_model(self):
pytest.skip("NextObservationDelta is an env-side transform; not a module hook.")

def test_transform_rb(self):
from torchrl.data import LazyTensorStorage, ReplayBuffer

rb = ReplayBuffer(
storage=LazyTensorStorage(4),
transform=NextObservationDelta(in_keys=["observation"]),
batch_size=4,
)
# Extend goes through `inv`, which is a no-op for this transform
# (it has no in_keys_inv) -- so the write succeeds.
rb.extend(TensorDict({"observation": torch.zeros(4, 3)}, batch_size=[4]))
# Sampling, however, routes through `forward`, which is unsupported
# for this env-side-only transform.
with pytest.raises(NotImplementedError, match="env-side transform"):
rb.sample()

def test_transform_inverse(self):
pytest.skip("NextObservationDelta has no inverse (no in_keys_inv).")

def test_auto_infer_keys_skips_uint8(self):
# Build an env whose observation_spec contains a uint8 image key plus
# a float vector. Only the float vector should be picked up by
# auto-inference.
from torchrl.data.tensor_specs import Bounded, Composite, Unbounded

class _DualObsEnv(ContinuousActionVecMockEnv):
pass

env = TransformedEnv(
ContinuousActionVecMockEnv(
observation_spec=Composite(
observation=Unbounded(shape=(7,)),
pixels=Bounded(low=0, high=255, shape=(3, 4, 4), dtype=torch.uint8),
)
),
NextObservationDelta(),
)
# Access lazy in_keys via the transform.
in_keys = list(env.transform.in_keys)
assert ("observation",) in [
(k,) if isinstance(k, str) else tuple(k) for k in in_keys
]
# No uint8 leaf made it in.
for k in in_keys:
spec = (
env.observation_spec[k]
if not isinstance(k, tuple)
else env.observation_spec[k]
)
assert spec.dtype.is_floating_point

def test_multi_in_keys_explicit(self):
torch.manual_seed(1)
from torchrl.data.tensor_specs import Composite, Unbounded

env = TransformedEnv(
ContinuousActionVecMockEnv(
observation_spec=Composite(
observation=Unbounded(shape=(7,)),
observation_orig=Unbounded(shape=(7,)),
)
),
NextObservationDelta(in_keys=["observation", "observation_orig"]),
)
td = env.reset()
td.set("action", env.action_spec.rand())
out, out_ = env.step_and_maybe_reset(td)
assert out["next", "observation"].dtype == torch.float16
assert out["next", "observation_orig"].dtype == torch.float16
assert out_["observation"].dtype == torch.float32
assert out_["observation_orig"].dtype == torch.float32

def test_reset_between_steps(self):
# Stateless contract: a reset between two steps must not corrupt the
# delta computed on the second step.
torch.manual_seed(2)
env = TransformedEnv(
ContinuousActionVecMockEnv(),
NextObservationDelta(in_keys=["observation"]),
)
td0 = env.reset()
td0.set("action", env.action_spec.rand())
env.step_and_maybe_reset(td0)
td1 = env.reset()
td1.set("action", env.action_spec.rand())
out, out_ = env.step_and_maybe_reset(td1)
expected = out["observation"].to(torch.float32) + out["next", "observation"].to(
torch.float32
)
tol = self._delta_tol(
torch.float16, scale=max(1.0, expected.abs().max().item())
)
torch.testing.assert_close(
out_["observation"].to(torch.float32), expected, atol=tol, rtol=tol
)

def test_compose_with_downstream_transform(self):
# NextObservationDelta inside a Compose with another env transform.
# Round-trip through the rehydration must still match a reference run
# that doesn't compress.
from torchrl.envs.transforms import RewardSum

torch.manual_seed(3)
env = TransformedEnv(
ContinuousActionVecMockEnv(),
Compose(
NextObservationDelta(in_keys=["observation"]),
RewardSum(),
),
)
td = env.reset()
td.set("action", env.action_spec.rand())
out, out_ = env.step_and_maybe_reset(td)
assert out["next", "observation"].dtype == torch.float16
assert out_["observation"].dtype == torch.float32

@pytest.mark.parametrize("use_buffers", [True, False])
def test_collector_compressed(self, use_buffers):
# End-to-end: the stacked rollout actually carries float16
# ("next", obs) in both the pre-allocated (use_buffers=True) and the
# lazy (use_buffers=False) collector paths.
from torchrl.collectors import SyncDataCollector

torch.manual_seed(4)

def make_env():
return TransformedEnv(
ContinuousActionVecMockEnv(),
NextObservationDelta(in_keys=["observation"]),
)

collector = SyncDataCollector(
create_env_fn=make_env,
policy=None,
frames_per_batch=16,
total_frames=16,
use_buffers=use_buffers,
)
try:
batch = next(iter(collector))
finally:
collector.shutdown()

assert batch["next", "observation"].dtype == torch.float16
assert batch["observation"].dtype == torch.float32

recon = batch["observation"].to(torch.float32) + batch[
"next", "observation"
].to(torch.float32)
assert recon.shape == batch["observation"].shape
assert torch.isfinite(recon).all()

def test_env_rollout_hook_fires(self):
# env.rollout() must invoke _post_step_mdp_hooks so the flowing td
# carries full-precision root obs even on the stop-early path.
torch.manual_seed(5)
env = TransformedEnv(
ContinuousActionVecMockEnv(),
NextObservationDelta(in_keys=["observation"]),
)
out = env.rollout(8, break_when_any_done=True)
# Stacked ("next", obs) is compressed; root obs is full precision
# because each iteration's flowing td was rehydrated by the hook
# before becoming the next iteration's root.
assert out["next", "observation"].dtype == torch.float16
assert out["observation"].dtype == torch.float32

def test_env_rollout_nonstop_hook_fires(self):
torch.manual_seed(6)
env = TransformedEnv(
ContinuousActionVecMockEnv(),
NextObservationDelta(in_keys=["observation"]),
)
out = env.rollout(8, break_when_any_done=False)
assert out["next", "observation"].dtype == torch.float16
assert out["observation"].dtype == torch.float32
2 changes: 2 additions & 0 deletions torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
MeanActionSelector,
MultiAction,
MultiStepTransform,
NextObservationDelta,
NoopResetEnv,
ObservationNorm,
ObservationTransform,
Expand Down Expand Up @@ -210,6 +211,7 @@
"MultiStepTransform",
"MultiThreadedEnv",
"MultiThreadedEnvWrapper",
"NextObservationDelta",
"NoopResetEnv",
"ObservationNorm",
"ObservationTransform",
Expand Down
24 changes: 22 additions & 2 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,13 +714,29 @@ def __getstate__(self):
def _has_dynamic_specs(self):
return not self._use_buffers

@staticmethod
def _validate_worker_env(env) -> None:
"""Check that each transform on a worker env is batched-env compatible.

Walks ``env.transform`` and invokes
:meth:`~torchrl.envs.transforms.Transform._check_batched_worker_compat`
on each entry. Transforms that should not live inside a batched-env
worker raise here so the user gets immediate feedback rather than
silently-wrong runtime behavior.
"""
transform = getattr(env, "transform", None)
if transform is not None:
transform._check_batched_worker_compat()

def _get_metadata(
self, create_env_fn: list[Callable], create_env_kwargs: list[dict]
):
if self._single_task:
# if EnvCreator, the metadata are already there
meta_data: EnvMetaData = get_env_metadata(
create_env_fn[0], create_env_kwargs[0]
create_env_fn[0],
create_env_kwargs[0],
env_validator=self._validate_worker_env,
)
self.meta_data = meta_data.expand(
*(self.num_workers, *meta_data.batch_size)
Expand All @@ -740,7 +756,11 @@ def _get_metadata(
self.meta_data: list[EnvMetaData] = []
for i in range(n_tasks):
self.meta_data.append(
get_env_metadata(create_env_fn[i], create_env_kwargs[i]).clone()
get_env_metadata(
create_env_fn[i],
create_env_kwargs[i],
env_validator=self._validate_worker_env,
).clone()
)
if self.share_individual_td is not True:
share_individual_td = not _stackable(
Expand Down
Loading
Loading