diff --git a/lerobot/common/envs/abstract.py b/lerobot/common/envs/abstract.py
index 0754fb768..882580bba 100644
--- a/lerobot/common/envs/abstract.py
+++ b/lerobot/common/envs/abstract.py
@@ -4,9 +4,144 @@ from typing import Optional
 
 from tensordict import TensorDict
 from torchrl.envs import EnvBase
+from torchrl.envs.utils import _terminated_or_truncated, step_mdp
 
 
-class AbstractEnv(EnvBase):
+class EnvBaseWithMultiStepRollouts(EnvBase):
+    """Adds handling of policies that output action trajectories to be executed with a fixed horizon."""
+
+    def _rollout_stop_early(
+        self,
+        *,
+        tensordict,
+        auto_cast_to_device,
+        max_steps,
+        policy,
+        policy_device,
+        env_device,
+        callback,
+    ):
+        """Roll out `policy` until `max_steps` env steps are taken or any done key is set.
+
+        Override adds handling of multi-step policies: the "action" entry written by the policy is
+        iterated over, and each element is stepped through the environment before the policy is
+        called again.
+
+        Args:
+            tensordict: Tensordict passed to the first policy call.
+            auto_cast_to_device: If true, move the tensordict to `policy_device` before the policy
+                call and to `env_device` before stepping; a device of None clears the device instead.
+            max_steps: Maximum number of environment steps.
+            policy: Callable taking and returning a tensordict; its output must hold "action".
+            policy_device: Device for the policy call, or None.
+            env_device: Device for the environment steps, or None.
+            callback: Optional callable, invoked as `callback(self, tensordict)` after every step
+                that does not end the rollout.
+
+        Returns:
+            List with one tensordict per environment step taken.
+        """
+        tensordicts = []
+        step_ix = 0
+        do_break = False
+        while not do_break:
+            if auto_cast_to_device:
+                if policy_device is not None:
+                    tensordict = tensordict.to(policy_device, non_blocking=True)
+                else:
+                    tensordict.clear_device_()
+            tensordict = policy(tensordict)
+            if auto_cast_to_device:
+                if env_device is not None:
+                    tensordict = tensordict.to(env_device, non_blocking=True)
+                else:
+                    tensordict.clear_device_()
+
+            # Clone the action trajectory first: tensordict["action"] is overwritten on every pass.
+            for action in tensordict["action"].clone():
+                tensordict["action"] = action
+                tensordict = self.step(tensordict)
+                tensordicts.append(tensordict.clone(False))
+
+                if step_ix == max_steps - 1:
+                    # we don't truncate as one could potentially continue the run
+                    do_break = True
+                    break
+                tensordict = step_mdp(
+                    tensordict,
+                    keep_other=True,
+                    exclude_action=False,
+                    exclude_reward=True,
+                    reward_keys=self.reward_keys,
+                    action_keys=self.action_keys,
+                    done_keys=self.done_keys,
+                )
+                # done and truncated are in done_keys
+                # We read if any key is done.
+                any_done = _terminated_or_truncated(
+                    tensordict,
+                    full_done_spec=self.output_spec["full_done_spec"],
+                    key=None,
+                )
+                if any_done:
+                    # Stop the whole rollout, not just the current action trajectory; a bare
+                    # `break` would let the outer loop query the policy and step again.
+                    do_break = True
+                    break
+
+                if callback is not None:
+                    callback(self, tensordict)
+
+                step_ix += 1
+
+        return tensordicts
+
+    def _rollout_nonstop(
+        self,
+        *,
+        tensordict,
+        auto_cast_to_device,
+        max_steps,
+        policy,
+        policy_device,
+        env_device,
+        callback,
+    ):
+        """Roll out `policy` for `max_steps` steps via `step_and_maybe_reset`.
+
+        NOTE(review): unlike `_rollout_stop_early`, this loop passes the policy output straight to
+        `step_and_maybe_reset` and does not iterate over an action trajectory -- confirm whether
+        multi-step policies are meant to be supported here.
+
+        Returns:
+            List with one tensordict per loop iteration.
+        """
+        tensordicts = []
+        tensordict_ = tensordict
+        for i in range(max_steps):
+            if auto_cast_to_device:
+                if policy_device is not None:
+                    tensordict_ = tensordict_.to(policy_device, non_blocking=True)
+                else:
+                    tensordict_.clear_device_()
+            tensordict_ = policy(tensordict_)
+            if auto_cast_to_device:
+                if env_device is not None:
+                    tensordict_ = tensordict_.to(env_device, non_blocking=True)
+                else:
+                    tensordict_.clear_device_()
+            tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
+            tensordicts.append(tensordict)
+            if i == max_steps - 1:
+                # we don't truncate as one could potentially continue the run
+                break
+            if callback is not None:
+                callback(self, tensordict)
+
+        return tensordicts
+
+
+class AbstractEnv(EnvBaseWithMultiStepRollouts):
     def __init__(
         self,
         task,
diff --git a/lerobot/common/envs/transforms.py b/lerobot/common/envs/transforms.py
index 4832c91bf..b758bf7fe 100644
--- a/lerobot/common/envs/transforms.py
+++ b/lerobot/common/envs/transforms.py
@@ -4,7 +4,20 @@ import torch
 from tensordict import TensorDictBase
 from tensordict.nn import dispatch
 from tensordict.utils import NestedKey
-from torchrl.envs.transforms import ObservationTransform, Transform
+from torchrl.envs.transforms import ObservationTransform, Transform, TransformedEnv
+from torchrl.envs.transforms.transforms import _TEnvPostInit
+
+from lerobot.common.envs.abstract import EnvBaseWithMultiStepRollouts
+
+
+class TransformedEnv(EnvBaseWithMultiStepRollouts, TransformedEnv, metaclass=_TEnvPostInit):
+    """Keep method overrides from EnvBaseWithMultiStepRollouts.
+
+    Reuses the name of torchrl's TransformedEnv (the second base class), so within this module
+    the name refers to this subclass from here on.
+    """
+
+    pass
 
 
 class Prod(ObservationTransform):