forked from tangger/lerobot
Compare commits
1 Commit
torchcodec ... user/alexa
| Author | SHA1 | Date |
|---|---|---|
|  | 315cbcb422 |  |
```diff
@@ -4,9 +4,113 @@ from typing import Optional
 from tensordict import TensorDict
 from torchrl.envs import EnvBase
+from torchrl.envs.utils import _terminated_or_truncated, step_mdp
 
 
-class AbstractEnv(EnvBase):
+class EnvBaseWithMultiStepRollouts(EnvBase):
+    """Adds handling of policies that output action trajectories to be executed with a fixed horizon."""
+
+    def _rollout_stop_early(
+        self,
+        *,
+        tensordict,
+        auto_cast_to_device,
+        max_steps,
+        policy,
+        policy_device,
+        env_device,
+        callback,
+    ):
+        """Override adds handling of multi-step policies."""
+        tensordicts = []
+        step_ix = 0
+        do_break = False
+        while not do_break:
+            if auto_cast_to_device:
+                if policy_device is not None:
+                    tensordict = tensordict.to(policy_device, non_blocking=True)
+                else:
+                    tensordict.clear_device_()
+            tensordict = policy(tensordict)
+            if auto_cast_to_device:
+                if env_device is not None:
+                    tensordict = tensordict.to(env_device, non_blocking=True)
+                else:
+                    tensordict.clear_device_()
+
+            for action in tensordict["action"].clone():
+                tensordict["action"] = action
+                tensordict = self.step(tensordict)
+                tensordicts.append(tensordict.clone(False))
+
+                if step_ix == max_steps - 1:
+                    # we don't truncate, as one could potentially continue the run
+                    do_break = True
+                    break
+                tensordict = step_mdp(
+                    tensordict,
+                    keep_other=True,
+                    exclude_action=False,
+                    exclude_reward=True,
+                    reward_keys=self.reward_keys,
+                    action_keys=self.action_keys,
+                    done_keys=self.done_keys,
+                )
+                # done and truncated are in done_keys
+                # We read if any key is done.
+                any_done = _terminated_or_truncated(
+                    tensordict,
+                    full_done_spec=self.output_spec["full_done_spec"],
+                    key=None,
+                )
+                if any_done:
+                    break
+
+                if callback is not None:
+                    callback(self, tensordict)
+
+                step_ix += 1
+
+        return tensordicts
+
+    def _rollout_nonstop(
+        self,
+        *,
+        tensordict,
+        auto_cast_to_device,
+        max_steps,
+        policy,
+        policy_device,
+        env_device,
+        callback,
+    ):
+        """Override adds handling of multi-step policies."""
+        tensordicts = []
+        tensordict_ = tensordict
+        for i in range(max_steps):
+            if auto_cast_to_device:
+                if policy_device is not None:
+                    tensordict_ = tensordict_.to(policy_device, non_blocking=True)
+                else:
+                    tensordict_.clear_device_()
+            tensordict_ = policy(tensordict_)
+            if auto_cast_to_device:
+                if env_device is not None:
+                    tensordict_ = tensordict_.to(env_device, non_blocking=True)
+                else:
+                    tensordict_.clear_device_()
+            tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
+            tensordicts.append(tensordict)
+            if i == max_steps - 1:
+                # we don't truncate, as one could potentially continue the run
+                break
+            if callback is not None:
+                callback(self, tensordict)
+
+        return tensordicts
+
+
+class AbstractEnv(EnvBaseWithMultiStepRollouts):
     def __init__(
         self,
         task,
```
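For orientation (not part of the commit): the new `_rollout_stop_early` override assumes the policy writes a whole action trajectory under the `"action"` key, which the environment then executes one action at a time. The sketch below illustrates that contract; `dummy_trajectory_policy` and the shapes are illustrative assumptions, only the key name mirrors the diff.

```python
# Illustrative sketch, not from the commit: a policy that emits an action
# *trajectory*, consumed one action per env step as in _rollout_stop_early.
import torch
from tensordict import TensorDict


def dummy_trajectory_policy(tensordict: TensorDict) -> TensorDict:
    # Hypothetical policy: writes a (horizon, action_dim) trajectory under "action".
    horizon, action_dim = 5, 2  # assumed sizes, for illustration only
    tensordict["action"] = torch.zeros(horizon, action_dim)
    return tensordict


td = TensorDict({"observation": torch.randn(3)}, batch_size=[])
td = dummy_trajectory_policy(td)

# The rollout helper loops over the trajectory, one env step per action:
for action in td["action"].clone():
    td["action"] = action  # single action for this step
    # ...the real code then calls self.step(td), appends the result,
    # advances with step_mdp, and checks the done flags...
    print(tuple(action.shape))  # (2,)
```

In torchrl, `EnvBase.rollout()` dispatches to `_rollout_stop_early` when `break_when_any_done=True` and to `_rollout_nonstop` otherwise, so overriding both keeps either rollout mode trajectory-aware.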
```diff
@@ -4,7 +4,16 @@ import torch
 from tensordict import TensorDictBase
 from tensordict.nn import dispatch
 from tensordict.utils import NestedKey
-from torchrl.envs.transforms import ObservationTransform, Transform
+from torchrl.envs.transforms import ObservationTransform, Transform, TransformedEnv
+from torchrl.envs.transforms.transforms import _TEnvPostInit
+
+from lerobot.common.envs.abstract import EnvBaseWithMultiStepRollouts
+
+
+class TransformedEnv(EnvBaseWithMultiStepRollouts, TransformedEnv, metaclass=_TEnvPostInit):
+    """Keep method overrides from EnvBaseWithMultiStepRollouts."""
+
+    pass
 
 
 class Prod(ObservationTransform):
```
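A note on the `TransformedEnv` shim (my reading of the diff, not text from the commit): listing `EnvBaseWithMultiStepRollouts` before torchrl's `TransformedEnv` makes Python's method resolution order pick up the multi-step `_rollout_*` overrides instead of the stock ones. The toy classes below only demonstrate that MRO behaviour; they are stand-ins, not the real torchrl types.

```python
# Illustrative stand-ins only: why the mixin is listed as the first base.
class Base:  # plays the role of torchrl's EnvBase
    def _rollout_stop_early(self):
        return "stock rollout"


class MultiStepMixin(Base):  # plays the role of EnvBaseWithMultiStepRollouts
    def _rollout_stop_early(self):
        return "multi-step rollout"


class Wrapper(Base):  # plays the role of torchrl's TransformedEnv
    pass


class WrappedWithMultiStep(MultiStepMixin, Wrapper):
    """MRO: WrappedWithMultiStep -> MultiStepMixin -> Wrapper -> Base."""


print(WrappedWithMultiStep()._rollout_stop_early())  # "multi-step rollout"
```

The explicit `metaclass=_TEnvPostInit` mirrors what torchrl's own `TransformedEnv` uses, presumably so the library's post-init hook still runs on the combined class.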