fix environment seeding
add fixes for reproducibility only try to start env if it is closed revision fix normalization and data type Improve README Improve README Tests are passing, Eval pretrained model works, Add gif Update gif Update gif Update gif Update gif Update README Update README update minor Update README.md Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Update README.md Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Address suggestions Update thumbnail + stats Update thumbnail + stats Update README.md Co-authored-by: Alexander Soare <alexander.soare159@gmail.com> Add more comments Add test_examples.py
This commit is contained in:
@@ -8,6 +8,20 @@ from lerobot.common.utils import set_global_seed
|
||||
|
||||
|
||||
class AbstractEnv(EnvBase):
|
||||
"""
|
||||
Note:
|
||||
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
||||
1. set the required class attributes:
|
||||
- for classes inheriting from `AbstractDataset`: `available_datasets`
|
||||
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
|
||||
- for classes inheriting from `AbstractPolicy`: `name`
|
||||
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
|
||||
3. update variables in `tests/test_available.py` by importing your new class
|
||||
"""
|
||||
|
||||
name: str | None = None # same name should be used to instantiate the environment in factory.py
|
||||
available_tasks: list[str] | None = None # for instance: sim_insertion, sim_transfer_cube, pusht, lift
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
@@ -21,6 +35,14 @@ class AbstractEnv(EnvBase):
|
||||
num_prev_action=0,
|
||||
):
|
||||
super().__init__(device=device, batch_size=[])
|
||||
assert self.name is not None, "Subclasses of `AbstractEnv` should set the `name` class attribute."
|
||||
assert (
|
||||
self.available_tasks is not None
|
||||
), "Subclasses of `AbstractEnv` should set the `available_tasks` class attribute."
|
||||
assert (
|
||||
task in self.available_tasks
|
||||
), f"The provided task ({task}) is not on the list of available tasks {self.available_tasks}."
|
||||
|
||||
self.task = task
|
||||
self.frame_skip = frame_skip
|
||||
self.from_pixels = from_pixels
|
||||
|
||||
@@ -35,6 +35,8 @@ _has_gym = importlib.util.find_spec("gymnasium") is not None
|
||||
|
||||
|
||||
class AlohaEnv(AbstractEnv):
|
||||
name = "aloha"
|
||||
available_tasks = ["sim_insertion", "sim_transfer_cube"]
|
||||
_reset_warning_issued = False
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -22,6 +22,8 @@ _has_gym = importlib.util.find_spec("gymnasium") is not None
|
||||
|
||||
|
||||
class PushtEnv(AbstractEnv):
|
||||
name = "pusht"
|
||||
available_tasks = ["pusht"]
|
||||
_reset_warning_issued = False
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -24,6 +24,9 @@ _has_gym = importlib.util.find_spec("gymnasium") is not None
|
||||
|
||||
|
||||
class SimxarmEnv(AbstractEnv):
|
||||
name = "simxarm"
|
||||
available_tasks = ["lift"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
|
||||
@@ -1,119 +0,0 @@
|
||||
from typing import Sequence
|
||||
|
||||
import torch
|
||||
from tensordict import TensorDictBase
|
||||
from tensordict.nn import dispatch
|
||||
from tensordict.utils import NestedKey
|
||||
from torchrl.envs.transforms import ObservationTransform, Transform
|
||||
|
||||
|
||||
class Prod(ObservationTransform):
|
||||
invertible = True
|
||||
|
||||
def __init__(self, in_keys: Sequence[NestedKey], prod: float):
|
||||
super().__init__()
|
||||
self.in_keys = in_keys
|
||||
self.prod = prod
|
||||
self.original_dtypes = {}
|
||||
|
||||
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
|
||||
# _reset is called once when the environment reset to normalize the first observation
|
||||
tensordict_reset = self._call(tensordict_reset)
|
||||
return tensordict_reset
|
||||
|
||||
@dispatch(source="in_keys", dest="out_keys")
|
||||
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
||||
return self._call(tensordict)
|
||||
|
||||
def _call(self, td):
|
||||
for key in self.in_keys:
|
||||
if td.get(key, None) is None:
|
||||
continue
|
||||
self.original_dtypes[key] = td[key].dtype
|
||||
td[key] = td[key].type(torch.float32) * self.prod
|
||||
return td
|
||||
|
||||
def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
|
||||
for key in self.in_keys:
|
||||
if td.get(key, None) is None:
|
||||
continue
|
||||
td[key] = (td[key] / self.prod).type(self.original_dtypes[key])
|
||||
return td
|
||||
|
||||
def transform_observation_spec(self, obs_spec):
|
||||
for key in self.in_keys:
|
||||
if obs_spec.get(key, None) is None:
|
||||
continue
|
||||
obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
|
||||
obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
|
||||
obs_spec[key].dtype = torch.float32
|
||||
return obs_spec
|
||||
|
||||
|
||||
class NormalizeTransform(Transform):
|
||||
invertible = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stats: TensorDictBase,
|
||||
in_keys: Sequence[NestedKey] = None,
|
||||
out_keys: Sequence[NestedKey] | None = None,
|
||||
in_keys_inv: Sequence[NestedKey] | None = None,
|
||||
out_keys_inv: Sequence[NestedKey] | None = None,
|
||||
mode="mean_std",
|
||||
):
|
||||
if out_keys is None:
|
||||
out_keys = in_keys
|
||||
if in_keys_inv is None:
|
||||
in_keys_inv = out_keys
|
||||
if out_keys_inv is None:
|
||||
out_keys_inv = in_keys
|
||||
super().__init__(
|
||||
in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
|
||||
)
|
||||
self.stats = stats
|
||||
assert mode in ["mean_std", "min_max"]
|
||||
self.mode = mode
|
||||
|
||||
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
|
||||
# _reset is called once when the environment reset to normalize the first observation
|
||||
tensordict_reset = self._call(tensordict_reset)
|
||||
return tensordict_reset
|
||||
|
||||
@dispatch(source="in_keys", dest="out_keys")
|
||||
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
||||
return self._call(tensordict)
|
||||
|
||||
def _call(self, td: TensorDictBase) -> TensorDictBase:
|
||||
for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
|
||||
# TODO(rcadene): don't know how to do `inkey not in td`
|
||||
if td.get(inkey, None) is None:
|
||||
continue
|
||||
if self.mode == "mean_std":
|
||||
mean = self.stats[inkey]["mean"]
|
||||
std = self.stats[inkey]["std"]
|
||||
td[outkey] = (td[inkey] - mean) / (std + 1e-8)
|
||||
else:
|
||||
min = self.stats[inkey]["min"]
|
||||
max = self.stats[inkey]["max"]
|
||||
# normalize to [0,1]
|
||||
td[outkey] = (td[inkey] - min) / (max - min)
|
||||
# normalize to [-1, 1]
|
||||
td[outkey] = td[outkey] * 2 - 1
|
||||
return td
|
||||
|
||||
def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
|
||||
for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
|
||||
# TODO(rcadene): don't know how to do `inkey not in td`
|
||||
if td.get(inkey, None) is None:
|
||||
continue
|
||||
if self.mode == "mean_std":
|
||||
mean = self.stats[inkey]["mean"]
|
||||
std = self.stats[inkey]["std"]
|
||||
td[outkey] = td[inkey] * std + mean
|
||||
else:
|
||||
min = self.stats[inkey]["min"]
|
||||
max = self.stats[inkey]["max"]
|
||||
td[outkey] = (td[inkey] + 1) / 2
|
||||
td[outkey] = td[outkey] * (max - min) + min
|
||||
return td
|
||||
Reference in New Issue
Block a user