forked from tangger/lerobot
Compare commits
1 Commits
main
...
user/alexa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
315cbcb422 |
@@ -4,9 +4,113 @@ from typing import Optional
|
||||
|
||||
from tensordict import TensorDict
|
||||
from torchrl.envs import EnvBase
|
||||
from torchrl.envs.utils import _terminated_or_truncated, step_mdp
|
||||
|
||||
|
||||
class AbstractEnv(EnvBase):
|
||||
class EnvBaseWithMultiStepRollouts(EnvBase):
|
||||
"""Adds handling of policies that output action trajectories to be execute with a fixed horizon."""
|
||||
|
||||
def _rollout_stop_early(
|
||||
self,
|
||||
*,
|
||||
tensordict,
|
||||
auto_cast_to_device,
|
||||
max_steps,
|
||||
policy,
|
||||
policy_device,
|
||||
env_device,
|
||||
callback,
|
||||
):
|
||||
"""Override adds handling of multi-step policies."""
|
||||
tensordicts = []
|
||||
step_ix = 0
|
||||
do_break = False
|
||||
while not do_break:
|
||||
if auto_cast_to_device:
|
||||
if policy_device is not None:
|
||||
tensordict = tensordict.to(policy_device, non_blocking=True)
|
||||
else:
|
||||
tensordict.clear_device_()
|
||||
tensordict = policy(tensordict)
|
||||
if auto_cast_to_device:
|
||||
if env_device is not None:
|
||||
tensordict = tensordict.to(env_device, non_blocking=True)
|
||||
else:
|
||||
tensordict.clear_device_()
|
||||
|
||||
for action in tensordict["action"].clone():
|
||||
tensordict["action"] = action
|
||||
tensordict = self.step(tensordict)
|
||||
tensordicts.append(tensordict.clone(False))
|
||||
|
||||
if step_ix == max_steps - 1:
|
||||
# we don't truncated as one could potentially continue the run
|
||||
do_break = True
|
||||
break
|
||||
tensordict = step_mdp(
|
||||
tensordict,
|
||||
keep_other=True,
|
||||
exclude_action=False,
|
||||
exclude_reward=True,
|
||||
reward_keys=self.reward_keys,
|
||||
action_keys=self.action_keys,
|
||||
done_keys=self.done_keys,
|
||||
)
|
||||
# done and truncated are in done_keys
|
||||
# We read if any key is done.
|
||||
any_done = _terminated_or_truncated(
|
||||
tensordict,
|
||||
full_done_spec=self.output_spec["full_done_spec"],
|
||||
key=None,
|
||||
)
|
||||
if any_done:
|
||||
break
|
||||
|
||||
if callback is not None:
|
||||
callback(self, tensordict)
|
||||
|
||||
step_ix += 1
|
||||
|
||||
return tensordicts
|
||||
|
||||
def _rollout_nonstop(
|
||||
self,
|
||||
*,
|
||||
tensordict,
|
||||
auto_cast_to_device,
|
||||
max_steps,
|
||||
policy,
|
||||
policy_device,
|
||||
env_device,
|
||||
callback,
|
||||
):
|
||||
"""Override adds handling of multi-step policies."""
|
||||
tensordicts = []
|
||||
tensordict_ = tensordict
|
||||
for i in range(max_steps):
|
||||
if auto_cast_to_device:
|
||||
if policy_device is not None:
|
||||
tensordict_ = tensordict_.to(policy_device, non_blocking=True)
|
||||
else:
|
||||
tensordict_.clear_device_()
|
||||
tensordict_ = policy(tensordict_)
|
||||
if auto_cast_to_device:
|
||||
if env_device is not None:
|
||||
tensordict_ = tensordict_.to(env_device, non_blocking=True)
|
||||
else:
|
||||
tensordict_.clear_device_()
|
||||
tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
|
||||
tensordicts.append(tensordict)
|
||||
if i == max_steps - 1:
|
||||
# we don't truncated as one could potentially continue the run
|
||||
break
|
||||
if callback is not None:
|
||||
callback(self, tensordict)
|
||||
|
||||
return tensordicts
|
||||
|
||||
|
||||
class AbstractEnv(EnvBaseWithMultiStepRollouts):
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
|
||||
@@ -4,7 +4,16 @@ import torch
|
||||
from tensordict import TensorDictBase
|
||||
from tensordict.nn import dispatch
|
||||
from tensordict.utils import NestedKey
|
||||
from torchrl.envs.transforms import ObservationTransform, Transform
|
||||
from torchrl.envs.transforms import ObservationTransform, Transform, TransformedEnv
|
||||
from torchrl.envs.transforms.transforms import _TEnvPostInit
|
||||
|
||||
from lerobot.common.envs.abstract import EnvBaseWithMultiStepRollouts
|
||||
|
||||
|
||||
class TransformedEnv(EnvBaseWithMultiStepRollouts, TransformedEnv, metaclass=_TEnvPostInit):
|
||||
"""Keep method overrides from EnvBaseWithMultiStepRollouts."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Prod(ObservationTransform):
|
||||
|
||||
Reference in New Issue
Block a user