Compare commits

..

54 Commits

Author SHA1 Message Date
Remi Cadene
ca81b1d6f4 WIP 2024-06-08 12:10:04 +02:00
Simon Alibert
e52942a200 fix TestMakeImageTransforms 2024-06-07 17:23:54 +02:00
Simon Alibert
b60810a8b6 Update config doc 2024-06-07 10:55:18 +00:00
Simon Alibert
faacb36271 transform -> image_transforms 2024-06-06 16:53:37 +00:00
Simon Alibert
c45dd8f848 rename to image_transforms 2024-06-06 16:50:22 +00:00
Simon Alibert
a86f387554 WIP 2024-06-06 15:24:31 +00:00
Marina Barannikov
bdc0ebd36a Updated default transform parameters 2024-06-06 13:50:48 +00:00
Marina Barannikov
19f4a6568d Updated default image_transform parameters 2024-06-06 13:37:58 +00:00
Marina Barannikov
9552a4f010 Added clarification comments 2024-06-06 09:24:58 +00:00
Marina Barannikov
d657139828 Updated comments 2024-06-06 09:23:39 +00:00
Simon Alibert
b1714803a3 Disable image_transform by default 2024-06-06 08:39:52 +00:00
Simon Alibert
5d55b19cbd Fix tests 2024-06-05 16:47:52 +00:00
Simon Alibert
641d349df4 Add save_image_transforms.py & artifacts 2024-06-05 16:30:47 +00:00
Simon Alibert
e444b0d529 Add first tests 2024-06-05 16:29:54 +00:00
Marina Barannikov
8237ed9aa4 Updated visualize script 2024-06-05 16:01:37 +00:00
Simon Alibert
82e32f1fcd Fix RandomSubsetApply weighted sampling 2024-06-05 14:19:37 +00:00
Marina Barannikov
644e77e413 Renamed scripts 2024-06-05 13:35:41 +00:00
Marina Barannikov
1b1bbb1632 Minor formatting 2024-06-05 13:31:40 +00:00
Marina Barannikov
0fb3dd745b Implented visualize_image_transforms script 2024-06-05 13:30:32 +00:00
Marina Barannikov
4dbc1adb0d Updated show_transform to match config 2024-06-05 12:32:53 +00:00
Simon Alibert
ceb95592af Remove prints 2024-06-05 12:21:00 +00:00
Simon Alibert
6509c3f6d4 Implement RandomSubsetApply features 2024-06-05 12:15:36 +00:00
Marina Barannikov
8b134725d5 Merge branch 'huggingface:main' into 2024_05_30_add_data_augmentation 2024-06-05 13:56:47 +02:00
Marina Barannikov
a544949ebe Added example of torchvision image augmentation on LeRobotDataset 2024-06-05 10:53:18 +00:00
Simon Alibert
fdf56e7a62 Redesign config 2024-06-05 09:49:31 +00:00
Simon Alibert
443b06b412 refactor show_image_transforms 2024-06-05 09:34:39 +00:00
Alexander Soare
1eb4bfe2e4 Fix videos_dir documentation (#247) 2024-06-05 08:25:20 +01:00
Alexander Soare
21f222fa1d Add out_dir option to eval (#244) 2024-06-04 21:01:53 +02:00
amandip7
33362dbd17 Adding parameter dataloading_s to console logs and wandb for tracking… (#243)
Co-authored-by: Remi <re.cadene@gmail.com>
2024-06-04 17:02:05 +01:00
Marina Barannikov
22bd1f0669 Updated formatting 2024-06-04 12:06:36 +00:00
Marina Barannikov
31e3c82386 Merge remote-tracking branch 'refs/remotes/origin/2024_05_30_add_data_augmentation' into 2024_05_30_add_data_augmentation 2024-06-04 12:00:46 +00:00
Marina Barannikov
5eea2542d9 Added visualisations for image augmentation 2024-06-04 11:57:45 +00:00
Marina Barannikov
42f9cc9c2a Updated transforms arguments 2024-06-04 11:14:54 +00:00
Marina Barannikov
66629a956d Updated config to match transforms 2024-06-04 11:09:23 +00:00
Simon Alibert
7be2c35c0a Merge branch 'huggingface:main' into 2024_05_30_add_data_augmentation 2024-06-04 12:28:06 +02:00
Ruijie
b0d954c6e1 Fix bug in normalize to avoid divide by zero (#239)
Co-authored-by: rj <rj@teleopstrio-razer.lan>
Co-authored-by: Remi <re.cadene@gmail.com>
2024-06-04 12:21:28 +02:00
Marina Barannikov
14291171cc Updated default.yaml 2024-06-03 17:21:16 +00:00
Marina Barannikov
cc4b3bd8e7 Updated default.yaml 2024-06-03 17:18:33 +00:00
Simon Alibert
602ea9844b Add RandomSubsetApply 2024-06-03 17:15:37 +00:00
Simon Alibert
bd3111f28b Fix visualize_dataset.py --help (#241) 2024-06-03 16:35:16 +02:00
Marina Barannikov
9f8415fa83 Added clarification comments 2024-06-03 14:18:08 +00:00
Alexander Soare
cf15cba5fc Remove redundant slicing operation in Diffusion Policy (#240) 2024-06-03 13:04:24 +01:00
jganitzer
042e193995 Typo in examples\4_train_policy_with_script.md (#235) 2024-05-31 18:14:14 +01:00
Marina Barannikov
212a5ab29b Updated implementation on MultiLeRobotDataset 2024-05-31 16:28:08 +00:00
Marina Barannikov
c4870e5892 Added data augmentation feature to MultiLeRobotDataset 2024-05-31 15:42:31 +00:00
Marina Barannikov
20a3715469 Merge remote-tracking branch 'origin/main' into 2024_05_30_add_data_augmentation 2024-05-31 14:50:31 +00:00
marina.barannikov@huggingface.co
65e46a49e1 Implemented data augmentation with LeRobot class 2024-05-31 14:16:38 +00:00
Remi
d585c73f9f Add real-world support for ACT on Aloha/Aloha2 (#228)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
2024-05-31 15:31:02 +02:00
Radek Osmulski
504d2aaf48 add EpisodeAwareSampler (#217)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
2024-05-31 13:43:47 +01:00
Radek Osmulski
83f4f7f7e8 Add precision param to format_big_number (#232) 2024-05-31 10:19:01 +02:00
Alexander Soare
633115d861 Fix chaining in MultiLerobotDataset (#233) 2024-05-31 09:03:28 +01:00
Alexander Soare
57fb5fe8a6 Improve documentation on VAE encoder inputs (#215) 2024-05-30 19:16:44 +02:00
Alexander Soare
0b51a335bc Add a test for MultiLeRobotDataset making sure it produces all frames. (#230)
Co-authored-by: Remi <re.cadene@gmail.com>
2024-05-30 17:46:25 +01:00
Alexander Soare
111cd58f8a Add MultiLerobotDataset for training with multiple LeRobotDatasets (#229) 2024-05-30 16:12:21 +01:00
40 changed files with 1357 additions and 265 deletions

View File

@@ -70,7 +70,7 @@ python lerobot/scripts/train.py policy=act env=aloha
There are two things to note here:
- Config overrides are passed as `param_name=param_value`.
- Here we have overridden the defaults section. `policy=act` tells Hydra to use `policy/act.yaml`, and `env=aloha` tells Hydra to use `env/pusht.yaml`.
- Here we have overridden the defaults section. `policy=act` tells Hydra to use `policy/act.yaml`, and `env=aloha` tells Hydra to use `env/aloha.yaml`.
_As an aside: we've set up all of our configurations so that they reproduce state-of-the-art results from papers in the literature._

View File

@@ -0,0 +1,51 @@
"""
This script demonstrates how to implement torchvision image augmentation on an instance of a LeRobotDataset and how to show some transformed images.
The transformations are passed to the dataset as an argument upon creation, and transforms are applied to the observation images before they are returned.
"""
from pathlib import Path
from torchvision.transforms import ToPILImage, v2
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
to_pil = ToPILImage()
# Create a directory to store output images
output_dir = Path("outputs/image_transforms")
output_dir.mkdir(parents=True, exist_ok=True)
repo_id = "lerobot/aloha_static_tape"
# Create a LeRobotDataset with no transformations
dataset = LeRobotDataset(repo_id, image_transforms=None)
# Get the index of the first observation in the first episode
first_idx = dataset.episode_data_index["from"][0].item()
# Get the frame corresponding to the first camera
frame = dataset[first_idx][dataset.camera_keys[0]]
# Save the original frame
to_pil(frame).save(output_dir / "original_frame.png", quality=100)
print(f"Original frame saved to {output_dir / 'original_frame.png'}.")
# Define the transformations
transforms = v2.Compose(
[
v2.ColorJitter(brightness=(0.5, 1.5)),
v2.ColorJitter(contrast=(0.5, 1.5)),
v2.RandomAdjustSharpness(sharpness_factor=2, p=1),
]
)
# Create another LeRobotDataset with the defined transformations
transformed_dataset = LeRobotDataset(repo_id, image_transforms=transforms)
# Get a frame from the transformed dataset
transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]
# Save the transformed frame
to_pil(transformed_frame).save(output_dir / "transformed_frame.png", quality=100)
print(f"Transformed frame saved to {output_dir / 'transformed_frame.png'}.")

View File

@@ -45,6 +45,9 @@ import itertools
from lerobot.__version__ import __version__ # noqa: F401
# TODO(rcadene): Improve policies and envs. As of now, an item in `available_policies`
# refers to a yaml file AND a modeling name. Same for `available_envs` which refers to
# a yaml file AND a environment name. The difference should be more obvious.
available_tasks_per_env = {
"aloha": [
"AlohaInsertion-v0",
@@ -52,7 +55,7 @@ available_tasks_per_env = {
],
"pusht": ["PushT-v0"],
"xarm": ["XarmLift-v0"],
"dora": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
"dora_aloha_real": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
}
available_envs = list(available_tasks_per_env.keys())
@@ -78,7 +81,7 @@ available_datasets_per_env = {
"lerobot/xarm_push_medium_image",
"lerobot/xarm_push_medium_replay_image",
],
"dora": [
"dora_aloha_real": [
"lerobot/aloha_static_battery",
"lerobot/aloha_static_candy",
"lerobot/aloha_static_coffee",
@@ -126,17 +129,19 @@ available_datasets = list(
itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)
)
# lists all available policies from `lerobot/common/policies` by their class attribute: `name`.
available_policies = [
"act",
"diffusion",
"tdmpc",
]
# keys and values refer to yaml files
available_policies_per_env = {
"aloha": ["act"],
"dora": ["act"],
"pusht": ["diffusion"],
"xarm": ["tdmpc"],
"dora_aloha_real": ["act_real"],
}
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]

View File

@@ -16,17 +16,15 @@
from copy import deepcopy
from math import ceil
import datasets
import einops
import torch
import tqdm
from datasets import Image
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.video_utils import VideoFrame
def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_workers=0):
def get_stats_einops_patterns(dataset, num_workers=0):
"""These einops patterns will be used to aggregate batches and compute statistics.
Note: We assume the images are in channel first format
@@ -66,9 +64,8 @@ def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_wo
return stats_patterns
def compute_stats(
dataset: LeRobotDataset | datasets.Dataset, batch_size=32, num_workers=16, max_num_samples=None
):
def compute_stats(dataset, batch_size=32, num_workers=16, max_num_samples=None):
"""Compute mean/std and min/max statistics of all data keys in a LeRobotDataset."""
if max_num_samples is None:
max_num_samples = len(dataset)
@@ -159,3 +156,54 @@ def compute_stats(
"min": min[key],
}
return stats
def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
"""Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch.
The final stats will have the union of all data keys from each of the datasets.
The final stats will have the union of all data keys from each of the datasets. For instance:
- new_max = max(max_dataset_0, max_dataset_1, ...)
- new_min = min(min_dataset_0, min_dataset_1, ...)
- new_mean = (mean of all data)
- new_std = (std of all data)
"""
data_keys = set()
for dataset in ls_datasets:
data_keys.update(dataset.stats.keys())
stats = {k: {} for k in data_keys}
for data_key in data_keys:
for stat_key in ["min", "max"]:
# compute `max(dataset_0["max"], dataset_1["max"], ...)`
stats[data_key][stat_key] = einops.reduce(
torch.stack([d.stats[data_key][stat_key] for d in ls_datasets if data_key in d.stats], dim=0),
"n ... -> ...",
stat_key,
)
total_samples = sum(d.num_samples for d in ls_datasets if data_key in d.stats)
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
# dataset, then divide by total_samples to get the overall "mean".
# NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
# numerical overflow!
stats[data_key]["mean"] = sum(
d.stats[data_key]["mean"] * (d.num_samples / total_samples)
for d in ls_datasets
if data_key in d.stats
)
# The derivation for standard deviation is a little more involved but is much in the same spirit as
# the computation of the mean.
# Given two sets of data where the statistics are known:
# σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
# where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
# NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
# numerical overflow!
stats[data_key]["std"] = torch.sqrt(
sum(
(d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2)
* (d.num_samples / total_samples)
for d in ls_datasets
if data_key in d.stats
)
)
return stats

View File

@@ -16,9 +16,10 @@
import logging
import torch
from omegaconf import OmegaConf
from omegaconf import ListConfig, OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.transforms import get_image_transforms
def resolve_delta_timestamps(cfg):
@@ -35,25 +36,73 @@ def resolve_delta_timestamps(cfg):
cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])
def make_dataset(
cfg,
split="train",
):
if cfg.env.name not in cfg.dataset_repo_id:
logging.warning(
f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
f"environment ({cfg.env.name=})."
def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
"""
Args:
cfg: A Hydra config as per the LeRobot config scheme.
split: Select the data subset used to create an instance of LeRobotDataset.
All datasets hosted on [lerobot](https://huggingface.co/lerobot) contain only one subset: "train".
Thus, by default, `split="train"` selects all the available data. `split` aims to work like the
slicer in the hugging face datasets:
https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
As of now, it only supports `split="train[:n]"` to load the first n frames of the dataset or
`split="train[n:]"` to load the last n frames. For instance `split="train[:1000]"`.
Returns:
The LeRobotDataset.
"""
if not isinstance(cfg.dataset_repo_id, (str, ListConfig)):
raise ValueError(
"Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of "
"strings to load multiple datasets."
)
# A soft check to warn if the environment matches the dataset. Don't check if we are using a real world env (dora).
if cfg.env.name != "dora":
if isinstance(cfg.dataset_repo_id, str):
dataset_repo_ids = [cfg.dataset_repo_id] # single dataset
else:
dataset_repo_ids = cfg.dataset_repo_id # multiple datasets
for dataset_repo_id in dataset_repo_ids:
if cfg.env.name not in dataset_repo_id:
logging.warning(
f"There might be a mismatch between your training dataset ({dataset_repo_id=}) and your "
f"environment ({cfg.env.name=})."
)
resolve_delta_timestamps(cfg)
# TODO(rcadene): add data augmentations
image_transforms = None
if cfg.image_transforms.enable:
image_transforms = get_image_transforms(
brightness_weight=cfg.brightness.weight,
brightness_min_max=cfg.brightness.min_max,
contrast_weight=cfg.contrast.weight,
contrast_min_max=cfg.contrast.min_max,
saturation_weight=cfg.saturation.weight,
saturation_min_max=cfg.saturation.min_max,
hue_weight=cfg.hue.weight,
hue_min_max=cfg.hue.min_max,
sharpness_weight=cfg.sharpness.weight,
sharpness_min_max=cfg.sharpness.min_max,
max_num_transforms=cfg.max_num_transforms,
random_order=cfg.random_order,
)
dataset = LeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
)
if isinstance(cfg.dataset_repo_id, str):
dataset = LeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
image_transforms=image_transforms,
)
else:
dataset = MultiLeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
image_transforms=image_transforms,
)
if cfg.get("override_dataset_stats"):
for key, stats_dict in cfg.override_dataset_stats.items():

View File

@@ -13,12 +13,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from pathlib import Path
from typing import Callable
import datasets
import torch
import torch.utils
from lerobot.common.datasets.compute_stats import aggregate_stats
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
load_episode_data_index,
@@ -42,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR,
split: str = "train",
transform: callable = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
@@ -50,7 +54,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.version = version
self.root = root
self.split = split
self.transform = transform
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
# TODO(rcadene, aliberts): implement faster transfer
@@ -147,8 +151,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.tolerance_s,
)
if self.transform is not None:
item = self.transform(item)
if self.image_transforms is not None:
for cam in self.camera_keys:
item[cam] = self.image_transforms(item[cam])
return item
@@ -164,14 +169,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
f" Recorded Frames per Second: {self.fps},\n"
f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.transform},\n"
f" Transformations: {self.image_transforms},\n"
f")"
)
@classmethod
def from_preloaded(
cls,
repo_id: str,
repo_id: str = "from_preloaded",
version: str | None = CODEBASE_VERSION,
root: Path | None = None,
split: str = "train",
@@ -183,18 +188,214 @@ class LeRobotDataset(torch.utils.data.Dataset):
stats=None,
info=None,
videos_dir=None,
):
) -> "LeRobotDataset":
"""Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem.
It is especially useful when converting raw data into LeRobotDataset before saving the dataset
on the filesystem or uploading to the hub.
Note: Meta-data attributes like `repo_id`, `version`, `root`, etc are optional and potentially
meaningless depending on the downstream usage of the return dataset.
"""
# create an empty object of type LeRobotDataset
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj.version = version
obj.root = root
obj.split = split
obj.transform = transform
obj.image_transforms = transform
obj.delta_timestamps = delta_timestamps
obj.hf_dataset = hf_dataset
obj.episode_data_index = episode_data_index
obj.stats = stats
obj.info = info
obj.info = info if info is not None else {}
obj.videos_dir = videos_dir
return obj
class MultiLeRobotDataset(torch.utils.data.Dataset):
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
structure of `LeRobotDataset`.
"""
def __init__(
self,
repo_ids: list[str],
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR,
split: str = "train",
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.repo_ids = repo_ids
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self._datasets = [
LeRobotDataset(
repo_id,
version=version,
root=root,
split=split,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
)
for repo_id in repo_ids
]
# Check that some properties are consistent across datasets. Note: We may relax some of these
# consistency requirements in future iterations of this class.
for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
if dataset.info != self._datasets[0].info:
raise ValueError(
f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is "
"not yet supported."
)
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
# restriction in future iterations of this class. For now, this is necessary at least for being able
# to use PyTorch's default DataLoader collate function.
self.disabled_data_keys = set()
intersection_data_keys = set(self._datasets[0].hf_dataset.features)
for dataset in self._datasets:
intersection_data_keys.intersection_update(dataset.hf_dataset.features)
if len(intersection_data_keys) == 0:
raise RuntimeError(
"Multiple datasets were provided but they had no keys common to all of them. The "
"multi-dataset functionality currently only keeps common keys."
)
for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(dataset.hf_dataset.features).difference(intersection_data_keys)
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_data_keys.update(extra_keys)
self.version = version
self.root = root
self.split = split
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
self.stats = aggregate_stats(self._datasets)
@property
def repo_id_to_index(self):
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
This index is incorporated as a data key in the dictionary returned by `__getitem__`.
"""
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
@property
def repo_index_to_id(self):
"""Return the inverse mapping if repo_id_to_index."""
return {v: k for k, v in self.repo_id_to_index}
@property
def fps(self) -> int:
"""Frames per second used during data collection.
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].info["fps"]
@property
def video(self) -> bool:
"""Returns True if this dataset loads video frames from mp4 files.
Returns False if it only loads images from png files.
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].info.get("video", False)
@property
def features(self) -> datasets.Features:
features = {}
for dataset in self._datasets:
features.update({k: v for k, v in dataset.features.items() if k not in self.disabled_data_keys})
return features
@property
def camera_keys(self) -> list[str]:
"""Keys to access image and video stream from cameras."""
keys = []
for key, feats in self.features.items():
if isinstance(feats, (datasets.Image, VideoFrame)):
keys.append(key)
return keys
@property
def video_frame_keys(self) -> list[str]:
"""Keys to access video frames that requires to be decoded into images.
Note: It is empty if the dataset contains images only,
or equal to `self.cameras` if the dataset contains videos only,
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
"""
video_frame_keys = []
for key, feats in self.features.items():
if isinstance(feats, VideoFrame):
video_frame_keys.append(key)
return video_frame_keys
@property
def num_samples(self) -> int:
"""Number of samples/frames."""
return sum(d.num_samples for d in self._datasets)
@property
def num_episodes(self) -> int:
"""Number of episodes."""
return sum(d.num_episodes for d in self._datasets)
@property
def tolerance_s(self) -> float:
"""Tolerance in seconds used to discard loaded frames when their timestamps
are not close enough from the requested frames. It is only used when `delta_timestamps`
is provided or when loading video frames from mp4 files.
"""
# 1e-4 to account for possible numerical error
return 1 / self.fps - 1e-4
def __len__(self):
return self.num_samples
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
if idx >= len(self):
raise IndexError(f"Index {idx} out of bounds.")
# Determine which dataset to get an item from based on the index.
start_idx = 0
dataset_idx = 0
for dataset in self._datasets:
if idx >= start_idx + dataset.num_samples:
start_idx += dataset.num_samples
dataset_idx += 1
continue
break
else:
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
item = self._datasets[dataset_idx][idx - start_idx]
item["dataset_index"] = torch.tensor(dataset_idx)
for data_key in self.disabled_data_keys:
if data_key in item:
del item[data_key]
return item
def __repr__(self):
return (
f"{self.__class__.__name__}(\n"
f" Repository IDs: '{self.repo_ids}',\n"
f" Version: '{self.version}',\n"
f" Split: '{self.split}',\n"
f" Number of Samples: {self.num_samples},\n"
f" Number of Episodes: {self.num_episodes},\n"
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
f" Recorded Frames per Second: {self.fps},\n"
f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.image_transforms},\n"
f")"
)

View File

@@ -78,29 +78,15 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
image_keys = [key for key in df if "observation.images." in key]
num_unaligned_images = 0
max_episode = 0
def get_episode_index(row):
nonlocal num_unaligned_images
nonlocal max_episode
episode_index_per_cam = {}
for key in image_keys:
if isinstance(row[key], float):
num_unaligned_images += 1
return float("nan")
path = row[key][0]["path"]
match = re.search(r"_(\d{6}).mp4", path)
if not match:
raise ValueError(path)
episode_index = int(match.group(1))
episode_index_per_cam[key] = episode_index
if episode_index > max_episode:
assert episode_index - max_episode == 1
max_episode = episode_index
else:
assert episode_index == max_episode
if len(set(episode_index_per_cam.values())) != 1:
raise ValueError(
f"All cameras are expected to belong to the same episode, but getting {episode_index_per_cam}"
@@ -125,24 +111,11 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
del df["timestamp_utc"]
# sanity check
num_rows_with_nan = df.isna().any(axis=1).sum()
assert (
num_rows_with_nan == num_unaligned_images
), f"Found {num_rows_with_nan} rows with NaN values but {num_unaligned_images} unaligned images."
if num_unaligned_images > max_episode * 2:
# We allow a few unaligned images, typically at the beginning and end of the episodes for instance
# but if there are too many, we raise an error to avoid large chunks of missing data
raise ValueError(
f"Found {num_unaligned_images} unaligned images out of {max_episode} episodes. "
f"Check the timestamps of the cameras."
)
# Drop rows with NaN values now that we double checked and convert episode_index to int
df = df.dropna()
df["episode_index"] = df["episode_index"].astype(int)
has_nan = df.isna().any().any()
if has_nan:
raise ValueError("Dataset contains Nan values.")
# sanity check episode indices go from 0 to n-1
assert df["episode_index"].max() == max_episode
ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
expected_ep_ids = list(range(df["episode_index"].max() + 1))
if ep_ids != expected_ep_ids:
@@ -241,6 +214,8 @@ def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=Tru
if fps is None:
fps = 30
else:
raise NotImplementedError()
if not video:
raise NotImplementedError()

View File

@@ -0,0 +1,61 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterator, Union
import torch
class EpisodeAwareSampler:
def __init__(
self,
episode_data_index: dict,
episode_indices_to_use: Union[list, None] = None,
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = False,
):
"""Sampler that optionally incorporates episode boundary information.
Args:
episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode.
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
Assumes that episodes are indexed from 0 to N-1.
drop_n_first_frames: Number of frames to drop from the start of each episode.
drop_n_last_frames: Number of frames to drop from the end of each episode.
shuffle: Whether to shuffle the indices.
"""
indices = []
for episode_idx, (start_index, end_index) in enumerate(
zip(episode_data_index["from"], episode_data_index["to"], strict=True)
):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
indices.extend(
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
)
self.indices = indices
self.shuffle = shuffle
def __iter__(self) -> Iterator[int]:
if self.shuffle:
for i in torch.randperm(len(self.indices)):
yield self.indices[i]
else:
for i in self.indices:
yield i
def __len__(self) -> int:
return len(self.indices)

View File

@@ -0,0 +1,157 @@
from typing import Any, Callable, Dict, Sequence
import torch
from torchvision.transforms import v2
from torchvision.transforms.v2 import Transform
from torchvision.transforms.v2 import functional as F # noqa: N812
class RandomSubsetApply(Transform):
"""
Apply a random subset of N transformations from a list of transformations.
Args:
transforms (sequence or torch.nn.Module): list of transformations
p (list of floats or None, optional): probability of each transform being picked.
If ``p`` doesn't sum to 1, it is automatically normalized. If ``None``
(default), all transforms have the same probability.
n_subset (int or None): number of transformations to apply. If ``None``,
all transforms are applied.
random_order (bool): apply transformations in a random order
"""
def __init__(
self,
transforms: Sequence[Callable],
p: list[float] | None = None,
n_subset: int | None = None,
random_order: bool = False,
) -> None:
super().__init__()
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
if p is None:
p = [1] * len(transforms)
elif len(p) != len(transforms):
raise ValueError(
f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}"
)
if n_subset is None:
n_subset = len(transforms)
elif not isinstance(n_subset, int):
raise TypeError("n_subset should be an int or None")
elif not (0 <= n_subset <= len(transforms)):
raise ValueError(f"n_subset should be in the interval [0, {len(transforms)}]")
self.transforms = transforms
total = sum(p)
self.p = [prob / total for prob in p]
self.n_subset = n_subset
self.random_order = random_order
def forward(self, *inputs: Any) -> Any:
needs_unpacking = len(inputs) > 1
selected_indices = torch.multinomial(torch.tensor(self.p), self.n_subset)
if not self.random_order:
selected_indices = selected_indices.sort().values
selected_transforms = [self.transforms[i] for i in selected_indices]
for transform in selected_transforms:
outputs = transform(*inputs)
inputs = outputs if needs_unpacking else (outputs,)
return outputs
def extra_repr(self) -> str:
return (
f"transforms={self.transforms}, "
f"p={self.p}, "
f"n_subset={self.n_subset}, "
f"random_order={self.random_order}"
)
class RangeRandomSharpness(Transform):
"""Similar to v2.RandomAdjustSharpness but with p=1 and a sharpness_factor sampled randomly
each time in [range_min, range_max].
If the input is a :class:`torch.Tensor`,
it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
"""
def __init__(self, range_min: float, range_max) -> None:
super().__init__()
self.range_min, self.range_max = self._check_input(range_min, range_max)
def _check_input(self, range_min, range_max):
if range_min < 0:
raise ValueError("range_min must be non negative.")
if range_min > range_max:
raise ValueError("range_max must greater or equal to range_min")
return range_min, range_max
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
sharpness_factor = self.range_min + (self.range_max - self.range_min) * torch.rand(1).item()
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
def get_image_transforms(
brightness_weight: float = 1.0,
brightness_min_max: tuple[float, float] | None = None,
contrast_weight: float = 1.0,
contrast_min_max: tuple[float, float] | None = None,
saturation_weight: float = 1.0,
saturation_min_max: tuple[float, float] | None = None,
hue_weight: float = 1.0,
hue_min_max: tuple[float, float] | None = None,
sharpness_weight: float = 1.0,
sharpness_min_max: tuple[float, float] | None = None,
max_num_transforms: int | None = None,
random_order: bool = False,
):
def check_value_error(name, weight, min_max):
if min_max is not None:
if len(min_max) != 2:
raise ValueError(f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided.")
if weight < 0.:
raise ValueError(f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight}).")
check_value_error("brightness", brightness_weight, brightness_min_max)
check_value_error("contrast", contrast_weight, contrast_min_max)
check_value_error("saturation", saturation_weight, saturation_min_max)
check_value_error("hue", hue_weight, hue_min_max)
check_value_error("sharpness", sharpness_weight, sharpness_min_max)
weights = []
transforms = []
if brightness_min_max is not None:
weights.append(brightness_weight)
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
if contrast_min_max is not None:
weights.append(contrast_weight)
transforms.append(v2.ColorJitter(contrast=contrast_min_max))
if saturation_min_max is not None:
weights.append(saturation_weight)
transforms.append(v2.ColorJitter(saturation=saturation_min_max))
if hue_min_max is not None:
weights.append(hue_weight)
transforms.append(v2.ColorJitter(hue=hue_min_max))
if sharpness_min_max is not None:
weights.append(sharpness_weight)
transforms.append(RangeRandomSharpness(**sharpness_min_max))
if max_num_transforms is None:
n_subset = len(transforms)
else:
n_subset = min(len(transforms), max_num_transforms)
final_transforms = RandomSubsetApply(
transforms, p=weights, n_subset=n_subset, random_order=random_order
)
# TODO(rcadene, aliberts): add v2.ToDtype float16?
return final_transforms

View File

@@ -59,7 +59,7 @@ def unflatten_dict(d, sep="/"):
return outdict
def hf_transform_to_torch(items_dict):
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. Importantly, images are converted from PIL, which corresponds to
a channel last representation (h w c) of uint8 type, to a torch image representation
@@ -73,6 +73,8 @@ def hf_transform_to_torch(items_dict):
elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
# video frame will be processed downstream
pass
elif first_item is None:
pass
else:
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
return items_dict
@@ -318,8 +320,7 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
"""
Reset the `episode_index` of the provided HuggingFace Dataset.
"""Reset the `episode_index` of the provided HuggingFace Dataset.
`episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the
`episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0.
@@ -338,6 +339,7 @@ def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
return example
hf_dataset = hf_dataset.map(modify_ep_idx_func)
return hf_dataset

View File

@@ -26,11 +26,10 @@ class ACTConfig:
Those are: `input_shapes` and 'output_shapes`.
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
- At least one key starting with "observation.image is required as an input.
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
views.
Right now we only support all images having the same shape.
- If there are multiple keys beginning with "observation.images." they are treated as multiple camera
views. Right now we only support all images having the same shape.
- May optionally work without an "observation.state" key for the proprioceptive robot state.
- "action" is required as an output key.
Args:

View File

@@ -198,15 +198,14 @@ class ACT(nn.Module):
def __init__(self, config: ACTConfig):
super().__init__()
self.config = config
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
self.has_state = "observation.state" in config.input_shapes
self.latent_dim = config.latent_dim
self.use_input_state = "observation.state" in config.input_shapes
if self.config.use_vae:
self.vae_encoder = ACTEncoder(config)
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
# Projection layer for joint-space configuration to hidden dimension.
if self.has_state:
if self.use_input_state:
self.vae_encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
)
@@ -215,10 +214,12 @@ class ACT(nn.Module):
config.output_shapes["action"][0], config.dim_model
)
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
# Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
# dimension.
num_input_token_encoder = 1 + 1 + config.chunk_size if self.has_state else 1 + config.chunk_size
num_input_token_encoder = 1 + config.chunk_size
if self.use_input_state:
num_input_token_encoder += 1
self.register_buffer(
"vae_encoder_pos_enc",
create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
@@ -241,16 +242,16 @@ class ACT(nn.Module):
# Transformer encoder input projections. The tokens will be structured like
# [latent, robot_state, image_feature_map_pixels].
if self.has_state:
if self.use_input_state:
self.encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
)
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, config.dim_model, kernel_size=1
)
# Transformer encoder positional embeddings.
num_input_token_decoder = 2 if self.has_state else 1
num_input_token_decoder = 2 if self.use_input_state else 1
self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
@@ -298,12 +299,12 @@ class ACT(nn.Module):
cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
if self.has_state:
if self.use_input_state:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
if self.has_state:
if self.use_input_state:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
else:
vae_encoder_input = [cls_embed, action_embed]
@@ -318,9 +319,9 @@ class ACT(nn.Module):
vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
)[0] # select the class token, with shape (B, D)
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
mu = latent_pdf_params[:, : self.latent_dim]
mu = latent_pdf_params[:, : self.config.latent_dim]
# This is 2log(sigma). Done this way to match the original implementation.
log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :]
# Sample the latent with the reparameterization trick.
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
@@ -328,7 +329,7 @@ class ACT(nn.Module):
# When not using the VAE encoder, we set the latent to be all zeros.
mu = log_sigma_x2 = None
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
batch["observation.state"].device
)
@@ -350,12 +351,12 @@ class ACT(nn.Module):
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
# Get positional embeddings for robot state and latent.
if self.has_state:
if self.use_input_state:
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C)
# Stack encoder input and positional embeddings moving to (S, B, C).
encoder_in_feats = [latent_embed, robot_state_embed] if self.has_state else [latent_embed]
encoder_in_feats = [latent_embed, robot_state_embed] if self.use_input_state else [latent_embed]
encoder_in = torch.cat(
[
torch.stack(encoder_in_feats, axis=0),

View File

@@ -28,10 +28,7 @@ class DiffusionConfig:
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
- At least one key starting with "observation.image is required as an input.
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
views.
Right now we only support all images having the same shape.
- A key starting with "observation.image is required as an input.
- "action" is required as an output key.
Args:

View File

@@ -239,10 +239,8 @@ class DiffusionModel(nn.Module):
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
# run sampling
sample = self.conditional_sample(batch_size, global_cond=global_cond)
actions = self.conditional_sample(batch_size, global_cond=global_cond)
# `horizon` steps worth of actions (from the first observation).
actions = sample[..., : self.config.output_shapes["action"][0]]
# Extract `n_action_steps` steps worth of actions (from the current observation).
start = n_obs_steps - 1
end = start + self.config.n_action_steps

View File

@@ -147,7 +147,7 @@ class Normalize(nn.Module):
assert not torch.isinf(min).any(), _no_stats_error_str("min")
assert not torch.isinf(max).any(), _no_stats_error_str("max")
# normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min)
batch[key] = (batch[key] - min) / (max - min + 1e-8)
# normalize to [-1, 1]
batch[key] = batch[key] * 2 - 1
else:

View File

@@ -120,13 +120,13 @@ def init_logging():
logging.getLogger().addHandler(console_handler)
def format_big_number(num):
def format_big_number(num, precision=0):
suffixes = ["", "K", "M", "B", "T", "Q"]
divisor = 1000.0
for suffix in suffixes:
if abs(num) < divisor:
return f"{num:.0f}{suffix}"
return f"{num:.{precision}f}{suffix}"
num /= divisor
return num

View File

@@ -23,6 +23,10 @@ use_amp: false
# `seed` is used for training (eg: model initialization, dataset shuffling)
# AND for the evaluation environments.
seed: ???
# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
# keys common between the datasets are kept. Each dataset gets and additional transform that inserts the
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
# datsets are provided.
dataset_repo_id: lerobot/pusht
training:
@@ -53,3 +57,37 @@ wandb:
disable_artifact: false
project: lerobot
notes: ""
image_transforms:
# These transforms are all using standard torchvision.transforms.v2
# You can find out how these transformations affect images here:
# https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
# We use a custom RandomSubsetApply container to sample them.
# For each transform, the following parameters are available:
# weight: This represents the multinomial probability (with no replacement)
# used for sampling the transform. If the sum of the weights is not 1,
# they will be normalized.
# min_max: Lower & upper bound respectively used for sampling the transform's parameter
# (following uniform distribution) when it's applied.
enable: false
# This is the number of transforms (sampled from these below) that will be applied to each frame.
# It's an integer in the interval [0, number of available transforms].
max_num_transforms: 3
# By default, transforms are applied in Torchvision's suggested order (shown below).
# Set this to True to apply them in a random order.
random_order: false
brightness:
weight: 1
min_max: [0.8, 1.2]
contrast:
weight: 1
min_max: [0.8, 1.2]
saturation:
weight: 1
min_max: [0.5, 1.5]
hue:
weight: 1
min_max: [-0.05, 0.05]
sharpness:
weight: 1
min_max: [0.8, 1.2]

View File

@@ -11,7 +11,7 @@
# ```bash
# python lerobot/scripts/train.py \
# policy=act_real \
# env=aloha_real
# env=dora_aloha_real
# ```
seed: 1000

View File

@@ -9,7 +9,7 @@
# ```bash
# python lerobot/scripts/train.py \
# policy=act_real_no_state \
# env=aloha_real
# env=dora_aloha_real
# ```
seed: 1000

View File

@@ -44,6 +44,10 @@ training:
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
# The original implementation doesn't sample frames for the last 7 steps,
# which avoids excessive padding and leads to improved training results.
drop_n_last_frames: 7 # ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1
eval:
n_episodes: 50
batch_size: 50

View File

@@ -209,7 +209,7 @@ def eval_policy(
policy: torch.nn.Module,
n_episodes: int,
max_episodes_rendered: int = 0,
video_dir: Path | None = None,
videos_dir: Path | None = None,
return_episode_data: bool = False,
start_seed: int | None = None,
enable_progbar: bool = False,
@@ -221,7 +221,7 @@ def eval_policy(
policy: The policy.
n_episodes: The number of episodes to evaluate.
max_episodes_rendered: Maximum number of episodes to render into videos.
video_dir: Where to save rendered videos.
videos_dir: Where to save rendered videos.
return_episode_data: Whether to return episode data for online training. Incorporates the data into
the "episodes" key of the returned dictionary.
start_seed: The first seed to use for the first individual rollout. For all subsequent rollouts the
@@ -347,8 +347,8 @@ def eval_policy(
):
if n_episodes_rendered >= max_episodes_rendered:
break
video_dir.mkdir(parents=True, exist_ok=True)
video_path = video_dir / f"eval_episode_{n_episodes_rendered}.mp4"
videos_dir.mkdir(parents=True, exist_ok=True)
video_path = videos_dir / f"eval_episode_{n_episodes_rendered}.mp4"
video_paths.append(str(video_path))
thread = threading.Thread(
target=write_video,
@@ -503,9 +503,10 @@ def _compile_episode_data(
}
def eval(
def main(
pretrained_policy_path: str | None = None,
hydra_cfg_path: str | None = None,
out_dir: str | None = None,
config_overrides: list[str] | None = None,
):
assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None)
@@ -513,12 +514,8 @@ def eval(
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", config_overrides)
else:
hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)
out_dir = (
f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
)
if out_dir is None:
raise NotImplementedError()
out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
# Check device is available
device = get_safe_torch_device(hydra_cfg.device, log=True)
@@ -546,7 +543,7 @@ def eval(
policy,
hydra_cfg.eval.n_episodes,
max_episodes_rendered=10,
video_dir=Path(out_dir) / "eval",
videos_dir=Path(out_dir) / "videos",
start_seed=hydra_cfg.seed,
enable_progbar=True,
enable_inner_progbar=True,
@@ -586,6 +583,13 @@ if __name__ == "__main__":
),
)
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
parser.add_argument(
"--out-dir",
help=(
"Where to save the evaluation outputs. If not provided, outputs are saved in "
"outputs/eval/{timestamp}_{env_name}_{policy_name}"
),
)
parser.add_argument(
"overrides",
nargs="*",
@@ -594,7 +598,7 @@ if __name__ == "__main__":
args = parser.parse_args()
if args.pretrained_policy_name_or_path is None:
eval(hydra_cfg_path=args.config, config_overrides=args.overrides)
main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides)
else:
try:
pretrained_policy_path = Path(
@@ -618,4 +622,8 @@ if __name__ == "__main__":
"repo ID, nor is it an existing local directory."
)
eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)
main(
pretrained_policy_path=pretrained_policy_path,
out_dir=args.out_dir,
config_overrides=args.overrides,
)

View File

@@ -71,9 +71,9 @@ import torch
from huggingface_hub import HfApi
from safetensors.torch import save_file
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
from lerobot.common.datasets.push_dataset_to_hub.compute_stats import compute_stats
from lerobot.common.datasets.utils import flatten_dict

View File

@@ -16,7 +16,6 @@
import logging
import time
from contextlib import nullcontext
from copy import deepcopy
from pathlib import Path
from pprint import pformat
@@ -28,6 +27,8 @@ from termcolor import colored
from torch.cuda.amp import GradScaler
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
@@ -149,6 +150,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
grad_norm = info["grad_norm"]
lr = info["lr"]
update_s = info["update_s"]
dataloading_s = info["dataloading_s"]
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
@@ -169,6 +171,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
f"lr:{lr:0.1e}",
# in seconds
f"updt_s:{update_s:.3f}",
f"data_s:{dataloading_s:.3f}", # if not ~0, you are bottlenecked by cpu or io
]
logging.info(" ".join(log_items))
@@ -280,6 +283,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info("make_dataset")
offline_dataset = make_dataset(cfg)
if isinstance(offline_dataset, MultiLeRobotDataset):
logging.info(
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
f"{pformat(offline_dataset.repo_id_to_index , indent=2)}"
)
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
@@ -319,6 +327,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# Note: this helper will be used in offline and online training loops.
def evaluate_and_checkpoint_if_needed(step):
_num_digits = max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
step_identifier = f"{step:0{_num_digits}d}"
if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
logging.info(f"Eval policy at step {step}")
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
@@ -326,11 +337,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
eval_env,
policy,
cfg.eval.n_episodes,
video_dir=Path(out_dir) / "eval",
videos_dir=Path(out_dir) / "eval" / f"videos_step_{step_identifier}",
max_episodes_rendered=4,
start_seed=cfg.seed,
)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline=True)
if cfg.wandb.enable:
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
logging.info("Resume training")
@@ -344,29 +355,40 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
policy,
optimizer,
lr_scheduler,
identifier=str(step).zfill(
max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
),
identifier=step_identifier,
)
logging.info("Resume training")
# create dataloader for offline training
if cfg.training.get("drop_n_last_frames"):
shuffle = False
sampler = EpisodeAwareSampler(
offline_dataset.episode_data_index,
drop_n_last_frames=cfg.training.drop_n_last_frames,
shuffle=True,
)
else:
shuffle = True
sampler = None
dataloader = torch.utils.data.DataLoader(
offline_dataset,
num_workers=cfg.training.num_workers,
batch_size=cfg.training.batch_size,
shuffle=True,
shuffle=shuffle,
sampler=sampler,
pin_memory=device.type != "cpu",
drop_last=False,
)
dl_iter = cycle(dataloader)
policy.train()
is_offline = True
for _ in range(step, cfg.training.offline_steps):
if step == 0:
logging.info("Start offline training on a fixed dataset")
start_time = time.perf_counter()
batch = next(dl_iter)
dataloading_s = time.perf_counter() - start_time
for key in batch:
batch[key] = batch[key].to(device, non_blocking=True)
@@ -381,8 +403,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
use_amp=cfg.use_amp,
)
train_info["dataloading_s"] = dataloading_s
if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True)
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
# so we pass in step + 1.
@@ -390,41 +414,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
step += 1
logging.info("End of offline training")
if cfg.training.online_steps == 0:
if cfg.training.eval_freq > 0:
eval_env.close()
return
# create an env dedicated to online episodes collection from policy rollout
online_training_env = make_env(cfg, n_envs=1)
# create an empty online dataset similar to offline dataset
online_dataset = deepcopy(offline_dataset)
online_dataset.hf_dataset = {}
online_dataset.episode_data_index = {}
# create dataloader for online training
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
weights = [1.0] * len(concat_dataset)
sampler = torch.utils.data.WeightedRandomSampler(
weights, num_samples=len(concat_dataset), replacement=True
)
dataloader = torch.utils.data.DataLoader(
concat_dataset,
num_workers=4,
batch_size=cfg.training.batch_size,
sampler=sampler,
pin_memory=device.type != "cpu",
drop_last=False,
)
logging.info("End of online training")
if cfg.training.eval_freq > 0:
eval_env.close()
online_training_env.close()
eval_env.close()
logging.info("End of training")
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")

View File

@@ -224,7 +224,8 @@ def main():
help=(
"Mode of viewing between 'local' or 'distant'. "
"'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
"'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
"'distant' creates a server on the distant machine where the data is stored. "
"Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
),
)
parser.add_argument(
@@ -245,8 +246,8 @@ def main():
default=0,
help=(
"Save a .rrd file in the directory provided by `--output-dir`. "
"It also deactivates the spawning of a viewer. ",
"Visualize the data by running `rerun path/to/file.rrd` on your local machine.",
"It also deactivates the spawning of a viewer. "
"Visualize the data by running `rerun path/to/file.rrd` on your local machine."
),
)
parser.add_argument(

View File

@@ -0,0 +1,46 @@
from pathlib import Path
import hydra
from torchvision.transforms import ToPILImage
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import make_image_transforms
to_pil = ToPILImage()
def main(cfg, output_dir=Path("outputs/image_transforms")):
dataset = LeRobotDataset(cfg.dataset_repo_id, image_transforms=None)
output_dir = Path(output_dir) / Path(cfg.dataset_repo_id.split("/")[-1])
output_dir.mkdir(parents=True, exist_ok=True)
# Get first frame of 1st episode
first_idx = dataset.episode_data_index["from"][0].item()
frame = dataset[first_idx][dataset.camera_keys[0]]
to_pil(frame).save(output_dir / "original_frame.png", quality=100)
transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"]
# Apply each single transformation
for transform_name in transforms:
for t in transforms:
if t == transform_name:
cfg.image_transforms[t].weight = 1
else:
cfg.image_transforms[t].weight = 0
transform = make_image_transforms(cfg.image_transforms)
img = transform(frame)
to_pil(img).save(output_dir / f"{transform_name}.png", quality=100)
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def visualize_transforms_cli(cfg: dict):
main(
cfg,
)
if __name__ == "__main__":
visualize_transforms_cli()

163
poetry.lock generated
View File

@@ -444,63 +444,63 @@ files = [
[[package]]
name = "coverage"
version = "7.5.1"
version = "7.5.3"
description = "Code coverage measurement for Python"
optional = true
python-versions = ">=3.8"
files = [
{file = "coverage-7.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0884920835a033b78d1c73b6d3bbcda8161a900f38a488829a83982925f6c2e"},
{file = "coverage-7.5.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:39afcd3d4339329c5f58de48a52f6e4e50f6578dd6099961cf22228feb25f38f"},
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a7b0ceee8147444347da6a66be737c9d78f3353b0681715b668b72e79203e4a"},
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a9ca3f2fae0088c3c71d743d85404cec8df9be818a005ea065495bedc33da35"},
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd215c0c7d7aab005221608a3c2b46f58c0285a819565887ee0b718c052aa4e"},
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4bf0655ab60d754491004a5efd7f9cccefcc1081a74c9ef2da4735d6ee4a6223"},
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:61c4bf1ba021817de12b813338c9be9f0ad5b1e781b9b340a6d29fc13e7c1b5e"},
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:db66fc317a046556a96b453a58eced5024af4582a8dbdc0c23ca4dbc0d5b3146"},
{file = "coverage-7.5.1-cp310-cp310-win32.whl", hash = "sha256:b016ea6b959d3b9556cb401c55a37547135a587db0115635a443b2ce8f1c7228"},
{file = "coverage-7.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:df4e745a81c110e7446b1cc8131bf986157770fa405fe90e15e850aaf7619bc8"},
{file = "coverage-7.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:796a79f63eca8814ca3317a1ea443645c9ff0d18b188de470ed7ccd45ae79428"},
{file = "coverage-7.5.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4fc84a37bfd98db31beae3c2748811a3fa72bf2007ff7902f68746d9757f3746"},
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6175d1a0559986c6ee3f7fccfc4a90ecd12ba0a383dcc2da30c2b9918d67d8a3"},
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fc81d5878cd6274ce971e0a3a18a8803c3fe25457165314271cf78e3aae3aa2"},
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:556cf1a7cbc8028cb60e1ff0be806be2eded2daf8129b8811c63e2b9a6c43bca"},
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9981706d300c18d8b220995ad22627647be11a4276721c10911e0e9fa44c83e8"},
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d7fed867ee50edf1a0b4a11e8e5d0895150e572af1cd6d315d557758bfa9c057"},
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ef48e2707fb320c8f139424a596f5b69955a85b178f15af261bab871873bb987"},
{file = "coverage-7.5.1-cp311-cp311-win32.whl", hash = "sha256:9314d5678dcc665330df5b69c1e726a0e49b27df0461c08ca12674bcc19ef136"},
{file = "coverage-7.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:5fa567e99765fe98f4e7d7394ce623e794d7cabb170f2ca2ac5a4174437e90dd"},
{file = "coverage-7.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b6cf3764c030e5338e7f61f95bd21147963cf6aa16e09d2f74f1fa52013c1206"},
{file = "coverage-7.5.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ec92012fefebee89a6b9c79bc39051a6cb3891d562b9270ab10ecfdadbc0c34"},
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16db7f26000a07efcf6aea00316f6ac57e7d9a96501e990a36f40c965ec7a95d"},
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:beccf7b8a10b09c4ae543582c1319c6df47d78fd732f854ac68d518ee1fb97fa"},
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8748731ad392d736cc9ccac03c9845b13bb07d020a33423fa5b3a36521ac6e4e"},
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7352b9161b33fd0b643ccd1f21f3a3908daaddf414f1c6cb9d3a2fd618bf2572"},
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:7a588d39e0925f6a2bff87154752481273cdb1736270642aeb3635cb9b4cad07"},
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:68f962d9b72ce69ea8621f57551b2fa9c70509af757ee3b8105d4f51b92b41a7"},
{file = "coverage-7.5.1-cp312-cp312-win32.whl", hash = "sha256:f152cbf5b88aaeb836127d920dd0f5e7edff5a66f10c079157306c4343d86c19"},
{file = "coverage-7.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:5a5740d1fb60ddf268a3811bcd353de34eb56dc24e8f52a7f05ee513b2d4f596"},
{file = "coverage-7.5.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e2213def81a50519d7cc56ed643c9e93e0247f5bbe0d1247d15fa520814a7cd7"},
{file = "coverage-7.5.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5037f8fcc2a95b1f0e80585bd9d1ec31068a9bcb157d9750a172836e98bc7a90"},
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3721c2c9e4c4953a41a26c14f4cef64330392a6d2d675c8b1db3b645e31f0e"},
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca498687ca46a62ae590253fba634a1fe9836bc56f626852fb2720f334c9e4e5"},
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0cdcbc320b14c3e5877ee79e649677cb7d89ef588852e9583e6b24c2e5072661"},
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:57e0204b5b745594e5bc14b9b50006da722827f0b8c776949f1135677e88d0b8"},
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fe7502616b67b234482c3ce276ff26f39ffe88adca2acf0261df4b8454668b4"},
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9e78295f4144f9dacfed4f92935fbe1780021247c2fabf73a819b17f0ccfff8d"},
{file = "coverage-7.5.1-cp38-cp38-win32.whl", hash = "sha256:1434e088b41594baa71188a17533083eabf5609e8e72f16ce8c186001e6b8c41"},
{file = "coverage-7.5.1-cp38-cp38-win_amd64.whl", hash = "sha256:0646599e9b139988b63704d704af8e8df7fa4cbc4a1f33df69d97f36cb0a38de"},
{file = "coverage-7.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4cc37def103a2725bc672f84bd939a6fe4522310503207aae4d56351644682f1"},
{file = "coverage-7.5.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fc0b4d8bfeabd25ea75e94632f5b6e047eef8adaed0c2161ada1e922e7f7cece"},
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d0a0f5e06881ecedfe6f3dd2f56dcb057b6dbeb3327fd32d4b12854df36bf26"},
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9735317685ba6ec7e3754798c8871c2f49aa5e687cc794a0b1d284b2389d1bd5"},
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d21918e9ef11edf36764b93101e2ae8cc82aa5efdc7c5a4e9c6c35a48496d601"},
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c3e757949f268364b96ca894b4c342b41dc6f8f8b66c37878aacef5930db61be"},
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:79afb6197e2f7f60c4824dd4b2d4c2ec5801ceb6ba9ce5d2c3080e5660d51a4f"},
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d1d0d98d95dd18fe29dc66808e1accf59f037d5716f86a501fc0256455219668"},
{file = "coverage-7.5.1-cp39-cp39-win32.whl", hash = "sha256:1cc0fe9b0b3a8364093c53b0b4c0c2dd4bb23acbec4c9240b5f284095ccf7981"},
{file = "coverage-7.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:dde0070c40ea8bb3641e811c1cfbf18e265d024deff6de52c5950677a8fb1e0f"},
{file = "coverage-7.5.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:6537e7c10cc47c595828b8a8be04c72144725c383c4702703ff4e42e44577312"},
{file = "coverage-7.5.1.tar.gz", hash = "sha256:54de9ef3a9da981f7af93eafde4ede199e0846cd819eb27c88e2b712aae9708c"},
{file = "coverage-7.5.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a6519d917abb15e12380406d721e37613e2a67d166f9fb7e5a8ce0375744cd45"},
{file = "coverage-7.5.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:aea7da970f1feccf48be7335f8b2ca64baf9b589d79e05b9397a06696ce1a1ec"},
{file = "coverage-7.5.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:923b7b1c717bd0f0f92d862d1ff51d9b2b55dbbd133e05680204465f454bb286"},
{file = "coverage-7.5.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62bda40da1e68898186f274f832ef3e759ce929da9a9fd9fcf265956de269dbc"},
{file = "coverage-7.5.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8b7339180d00de83e930358223c617cc343dd08e1aa5ec7b06c3a121aec4e1d"},
{file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:25a5caf742c6195e08002d3b6c2dd6947e50efc5fc2c2205f61ecb47592d2d83"},
{file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:05ac5f60faa0c704c0f7e6a5cbfd6f02101ed05e0aee4d2822637a9e672c998d"},
{file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:239a4e75e09c2b12ea478d28815acf83334d32e722e7433471fbf641c606344c"},
{file = "coverage-7.5.3-cp310-cp310-win32.whl", hash = "sha256:a5812840d1d00eafae6585aba38021f90a705a25b8216ec7f66aebe5b619fb84"},
{file = "coverage-7.5.3-cp310-cp310-win_amd64.whl", hash = "sha256:33ca90a0eb29225f195e30684ba4a6db05dbef03c2ccd50b9077714c48153cac"},
{file = "coverage-7.5.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f81bc26d609bf0fbc622c7122ba6307993c83c795d2d6f6f6fd8c000a770d974"},
{file = "coverage-7.5.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7cec2af81f9e7569280822be68bd57e51b86d42e59ea30d10ebdbb22d2cb7232"},
{file = "coverage-7.5.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55f689f846661e3f26efa535071775d0483388a1ccfab899df72924805e9e7cd"},
{file = "coverage-7.5.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50084d3516aa263791198913a17354bd1dc627d3c1639209640b9cac3fef5807"},
{file = "coverage-7.5.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:341dd8f61c26337c37988345ca5c8ccabeff33093a26953a1ac72e7d0103c4fb"},
{file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ab0b028165eea880af12f66086694768f2c3139b2c31ad5e032c8edbafca6ffc"},
{file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:5bc5a8c87714b0c67cfeb4c7caa82b2d71e8864d1a46aa990b5588fa953673b8"},
{file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:38a3b98dae8a7c9057bd91fbf3415c05e700a5114c5f1b5b0ea5f8f429ba6614"},
{file = "coverage-7.5.3-cp311-cp311-win32.whl", hash = "sha256:fcf7d1d6f5da887ca04302db8e0e0cf56ce9a5e05f202720e49b3e8157ddb9a9"},
{file = "coverage-7.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:8c836309931839cca658a78a888dab9676b5c988d0dd34ca247f5f3e679f4e7a"},
{file = "coverage-7.5.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:296a7d9bbc598e8744c00f7a6cecf1da9b30ae9ad51c566291ff1314e6cbbed8"},
{file = "coverage-7.5.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:34d6d21d8795a97b14d503dcaf74226ae51eb1f2bd41015d3ef332a24d0a17b3"},
{file = "coverage-7.5.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e317953bb4c074c06c798a11dbdd2cf9979dbcaa8ccc0fa4701d80042d4ebf1"},
{file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:705f3d7c2b098c40f5b81790a5fedb274113373d4d1a69e65f8b68b0cc26f6db"},
{file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1196e13c45e327d6cd0b6e471530a1882f1017eb83c6229fc613cd1a11b53cd"},
{file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:015eddc5ccd5364dcb902eaecf9515636806fa1e0d5bef5769d06d0f31b54523"},
{file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fd27d8b49e574e50caa65196d908f80e4dff64d7e592d0c59788b45aad7e8b35"},
{file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:33fc65740267222fc02975c061eb7167185fef4cc8f2770267ee8bf7d6a42f84"},
{file = "coverage-7.5.3-cp312-cp312-win32.whl", hash = "sha256:7b2a19e13dfb5c8e145c7a6ea959485ee8e2204699903c88c7d25283584bfc08"},
{file = "coverage-7.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:0bbddc54bbacfc09b3edaec644d4ac90c08ee8ed4844b0f86227dcda2d428fcb"},
{file = "coverage-7.5.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f78300789a708ac1f17e134593f577407d52d0417305435b134805c4fb135adb"},
{file = "coverage-7.5.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b368e1aee1b9b75757942d44d7598dcd22a9dbb126affcbba82d15917f0cc155"},
{file = "coverage-7.5.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f836c174c3a7f639bded48ec913f348c4761cbf49de4a20a956d3431a7c9cb24"},
{file = "coverage-7.5.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:244f509f126dc71369393ce5fea17c0592c40ee44e607b6d855e9c4ac57aac98"},
{file = "coverage-7.5.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4c2872b3c91f9baa836147ca33650dc5c172e9273c808c3c3199c75490e709d"},
{file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:dd4b3355b01273a56b20c219e74e7549e14370b31a4ffe42706a8cda91f19f6d"},
{file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:f542287b1489c7a860d43a7d8883e27ca62ab84ca53c965d11dac1d3a1fab7ce"},
{file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:75e3f4e86804023e991096b29e147e635f5e2568f77883a1e6eed74512659ab0"},
{file = "coverage-7.5.3-cp38-cp38-win32.whl", hash = "sha256:c59d2ad092dc0551d9f79d9d44d005c945ba95832a6798f98f9216ede3d5f485"},
{file = "coverage-7.5.3-cp38-cp38-win_amd64.whl", hash = "sha256:fa21a04112c59ad54f69d80e376f7f9d0f5f9123ab87ecd18fbb9ec3a2beed56"},
{file = "coverage-7.5.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f5102a92855d518b0996eb197772f5ac2a527c0ec617124ad5242a3af5e25f85"},
{file = "coverage-7.5.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d1da0a2e3b37b745a2b2a678a4c796462cf753aebf94edcc87dcc6b8641eae31"},
{file = "coverage-7.5.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8383a6c8cefba1b7cecc0149415046b6fc38836295bc4c84e820872eb5478b3d"},
{file = "coverage-7.5.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9aad68c3f2566dfae84bf46295a79e79d904e1c21ccfc66de88cd446f8686341"},
{file = "coverage-7.5.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e079c9ec772fedbade9d7ebc36202a1d9ef7291bc9b3a024ca395c4d52853d7"},
{file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bde997cac85fcac227b27d4fb2c7608a2c5f6558469b0eb704c5726ae49e1c52"},
{file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:990fb20b32990b2ce2c5f974c3e738c9358b2735bc05075d50a6f36721b8f303"},
{file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3d5a67f0da401e105753d474369ab034c7bae51a4c31c77d94030d59e41df5bd"},
{file = "coverage-7.5.3-cp39-cp39-win32.whl", hash = "sha256:e08c470c2eb01977d221fd87495b44867a56d4d594f43739a8028f8646a51e0d"},
{file = "coverage-7.5.3-cp39-cp39-win_amd64.whl", hash = "sha256:1d2a830ade66d3563bb61d1e3c77c8def97b30ed91e166c67d0632c018f380f0"},
{file = "coverage-7.5.3-pp38.pp39.pp310-none-any.whl", hash = "sha256:3538d8fb1ee9bdd2e2692b3b18c22bb1c19ffbefd06880f5ac496e42d7bb3884"},
{file = "coverage-7.5.3.tar.gz", hash = "sha256:04aefca5190d1dc7a53a4c1a5a7f8568811306d7a8ee231c42fb69215571944f"},
]
[package.dependencies]
@@ -1104,7 +1104,7 @@ pyarrow = ">=12.0.0"
type = "git"
url = "https://github.com/dora-rs/dora-lerobot.git"
reference = "HEAD"
resolved_reference = "1c6c2a401c3a2967d41444be6286ca9a28893abf"
resolved_reference = "ed0c00a4fdc6ec856c9842551acd7dc7ee776f79"
subdirectory = "gym_dora"
[[package]]
@@ -1310,13 +1310,13 @@ files = [
[[package]]
name = "huggingface-hub"
version = "0.23.1"
version = "0.23.2"
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
optional = false
python-versions = ">=3.8.0"
files = [
{file = "huggingface_hub-0.23.1-py3-none-any.whl", hash = "sha256:720a5bffd2b1b449deb793da8b0df7a9390a7e238534d5a08c9fbcdecb1dd3cb"},
{file = "huggingface_hub-0.23.1.tar.gz", hash = "sha256:4f62dbf6ae94f400c6d3419485e52bce510591432a5248a65d0cb72e4d479eb4"},
{file = "huggingface_hub-0.23.2-py3-none-any.whl", hash = "sha256:48727a16e704d409c4bb5913613308499664f22a99743435dc3a13b23c485827"},
{file = "huggingface_hub-0.23.2.tar.gz", hash = "sha256:f6829b62d5fdecb452a76fdbec620cba4c1573655a8d710c1df71735fd9edbd2"},
]
[package.dependencies]
@@ -2102,18 +2102,15 @@ test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"]
[[package]]
name = "nodeenv"
version = "1.8.0"
version = "1.9.0"
description = "Node.js virtual environment builder"
optional = true
python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*"
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
files = [
{file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"},
{file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"},
{file = "nodeenv-1.9.0-py2.py3-none-any.whl", hash = "sha256:508ecec98f9f3330b636d4448c0f1a56fc68017c68f1e7857ebc52acf0eb879a"},
{file = "nodeenv-1.9.0.tar.gz", hash = "sha256:07f144e90dae547bf0d4ee8da0ee42664a42a04e02ed68e06324348dafe4bdb1"},
]
[package.dependencies]
setuptools = "*"
[[package]]
name = "numba"
version = "0.59.1"
@@ -3231,13 +3228,13 @@ files = [
[[package]]
name = "requests"
version = "2.32.2"
version = "2.32.3"
description = "Python HTTP for Humans."
optional = false
python-versions = ">=3.8"
files = [
{file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"},
{file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"},
{file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
{file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
]
[package.dependencies]
@@ -3253,16 +3250,16 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "rerun-sdk"
version = "0.16.0"
version = "0.16.1"
description = "The Rerun Logging SDK"
optional = false
python-versions = "<3.13,>=3.8"
files = [
{file = "rerun_sdk-0.16.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:1cc6dc66d089e296f945dc238301889efb61dd6d338b5d00f76981cf7aed0a74"},
{file = "rerun_sdk-0.16.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:faf231897655e46eb975695df2b0ace07db362d697e697f9a3dff52f81c0dc5d"},
{file = "rerun_sdk-0.16.0-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:860a6394380d3e9b9e48bf34423bd56dda54d5b0158d2ae0e433698659b34198"},
{file = "rerun_sdk-0.16.0-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:5b8d1476f73a3ad1a5d3f21b61c633f3ab62aa80fa0b049f5ad10bf1227681ab"},
{file = "rerun_sdk-0.16.0-cp38-abi3-win_amd64.whl", hash = "sha256:aff0051a263b8c3067243c0126d319845baf4fe640899f04aeef7daf151f35e4"},
{file = "rerun_sdk-0.16.1-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:170c6976634008611753e10dfef8cdc395ce8180e634c169e7c61cef2f89a277"},
{file = "rerun_sdk-0.16.1-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c9a76eab7eb5559276737dad655200e9350df0837158dbc5a896970ab4201454"},
{file = "rerun_sdk-0.16.1-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:4d6436752d57e8b8038489a0e7e37f0c760b088e96db5fb81667d3a376d63fea"},
{file = "rerun_sdk-0.16.1-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:37b7b47948471873e84f224b16f417a94a91c7cbd6c72c68281eeff1ba414b8f"},
{file = "rerun_sdk-0.16.1-cp38-abi3-win_amd64.whl", hash = "sha256:be88799c8afdf68eafa99e64e2e4f0a484e187e017a180219abbe6bb988acd4e"},
]
[package.dependencies]
@@ -3739,17 +3736,17 @@ files = [
[[package]]
name = "sympy"
version = "1.12"
version = "1.12.1"
description = "Computer algebra system (CAS) in Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"},
{file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"},
{file = "sympy-1.12.1-py3-none-any.whl", hash = "sha256:9b2cbc7f1a640289430e13d2a56f02f867a1da0190f2f99d8968c2f74da0e515"},
{file = "sympy-1.12.1.tar.gz", hash = "sha256:2877b03f998cd8c08f07cd0de5b767119cd3ef40d09f41c30d722f6686b0fb88"},
]
[package.dependencies]
mpmath = ">=0.19"
mpmath = ">=1.1.0,<1.4.0"
[[package]]
name = "tbb"
@@ -4263,13 +4260,13 @@ multidict = ">=4.0"
[[package]]
name = "zarr"
version = "2.18.1"
version = "2.18.2"
description = "An implementation of chunked, compressed, N-dimensional arrays for Python"
optional = false
python-versions = ">=3.9"
files = [
{file = "zarr-2.18.1-py3-none-any.whl", hash = "sha256:a1770d194eec4ec0a41a01295a6f724e1c3471d704d3aca906d3b3a7f8830245"},
{file = "zarr-2.18.1.tar.gz", hash = "sha256:28c360ed123e606c425a694a83300227a907cb86a995fc9eef620ecafbe5f92d"},
{file = "zarr-2.18.2-py3-none-any.whl", hash = "sha256:a638754902f97efa99b406083fdc807a0e2ccf12a949117389d2a4ba9b05df38"},
{file = "zarr-2.18.2.tar.gz", hash = "sha256:9bb393b8a0a38fb121dbb913b047d75db28de9890f6d644a217a73cf4ae74f47"},
]
[package.dependencies]
@@ -4284,13 +4281,13 @@ jupyter = ["ipytree (>=0.2.2)", "ipywidgets (>=8.0.0)", "notebook"]
[[package]]
name = "zipp"
version = "3.18.2"
version = "3.19.0"
description = "Backport of pathlib-compatible object wrapper for zip files"
optional = false
python-versions = ">=3.8"
files = [
{file = "zipp-3.18.2-py3-none-any.whl", hash = "sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e"},
{file = "zipp-3.18.2.tar.gz", hash = "sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059"},
{file = "zipp-3.19.0-py3-none-any.whl", hash = "sha256:96dc6ad62f1441bcaccef23b274ec471518daf4fbbc580341204936a5a3dddec"},
{file = "zipp-3.19.0.tar.gz", hash = "sha256:952df858fb3164426c976d9338d3961e8e8b3758e2e059e0f754b8c4262625ee"},
]
[package.extras]

Binary file not shown.

After

Width:  |  Height:  |  Size: 185 KiB

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ebd21273f6048b66c806f92035352843a9069908b3296863fd55d34cf71cd0ef
size 51248
oid sha256:2fff6294b94cf42d4dd1249dcc5c3b0269d6d9c697f894e61b867d7ab81a94e4
size 5104

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b9bbf951891077320a5da27e77ddb580a6e833e8d3162b62a2f887a1989585cc
oid sha256:4aa23e51607604a18b70fa42edbbe1af34f119d985628fc27cc1bbb0efbc8901
size 31688

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d4070bd1f1cd8c72bc2daf628088e42b8ef113f6df0bfd9e91be052bc90038c3
oid sha256:6fd368406c93cb562a69ff11cf7adf34a4b223507dcb2b9e9b8f44ee1036988a
size 68

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:42f92239223bb4df32d5c3016bc67450159f1285a7ab046307b645f699ccc34e
oid sha256:5663ee79a13bb70a1604b887dd21bf89d18482287442419c6cc6c5bf0e753e99
size 34928

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:52f85d6262ad1dd0b66578b25829fed96aaaca3c7458cb73ac75111350d17fcf
size 51248
oid sha256:fb1a45463efd860af2ca22c16c77d55a18bd96fef080ae77978845a2f22ef716
size 5104

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5ba7c910618f0f3ca69f82f3d70c880d2b2e432456524a2a63dfd5c50efa45f0
oid sha256:aa5a43e22f01d8e2f8d19f31753608794f1edbd74aaf71660091ab80ea58dc9b
size 30808

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:53ad410f43855254438790f54aa7c895a052776acdd922906ae430684f659b53
oid sha256:54d1f75cf67a7b1d7a7c6865ecb9b1cc86a2f032d1890245f8996789ab6e0df6
size 33608

View File

@@ -0,0 +1,47 @@
from pathlib import Path
import torch
from torchvision.transforms import v2
from safetensors.torch import save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import RangeRandomSharpness
from lerobot.common.utils.utils import seeded_context
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
ARTIFACT_DIR = "tests/data/save_image_transforms"
SEED = 1336
to_pil = v2.ToPILImage()
def main(repo_id):
dataset = LeRobotDataset(repo_id, image_transforms=None)
output_dir = Path(ARTIFACT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)
# Get first frame of given episode
from_idx = dataset.episode_data_index["from"][0].item()
original_frame = dataset[from_idx][dataset.camera_keys[0]]
to_pil(original_frame).save(output_dir / "original_frame.png", quality=100)
transforms = {
"brightness": v2.ColorJitter(brightness=(0.0, 2.0)),
"contrast": v2.ColorJitter(contrast=(0.0, 2.0)),
"saturation": v2.ColorJitter(saturation=(0.0, 2.0)),
"hue": v2.ColorJitter(hue=(-0.5, 0.5)),
"sharpness": RangeRandomSharpness(0.0, 2.0),
}
# frames = {"original_frame": original_frame}
for name, transform in transforms.items():
with seeded_context(SEED):
# transform = v2.Compose([transform, v2.ToDtype(torch.float32, scale=True)])
transformed_frame = transform(original_frame)
# frames[name] = transform(original_frame)
to_pil(transformed_frame).save(output_dir / f"{SEED}_{name}.png", quality=100)
# save_file(frames, output_dir / f"transformed_frames_{SEED}.safetensors")
if __name__ == "__main__":
repo_id = "lerobot/aloha_mobile_shrimp"
main(repo_id)

View File

@@ -77,7 +77,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
batch = next(iter(dataloader))
obs = {}
for k in batch:
if "observation" in k:
if k.startswith("observation"):
obs[k] = batch[k]
if "n_action_steps" in cfg.policy:
@@ -115,8 +115,8 @@ if __name__ == "__main__":
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
),
("aloha", "act", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real", []),
("dora_aloha_real", "act_real_no_state", []),
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
]
for env, policy, extra_overrides in env_policies:
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)

View File

@@ -16,6 +16,7 @@
import json
import logging
from copy import deepcopy
from itertools import chain
from pathlib import Path
import einops
@@ -25,26 +26,34 @@ from datasets import Dataset
from safetensors.torch import load_file
import lerobot
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
)
from lerobot.common.datasets.push_dataset_to_hub.compute_stats import (
from lerobot.common.datasets.compute_stats import (
aggregate_stats,
compute_stats,
get_stats_einops_patterns,
)
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.utils import (
flatten_dict,
hf_transform_to_torch,
load_previous_and_future_frames,
unflatten_dict,
)
from lerobot.common.utils.utils import init_hydra_config
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE
@pytest.mark.parametrize("env_name, repo_id, policy_name", lerobot.env_dataset_policy_triplets)
@pytest.mark.parametrize(
"env_name, repo_id, policy_name",
lerobot.env_dataset_policy_triplets
+ [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")],
)
def test_factory(env_name, repo_id, policy_name):
"""
Tests that:
- we can create a dataset with the factory.
- for a commonly used set of data keys, the data dimensions are correct.
"""
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
overrides=[
@@ -105,6 +114,39 @@ def test_factory(env_name, repo_id, policy_name):
assert key in item, f"{key}"
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
def test_multilerobotdataset_frames():
"""Check that all dataset frames are incorporated."""
# Note: use the image variants of the dataset to make the test approx 3x faster.
# Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
# logic that wouldn't be caught with two repo IDs.
repo_ids = [
"lerobot/aloha_sim_insertion_human_image",
"lerobot/aloha_sim_transfer_cube_human_image",
"lerobot/aloha_sim_insertion_scripted_image",
]
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
dataset = MultiLeRobotDataset(repo_ids)
assert len(dataset) == sum(len(d) for d in sub_datasets)
assert dataset.num_samples == sum(d.num_samples for d in sub_datasets)
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
# check they match.
expected_dataset_indices = []
for i, sub_dataset in enumerate(sub_datasets):
expected_dataset_indices.extend([i] * len(sub_dataset))
for expected_dataset_index, sub_dataset_item, dataset_item in zip(
expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
):
dataset_index = dataset_item.pop("dataset_index")
assert dataset_index == expected_dataset_index
assert sub_dataset_item.keys() == dataset_item.keys()
for k in sub_dataset_item:
assert torch.equal(sub_dataset_item[k], dataset_item[k])
def test_compute_stats_on_xarm():
"""Check that the statistics are computed correctly according to the stats_patterns property.
@@ -315,3 +357,31 @@ def test_backward_compatibility(repo_id):
# i = dataset.episode_data_index["to"][-1].item()
# load_and_compare(i - 2)
# load_and_compare(i - 1)
def test_aggregate_stats():
"""Makes 3 basic datasets and checks that aggregate stats are computed correctly."""
with seeded_context(0):
data_a = torch.rand(30, dtype=torch.float32)
data_b = torch.rand(20, dtype=torch.float32)
data_c = torch.rand(20, dtype=torch.float32)
hf_dataset_1 = Dataset.from_dict(
{"a": data_a[:10], "b": data_b[:10], "c": data_c[:10], "index": torch.arange(10)}
)
hf_dataset_1.set_transform(hf_transform_to_torch)
hf_dataset_2 = Dataset.from_dict({"a": data_a[10:20], "b": data_b[10:], "index": torch.arange(10)})
hf_dataset_2.set_transform(hf_transform_to_torch)
hf_dataset_3 = Dataset.from_dict({"a": data_a[20:], "c": data_c[10:], "index": torch.arange(10)})
hf_dataset_3.set_transform(hf_transform_to_torch)
dataset_1 = LeRobotDataset.from_preloaded("d1", hf_dataset=hf_dataset_1)
dataset_1.stats = compute_stats(dataset_1, batch_size=len(hf_dataset_1), num_workers=0)
dataset_2 = LeRobotDataset.from_preloaded("d2", hf_dataset=hf_dataset_2)
dataset_2.stats = compute_stats(dataset_2, batch_size=len(hf_dataset_2), num_workers=0)
dataset_3 = LeRobotDataset.from_preloaded("d3", hf_dataset=hf_dataset_3)
dataset_3.stats = compute_stats(dataset_3, batch_size=len(hf_dataset_3), num_workers=0)
stats = aggregate_stats([dataset_1, dataset_2, dataset_3])
for data_key, data in zip(["a", "b", "c"], [data_a, data_b, data_c], strict=True):
for agg_fn in ["mean", "min", "max"]:
assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn))
assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0))

View File

@@ -86,6 +86,9 @@ def test_policy(env_name, policy_name, extra_overrides):
- Updating the policy.
- Using the policy to select actions at inference time.
- Test the action can be applied to the policy
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
and for now we add tests as we see fit.
"""
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
@@ -137,7 +140,7 @@ def test_policy(env_name, policy_name, extra_overrides):
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
num_workers=0,
batch_size=2,
shuffle=True,
pin_memory=DEVICE != "cpu",

90
tests/test_sampler.py Normal file
View File

@@ -0,0 +1,90 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datasets import Dataset
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
def test_drop_n_first_frames():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1)
assert sampler.indices == [1, 4, 5]
assert len(sampler) == 3
assert list(sampler) == [1, 4, 5]
def test_drop_n_last_frames():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1)
assert sampler.indices == [0, 3, 4]
assert len(sampler) == 3
assert list(sampler) == [0, 3, 4]
def test_episode_indices_to_use():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2])
assert sampler.indices == [0, 1, 3, 4, 5]
assert len(sampler) == 5
assert list(sampler) == [0, 1, 3, 4, 5]
def test_shuffle():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, shuffle=False)
assert sampler.indices == [0, 1, 2, 3, 4, 5]
assert len(sampler) == 6
assert list(sampler) == [0, 1, 2, 3, 4, 5]
sampler = EpisodeAwareSampler(episode_data_index, shuffle=True)
assert sampler.indices == [0, 1, 2, 3, 4, 5]
assert len(sampler) == 6
assert set(sampler) == {0, 1, 2, 3, 4, 5}

253
tests/test_transforms.py Normal file
View File

@@ -0,0 +1,253 @@
from pathlib import Path
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import numpy as np
import pytest
import torch
from omegaconf import OmegaConf
from PIL import Image
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F # noqa: N812
from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, get_image_transforms
from lerobot.common.utils.utils import seeded_context
# test_make_image_transforms
# -
# test backward compatibility torchvision
# - save artifacts
# test backward compatibility default yaml (enable false, enable true)
# - save artifacts
def test_get_image_transforms_no_transform():
get_image_transforms()
get_image_transforms(sharpness_weight=0.0)
get_image_transforms(max_num_transforms=0)
@pytest.fixture
def img():
# dataset = LeRobotDataset("lerobot/pusht")
# item = dataset[0]
# return item["observation.image"]
path = "tests/data/save_image_transforms/original_frame.png"
img_chw = torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)
return img_chw
def test_get_image_transforms_brightness(img):
brightness_min_max = (0.5, 0.5)
tf_actual = get_image_transforms(brightness_weight=1., brightness_min_max=brightness_min_max)
tf_expected = v2.ColorJitter(brightness=brightness_min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
def test_get_image_transforms_contrast(img):
contrast_min_max = (0.5, 0.5)
tf_actual = get_image_transforms(contrast_weight=1., contrast_min_max=contrast_min_max)
tf_expected = v2.ColorJitter(contrast=contrast_min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
def test_get_image_transforms_saturation(img):
saturation_min_max = (0.5, 0.5)
tf_actual = get_image_transforms(saturation_weight=1., saturation_min_max=saturation_min_max)
tf_expected = v2.ColorJitter(saturation=saturation_min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
def test_get_image_transforms_hue(img):
hue_min_max = (0.5, 0.5)
tf_actual = get_image_transforms(hue_weight=1., hue_min_max=hue_min_max)
tf_expected = v2.ColorJitter(hue=hue_min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
def test_get_image_transforms_sharpness(img):
sharpness_min_max = (0.5, 0.5)
tf_actual = get_image_transforms(sharpness_weight=1., sharpness_min_max=sharpness_min_max)
tf_expected = RangeRandomSharpness(**sharpness_min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
def test_get_image_transforms_max_num_transforms(img):
tf_actual = get_image_transforms(
saturation_min_max=(0.5, 0.5),
constrast_min_max=(0.5, 0.5),
saturation_min_max=(0.5, 0.5),
hue_min_max=(0.5, 0.5),
sharpness_min_max=(0.5, 0.5),
random_order=False,
)
tf_expected = v2.Compose([
v2.ColorJitter(brightness=(0.5, 0.5)),
v2.ColorJitter(contrast=(0.5, 0.5)),
v2.ColorJitter(saturation=(0.5, 0.5)),
v2.ColorJitter(hue=(0.5, 0.5)),
RangeRandomSharpness(sharpness=(0.5, 0.5)),
])
torch.testing.assert_close(tf_actual(img), tf_expected(img))
def test_get_image_transforms_random_order(img):
out_imgs = []
with seeded_context(1337):
for _ in range(20):
tf = get_image_transforms(
saturation_min_max=(0.5, 0.5),
constrast_min_max=(0.5, 0.5),
saturation_min_max=(0.5, 0.5),
hue_min_max=(0.5, 0.5),
sharpness_min_max=(0.5, 0.5),
random_order=False,
)
out_imgs.append(tf(img))
for i in range(1,10):
with pytest.raises(ValueError):
torch.testing.assert_close(out_imgs[0], out_imgs[i])
def test_backward_compatibility_torchvision():
pass
def test_backward_compatibility_default_yaml():
pass
# class TestRandomSubsetApply:
# @pytest.fixture(autouse=True)
# def setup(self):
# self.jitters = [
# v2.ColorJitter(brightness=0.5),
# v2.ColorJitter(contrast=0.5),
# v2.ColorJitter(saturation=0.5),
# ]
# self.flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
# self.img = torch.rand(3, 224, 224)
# @pytest.mark.parametrize("p", [[0, 1], [1, 0]])
# def test_random_choice(self, p):
# random_choice = RandomSubsetApply(self.flips, p=p, n_subset=1, random_order=False)
# output = random_choice(self.img)
# p_horz, _ = p
# if p_horz:
# torch.testing.assert_close(output, F.horizontal_flip(self.img))
# else:
# torch.testing.assert_close(output, F.vertical_flip(self.img))
# def test_transform_all(self):
# transform = RandomSubsetApply(self.jitters)
# output = transform(self.img)
# assert output.shape == self.img.shape
# def test_transform_subset(self):
# transform = RandomSubsetApply(self.jitters, n_subset=2)
# output = transform(self.img)
# assert output.shape == self.img.shape
# def test_random_order(self):
# random_order = RandomSubsetApply(self.flips, p=[0.5, 0.5], n_subset=2, random_order=True)
# # We can't really check whether the transforms are actually applied in random order. However,
# # horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform
# # applies them in random order, we can use a fixed order to compute the expected value.
# actual = random_order(self.img)
# expected = v2.Compose(self.flips)(self.img)
# torch.testing.assert_close(actual, expected)
# def test_probability_length_mismatch(self):
# with pytest.raises(ValueError):
# RandomSubsetApply(self.jitters, p=[0.5, 0.5])
# def test_invalid_n_subset(self):
# with pytest.raises(ValueError):
# RandomSubsetApply(self.jitters, n_subset=5)
# class TestRangeRandomSharpness:
# @pytest.fixture(autouse=True)
# def setup(self):
# self.img = torch.rand(3, 224, 224)
# def test_valid_range(self):
# transform = RangeRandomSharpness(0.1, 2.0)
# output = transform(self.img)
# assert output.shape == self.img.shape
# def test_invalid_range_min_negative(self):
# with pytest.raises(ValueError):
# RangeRandomSharpness(-0.1, 2.0)
# def test_invalid_range_max_smaller(self):
# with pytest.raises(ValueError):
# RangeRandomSharpness(2.0, 0.1)
# class TestMakeImageTransforms:
# @pytest.fixture(autouse=True)
# def setup(self):
# """Seed should be the same as the one that was used to generate artifacts"""
# self.config = {
# "enable": True,
# "max_num_transforms": 1,
# "random_order": False,
# "brightness": {"weight": 0, "min": 2.0, "max": 2.0},
# "contrast": {
# "weight": 0,
# "min": 2.0,
# "max": 2.0,
# },
# "saturation": {
# "weight": 0,
# "min": 2.0,
# "max": 2.0,
# },
# "hue": {
# "weight": 0,
# "min": 0.5,
# "max": 0.5,
# },
# "sharpness": {
# "weight": 0,
# "min": 2.0,
# "max": 2.0,
# },
# }
# self.path = Path("tests/data/save_image_transforms")
# self.original_frame = self.load_png_to_tensor(self.path / "original_frame.png")
# self.transforms = {
# "brightness": v2.ColorJitter(brightness=(2.0, 2.0)),
# "contrast": v2.ColorJitter(contrast=(2.0, 2.0)),
# "saturation": v2.ColorJitter(saturation=(2.0, 2.0)),
# "hue": v2.ColorJitter(hue=(0.5, 0.5)),
# "sharpness": RangeRandomSharpness(2.0, 2.0),
# }
# @staticmethod
# def load_png_to_tensor(path: Path):
# return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)
# @pytest.mark.parametrize(
# "transform_key, seed",
# [
# ("brightness", 1336),
# ("contrast", 1336),
# ("saturation", 1336),
# ("hue", 1336),
# ("sharpness", 1336),
# ],
# )
# def test_single_transform(self, transform_key, seed):
# config = self.config
# config[transform_key]["weight"] = 1
# cfg = OmegaConf.create(config)
# actual_t = make_image_transforms(cfg, to_dtype=torch.uint8)
# with seeded_context(1336):
# actual = actual_t(self.original_frame)
# expected_t = self.transforms[transform_key]
# with seeded_context(1336):
# expected = expected_t(self.original_frame)
# torch.testing.assert_close(actual, expected)