Compare commits

...

1 Commit

Author SHA1 Message Date
Alexander Soare  315cbcb422  wip - requesting feedback  2024-03-13 12:54:31 +00:00
2 changed files with 115 additions and 2 deletions

View File

@@ -4,9 +4,113 @@ from typing import Optional
from tensordict import TensorDict
from torchrl.envs import EnvBase
from torchrl.envs.utils import _terminated_or_truncated, step_mdp
class EnvBaseWithMultiStepRollouts(EnvBase):
"""Adds handling of policies that output action trajectories to be execute with a fixed horizon."""
def _rollout_stop_early(
self,
*,
tensordict,
auto_cast_to_device,
max_steps,
policy,
policy_device,
env_device,
callback,
):
"""Override adds handling of multi-step policies."""
tensordicts = []
step_ix = 0
do_break = False
while not do_break:
if auto_cast_to_device:
if policy_device is not None:
tensordict = tensordict.to(policy_device, non_blocking=True)
else:
tensordict.clear_device_()
tensordict = policy(tensordict)
if auto_cast_to_device:
if env_device is not None:
tensordict = tensordict.to(env_device, non_blocking=True)
else:
tensordict.clear_device_()
for action in tensordict["action"].clone():
tensordict["action"] = action
tensordict = self.step(tensordict)
tensordicts.append(tensordict.clone(False))
if step_ix == max_steps - 1:
# we don't truncate, as one could potentially continue the run
do_break = True
break
tensordict = step_mdp(
tensordict,
keep_other=True,
exclude_action=False,
exclude_reward=True,
reward_keys=self.reward_keys,
action_keys=self.action_keys,
done_keys=self.done_keys,
)
# done and truncated are in done_keys
# We read if any key is done.
any_done = _terminated_or_truncated(
tensordict,
full_done_spec=self.output_spec["full_done_spec"],
key=None,
)
if any_done:
break
if callback is not None:
callback(self, tensordict)
step_ix += 1
return tensordicts
def _rollout_nonstop(
self,
*,
tensordict,
auto_cast_to_device,
max_steps,
policy,
policy_device,
env_device,
callback,
):
"""Override adds handling of multi-step policies."""
tensordicts = []
tensordict_ = tensordict
for i in range(max_steps):
if auto_cast_to_device:
if policy_device is not None:
tensordict_ = tensordict_.to(policy_device, non_blocking=True)
else:
tensordict_.clear_device_()
tensordict_ = policy(tensordict_)
if auto_cast_to_device:
if env_device is not None:
tensordict_ = tensordict_.to(env_device, non_blocking=True)
else:
tensordict_.clear_device_()
tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
tensordicts.append(tensordict)
if i == max_steps - 1:
# we don't truncate, as one could potentially continue the run
break
if callback is not None:
callback(self, tensordict)
return tensordicts
class AbstractEnv(EnvBaseWithMultiStepRollouts):
def __init__(
self,
task,
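
To make the control flow concrete, here is a minimal, torchrl-free sketch of what the _rollout_stop_early override does differently: the policy is queried once per chunk, and every action in the returned trajectory is executed as its own environment step, stopping early on done or once max_steps steps have been collected. chunk_policy, toy_env_step and rollout below are illustrative stand-ins, not names from this diff.

import torch

def chunk_policy(obs: torch.Tensor, horizon: int = 4) -> torch.Tensor:
    # Hypothetical multi-step policy: one call returns a whole action trajectory
    # of shape (horizon, action_dim), analogous to tensordict["action"] above.
    return torch.randn(horizon, 2)

def toy_env_step(obs: torch.Tensor, action: torch.Tensor):
    # Hypothetical environment step: returns (next_obs, done).
    next_obs = obs + action
    return next_obs, bool(next_obs.abs().sum() > 50.0)

def rollout(max_steps: int = 20):
    obs = torch.zeros(2)
    transitions = []
    step_ix = 0
    while True:
        # One policy call yields several actions ...
        for action in chunk_policy(obs):
            # ... and each action is executed as its own environment step.
            obs, done = toy_env_step(obs, action)
            transitions.append(obs.clone())
            if done or step_ix == max_steps - 1:
                return transitions
            step_ix += 1

print(len(rollout()))

The override above does the same thing with TensorDicts, plus the step_mdp bookkeeping, device casting, and done-flag handling that torchrl's rollout machinery expects.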

View File

@@ -4,7 +4,16 @@ import torch
from tensordict import TensorDictBase
from tensordict.nn import dispatch
from tensordict.utils import NestedKey
from torchrl.envs.transforms import ObservationTransform, Transform, TransformedEnv
from torchrl.envs.transforms.transforms import _TEnvPostInit
from lerobot.common.envs.abstract import EnvBaseWithMultiStepRollouts
class TransformedEnv(EnvBaseWithMultiStepRollouts, TransformedEnv, metaclass=_TEnvPostInit):
"""Keep method overrides from EnvBaseWithMultiStepRollouts."""
pass
class Prod(ObservationTransform):
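
The reason for listing EnvBaseWithMultiStepRollouts before TransformedEnv (while reusing torchrl's _TEnvPostInit metaclass) is Python's method resolution order: methods defined on the first base shadow same-named methods inherited through the second. A self-contained toy illustration of that resolution order, with made-up class names standing in for the real ones:

# Toy classes standing in for EnvBaseWithMultiStepRollouts and torchrl's
# TransformedEnv; only the method-resolution order is being illustrated.
class Base:
    def _rollout_nonstop(self):
        return "single-step rollout"

class MultiStepMixin(Base):
    def _rollout_nonstop(self):
        return "multi-step rollout"

class WrappedEnv(Base):  # plays the role of torchrl's TransformedEnv
    pass

class PatchedTransformedEnv(MultiStepMixin, WrappedEnv):
    """Mixin listed first, so its overrides shadow WrappedEnv's inherited ones."""

print(PatchedTransformedEnv.__mro__)
# (PatchedTransformedEnv, MultiStepMixin, WrappedEnv, Base, object)
print(PatchedTransformedEnv()._rollout_nonstop())  # "multi-step rollout"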