fix environment seeding
add fixes for reproducibility only try to start env if it is closed revision fix normalization and data type Improve README Improve README Tests are passing, Eval pretrained model works, Add gif Update gif Update gif Update gif Update gif Update README Update README update minor Update README.md Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Update README.md Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Address suggestions Update thumbnail + stats Update thumbnail + stats Update README.md Co-authored-by: Alexander Soare <alexander.soare159@gmail.com> Add more comments Add test_examples.py
This commit is contained in:
@@ -1 +1,59 @@
|
||||
"""
|
||||
This file contains lists of available environments, dataset and policies to reflect the current state of LeRobot library.
|
||||
We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables.
|
||||
|
||||
Example:
|
||||
```python
|
||||
import lerobot
|
||||
print(lerobot.available_envs)
|
||||
print(lerobot.available_tasks_per_env)
|
||||
print(lerobot.available_datasets_per_env)
|
||||
print(lerobot.available_datasets)
|
||||
print(lerobot.available_policies)
|
||||
```
|
||||
|
||||
Note:
|
||||
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
||||
1. set the required class attributes:
|
||||
- for classes inheriting from `AbstractDataset`: `available_datasets`
|
||||
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
|
||||
- for classes inheriting from `AbstractPolicy`: `name`
|
||||
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
|
||||
3. update variables in `tests/test_available.py` by importing your new class
|
||||
"""
|
||||
|
||||
from lerobot.__version__ import __version__ # noqa: F401
|
||||
|
||||
available_envs = [
|
||||
"aloha",
|
||||
"pusht",
|
||||
"simxarm",
|
||||
]
|
||||
|
||||
available_tasks_per_env = {
|
||||
"aloha": [
|
||||
"sim_insertion",
|
||||
"sim_transfer_cube",
|
||||
],
|
||||
"pusht": ["pusht"],
|
||||
"simxarm": ["lift"],
|
||||
}
|
||||
|
||||
available_datasets_per_env = {
|
||||
"aloha": [
|
||||
"aloha_sim_insertion_human",
|
||||
"aloha_sim_insertion_scripted",
|
||||
"aloha_sim_transfer_cube_human",
|
||||
"aloha_sim_transfer_cube_scripted",
|
||||
],
|
||||
"pusht": ["pusht"],
|
||||
"simxarm": ["xarm_lift_medium"],
|
||||
}
|
||||
|
||||
available_datasets = [dataset for env in available_envs for dataset in available_datasets_per_env[env]]
|
||||
|
||||
available_policies = [
|
||||
"act",
|
||||
"diffusion",
|
||||
"tdmpc",
|
||||
]
|
||||
|
||||
@@ -9,7 +9,7 @@ import tqdm
|
||||
from huggingface_hub import snapshot_download
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||
from torchrl.data.replay_buffers.samplers import Sampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
|
||||
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
|
||||
from torchrl.envs.transforms.transforms import Compose
|
||||
@@ -17,22 +17,56 @@ from torchrl.envs.transforms.transforms import Compose
|
||||
HF_USER = "lerobot"
|
||||
|
||||
|
||||
class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||
class AbstractDataset(TensorDictReplayBuffer):
|
||||
"""
|
||||
AbstractDataset represents a dataset in the context of imitation learning or reinforcement learning.
|
||||
This class is designed to be subclassed by concrete implementations that specify particular types of datasets.
|
||||
These implementations can vary based on the source of the data, the environment the data pertains to,
|
||||
or the specific kind of data manipulation applied.
|
||||
|
||||
Note:
|
||||
- `TensorDictReplayBuffer` is the base class from which `AbstractDataset` inherits. It provides the foundational
|
||||
functionality for storing and retrieving `TensorDict`-like data.
|
||||
- `available_datasets` should be overridden by concrete subclasses to list the specific dataset variants supported.
|
||||
It is expected that these variants correspond to a HuggingFace dataset on the hub.
|
||||
For instance, the `AlohaDataset` which inherites from `AbstractDataset` has 4 available dataset variants:
|
||||
- [aloha_sim_transfer_cube_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
|
||||
- [aloha_sim_insertion_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
|
||||
- [aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
|
||||
- [aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
|
||||
- When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
||||
1. set the required class attributes:
|
||||
- for classes inheriting from `AbstractDataset`: `available_datasets`
|
||||
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
|
||||
- for classes inheriting from `AbstractPolicy`: `name`
|
||||
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
|
||||
3. update variables in `tests/test_available.py` by importing your new class
|
||||
"""
|
||||
|
||||
available_datasets: list[str] | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = None,
|
||||
batch_size: int = None,
|
||||
batch_size: int | None = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: SliceSampler = None,
|
||||
collate_fn: Callable = None,
|
||||
writer: Writer = None,
|
||||
sampler: Sampler | None = None,
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
):
|
||||
assert (
|
||||
self.available_datasets is not None
|
||||
), "Subclasses of `AbstractDataset` should set the `available_datasets` class attribute."
|
||||
assert (
|
||||
dataset_id in self.available_datasets
|
||||
), f"The provided dataset ({dataset_id}) is not on the list of available datasets {self.available_datasets}."
|
||||
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.shuffle = shuffle
|
||||
|
||||
@@ -9,11 +9,11 @@ import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||
from torchrl.data.replay_buffers.samplers import Sampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
||||
from lerobot.common.datasets.abstract import AbstractDataset
|
||||
|
||||
DATASET_IDS = [
|
||||
"aloha_sim_insertion_human",
|
||||
@@ -80,24 +80,24 @@ def download(data_dir, dataset_id):
|
||||
gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)
|
||||
|
||||
|
||||
class AlohaExperienceReplay(AbstractExperienceReplay):
|
||||
class AlohaDataset(AbstractDataset):
|
||||
available_datasets = DATASET_IDS
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.2",
|
||||
batch_size: int = None,
|
||||
batch_size: int | None = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: SliceSampler = None,
|
||||
collate_fn: Callable = None,
|
||||
writer: Writer = None,
|
||||
sampler: Sampler | None = None,
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
):
|
||||
assert dataset_id in DATASET_IDS
|
||||
|
||||
super().__init__(
|
||||
dataset_id,
|
||||
version,
|
||||
|
||||
@@ -5,7 +5,7 @@ from pathlib import Path
|
||||
import torch
|
||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
|
||||
|
||||
from lerobot.common.envs.transforms import NormalizeTransform, Prod
|
||||
from lerobot.common.transforms import NormalizeTransform, Prod
|
||||
|
||||
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
|
||||
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
|
||||
@@ -16,6 +16,7 @@ DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
def make_offline_buffer(
|
||||
cfg,
|
||||
overwrite_sampler=None,
|
||||
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
|
||||
normalize=True,
|
||||
overwrite_batch_size=None,
|
||||
overwrite_prefetch=None,
|
||||
@@ -64,25 +65,27 @@ def make_offline_buffer(
|
||||
sampler = overwrite_sampler
|
||||
|
||||
if cfg.env.name == "simxarm":
|
||||
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
|
||||
from lerobot.common.datasets.simxarm import SimxarmDataset
|
||||
|
||||
clsfunc = SimxarmExperienceReplay
|
||||
dataset_id = f"xarm_{cfg.env.task}_medium"
|
||||
clsfunc = SimxarmDataset
|
||||
|
||||
elif cfg.env.name == "pusht":
|
||||
from lerobot.common.datasets.pusht import PushtExperienceReplay
|
||||
from lerobot.common.datasets.pusht import PushtDataset
|
||||
|
||||
clsfunc = PushtExperienceReplay
|
||||
dataset_id = "pusht"
|
||||
clsfunc = PushtDataset
|
||||
|
||||
elif cfg.env.name == "aloha":
|
||||
from lerobot.common.datasets.aloha import AlohaExperienceReplay
|
||||
from lerobot.common.datasets.aloha import AlohaDataset
|
||||
|
||||
clsfunc = AlohaExperienceReplay
|
||||
dataset_id = f"aloha_{cfg.env.task}"
|
||||
clsfunc = AlohaDataset
|
||||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
# TODO(rcadene): backward compatiblity to load pretrained pusht policy
|
||||
dataset_id = cfg.get("dataset_id")
|
||||
if dataset_id is None and cfg.env.name == "pusht":
|
||||
dataset_id = "pusht"
|
||||
|
||||
offline_buffer = clsfunc(
|
||||
dataset_id=dataset_id,
|
||||
sampler=sampler,
|
||||
@@ -100,36 +103,40 @@ def make_offline_buffer(
|
||||
else:
|
||||
img_keys = offline_buffer.image_keys
|
||||
|
||||
transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
|
||||
if normalize:
|
||||
transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
|
||||
|
||||
if normalize:
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
|
||||
# min_max_from_spec
|
||||
stats = offline_buffer.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
|
||||
# min_max_from_spec
|
||||
stats = offline_buffer.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
|
||||
|
||||
# we only normalize the state and action, since the images are usually normalized inside the model for
|
||||
# now (except for tdmpc: see the following)
|
||||
in_keys = [("observation", "state"), ("action")]
|
||||
# we only normalize the state and action, since the images are usually normalized inside the model for
|
||||
# now (except for tdmpc: see the following)
|
||||
in_keys = [("observation", "state"), ("action")]
|
||||
|
||||
if cfg.policy.name == "tdmpc":
|
||||
# TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now
|
||||
in_keys += img_keys
|
||||
# TODO(racdene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now.
|
||||
in_keys += [("next", *key) for key in img_keys]
|
||||
in_keys.append(("next", "observation", "state"))
|
||||
if cfg.policy.name == "tdmpc":
|
||||
# TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now
|
||||
in_keys += img_keys
|
||||
# TODO(racdene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now.
|
||||
in_keys += [("next", *key) for key in img_keys]
|
||||
in_keys.append(("next", "observation", "state"))
|
||||
|
||||
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
||||
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||
stats["observation", "state", "min"] = torch.tensor(
|
||||
[13.456424, 32.938293], dtype=torch.float32
|
||||
)
|
||||
stats["observation", "state", "max"] = torch.tensor(
|
||||
[496.14618, 510.9579], dtype=torch.float32
|
||||
)
|
||||
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
|
||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
|
||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
|
||||
|
||||
offline_buffer.set_transform(transforms)
|
||||
offline_buffer.set_transform(transforms)
|
||||
|
||||
if not overwrite_sampler:
|
||||
index = torch.arange(0, offline_buffer.num_samples, 1)
|
||||
|
||||
@@ -9,11 +9,11 @@ import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||
from torchrl.data.replay_buffers.samplers import Sampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
||||
from lerobot.common.datasets.abstract import AbstractDataset
|
||||
from lerobot.common.datasets.utils import download_and_extract_zip
|
||||
from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
|
||||
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
||||
@@ -83,20 +83,22 @@ def add_tee(
|
||||
return body
|
||||
|
||||
|
||||
class PushtExperienceReplay(AbstractExperienceReplay):
|
||||
class PushtDataset(AbstractDataset):
|
||||
available_datasets = ["pusht"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.2",
|
||||
batch_size: int = None,
|
||||
batch_size: int | None = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: SliceSampler = None,
|
||||
collate_fn: Callable = None,
|
||||
writer: Writer = None,
|
||||
sampler: Sampler | None = None,
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
):
|
||||
super().__init__(
|
||||
|
||||
@@ -8,12 +8,12 @@ import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.samplers import (
|
||||
SliceSampler,
|
||||
Sampler,
|
||||
)
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from lerobot.common.datasets.abstract import AbstractExperienceReplay
|
||||
from lerobot.common.datasets.abstract import AbstractDataset
|
||||
|
||||
|
||||
def download():
|
||||
@@ -32,7 +32,7 @@ def download():
|
||||
Path(download_path).unlink()
|
||||
|
||||
|
||||
class SimxarmExperienceReplay(AbstractExperienceReplay):
|
||||
class SimxarmDataset(AbstractDataset):
|
||||
available_datasets = [
|
||||
"xarm_lift_medium",
|
||||
]
|
||||
@@ -41,15 +41,15 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.1",
|
||||
batch_size: int = None,
|
||||
batch_size: int | None = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: SliceSampler = None,
|
||||
collate_fn: Callable = None,
|
||||
writer: Writer = None,
|
||||
sampler: Sampler | None = None,
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
):
|
||||
super().__init__(
|
||||
|
||||
@@ -8,6 +8,20 @@ from lerobot.common.utils import set_global_seed
|
||||
|
||||
|
||||
class AbstractEnv(EnvBase):
|
||||
"""
|
||||
Note:
|
||||
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
||||
1. set the required class attributes:
|
||||
- for classes inheriting from `AbstractDataset`: `available_datasets`
|
||||
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
|
||||
- for classes inheriting from `AbstractPolicy`: `name`
|
||||
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
|
||||
3. update variables in `tests/test_available.py` by importing your new class
|
||||
"""
|
||||
|
||||
name: str | None = None # same name should be used to instantiate the environment in factory.py
|
||||
available_tasks: list[str] | None = None # for instance: sim_insertion, sim_transfer_cube, pusht, lift
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
@@ -21,6 +35,14 @@ class AbstractEnv(EnvBase):
|
||||
num_prev_action=0,
|
||||
):
|
||||
super().__init__(device=device, batch_size=[])
|
||||
assert self.name is not None, "Subclasses of `AbstractEnv` should set the `name` class attribute."
|
||||
assert (
|
||||
self.available_tasks is not None
|
||||
), "Subclasses of `AbstractEnv` should set the `available_tasks` class attribute."
|
||||
assert (
|
||||
task in self.available_tasks
|
||||
), f"The provided task ({task}) is not on the list of available tasks {self.available_tasks}."
|
||||
|
||||
self.task = task
|
||||
self.frame_skip = frame_skip
|
||||
self.from_pixels = from_pixels
|
||||
|
||||
@@ -35,6 +35,8 @@ _has_gym = importlib.util.find_spec("gymnasium") is not None
|
||||
|
||||
|
||||
class AlohaEnv(AbstractEnv):
|
||||
name = "aloha"
|
||||
available_tasks = ["sim_insertion", "sim_transfer_cube"]
|
||||
_reset_warning_issued = False
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -22,6 +22,8 @@ _has_gym = importlib.util.find_spec("gymnasium") is not None
|
||||
|
||||
|
||||
class PushtEnv(AbstractEnv):
|
||||
name = "pusht"
|
||||
available_tasks = ["pusht"]
|
||||
_reset_warning_issued = False
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -24,6 +24,9 @@ _has_gym = importlib.util.find_spec("gymnasium") is not None
|
||||
|
||||
|
||||
class SimxarmEnv(AbstractEnv):
|
||||
name = "simxarm"
|
||||
available_tasks = ["lift"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
|
||||
@@ -9,8 +9,19 @@ class AbstractPolicy(nn.Module):
|
||||
|
||||
The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
|
||||
documentation for more information.
|
||||
|
||||
Note:
|
||||
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
||||
1. set the required class attributes:
|
||||
- for classes inheriting from `AbstractDataset`: `available_datasets`
|
||||
- for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
|
||||
- for classes inheriting from `AbstractPolicy`: `name`
|
||||
2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
|
||||
3. update variables in `tests/test_available.py` by importing your new class
|
||||
"""
|
||||
|
||||
name: str | None = None # same name should be used to instantiate the policy in factory.py
|
||||
|
||||
def __init__(self, n_action_steps: int | None):
|
||||
"""
|
||||
n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
|
||||
@@ -18,6 +29,7 @@ class AbstractPolicy(nn.Module):
|
||||
adds that dimension.
|
||||
"""
|
||||
super().__init__()
|
||||
assert self.name is not None, "Subclasses of `AbstractPolicy` should set the `name` class attribute."
|
||||
self.n_action_steps = n_action_steps
|
||||
self.clear_action_queue()
|
||||
|
||||
|
||||
@@ -42,6 +42,8 @@ def kl_divergence(mu, logvar):
|
||||
|
||||
|
||||
class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||
name = "act"
|
||||
|
||||
def __init__(self, cfg, device, n_action_steps=1):
|
||||
super().__init__(n_action_steps)
|
||||
self.cfg = cfg
|
||||
|
||||
@@ -13,6 +13,8 @@ from lerobot.common.utils import get_safe_torch_device
|
||||
|
||||
|
||||
class DiffusionPolicy(AbstractPolicy):
|
||||
name = "diffusion"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
|
||||
@@ -3,9 +3,9 @@ def make_policy(cfg):
|
||||
raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")
|
||||
|
||||
if cfg.policy.name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPC
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
||||
|
||||
policy = TDMPC(cfg.policy, cfg.device)
|
||||
policy = TDMPCPolicy(cfg.policy, cfg.device)
|
||||
elif cfg.policy.name == "diffusion":
|
||||
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
|
||||
|
||||
|
||||
@@ -87,9 +87,11 @@ class TOLD(nn.Module):
|
||||
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
|
||||
|
||||
|
||||
class TDMPC(AbstractPolicy):
|
||||
class TDMPCPolicy(AbstractPolicy):
|
||||
"""Implementation of TD-MPC learning + inference."""
|
||||
|
||||
name = "tdmpc"
|
||||
|
||||
def __init__(self, cfg, device):
|
||||
super().__init__(None)
|
||||
self.action_dim = cfg.action_dim
|
||||
|
||||
@@ -26,6 +26,8 @@ fps: ???
|
||||
|
||||
offline_prioritized_sampler: true
|
||||
|
||||
dataset_id: ???
|
||||
|
||||
n_action_steps: ???
|
||||
n_obs_steps: ???
|
||||
env: ???
|
||||
|
||||
4
lerobot/configs/env/aloha.yaml
vendored
4
lerobot/configs/env/aloha.yaml
vendored
@@ -10,9 +10,11 @@ online_steps: 25000
|
||||
|
||||
fps: 50
|
||||
|
||||
dataset_id: aloha_sim_insertion_human
|
||||
|
||||
env:
|
||||
name: aloha
|
||||
task: sim_insertion_human
|
||||
task: sim_insertion
|
||||
from_pixels: True
|
||||
pixels_only: False
|
||||
image_size: [3, 480, 640]
|
||||
|
||||
2
lerobot/configs/env/pusht.yaml
vendored
2
lerobot/configs/env/pusht.yaml
vendored
@@ -10,6 +10,8 @@ online_steps: 25000
|
||||
|
||||
fps: 10
|
||||
|
||||
dataset_id: pusht
|
||||
|
||||
env:
|
||||
name: pusht
|
||||
task: pusht
|
||||
|
||||
2
lerobot/configs/env/simxarm.yaml
vendored
2
lerobot/configs/env/simxarm.yaml
vendored
@@ -9,6 +9,8 @@ online_steps: 25000
|
||||
|
||||
fps: 15
|
||||
|
||||
dataset_id: xarm_lift_medium
|
||||
|
||||
env:
|
||||
name: simxarm
|
||||
task: lift
|
||||
|
||||
@@ -13,8 +13,10 @@ Examples:
|
||||
You have a specific config file to go with trained model weights, and want to run 10 episodes.
|
||||
|
||||
```
|
||||
python lerobot/scripts/eval.py --config PATH/TO/FOLDER/config.yaml \
|
||||
policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth` eval_episodes=10
|
||||
python lerobot/scripts/eval.py \
|
||||
--config PATH/TO/FOLDER/config.yaml \
|
||||
policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth \
|
||||
eval_episodes=10
|
||||
```
|
||||
|
||||
You have a HuggingFace Hub ID, you know which revision you want, and want to run 10 episodes (note that in this case,
|
||||
|
||||
@@ -25,7 +25,7 @@ def visualize_dataset_cli(cfg: dict):
|
||||
|
||||
|
||||
def cat_and_write_video(video_path, frames, fps):
|
||||
# Expects images in [0, 1].
|
||||
# Expects images in [0, 255].
|
||||
frames = torch.cat(frames)
|
||||
assert frames.max() <= 1 and frames.min() >= 0
|
||||
frames = (255 * frames).to(dtype=torch.uint8)
|
||||
@@ -47,44 +47,63 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
||||
|
||||
logging.info("make_offline_buffer")
|
||||
offline_buffer = make_offline_buffer(
|
||||
cfg, overwrite_sampler=sampler, normalize=False, overwrite_batch_size=1, overwrite_prefetch=12
|
||||
cfg,
|
||||
overwrite_sampler=sampler,
|
||||
# remove all transformations such as rescale images from [0,255] to [0,1] or normalization
|
||||
normalize=False,
|
||||
overwrite_batch_size=1,
|
||||
overwrite_prefetch=12,
|
||||
)
|
||||
|
||||
logging.info("Start rendering episodes from offline buffer")
|
||||
video_paths = render_dataset(offline_buffer, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps)
|
||||
for video_path in video_paths:
|
||||
logging.info(video_path)
|
||||
|
||||
|
||||
def render_dataset(offline_buffer, out_dir, max_num_samples, fps):
|
||||
out_dir = Path(out_dir)
|
||||
video_paths = []
|
||||
threads = []
|
||||
frames = {}
|
||||
current_ep_idx = 0
|
||||
logging.info(f"Visualizing episode {current_ep_idx}")
|
||||
for _ in range(MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER):
|
||||
for i in range(max_num_samples):
|
||||
# TODO(rcadene): make it work with bsize > 1
|
||||
ep_td = offline_buffer.sample(1)
|
||||
ep_idx = ep_td["episode"][FIRST_FRAME].item()
|
||||
|
||||
# TODO(rcadene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
|
||||
no_more_frames = offline_buffer._sampler._sample_list.numel() == 0
|
||||
new_episode = ep_idx != current_ep_idx
|
||||
num_frames_left = offline_buffer._sampler._sample_list.numel()
|
||||
episode_is_done = ep_idx != current_ep_idx
|
||||
|
||||
if new_episode:
|
||||
logging.info(f"Visualizing episode {current_ep_idx}")
|
||||
if episode_is_done:
|
||||
logging.info(f"Rendering episode {current_ep_idx}")
|
||||
|
||||
for im_key in offline_buffer.image_keys:
|
||||
if new_episode or no_more_frames:
|
||||
# append last observed frames (the ones after last action taken)
|
||||
frames[im_key].append(offline_buffer.transform(ep_td["next"])[im_key])
|
||||
|
||||
video_dir = Path(out_dir) / "visualize_dataset"
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1):
|
||||
# when first frame of episode, initialize frames dict
|
||||
if im_key not in frames:
|
||||
frames[im_key] = []
|
||||
# add current frame to list of frames to render
|
||||
frames[im_key].append(ep_td[im_key])
|
||||
else:
|
||||
# When episode has no more frame in its list of observation,
|
||||
# one frame still remains. It is the result of the last action taken.
|
||||
# It is stored in `"next"`, so we add it to the list of frames to render.
|
||||
frames[im_key].append(ep_td["next"][im_key])
|
||||
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
if len(offline_buffer.image_keys) > 1:
|
||||
camera = im_key[-1]
|
||||
video_path = video_dir / f"episode_{current_ep_idx}_{camera}.mp4"
|
||||
video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4"
|
||||
else:
|
||||
video_path = video_dir / f"episode_{current_ep_idx}.mp4"
|
||||
video_path = out_dir / f"episode_{current_ep_idx}.mp4"
|
||||
video_paths.append(str(video_path))
|
||||
|
||||
thread = threading.Thread(
|
||||
target=cat_and_write_video,
|
||||
args=(str(video_path), frames[im_key], cfg.fps),
|
||||
args=(str(video_path), frames[im_key], fps),
|
||||
)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
@@ -94,12 +113,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
||||
# reset list of frames
|
||||
del frames[im_key]
|
||||
|
||||
# append current cameras images to list of frames
|
||||
if im_key not in frames:
|
||||
frames[im_key] = []
|
||||
frames[im_key].append(ep_td[im_key])
|
||||
|
||||
if no_more_frames:
|
||||
if num_frames_left == 0:
|
||||
logging.info("Ran out of frames")
|
||||
break
|
||||
|
||||
@@ -110,6 +124,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
||||
thread.join()
|
||||
|
||||
logging.info("End of visualize_dataset")
|
||||
return video_paths
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user