Compare commits

15 commits on branch test/robot (qgallouede)

Commit SHA1s: 0a33a414fb, cd76980d50, 1ffc0e0d94, fda092d5bf, 2a9ea01d5a, 0bca982fca, dae901f556, 7626b9a4a3, 6d56bcb5de, 0ec28bf71a, bc96284ca0, 343579368b, 72751b7cf6, 0660f71556, 42ed7bb670
.gitignore (vendored): 1 change
@@ -6,6 +6,7 @@ data
outputs
.vscode
rl
.DS_Store

# HPC
nautilus/*.yaml
@@ -73,15 +73,14 @@ environments ([aloha](https://github.com/huggingface/gym-aloha),
[pusht](https://github.com/huggingface/gym-pusht))
and follow the same api design.

When implementing a new dataset class (e.g. `AlohaDataset`) follow these steps:
- Update `available_datasets` in `lerobot/__init__.py`
- Copy it in the required `available_datasets` class attribute
When implementing a new dataset loadable with LeRobotDataset follow these steps:
- Update `available_datasets_per_env` in `lerobot/__init__.py`

When implementing a new environment (e.g. `gym_aloha`), follow these steps:
- Update `available_envs`, `available_tasks_per_env` and `available_datasets` in `lerobot/__init__.py`
- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py`

When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
- Update `available_policies` in `lerobot/__init__.py`
- Update `available_policies` and `available_policies_per_env`, in `lerobot/__init__.py`
- Set the required `name` class attribute.
- Update variables in `tests/test_available.py` by importing your new Policy class
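To make the first registration step concrete, here is a sketch of what adding an entry to `available_datasets_per_env` in `lerobot/__init__.py` amounts to, based on the structure that appears later in this diff; the `lerobot/my_new_dataset` id is hypothetical.

```python
# lerobot/__init__.py (sketch; "lerobot/my_new_dataset" is a hypothetical repo id)
available_datasets_per_env = {
    "aloha": [
        "lerobot/aloha_sim_insertion_human",
        # ... existing entries ...
        "lerobot/my_new_dataset",  # new dataset registered under its environment key
    ],
    "pusht": ["lerobot/pusht"],
    "xarm": ["lerobot/xarm_lift_medium"],  # abbreviated here
}
# The flat list is derived from the mapping, so it does not need a separate edit.
available_datasets = [dataset for datasets in available_datasets_per_env.values() for dataset in datasets]
```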
README.md: 45 changes
@@ -118,30 +118,7 @@ wandb login

### Visualize datasets

You can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities:
```python
""" Copy pasted from `examples/1_visualize_dataset.py` """
import os
from pathlib import Path

import lerobot
from lerobot.common.datasets.aloha import AlohaDataset
from lerobot.scripts.visualize_dataset import render_dataset

print(lerobot.available_datasets)
# >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']

# TODO(rcadene): remove DATA_DIR
dataset = AlohaDataset("pusht", root=Path(os.environ.get("DATA_DIR")))

video_paths = render_dataset(
dataset,
out_dir="outputs/visualize_dataset/example",
max_num_episodes=1,
)
print(video_paths)
# ['outputs/visualize_dataset/example/episode_0.mp4']
```
Check out [examples](./examples) to see how you can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities.

Or you can achieve the same result by executing our script from the command line:
```bash
@@ -153,7 +130,7 @@ hydra.run.dir=outputs/visualize_dataset/example

### Evaluate a pretrained policy

Check out [example 2](./examples/2_evaluate_pretrained_policy.py) to see how you can load a pretrained policy from HuggingFace hub, load up the corresponding environment and model, and run an evaluation.
Check out [examples](./examples) to see how you can load a pretrained policy from HuggingFace hub, load up the corresponding environment and model, and run an evaluation.

Or you can achieve the same result by executing our script from the command line:
```bash
@@ -176,24 +153,30 @@ See `python lerobot/scripts/eval.py --help` for more instructions.

### Train your own policy

You can import our dataset, environment, policy classes, and use our training utilities (if some data is missing, it will be automatically downloaded from HuggingFace hub): check out [example 3](./examples/3_train_policy.py). After you run this, you may want to revisit [example 2](./examples/2_evaluate_pretrained_policy.py) to evaluate your training output!
Check out [examples](./examples) to see how you can start training a model on a dataset, which will be automatically downloaded if needed.

In general, you can use our training script to easily train any policy on any environment:
```bash
python lerobot/scripts/train.py \
env=aloha \
task=sim_insertion \
dataset_id=aloha_sim_insertion_scripted \
repo_id=lerobot/aloha_sim_insertion_scripted \
policy=act \
hydra.run.dir=outputs/train/aloha_act
```

After training, you may want to revisit model evaluation to change the evaluation settings. In fact, during training every checkpoint is already evaluated, but on a low number of episodes for efficiency. Check out [example](./examples) to evaluate any model checkpoint on more episodes to increase statistical significance.

## Contribute

If you would like to contribute to 🤗 LeRobot, please check out our [contribution guide](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md).

### Add a new dataset

```python
# TODO(rcadene, AdilZouitine): rewrite this section
```

To add a dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access:
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
@@ -255,6 +238,10 @@ python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir

### Add a pretrained policy

```python
# TODO(rcadene, alexander-soare): rewrite this section
```

Once you have trained a policy you may upload it to the HuggingFace hub.

Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME.
@@ -263,15 +250,13 @@ Secondly, assuming you have trained a policy, you need:

- `config.yaml` which you can get from the `.hydra` directory of your training output folder.
- `model.pt` which should be one of the saved models in the `models` directory of your training output folder (they won't be named `model.pt` but you will need to choose one).
- `stats.pth` which should point to the same file in the dataset directory (found in `data/{dataset_name}`).

To upload these to the hub, prepare a folder with the following structure (you can use symlinks rather than copying):

```
to_upload
├── config.yaml
├── model.pt
└── stats.pth
└── model.pt
```

With the folder prepared, run the following with a desired revision ID.
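The exact upload command is cut off by the hunk boundary above. As a sketch, the prepared folder could also be pushed with the `huggingface_hub` Python API; the repo id and revision below are placeholders, and the revision branch is assumed to already exist on the hub.

```python
from huggingface_hub import HfApi

api = HfApi()
# Upload the prepared folder to the model repo ("HF_USER/REPO_NAME" and "v1.0" are placeholders).
api.upload_folder(
    folder_path="to_upload",
    repo_id="HF_USER/REPO_NAME",
    repo_type="model",
    revision="v1.0",
)
```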
@@ -23,6 +23,7 @@ from lerobot.common.datasets.utils import compute_stats, flatten_dict, hf_transf
|
||||
|
||||
|
||||
def download_and_upload(root, revision, dataset_id):
|
||||
# TODO(rcadene, adilzouitine): add community_id/user_id (e.g. "lerobot", "cadene") or repo_id (e.g. "lerobot/pusht")
|
||||
if "pusht" in dataset_id:
|
||||
download_and_upload_pusht(root, revision, dataset_id)
|
||||
elif "xarm" in dataset_id:
|
||||
@@ -149,11 +150,11 @@ def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dat
|
||||
# copy in tests folder, the first episode and the meta_data directory
|
||||
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
|
||||
hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk(
|
||||
f"tests/data/{dataset_id}/train"
|
||||
f"tests/data/lerobot/{dataset_id}/train"
|
||||
)
|
||||
if Path(f"tests/data/{dataset_id}/meta_data").exists():
|
||||
shutil.rmtree(f"tests/data/{dataset_id}/meta_data")
|
||||
shutil.copytree(meta_data_dir, f"tests/data/{dataset_id}/meta_data")
|
||||
if Path(f"tests/data/lerobot/{dataset_id}/meta_data").exists():
|
||||
shutil.rmtree(f"tests/data/lerobot/{dataset_id}/meta_data")
|
||||
shutil.copytree(meta_data_dir, f"tests/data/lerobot/{dataset_id}/meta_data")
|
||||
|
||||
|
||||
def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
|
||||
|
||||
@@ -44,7 +44,7 @@ from datasets import load_dataset
|
||||
# TODO(rcadene): list available datasets on lerobot page using `datasets`
|
||||
|
||||
# download/load hugging face dataset in pyarrow format
|
||||
hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
|
||||
hf_dataset, fps = load_dataset("lerobot/pusht", split="train", revision="v1.1"), 10
|
||||
|
||||
# display name of dataset and its features
|
||||
# TODO(rcadene): update to make the print pretty
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
This script demonstrates the use of the PushtDataset class for handling and processing robotic datasets from Hugging Face.
|
||||
This script demonstrates the use of `LeRobotDataset` class for handling and processing robotic datasets from Hugging Face.
|
||||
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
|
||||
|
||||
Features included in this script:
|
||||
@@ -11,22 +11,6 @@ Features included in this script:
|
||||
- Demonstrating compatibility with PyTorch DataLoader for batch processing.
|
||||
|
||||
The script ends with examples of how to batch process data using PyTorch's DataLoader.
|
||||
|
||||
To try a different Hugging Face dataset, you can replace:
|
||||
```python
|
||||
dataset = PushtDataset()
|
||||
```
|
||||
by one of these:
|
||||
```python
|
||||
dataset = XarmDataset("xarm_lift_medium")
|
||||
dataset = XarmDataset("xarm_lift_medium_replay")
|
||||
dataset = XarmDataset("xarm_push_medium")
|
||||
dataset = XarmDataset("xarm_push_medium_replay")
|
||||
dataset = AlohaDataset("aloha_sim_insertion_human")
|
||||
dataset = AlohaDataset("aloha_sim_insertion_scripted")
|
||||
dataset = AlohaDataset("aloha_sim_transfer_cube_human")
|
||||
dataset = AlohaDataset("aloha_sim_transfer_cube_scripted")
|
||||
```
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
@@ -34,31 +18,33 @@ from pathlib import Path
|
||||
import imageio
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.pusht import PushtDataset
|
||||
import lerobot
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# TODO(rcadene): List available datasets and their dataset ids (e.g. PushtDataset, AlohaDataset(dataset_id="aloha_sim_insertion_human"))
|
||||
# print("List of available datasets", lerobot.available_datasets)
|
||||
# # >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted',
|
||||
# # 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted',
|
||||
# # 'pusht', 'xarm_lift_medium']
|
||||
print("List of available datasets", lerobot.available_datasets)
|
||||
# # >>> ['lerobot/aloha_sim_insertion_human', 'lerobot/aloha_sim_insertion_scripted',
|
||||
# # 'lerobot/aloha_sim_transfer_cube_human', 'lerobot/aloha_sim_transfer_cube_scripted',
|
||||
# # 'lerobot/pusht', 'lerobot/xarm_lift_medium']
|
||||
|
||||
repo_id = "lerobot/pusht"
|
||||
|
||||
# You can easily load datasets from LeRobot
|
||||
dataset = PushtDataset()
|
||||
# You can easily load a dataset from a Hugging Face repository
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
|
||||
# All LeRobot datasets are actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
|
||||
# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
|
||||
# TODO(rcadene): update to make the print pretty
|
||||
print(f"{dataset=}")
|
||||
print(f"{dataset.hf_dataset=}")
|
||||
|
||||
# and provide additional utilities for robotics and compatibility with pytorch
|
||||
# and provides additional utilities for robotics and compatibility with pytorch
|
||||
print(f"number of samples/frames: {dataset.num_samples=}")
|
||||
print(f"number of episodes: {dataset.num_episodes=}")
|
||||
print(f"average number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
|
||||
print(f"frames per second used during data collection: {dataset.fps=}")
|
||||
print(f"keys to access images from cameras: {dataset.image_keys=}")
|
||||
|
||||
# While the LeRobot dataset adds helpers for working within our library, we still expose the underlying Hugging Face dataset. It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5.
|
||||
# While the LeRobotDataset adds helpers for working within our library, we still expose the underlying Hugging Face dataset.
|
||||
# It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5.
|
||||
# TODO(rcadene): remove this example of accessing hf_dataset
|
||||
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
|
||||
|
||||
@@ -85,7 +71,7 @@ delta_timestamps = {
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
"action": [t / dataset.fps for t in range(64)],
}
dataset = PushtDataset(delta_timestamps=delta_timestamps)
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
print(f"{dataset[0]['observation.image'].shape=}")  # (4,c,h,w)
print(f"{dataset[0]['observation.state'].shape=}")  # (8,c)
print(f"{dataset[0]['action'].shape=}")  # (64,c)
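Since each item returned above is a dict of tensors, batching with a standard PyTorch `DataLoader` works out of the box; a minimal sketch (batch size and worker count are arbitrary choices):

```python
import torch

# The default collate_fn stacks each key of the per-frame dicts into a batched tensor.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)
batch = next(iter(dataloader))
print(batch["observation.image"].shape)  # expected: (32, 4, c, h, w)
print(batch["action"].shape)             # expected: (32, 64, c)
```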
@@ -19,7 +19,6 @@ folder = Path(snapshot_download(hub_id))
|
||||
|
||||
config_path = folder / "config.yaml"
|
||||
weights_path = folder / "model.pt"
|
||||
stats_path = folder / "stats.pth" # normalization stats
|
||||
|
||||
# Override some config parameters to do with evaluation.
|
||||
overrides = [
|
||||
@@ -36,5 +35,4 @@ cfg = init_hydra_config(config_path, overrides)
|
||||
eval(
|
||||
cfg,
|
||||
out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
|
||||
stats_path=stats_path,
|
||||
)
|
||||
|
||||
@@ -34,7 +34,7 @@ dataset = make_dataset(hydra_cfg)
|
||||
# If you're doing something different, you will likely need to change at least some of the defaults.
|
||||
cfg = DiffusionConfig()
|
||||
# TODO(alexander-soare): Remove LR scheduler from the policy.
|
||||
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps)
|
||||
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=dataset.stats)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
@@ -62,7 +62,6 @@ while not done:
|
||||
done = True
|
||||
break
|
||||
|
||||
# Save the policy, configuration, and normalization stats for later use.
|
||||
# Save the policy and configuration for later use.
|
||||
policy.save(output_directory / "model.pt")
|
||||
OmegaConf.save(hydra_cfg, output_directory / "config.yaml")
|
||||
torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth")
|
||||
|
||||
@@ -8,31 +8,25 @@ Example:
|
||||
print(lerobot.available_envs)
|
||||
print(lerobot.available_tasks_per_env)
|
||||
print(lerobot.available_datasets)
|
||||
print(lerobot.available_datasets_per_env)
|
||||
print(lerobot.available_policies)
|
||||
print(lerobot.available_policies_per_env)
|
||||
```
|
||||
|
||||
When implementing a new dataset class (e.g. `AlohaDataset`) follow these steps:
|
||||
- Update `available_datasets` in `lerobot/__init__.py`
|
||||
- Set the required `available_datasets` class attribute using the previously updated `lerobot.available_datasets`
|
||||
When implementing a new dataset loadable with LeRobotDataset follow these steps:
|
||||
- Update `available_datasets_per_env` in `lerobot/__init__.py`
|
||||
|
||||
When implementing a new environment (e.g. `gym_aloha`), follow these steps:
|
||||
- Update `available_envs`, `available_tasks_per_env` and `available_datasets` in `lerobot/__init__.py`
|
||||
- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py`
|
||||
|
||||
When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
|
||||
- Update `available_policies` in `lerobot/__init__.py`
|
||||
- Update `available_policies` and `available_policies_per_env`, in `lerobot/__init__.py`
|
||||
- Set the required `name` class attribute.
|
||||
- Update variables in `tests/test_available.py` by importing your new Policy class
|
||||
"""
|
||||
|
||||
from lerobot.__version__ import __version__ # noqa: F401
|
||||
|
||||
available_envs = [
|
||||
"aloha",
|
||||
"pusht",
|
||||
"xarm",
|
||||
]
|
||||
|
||||
available_tasks_per_env = {
|
||||
"aloha": [
|
||||
"AlohaInsertion-v0",
|
||||
@@ -41,22 +35,24 @@ available_tasks_per_env = {
|
||||
"pusht": ["PushT-v0"],
|
||||
"xarm": ["XarmLift-v0"],
|
||||
}
|
||||
available_envs = list(available_tasks_per_env.keys())
|
||||
|
||||
available_datasets = {
|
||||
available_datasets_per_env = {
|
||||
"aloha": [
|
||||
"aloha_sim_insertion_human",
|
||||
"aloha_sim_insertion_scripted",
|
||||
"aloha_sim_transfer_cube_human",
|
||||
"aloha_sim_transfer_cube_scripted",
|
||||
"lerobot/aloha_sim_insertion_human",
|
||||
"lerobot/aloha_sim_insertion_scripted",
|
||||
"lerobot/aloha_sim_transfer_cube_human",
|
||||
"lerobot/aloha_sim_transfer_cube_scripted",
|
||||
],
|
||||
"pusht": ["pusht"],
|
||||
"pusht": ["lerobot/pusht"],
|
||||
"xarm": [
|
||||
"xarm_lift_medium",
|
||||
"xarm_lift_medium_replay",
|
||||
"xarm_push_medium",
|
||||
"xarm_push_medium_replay",
|
||||
"lerobot/xarm_lift_medium",
|
||||
"lerobot/xarm_lift_medium_replay",
|
||||
"lerobot/xarm_push_medium",
|
||||
"lerobot/xarm_push_medium_replay",
|
||||
],
|
||||
}
|
||||
available_datasets = [dataset for datasets in available_datasets_per_env.values() for dataset in datasets]
|
||||
|
||||
available_policies = [
|
||||
"act",
|
||||
@@ -71,10 +67,12 @@ available_policies_per_env = {
|
||||
}
|
||||
|
||||
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
|
||||
env_dataset_pairs = [(env, dataset) for env, datasets in available_datasets.items() for dataset in datasets]
|
||||
env_dataset_pairs = [
|
||||
(env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets
|
||||
]
|
||||
env_dataset_policy_triplets = [
|
||||
(env, dataset, policy)
|
||||
for env, datasets in available_datasets.items()
|
||||
for env, datasets in available_datasets_per_env.items()
|
||||
for dataset in datasets
|
||||
for policy in available_policies_per_env[env]
|
||||
]
|
||||
|
||||
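These pair and triplet lists let the test suite sweep every supported combination; a hedged sketch of how they might be consumed with pytest (illustrative only, not the repository's actual `tests/test_available.py`):

```python
import pytest

import lerobot


# Illustrative: run one check per (env, dataset) combination exported by lerobot/__init__.py.
@pytest.mark.parametrize("env_name, dataset_repo_id", lerobot.env_dataset_pairs)
def test_dataset_is_registered(env_name, dataset_repo_id):
    assert dataset_repo_id in lerobot.available_datasets
    assert dataset_repo_id in lerobot.available_datasets_per_env[env_name]
```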
@@ -1,78 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
load_episode_data_index,
|
||||
load_hf_dataset,
|
||||
load_previous_and_future_frames,
|
||||
load_stats,
|
||||
)
|
||||
|
||||
|
||||
class AlohaDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human
|
||||
https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted
|
||||
https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human
|
||||
https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted
|
||||
"""
|
||||
|
||||
# Copied from lerobot/__init__.py
|
||||
available_datasets = [
|
||||
"aloha_sim_insertion_human",
|
||||
"aloha_sim_insertion_scripted",
|
||||
"aloha_sim_transfer_cube_human",
|
||||
"aloha_sim_transfer_cube_scripted",
|
||||
]
|
||||
fps = 50
|
||||
image_keys = ["observation.images.top"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.1",
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
# load data from hub or locally when root is provided
|
||||
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
|
||||
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
|
||||
self.stats = load_stats(dataset_id, version, root)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self.hf_dataset)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.hf_dataset.unique("episode_index"))
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.hf_dataset[idx]
|
||||
|
||||
if self.delta_timestamps is not None:
|
||||
item = load_previous_and_future_frames(
|
||||
item,
|
||||
self.hf_dataset,
|
||||
self.episode_data_index,
|
||||
self.delta_timestamps,
|
||||
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
|
||||
)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
return item
|
||||
@@ -1,76 +1,22 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import v2
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from lerobot.common.transforms import NormalizeTransform
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
|
||||
def make_dataset(
|
||||
cfg,
|
||||
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
|
||||
normalize=True,
|
||||
stats_path=None,
|
||||
split="train",
|
||||
):
|
||||
if cfg.env.name == "xarm":
|
||||
from lerobot.common.datasets.xarm import XarmDataset
|
||||
|
||||
clsfunc = XarmDataset
|
||||
|
||||
elif cfg.env.name == "pusht":
|
||||
from lerobot.common.datasets.pusht import PushtDataset
|
||||
|
||||
clsfunc = PushtDataset
|
||||
|
||||
elif cfg.env.name == "aloha":
|
||||
from lerobot.common.datasets.aloha import AlohaDataset
|
||||
|
||||
clsfunc = AlohaDataset
|
||||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
transforms = None
|
||||
if normalize:
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
|
||||
# min_max_from_spec
|
||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
|
||||
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||
stats = {}
|
||||
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||
stats["observation.state"] = {}
|
||||
stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||
stats["observation.state"]["max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
||||
stats["action"] = {}
|
||||
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
elif stats_path is None:
|
||||
# load a first dataset to access precomputed stats
|
||||
stats_dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
split="train",
|
||||
root=DATA_DIR,
|
||||
)
|
||||
stats = stats_dataset.stats
|
||||
else:
|
||||
stats = torch.load(stats_path)
|
||||
|
||||
transforms = v2.Compose(
|
||||
[
|
||||
NormalizeTransform(
|
||||
stats,
|
||||
in_keys=[
|
||||
"observation.state",
|
||||
"action",
|
||||
],
|
||||
mode=normalization_mode,
|
||||
),
|
||||
]
|
||||
if cfg.env.name not in cfg.dataset.repo_id:
|
||||
logging.warning(
|
||||
f"There might be a mismatch between your training dataset ({cfg.dataset.repo_id=}) and your environment ({cfg.env.name=})."
|
||||
)
|
||||
|
||||
delta_timestamps = cfg.policy.get("delta_timestamps")
|
||||
@@ -79,12 +25,20 @@ def make_dataset(
|
||||
if isinstance(delta_timestamps[key], str):
|
||||
delta_timestamps[key] = eval(delta_timestamps[key])
|
||||
|
||||
dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
# TODO(rcadene): add data augmentations
|
||||
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
split=split,
|
||||
root=DATA_DIR,
|
||||
delta_timestamps=delta_timestamps,
|
||||
transform=transforms,
|
||||
)
|
||||
|
||||
if cfg.get("override_dataset_stats"):
|
||||
for key, stats_dict in cfg.override_dataset_stats.items():
|
||||
for stats_type, listconfig in stats_dict.items():
|
||||
# example of stats_type: min, max, mean, std
|
||||
stats = OmegaConf.to_container(listconfig, resolve=True)
|
||||
dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -1,36 +1,21 @@
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
load_episode_data_index,
|
||||
load_hf_dataset,
|
||||
load_info,
|
||||
load_previous_and_future_frames,
|
||||
load_stats,
|
||||
)
|
||||
|
||||
|
||||
class XarmDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
https://huggingface.co/datasets/lerobot/xarm_lift_medium
|
||||
https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay
|
||||
https://huggingface.co/datasets/lerobot/xarm_push_medium
|
||||
https://huggingface.co/datasets/lerobot/xarm_push_medium_replay
|
||||
"""
|
||||
|
||||
# Copied from lerobot/__init__.py
|
||||
available_datasets = [
|
||||
"xarm_lift_medium",
|
||||
"xarm_lift_medium_replay",
|
||||
"xarm_push_medium",
|
||||
"xarm_push_medium_replay",
|
||||
]
|
||||
fps = 15
|
||||
image_keys = ["observation.image"]
|
||||
|
||||
class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
repo_id: str,
|
||||
version: str | None = "v1.1",
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
@@ -38,16 +23,25 @@ class XarmDataset(torch.utils.data.Dataset):
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dataset_id = dataset_id
|
||||
self.repo_id = repo_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
# load data from hub or locally when root is provided
|
||||
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
|
||||
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
|
||||
self.stats = load_stats(dataset_id, version, root)
|
||||
self.hf_dataset = load_hf_dataset(repo_id, version, root, split)
|
||||
self.episode_data_index = load_episode_data_index(repo_id, version, root)
|
||||
self.stats = load_stats(repo_id, version, root)
|
||||
self.info = load_info(repo_id, version, root)
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
return self.info["fps"]
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list[str]:
|
||||
return [key for key, feats in self.hf_dataset.features.items() if isinstance(feats, datasets.Image)]
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
@@ -1,76 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
load_episode_data_index,
|
||||
load_hf_dataset,
|
||||
load_previous_and_future_frames,
|
||||
load_stats,
|
||||
)
|
||||
|
||||
|
||||
class PushtDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
https://huggingface.co/datasets/lerobot/pusht
|
||||
|
||||
Arguments
|
||||
----------
|
||||
delta_timestamps : dict[list[float]] | None, optional
|
||||
Loads data from frames with a shift in timestamps with a different strategy for each data key (e.g. state, action or image)
|
||||
If `None`, no shift is applied to current timestamp and the data from the current frame is loaded.
|
||||
"""
|
||||
|
||||
# Copied from lerobot/__init__.py
|
||||
available_datasets = ["pusht"]
|
||||
fps = 10
|
||||
image_keys = ["observation.image"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str = "pusht",
|
||||
version: str | None = "v1.1",
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
# load data from hub or locally when root is provided
|
||||
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
|
||||
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
|
||||
self.stats = load_stats(dataset_id, version, root)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self.hf_dataset)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.episode_data_index["from"])
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.hf_dataset[idx]
|
||||
|
||||
if self.delta_timestamps is not None:
|
||||
item = load_previous_and_future_frames(
|
||||
item,
|
||||
self.hf_dataset,
|
||||
self.episode_data_index,
|
||||
self.delta_timestamps,
|
||||
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
|
||||
)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
return item
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
@@ -15,7 +16,7 @@ from torchvision import transforms
|
||||
|
||||
def flatten_dict(d, parent_key="", sep="/"):
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.

For example:
```
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
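The expected output of that docstring example is cut off by the hunk boundary; for reference, a self-contained helper with the behaviour the docstring describes could look like this sketch (not necessarily the repository's exact implementation):

```python
def flatten_dict_sketch(d: dict, parent_key: str = "", sep: str = "/") -> dict:
    """Collapse nested keys into single keys joined by `sep`."""
    items = {}
    for key, value in d.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            # Recurse into nested dictionaries, carrying the accumulated key prefix.
            items.update(flatten_dict_sketch(value, new_key, sep=sep))
        else:
            items[new_key] = value
    return items


assert flatten_dict_sketch({"a": {"b": 1, "c": {"d": 2}}, "e": 3}) == {"a/b": 1, "a/c/d": 2, "e": 3}
```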
@@ -61,19 +62,17 @@ def hf_transform_to_torch(items_dict):
|
||||
return items_dict
|
||||
|
||||
|
||||
def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset:
|
||||
def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
if root is not None:
|
||||
hf_dataset = load_from_disk(str(Path(root) / dataset_id / split))
|
||||
hf_dataset = load_from_disk(str(Path(root) / repo_id / split))
|
||||
else:
|
||||
# TODO(rcadene): remove dataset_id everywhere and use repo_id instead
|
||||
repo_id = f"lerobot/{dataset_id}"
|
||||
hf_dataset = load_dataset(repo_id, revision=version, split=split)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
|
||||
def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor]:
|
||||
def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
|
||||
"""episode_data_index contains the range of indices for each episode
|
||||
|
||||
Example:
|
||||
@@ -84,9 +83,8 @@ def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor
|
||||
```
|
||||
"""
|
||||
if root is not None:
|
||||
path = Path(root) / dataset_id / "meta_data" / "episode_data_index.safetensors"
|
||||
path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
|
||||
else:
|
||||
repo_id = f"lerobot/{dataset_id}"
|
||||
path = hf_hub_download(
|
||||
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version
|
||||
)
|
||||
@@ -94,7 +92,7 @@ def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor
|
||||
return load_file(path)
|
||||
|
||||
|
||||
def load_stats(dataset_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
|
||||
def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
|
||||
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
|
||||
|
||||
Example:
|
||||
@@ -103,15 +101,32 @@ def load_stats(dataset_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
|
||||
```
|
||||
"""
|
||||
if root is not None:
|
||||
path = Path(root) / dataset_id / "meta_data" / "stats.safetensors"
|
||||
path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
|
||||
else:
|
||||
repo_id = f"lerobot/{dataset_id}"
|
||||
path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version)
|
||||
|
||||
stats = load_file(path)
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
def load_info(repo_id, version, root) -> dict:
|
||||
"""info contains useful information regarding the dataset that are not stored elsewhere
|
||||
|
||||
Example:
|
||||
```python
|
||||
print("frame per second used to collect the video", info["fps"])
|
||||
```
|
||||
"""
|
||||
if root is not None:
|
||||
path = Path(root) / repo_id / "meta_data" / "info.json"
|
||||
else:
|
||||
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=version)
|
||||
|
||||
with open(path) as f:
|
||||
info = json.load(f)
|
||||
return info
|
||||
|
||||
|
||||
def load_previous_and_future_frames(
|
||||
item: dict[str, torch.Tensor],
|
||||
hf_dataset: datasets.Dataset,
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import einops
|
||||
import torch
|
||||
|
||||
from lerobot.common.transforms import apply_inverse_transform
|
||||
|
||||
|
||||
def preprocess_observation(observation, transform=None):
|
||||
def preprocess_observation(observation):
|
||||
# map to expected inputs for the policy
|
||||
obs = {}
|
||||
|
||||
@@ -24,7 +22,7 @@ def preprocess_observation(observation, transform=None):
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w")
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
|
||||
@@ -33,19 +31,11 @@ def preprocess_observation(observation, transform=None):
|
||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"
|
||||
obs["observation.state"] = torch.from_numpy(observation["agent_pos"]).float()
|
||||
|
||||
# apply same transforms as in training
|
||||
if transform is not None:
|
||||
for key in obs:
|
||||
obs[key] = torch.stack([transform({key: item})[key] for item in obs[key]])
|
||||
|
||||
return obs
|
||||
|
||||
|
||||
def postprocess_action(action, transform=None):
|
||||
action = action.to("cpu")
|
||||
# action is a batch (num_env,action_dim) instead of an item (action_dim),
|
||||
# we assume applying inverse transform on a batch works the same
|
||||
action = apply_inverse_transform({"action": action}, transform)["action"].numpy()
|
||||
def postprocess_action(action):
|
||||
action = action.to("cpu").numpy()
|
||||
assert (
|
||||
action.ndim == 2
|
||||
), "we assume dimensions are respectively the number of parallel envs, action dimensions"
|
||||
|
||||
@@ -21,10 +21,20 @@ class ActionChunkingTransformerConfig:
|
||||
This should be no greater than the chunk size. For example, if the chunk size is 100, you may
|
||||
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
|
||||
environment, and throws the other 50 out.
|
||||
image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
|
||||
[0, 1]) for normalization.
|
||||
image_normalization_std: Value by which to divide the input image pixels (after the mean has been
|
||||
subtracted).
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
The key represents the input data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "observation.images.top" refers to an input from the
|
||||
"top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
|
||||
Importantly, shapes don't include the batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
The key represents the output data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||
14-dimensional actions. Importantly, shapes don't include the batch dimension or temporal dimension.
|
||||
normalize_input_modes: A dictionary where the key represents the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available
modes are "mean_std", which subtracts the mean and divides by the standard
deviation, and "min_max", which rescales to a [-1, 1] range.
unnormalize_output_modes: Similar dictionary to `normalize_input_modes`, but to unnormalize back to the original scale.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
|
||||
torchvision.
|
||||
@@ -51,6 +61,7 @@ class ActionChunkingTransformerConfig:
|
||||
"""
|
||||
|
||||
# Environment.
|
||||
# TODO(rcadene, alexander-soare): remove these as they are defined in input_shapes, output_shapes
|
||||
state_dim: int = 14
|
||||
action_dim: int = 14
|
||||
|
||||
@@ -60,11 +71,30 @@ class ActionChunkingTransformerConfig:
|
||||
chunk_size: int = 100
|
||||
n_action_steps: int = 100
|
||||
|
||||
# Vision preprocessing.
|
||||
image_normalization_mean: tuple[float, float, float] = field(
|
||||
default_factory=lambda: [0.485, 0.456, 0.406]
|
||||
input_shapes: dict[str, list[str]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.images.top": [3, 480, 640],
|
||||
"observation.state": [14],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[str]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [14],
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
normalize_input_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "mean_std",
|
||||
}
|
||||
)
|
||||
unnormalize_output_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": "mean_std",
|
||||
}
|
||||
)
|
||||
image_normalization_std: tuple[float, float, float] = field(default_factory=lambda: [0.229, 0.224, 0.225])
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
|
||||
@@ -5,7 +5,6 @@ The majority of changes here involve removing unused code, unifying naming, and
|
||||
"""
|
||||
|
||||
import math
|
||||
import time
|
||||
from collections import deque
|
||||
from itertools import chain
|
||||
from typing import Callable
|
||||
@@ -15,12 +14,12 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
from torch import Tensor, nn
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
|
||||
|
||||
class ActionChunkingTransformerPolicy(nn.Module):
|
||||
@@ -62,7 +61,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
|
||||
name = "act"
|
||||
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None):
|
||||
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
|
||||
"""
|
||||
Args:
|
||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||
@@ -72,6 +71,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
if cfg is None:
|
||||
cfg = ActionChunkingTransformerConfig()
|
||||
self.cfg = cfg
|
||||
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
|
||||
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
|
||||
|
||||
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||
@@ -93,9 +94,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
)
|
||||
|
||||
# Backbone for image feature extraction.
|
||||
self.image_normalizer = transforms.Normalize(
|
||||
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
|
||||
)
|
||||
backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
|
||||
replace_stride_with_dilation=[False, False, cfg.replace_final_stride_with_dilation],
|
||||
pretrained=cfg.use_pretrained_backbone,
|
||||
@@ -129,25 +127,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
self.action_head = nn.Linear(cfg.d_model, cfg.action_dim)
|
||||
|
||||
self._reset_parameters()
|
||||
self._create_optimizer()
|
||||
|
||||
def _create_optimizer(self):
|
||||
optimizer_params_dicts = [
|
||||
{
|
||||
"params": [
|
||||
p for n, p in self.named_parameters() if not n.startswith("backbone") and p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p for n, p in self.named_parameters() if n.startswith("backbone") and p.requires_grad
|
||||
],
|
||||
"lr": self.cfg.lr_backbone,
|
||||
},
|
||||
]
|
||||
self.optimizer = torch.optim.AdamW(
|
||||
optimizer_params_dicts, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay
|
||||
)
|
||||
|
||||
def _reset_parameters(self):
|
||||
"""Xavier-uniform initialization of the transformer parameters as in the original code."""
|
||||
@@ -169,10 +148,18 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
queue is empty.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
# `_forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue effectively
|
||||
# has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
self._action_queue.extend(self._forward(batch)[0][: self.cfg.n_action_steps].transpose(0, 1))
|
||||
actions = self._forward(batch)[0][: self.cfg.n_action_steps]
|
||||
|
||||
# TODO(rcadene): make _forward return output dictionary?
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch, **_) -> dict[str, Tensor]:
|
||||
@@ -199,29 +186,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
|
||||
return loss_dict
|
||||
|
||||
def update(self, batch, **_) -> dict:
|
||||
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
||||
start_time = time.time()
|
||||
self.train()
|
||||
loss_dict = self.forward(batch)
|
||||
loss = loss_dict["loss"]
|
||||
loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
|
||||
)
|
||||
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
info = {
|
||||
"loss": loss.item(),
|
||||
"grad_norm": float(grad_norm),
|
||||
"lr": self.cfg.lr,
|
||||
"update_s": time.time() - start_time,
|
||||
}
|
||||
|
||||
return info
|
||||
|
||||
def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Stacks all the images in a batch and puts them in a new key: "observation.images".
|
||||
@@ -309,7 +273,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
# Camera observation features and positional embeddings.
|
||||
all_cam_features = []
|
||||
all_cam_pos_embeds = []
|
||||
images = self.image_normalizer(batch["observation.images"])
|
||||
images = batch["observation.images"]
|
||||
for cam_index in range(len(self.cfg.camera_names)):
|
||||
cam_features = self.backbone(images[:, cam_index])["feature_map"]
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -19,10 +19,20 @@ class DiffusionConfig:
|
||||
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
||||
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
||||
See `DiffusionPolicy.select_action` for more details.
|
||||
image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
|
||||
[0, 1]) for normalization.
|
||||
image_normalization_std: Value by which to divide the input image pixels (after the mean has been
|
||||
subtracted).
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
The key represents the input data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "observation.image" refers to an input from
|
||||
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
|
||||
Importantly, shapes don't include the batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
The key represents the output data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||
14-dimensional actions. Importantly, shapes don't include the batch dimension or temporal dimension.
|
||||
normalize_input_modes: A dictionary where the key represents the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available
modes are "mean_std", which subtracts the mean and divides by the standard
deviation, and "min_max", which rescales to a [-1, 1] range.
unnormalize_output_modes: Similar dictionary to `normalize_input_modes`, but to unnormalize back to the original scale.
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
@@ -60,6 +70,7 @@ class DiffusionConfig:
|
||||
|
||||
# Environment.
|
||||
# Inherit these from the environment config.
|
||||
# TODO(rcadene, alexander-soare): remove these as they are defined in input_shapes, output_shapes
|
||||
state_dim: int = 2
|
||||
action_dim: int = 2
|
||||
image_size: tuple[int, int] = (96, 96)
|
||||
@@ -69,9 +80,30 @@ class DiffusionConfig:
|
||||
horizon: int = 16
|
||||
n_action_steps: int = 8
|
||||
|
||||
# Vision preprocessing.
|
||||
image_normalization_mean: tuple[float, float, float] = (0.5, 0.5, 0.5)
|
||||
image_normalization_std: tuple[float, float, float] = (0.5, 0.5, 0.5)
|
||||
input_shapes: dict[str, list[str]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 96, 96],
|
||||
"observation.state": [2],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[str]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [2],
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
normalize_input_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
)
|
||||
unnormalize_output_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": "min_max",
|
||||
}
|
||||
)
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
|
||||
@@ -11,22 +11,20 @@ TODO(alexander-soare):
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
from collections import deque
|
||||
from itertools import chain
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from robomimic.models.base_nets import SpatialSoftmax
|
||||
from torch import Tensor, nn
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.utils import (
|
||||
get_device_from_parameters,
|
||||
get_dtype_from_parameters,
|
||||
@@ -42,7 +40,9 @@ class DiffusionPolicy(nn.Module):
|
||||
|
||||
name = "diffusion"
|
||||
|
||||
def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
|
||||
def __init__(
|
||||
self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||
@@ -54,6 +54,8 @@ class DiffusionPolicy(nn.Module):
|
||||
if cfg is None:
|
||||
cfg = DiffusionConfig()
|
||||
self.cfg = cfg
|
||||
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
|
||||
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
|
||||
|
||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
self._queues = None
|
||||
@@ -67,25 +69,7 @@ class DiffusionPolicy(nn.Module):
|
||||
self.ema_diffusion = copy.deepcopy(self.diffusion)
|
||||
self.ema = _EMA(cfg, model=self.ema_diffusion)
|
||||
|
||||
# TODO(alexander-soare): Move optimizer out of policy.
|
||||
self.optimizer = torch.optim.Adam(
|
||||
self.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay
|
||||
)
|
||||
|
||||
# TODO(alexander-soare): Move LR scheduler out of policy.
|
||||
# TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps
|
||||
self.global_step = 0
|
||||
|
||||
# configure lr scheduler
|
||||
self.lr_scheduler = get_scheduler(
|
||||
cfg.lr_scheduler,
|
||||
optimizer=self.optimizer,
|
||||
num_warmup_steps=cfg.lr_warmup_steps,
|
||||
num_training_steps=lr_scheduler_num_training_steps,
|
||||
# pytorch assumes stepping LRScheduler every epoch
|
||||
# however huggingface diffusers steps it every batch
|
||||
last_epoch=self.global_step - 1,
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
@@ -126,6 +110,8 @@ class DiffusionPolicy(nn.Module):
|
||||
assert "observation.state" in batch
|
||||
assert len(batch) == 2
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
if len(self._queues["action"]) == 0:
|
||||
@@ -135,6 +121,10 @@ class DiffusionPolicy(nn.Module):
|
||||
actions = self.ema_diffusion.generate_actions(batch)
|
||||
else:
|
||||
actions = self.diffusion.generate_actions(batch)
|
||||
|
||||
# TODO(rcadene): make above methods return output dictionary?
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
self._queues["action"].extend(actions.transpose(0, 1))
|
||||
|
||||
action = self._queues["action"].popleft()
|
||||
@@ -145,37 +135,6 @@ class DiffusionPolicy(nn.Module):
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
return {"loss": loss}
|
||||
|
||||
def update(self, batch: dict[str, Tensor], **_) -> dict:
|
||||
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
||||
start_time = time.time()
|
||||
|
||||
self.diffusion.train()
|
||||
|
||||
loss = self.forward(batch)["loss"]
|
||||
loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.diffusion.parameters(),
|
||||
self.cfg.grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
self.lr_scheduler.step()
|
||||
|
||||
if self.ema is not None:
|
||||
self.ema.step(self.diffusion)
|
||||
|
||||
info = {
|
||||
"loss": loss.item(),
|
||||
"grad_norm": float(grad_norm),
|
||||
"lr": self.lr_scheduler.get_last_lr()[0],
|
||||
"update_s": time.time() - start_time,
|
||||
}
|
||||
|
||||
return info
|
||||
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
@@ -346,12 +305,6 @@ class _RgbEncoder(nn.Module):
|
||||
def __init__(self, cfg: DiffusionConfig):
|
||||
super().__init__()
|
||||
# Set up optional preprocessing.
|
||||
if all(v == 1.0 for v in chain(cfg.image_normalization_mean, cfg.image_normalization_std)):
|
||||
self.normalizer = nn.Identity()
|
||||
else:
|
||||
self.normalizer = torchvision.transforms.Normalize(
|
||||
mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
|
||||
)
|
||||
if cfg.crop_shape is not None:
|
||||
self.do_crop = True
|
||||
# Always use center crop for eval
|
||||
@@ -397,8 +350,7 @@ class _RgbEncoder(nn.Module):
|
||||
Returns:
|
||||
(B, D) image feature.
|
||||
"""
|
||||
# Preprocess: normalize and maybe crop (if it was set up in the __init__).
|
||||
x = self.normalizer(x)
|
||||
# Preprocess: maybe crop (if it was set up in the __init__).
|
||||
if self.do_crop:
|
||||
if self.training: # noqa: SIM108
|
||||
x = self.maybe_random_crop(x)
|
||||
|
||||
@@ -20,7 +20,7 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
|
||||
return policy_cfg
|
||||
|
||||
|
||||
def make_policy(hydra_cfg: DictConfig):
|
||||
def make_policy(hydra_cfg: DictConfig, dataset_stats=None):
|
||||
if hydra_cfg.policy.name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
||||
|
||||
@@ -35,14 +35,14 @@ def make_policy(hydra_cfg: DictConfig):
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
|
||||
policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps)
|
||||
policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps, dataset_stats)
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
elif hydra_cfg.policy.name == "act":
|
||||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
||||
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
|
||||
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg)
|
||||
policy = ActionChunkingTransformerPolicy(policy_cfg)
|
||||
policy = ActionChunkingTransformerPolicy(policy_cfg, dataset_stats)
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
else:
|
||||
raise ValueError(hydra_cfg.policy.name)
|
||||
|
||||
lerobot/common/policies/normalize.py (new file): 196 lines
@@ -0,0 +1,196 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
def create_stats_buffers(shapes, modes, stats=None):
|
||||
"""
|
||||
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max statistics.
|
||||
|
||||
Parameters:
|
||||
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3,96,96]`).
|
||||
These shapes are used to create the tensor buffer containing mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
|
||||
and width, assuming a channel-first (c, h, w) format.
|
||||
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values are their normalization modes among:
|
||||
- "mean_std": substract the mean and divide by standard deviation.
|
||||
- "min_max": map to [-1, 1] range.
|
||||
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values
|
||||
(e.g. `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for training the model for the first time,
|
||||
these statistics will overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should to be
|
||||
be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
|
||||
they are already in the policy state_dict.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing `nn.Parameters` set to
|
||||
`requires_grad=False`, suitable to not be updated during backpropagation.
|
||||
"""
|
||||
stats_buffers = {}
|
||||
|
||||
for key, mode in modes.items():
|
||||
assert mode in ["mean_std", "min_max"]
|
||||
|
||||
shape = tuple(shapes[key])
|
||||
|
||||
if "image" in key:
|
||||
# sanity checks
|
||||
assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
|
||||
c, h, w = shape
|
||||
assert c < h and c < w, f"{key} is not channel first ({shape=})"
|
||||
# override image shape to be invariant to height and width
|
||||
shape = (c, 1, 1)
|
||||
|
||||
# Note: we initialize mean, std, min, max to infinity. They should be overwritten
|
||||
# downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
|
||||
# we assert they are not infinity anymore.
|
||||
|
||||
buffer = {}
|
||||
if mode == "mean_std":
|
||||
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
std = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
{
|
||||
"mean": nn.Parameter(mean, requires_grad=False),
|
||||
"std": nn.Parameter(std, requires_grad=False),
|
||||
}
|
||||
)
|
||||
elif mode == "min_max":
|
||||
min = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
max = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
{
|
||||
"min": nn.Parameter(min, requires_grad=False),
|
||||
"max": nn.Parameter(max, requires_grad=False),
|
||||
}
|
||||
)
|
||||
|
||||
if stats is not None:
|
||||
if mode == "mean_std":
|
||||
buffer["mean"].data = stats[key]["mean"]
|
||||
buffer["std"].data = stats[key]["std"]
|
||||
elif mode == "min_max":
|
||||
buffer["min"].data = stats[key]["min"]
|
||||
buffer["max"].data = stats[key]["max"]
|
||||
|
||||
stats_buffers[key] = buffer
|
||||
return stats_buffers
|
||||
|
||||
|
||||
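To make the buffer layout concrete, here is a small sketch of what `create_stats_buffers` returns for a two-modality setup (the shapes, modes, and stats values below are illustrative, not taken from a real dataset):

```python
import torch

shapes = {"observation.image": [3, 96, 96], "observation.state": [2]}
modes = {"observation.image": "mean_std", "observation.state": "min_max"}
stats = {
    "observation.image": {"mean": torch.zeros(3, 1, 1), "std": torch.ones(3, 1, 1)},
    "observation.state": {"min": torch.zeros(2), "max": torch.ones(2)},
}

buffers = create_stats_buffers(shapes, modes, stats)
# buffers["observation.image"] is an nn.ParameterDict with "mean"/"std" of shape (3, 1, 1),
# i.e. image stats are broadcast over height and width.
# buffers["observation.state"] is an nn.ParameterDict with "min"/"max" of shape (2,).
print(buffers["observation.image"]["mean"].requires_grad)  # False
```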
class Normalize(nn.Module):
    """
    Normalizes the input data (e.g. "observation.image") for more stable and faster convergence during training.

    Parameters:
        shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values are their shapes (e.g. `[3, 96, 96]`).
            These shapes are used to create the tensor buffers containing the mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
            and width, assuming a channel-first (c, h, w) format.
        modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values are their normalization modes among:
            - "mean_std": subtract the mean and divide by standard deviation.
            - "min_max": map to [-1, 1] range.
        stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") and values are dictionaries of statistic types and their values
            (e.g. `{"mean": torch.randn(3, 1, 1), "std": torch.randn(3, 1, 1)}`). If provided, as expected when training the model for the first time,
            these statistics overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should
            be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
            they are already in the policy state_dict.
    """

    def __init__(self, shapes, modes, stats=None):
        super().__init__()
        self.shapes = shapes
        self.modes = modes
        self.stats = stats
        # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
        stats_buffers = create_stats_buffers(shapes, modes, stats)
        for key, buffer in stats_buffers.items():
            setattr(self, "buffer_" + key.replace(".", "_"), buffer)

    # TODO(rcadene): should we remove torch.no_grad?
    @torch.no_grad
    def forward(self, batch):
        for key, mode in self.modes.items():
            buffer = getattr(self, "buffer_" + key.replace(".", "_"))

            if mode == "mean_std":
                mean = buffer["mean"]
                std = buffer["std"]
                assert not torch.isinf(
                    mean
                ).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or to call `policy.load_state_dict`."
                assert not torch.isinf(
                    std
                ).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or to call `policy.load_state_dict`."
                batch[key] = (batch[key] - mean) / (std + 1e-8)
            elif mode == "min_max":
                min = buffer["min"]
                max = buffer["max"]
                assert not torch.isinf(
                    min
                ).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or to call `policy.load_state_dict`."
                assert not torch.isinf(
                    max
                ).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or to call `policy.load_state_dict`."
                # normalize to [0,1]
                batch[key] = (batch[key] - min) / (max - min)
                # normalize to [-1, 1]
                batch[key] = batch[key] * 2 - 1
            else:
                raise ValueError(mode)
        return batch
class Unnormalize(nn.Module):
    """
    Similar to `Normalize`, but unnormalizes output data (e.g. `{"action": torch.randn(b, c)}`) back to the original range used by the environment.

    Parameters:
        shapes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their shapes (e.g. `[10]`).
            These shapes are used to create the tensor buffers containing the mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape is adjusted to be invariant to height
            and width, assuming a channel-first (c, h, w) format.
        modes (dict): A dictionary where keys are output modalities (e.g. "action") and values are their unnormalization modes among:
            - "mean_std": multiply by standard deviation and add mean.
            - "min_max": go from [-1, 1] range to original range.
        stats (dict, optional): A dictionary where keys are output modalities (e.g. "action") and values are dictionaries of statistic types and their values
            (e.g. `{"max": torch.tensor(1), "min": torch.tensor(0)}`). If provided, as expected when training the model for the first time,
            these statistics overwrite the default buffers. If not provided, as expected for finetuning or evaluation, the default buffers should
            be overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the dataset is not needed to get the stats, since
            they are already in the policy state_dict.
    """

    def __init__(self, shapes, modes, stats=None):
        super().__init__()
        self.shapes = shapes
        self.modes = modes
        self.stats = stats
        # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
        stats_buffers = create_stats_buffers(shapes, modes, stats)
        for key, buffer in stats_buffers.items():
            setattr(self, "buffer_" + key.replace(".", "_"), buffer)

    # TODO(rcadene): should we remove torch.no_grad?
    @torch.no_grad
    def forward(self, batch):
        for key, mode in self.modes.items():
            buffer = getattr(self, "buffer_" + key.replace(".", "_"))

            if mode == "mean_std":
                mean = buffer["mean"]
                std = buffer["std"]
                assert not torch.isinf(
                    mean
                ).any(), "`mean` is infinity. You forgot to initialize with `stats` as argument, or to call `policy.load_state_dict`."
                assert not torch.isinf(
                    std
                ).any(), "`std` is infinity. You forgot to initialize with `stats` as argument, or to call `policy.load_state_dict`."
                batch[key] = batch[key] * std + mean
            elif mode == "min_max":
                min = buffer["min"]
                max = buffer["max"]
                assert not torch.isinf(
                    min
                ).any(), "`min` is infinity. You forgot to initialize with `stats` as argument, or to call `policy.load_state_dict`."
                assert not torch.isinf(
                    max
                ).any(), "`max` is infinity. You forgot to initialize with `stats` as argument, or to call `policy.load_state_dict`."
                batch[key] = (batch[key] + 1) / 2
                batch[key] = batch[key] * (max - min) + min
            else:
                raise ValueError(mode)
        return batch
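Taken together, the two modules form an inverse pair: `Normalize` maps raw batches into the training range and `Unnormalize` maps model outputs back. A small round-trip sketch (the shapes, modes, and stats below are made up for illustration):

```python
import torch

shapes = {"action": [2]}
modes = {"action": "min_max"}
stats = {"action": {"min": torch.tensor([12.0, 25.0]), "max": torch.tensor([511.0, 511.0])}}

normalize = Normalize(shapes, modes, stats)
unnormalize = Unnormalize(shapes, modes, stats)

batch = {"action": torch.tensor([[12.0, 511.0]])}
normalized = normalize(dict(batch))       # values mapped into [-1, 1]
restored = unnormalize(dict(normalized))  # back to the environment's native range
print(torch.allclose(restored["action"], torch.tensor([[12.0, 511.0]])))  # True
```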
@@ -1,65 +0,0 @@
from torchvision.transforms.v2 import Compose, Transform


def apply_inverse_transform(item, transform):
    transforms = transform.transforms if isinstance(transform, Compose) else [transform]
    for tf in transforms[::-1]:
        if tf.invertible:
            item = tf.inverse_transform(item)
        else:
            raise ValueError(f"Inverse transform called on a non invertible transform ({tf}).")
    return item


class NormalizeTransform(Transform):
    invertible = True

    def __init__(
        self,
        stats: dict,
        in_keys: list[str] = None,
        out_keys: list[str] | None = None,
        in_keys_inv: list[str] | None = None,
        out_keys_inv: list[str] | None = None,
        mode="mean_std",
    ):
        super().__init__()
        self.in_keys = in_keys
        self.out_keys = in_keys if out_keys is None else out_keys
        self.in_keys_inv = self.out_keys if in_keys_inv is None else in_keys_inv
        self.out_keys_inv = self.in_keys if out_keys_inv is None else out_keys_inv
        self.stats = stats
        assert mode in ["mean_std", "min_max"]
        self.mode = mode

    def forward(self, item):
        for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
            if inkey not in item:
                continue
            if self.mode == "mean_std":
                mean = self.stats[inkey]["mean"]
                std = self.stats[inkey]["std"]
                item[outkey] = (item[inkey] - mean) / (std + 1e-8)
            else:
                min = self.stats[inkey]["min"]
                max = self.stats[inkey]["max"]
                # normalize to [0,1]
                item[outkey] = (item[inkey] - min) / (max - min)
                # normalize to [-1, 1]
                item[outkey] = item[outkey] * 2 - 1
        return item

    def inverse_transform(self, item):
        for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
            if inkey not in item:
                continue
            if self.mode == "mean_std":
                mean = self.stats[inkey]["mean"]
                std = self.stats[inkey]["std"]
                item[outkey] = item[inkey] * std + mean
            else:
                min = self.stats[inkey]["min"]
                max = self.stats[inkey]["max"]
                item[outkey] = (item[inkey] + 1) / 2
                item[outkey] = item[outkey] * (max - min) + min
        return item
@@ -26,7 +26,8 @@ fps: ???

offline_prioritized_sampler: true

dataset_id: ???
dataset:
  repo_id: ???

n_action_steps: ???
n_obs_steps: ???
3
lerobot/configs/env/aloha.yaml
vendored
@@ -10,7 +10,8 @@ online_steps: 25000

fps: 50

dataset_id: aloha_sim_insertion_human
dataset:
  repo_id: lerobot/aloha_sim_insertion_human

env:
  name: aloha
3
lerobot/configs/env/pusht.yaml
vendored
@@ -10,7 +10,8 @@ online_steps: 25000

fps: 10

dataset_id: pusht
dataset:
  repo_id: lerobot/pusht

env:
  name: pusht
3
lerobot/configs/env/xarm.yaml
vendored
@@ -9,7 +9,8 @@ online_steps: 25000

fps: 15

dataset_id: xarm_lift_medium
dataset:
  repo_id: lerobot/xarm_lift_medium

env:
  name: xarm
@@ -11,6 +11,12 @@ log_freq: 250
n_obs_steps: 1
# when temporal_agg=False, n_action_steps=horizon

override_dataset_stats:
  observation.images.top:
    # stats from imagenet, since we use a pretrained vision model
    mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
    std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)

# See `configuration_act.py` for more details.
policy:
  name: act

@@ -28,9 +34,19 @@ policy:
  chunk_size: 100 # chunk_size
  n_action_steps: 100

  # Vision preprocessing.
  image_normalization_mean: [0.485, 0.456, 0.406]
  image_normalization_std: [0.229, 0.224, 0.225]
  input_shapes:
    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
    observation.images.top: [3, 480, 640]
    observation.state: ["${policy.state_dim}"]
  output_shapes:
    action: ["${policy.action_dim}"]

  # Normalization / Unnormalization
  normalize_input_modes:
    observation.images.top: mean_std
    observation.state: mean_std
  unnormalize_output_modes:
    action: mean_std

  # Architecture.
  # Vision backbone.
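The new `input_shapes`/`output_shapes` and `normalize_input_modes`/`unnormalize_output_modes` entries are exactly the arguments the `Normalize`/`Unnormalize` modules above expect. A hedged sketch of how a policy could wire them up in its constructor (the `cfg` attribute names follow this config; whether the actual ACT/diffusion policies do it precisely this way is not shown in this diff):

```python
from torch import nn

from lerobot.common.policies.normalize import Normalize, Unnormalize


class MyPolicy(nn.Module):  # hypothetical policy skeleton
    def __init__(self, cfg, dataset_stats=None):
        super().__init__()
        # Normalize raw inputs (images, state) before the network sees them.
        self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
        # Map predicted actions back to the environment's native range.
        self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
```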
@@ -18,6 +18,20 @@ online_steps: 0

offline_prioritized_sampler: true

override_dataset_stats:
  # TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
  observation.image:
    mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
    std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1)
  # TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
  # from the original codebase, but we should remove these and train our own pretrained model
  observation.state:
    min: [13.456424, 32.938293]
    max: [496.14618, 510.9579]
  action:
    min: [12.0, 25.0]
    max: [511.0, 511.0]

policy:
  name: diffusion

@@ -36,9 +50,19 @@ policy:
  horizon: ${horizon}
  n_action_steps: ${n_action_steps}

  # Vision preprocessing.
  image_normalization_mean: [0.5, 0.5, 0.5]
  image_normalization_std: [0.5, 0.5, 0.5]
  input_shapes:
    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
    observation.image: [3, 96, 96]
    observation.state: ["${policy.state_dim}"]
  output_shapes:
    action: ["${policy.action_dim}"]

  # Normalization / Unnormalization
  normalize_input_modes:
    observation.image: mean_std
    observation.state: min_max
  unnormalize_output_modes:
    action: min_max

  # Architecture / modeling.
  # Vision backbone.
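`override_dataset_stats` lets a config pin specific statistics (e.g. ImageNet image stats, or the legacy min/max values from the original codebase) instead of the values computed from the dataset. A plausible merging sketch, assuming a simple dict overwrite before the stats are handed to the policy; the `merge_stats` helper below is hypothetical and not part of this diff:

```python
import torch


def merge_stats(dataset_stats: dict, overrides: dict) -> dict:
    # Start from the dataset's computed stats, then overwrite only the
    # (modality, stat_name) entries listed under `override_dataset_stats`.
    merged = {key: dict(values) for key, values in dataset_stats.items()}
    for key, stat_dict in overrides.items():
        for stat_name, value in stat_dict.items():
            merged.setdefault(key, {})[stat_name] = torch.tensor(value, dtype=torch.float32)
    return merged

# e.g. merged = merge_stats(offline_dataset.stats, cfg.override_dataset_stats)
#      policy = make_policy(cfg, dataset_stats=merged)
```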
@@ -46,7 +46,6 @@ from huggingface_hub import snapshot_download
from PIL import Image as PILImage
from tqdm import trange

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import hf_transform_to_torch
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation

@@ -64,8 +63,6 @@ def eval_policy(
    policy: torch.nn.Module,
    max_episodes_rendered: int = 0,
    video_dir: Path = None,
    # TODO(rcadene): make it possible to overwrite fps? we should use env.fps
    transform: callable = None,
    return_episode_data: bool = False,
    seed=None,
):

@@ -132,10 +129,6 @@ def eval_policy(
        if return_episode_data:
            observations.append(deepcopy(observation))

        # apply transform to normalize the observations
        for key in observation:
            observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]])

        # send observation to device/gpu
        observation = {key: observation[key].to(device, non_blocking=True) for key in observation}

@@ -143,8 +136,8 @@ def eval_policy(
        with torch.inference_mode():
            action = policy.select_action(observation, step=step)

        # apply inverse transform to unnormalize the action
        action = postprocess_action(action, transform)
        # convert to cpu numpy
        action = postprocess_action(action)

        # apply the next action
        observation, reward, terminated, truncated, info = env.step(action)

@@ -360,7 +353,7 @@ def eval_policy(
    return info


def eval(cfg: dict, out_dir=None, stats_path=None):
def eval(cfg: dict, out_dir=None):
    if out_dir is None:
        raise NotImplementedError()

@@ -375,10 +368,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):

    log_output_dir(out_dir)

    logging.info("Making transforms.")
    # TODO(alexander-soare): Completely decouple datasets from evaluation.
    transform = make_dataset(cfg, stats_path=stats_path).transform

    logging.info("Making environment.")
    env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)

@@ -390,7 +379,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
        policy,
        max_episodes_rendered=10,
        video_dir=Path(out_dir) / "eval",
        transform=transform,
        return_episode_data=False,
        seed=cfg.seed,
    )

@@ -423,17 +411,13 @@ if __name__ == "__main__":
    if args.config is not None:
        # Note: For the config_path, Hydra wants a path relative to this script file.
        cfg = init_hydra_config(args.config, args.overrides)
        # TODO(alexander-soare): Save and load stats in trained model directory.
        stats_path = None
    elif args.hub_id is not None:
        folder = Path(snapshot_download(args.hub_id, revision=args.revision))
        cfg = init_hydra_config(
            folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
        )
        stats_path = folder / "stats.pth"

    eval(
        cfg,
        out_dir=f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
        stats_path=stats_path,
    )
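The net effect on evaluation is that all normalization state now lives inside the policy (and therefore in its `state_dict`), so the rollout no longer needs a dataset-derived `transform` or a separate `stats.pth`. A minimal sketch of evaluating a checkpoint under this scheme; the checkpoint-loading flow, `cfg`, `checkpoint_path`, and `out_dir` are assumptions, and the factory import paths follow the repository layout used elsewhere in this diff:

```python
from pathlib import Path

import torch

from lerobot.common.envs.factory import make_env
from lerobot.common.policies.factory import make_policy
from lerobot.scripts.eval import eval_policy

# `cfg`, `checkpoint_path`, and `out_dir` are assumed to be provided by the caller.
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)

policy = make_policy(cfg)  # stats buffers start at +inf ...
policy.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))  # ... and are restored here

info = eval_policy(
    env,
    policy,
    max_episodes_rendered=10,
    video_dir=Path(out_dir) / "eval",
    seed=cfg.seed,
)
```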
@@ -1,4 +1,5 @@
import logging
import time
from copy import deepcopy
from pathlib import Path

@@ -7,6 +8,7 @@ import hydra
import torch
from datasets import concatenate_datasets
from datasets.utils import disable_progress_bars, enable_progress_bars
from diffusers.optimization import get_scheduler

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle

@@ -22,6 +24,45 @@ from lerobot.common.utils.utils import (
from lerobot.scripts.eval import eval_policy


def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
    start_time = time.time()

    model = policy.diffusion if hasattr(policy, "diffusion") else policy  # TODO: hacky, remove this line
    model.train()

    batch = policy.normalize_inputs(batch)

    output_dict = policy.forward(batch)
    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
    loss = output_dict["loss"]
    loss.backward()

    # Diffusion
    model = policy.diffusion if hasattr(policy, "diffusion") else policy  # TODO: hacky, remove this line
    grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(),
        grad_clip_norm,
        error_if_nonfinite=False,
    )

    optimizer.step()
    optimizer.zero_grad()
    if lr_scheduler is not None:
        lr_scheduler.step()

    if hasattr(policy, "ema") and policy.ema is not None:
        policy.ema.step(model)

    info = {
        "loss": loss.item(),
        "grad_norm": float(grad_norm),
        "lr": optimizer.param_groups[0]["lr"],
        "update_s": time.time() - start_time,
    }

    return info
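A hedged sketch of how `update_policy` slots into the offline training loop, following the call later in this diff (the dataloader construction is simplified, `cfg.batch_size` is an assumed config field, and `policy`, `optimizer`, `lr_scheduler`, and `offline_dataset` are assumed to be set up as in `train.py`):

```python
import logging

import torch

from lerobot.common.datasets.utils import cycle

dataloader = torch.utils.data.DataLoader(offline_dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=4)
dl_iter = cycle(dataloader)

for step in range(cfg.offline_steps):
    batch = next(dl_iter)
    batch = {key: batch[key].to(cfg.device, non_blocking=True) for key in batch}

    # One optimization step: normalize inputs, forward, backward, clip, step optimizer/scheduler.
    train_info = update_policy(policy, batch, optimizer, cfg.grad_clip_norm, lr_scheduler)

    if step % cfg.log_freq == 0:
        logging.info(f"step={step} loss={train_info['loss']:.3f} lr={train_info['lr']:.2e}")
```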
@hydra.main(version_base=None, config_name="default", config_path="../configs")
def train_cli(cfg: dict):
    train(

@@ -232,7 +273,36 @@ def train(cfg: dict, out_dir=None, job_name=None):
    env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)

    logging.info("make_policy")
    policy = make_policy(cfg)
    policy = make_policy(cfg, dataset_stats=offline_dataset.stats)

    # Create optimizer and scheduler
    # Temporary hack to move optimizer out of policy
    if cfg.policy.name == "act":
        optimizer_params_dicts = [
            {"params": [p for n, p in policy.named_parameters() if not n.startswith("backbone") and p.requires_grad]},
            {
                "params": [p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad],
                "lr": cfg.lr_backbone,
            },
        ]
        optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay)
        lr_scheduler = None
    elif cfg.policy.name == "diffusion":
        optimizer = torch.optim.Adam(
            policy.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay
        )
        # TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps
        global_step = 0
        # configure lr scheduler
        lr_scheduler = get_scheduler(
            cfg.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=cfg.lr_warmup_steps,
            num_training_steps=cfg.offline_steps,
            # pytorch assumes stepping LRScheduler every epoch
            # however huggingface diffusers steps it every batch
            last_epoch=global_step - 1,
        )

    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in policy.parameters())

@@ -293,7 +363,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
        for key in batch:
            batch[key] = batch[key].to(cfg.device, non_blocking=True)

        train_info = policy.update(batch, step=step)
        train_info = update_policy(policy, batch, optimizer, cfg.grad_clip_norm, lr_scheduler)

        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
        if step % cfg.log_freq == 0:

@@ -316,9 +386,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
        # create dataloader for online training
        concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
        weights = [1.0] * len(concat_dataset)
        sampler = torch.utils.data.WeightedRandomSampler(
            weights, num_samples=len(concat_dataset), replacement=True
        )
        sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(concat_dataset), replacement=True)
        dataloader = torch.utils.data.DataLoader(
            concat_dataset,
            num_workers=4,

@@ -339,7 +407,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
        eval_info = eval_policy(
            rollout_env,
            policy,
            transform=offline_dataset.transform,
            return_episode_data=True,
            seed=cfg.seed,
        )
@@ -50,11 +50,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
    log_output_dir(out_dir)

    logging.info("make_dataset")
    dataset = make_dataset(
        cfg,
        # remove all transformations such as rescale images from [0,255] to [0,1] or normalization
        normalize=False,
    )
    dataset = make_dataset(cfg)

    logging.info("Start rendering episodes from offline buffer")
    video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER)
Some files were not shown because too many files have changed in this diff.