Compare commits

...

1 commit

Author           SHA1        Message                     Date
Alexander Soare  315cbcb422  wip - requesting feedback   2024-03-13 12:54:31 +00:00

2 changed files with 115 additions and 2 deletions

View File

@@ -4,9 +4,113 @@ from typing import Optional
 from tensordict import TensorDict
 from torchrl.envs import EnvBase
 from torchrl.envs.utils import _terminated_or_truncated, step_mdp
-class AbstractEnv(EnvBase):
+
+
+class EnvBaseWithMultiStepRollouts(EnvBase):
+    """Adds handling of policies that output action trajectories to be executed with a fixed horizon."""
+
+    def _rollout_stop_early(
+        self,
+        *,
+        tensordict,
+        auto_cast_to_device,
+        max_steps,
+        policy,
+        policy_device,
+        env_device,
+        callback,
+    ):
+        """Override adds handling of multi-step policies."""
+        tensordicts = []
+        step_ix = 0
+        do_break = False
+        while not do_break:
+            if auto_cast_to_device:
+                if policy_device is not None:
+                    tensordict = tensordict.to(policy_device, non_blocking=True)
+                else:
+                    tensordict.clear_device_()
+            tensordict = policy(tensordict)
+            if auto_cast_to_device:
+                if env_device is not None:
+                    tensordict = tensordict.to(env_device, non_blocking=True)
+                else:
+                    tensordict.clear_device_()
+            for action in tensordict["action"].clone():
+                tensordict["action"] = action
+                tensordict = self.step(tensordict)
+                tensordicts.append(tensordict.clone(False))
+                if step_ix == max_steps - 1:
+                    # we don't mark it truncated as one could potentially continue the run
+                    do_break = True
+                    break
+                tensordict = step_mdp(
+                    tensordict,
+                    keep_other=True,
+                    exclude_action=False,
+                    exclude_reward=True,
+                    reward_keys=self.reward_keys,
+                    action_keys=self.action_keys,
+                    done_keys=self.done_keys,
+                )
+                # done and truncated are in done_keys
+                # We check if any key is done.
+                any_done = _terminated_or_truncated(
+                    tensordict,
+                    full_done_spec=self.output_spec["full_done_spec"],
+                    key=None,
+                )
+                if any_done:
+                    break
+                if callback is not None:
+                    callback(self, tensordict)
+                step_ix += 1
+        return tensordicts
+
+    def _rollout_nonstop(
+        self,
+        *,
+        tensordict,
+        auto_cast_to_device,
+        max_steps,
+        policy,
+        policy_device,
+        env_device,
+        callback,
+    ):
+        """Override adds handling of multi-step policies."""
+        tensordicts = []
+        tensordict_ = tensordict
+        for i in range(max_steps):
+            if auto_cast_to_device:
+                if policy_device is not None:
+                    tensordict_ = tensordict_.to(policy_device, non_blocking=True)
+                else:
+                    tensordict_.clear_device_()
+            tensordict_ = policy(tensordict_)
+            if auto_cast_to_device:
+                if env_device is not None:
+                    tensordict_ = tensordict_.to(env_device, non_blocking=True)
+                else:
+                    tensordict_.clear_device_()
+            tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
+            tensordicts.append(tensordict)
+            if i == max_steps - 1:
+                # we don't mark it truncated as one could potentially continue the run
+                break
+            if callback is not None:
+                callback(self, tensordict)
+        return tensordicts
+
+
+class AbstractEnv(EnvBaseWithMultiStepRollouts):
     def __init__(
         self,
         task,
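
Note for reviewers: the overrides above assume a policy that writes a whole action trajectory (presumably of shape (horizon, action_dim)) into tensordict["action"], which the rollout then executes one primitive env step at a time. The sketch below is illustrative only and is not part of the commit; dummy_policy, dummy_step, HORIZON and ACTION_DIM are hypothetical stand-ins for the policy, EnvBase.step and the real dimensions.

# Minimal standalone sketch of the multi-step contract assumed by
# _rollout_stop_early (illustrative; names are hypothetical, not lerobot/torchrl APIs).
import torch
from tensordict import TensorDict

HORIZON, ACTION_DIM = 4, 2

def dummy_policy(td: TensorDict) -> TensorDict:
    # A multi-step policy writes a fixed-horizon trajectory, not a single action.
    td["action"] = torch.zeros(HORIZON, ACTION_DIM)
    return td

def dummy_step(td: TensorDict) -> TensorDict:
    # Stand-in for EnvBase.step: consumes one (ACTION_DIM,) action per call.
    td["reward"] = torch.zeros(1)
    return td

td = dummy_policy(TensorDict({}, batch_size=[]))
transitions = []
for action in td["action"].clone():  # iterate over the planned trajectory
    td["action"] = action            # overwrite the trajectory with this step's action
    td = dummy_step(td)
    transitions.append(td.clone(False))
print(len(transitions))  # -> 4, one transition per action in the chunk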

View File

@@ -4,7 +4,16 @@ import torch
 from tensordict import TensorDictBase
 from tensordict.nn import dispatch
 from tensordict.utils import NestedKey
-from torchrl.envs.transforms import ObservationTransform, Transform
+from torchrl.envs.transforms import ObservationTransform, Transform, TransformedEnv
+from torchrl.envs.transforms.transforms import _TEnvPostInit
+
+from lerobot.common.envs.abstract import EnvBaseWithMultiStepRollouts
+
+
+class TransformedEnv(EnvBaseWithMultiStepRollouts, TransformedEnv, metaclass=_TEnvPostInit):
+    """Keep method overrides from EnvBaseWithMultiStepRollouts."""
+
+    pass
 class Prod(ObservationTransform):
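
Note for reviewers: the TransformedEnv shim relies on Python's method resolution order. Listing EnvBaseWithMultiStepRollouts first in the bases makes its _rollout_* overrides shadow the ones inherited from torchrl's TransformedEnv. A minimal, self-contained illustration of that pattern (class names below are invented for the example, not the real torchrl classes):

class Base:
    def _rollout_nonstop(self):
        return "single-step rollout"

class MultiStepMixin(Base):
    # Override that should win for any class listing this mixin first.
    def _rollout_nonstop(self):
        return "multi-step rollout"

class Wrapped(Base):
    # Stand-in for torchrl's TransformedEnv, which inherits the base behaviour.
    pass

class WrappedWithMultiStep(MultiStepMixin, Wrapped):
    # Same shape as the shim in this commit: no body needed, the MRO does the work.
    pass

print(WrappedWithMultiStep._rollout_nonstop is MultiStepMixin._rollout_nonstop)  # True
print(WrappedWithMultiStep()._rollout_nonstop())  # "multi-step rollout"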