From 714a7762776d7845f69b2eed8c39cf36e76f1795 Mon Sep 17 00:00:00 2001 From: Cadene Date: Fri, 19 Apr 2024 10:33:42 +0000 Subject: [PATCH] id -> index, finish moving compute_stats before hf_dataset push_to_hub --- download_and_upload_dataset.py | 55 ++++++++-------- examples/1_load_hugging_face_dataset.py | 2 +- examples/2_load_lerobot_dataset.py | 2 +- lerobot/common/datasets/aloha.py | 2 +- lerobot/common/datasets/factory.py | 30 +++------ lerobot/common/datasets/utils.py | 87 ++++++++++++++++++------- lerobot/common/datasets/xarm.py | 2 +- lerobot/scripts/train.py | 8 +-- tests/test_datasets.py | 31 +++++---- 9 files changed, 120 insertions(+), 99 deletions(-) diff --git a/download_and_upload_dataset.py b/download_and_upload_dataset.py index 6a54833e5..062db690a 100644 --- a/download_and_upload_dataset.py +++ b/download_and_upload_dataset.py @@ -19,7 +19,7 @@ from huggingface_hub import HfApi from PIL import Image as PILImage from safetensors.torch import save_file -from lerobot.common.datasets.utils import compute_stats, flatten_dict +from lerobot.common.datasets.utils import compute_stats, convert_images_to_channel_first_tensors, flatten_dict def download_and_upload(root, revision, dataset_id): @@ -75,28 +75,18 @@ def concatenate_episodes(ep_dicts): for x in ep_dict[key]: data_dict[key].append(x) - total_frames = data_dict["frame_id"].shape[0] + total_frames = data_dict["frame_index"].shape[0] data_dict["index"] = torch.arange(0, total_frames, 1) return data_dict -def push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id): - hf_dataset = hf_dataset.with_format("torch") - +def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id): # push to main to indicate latest version hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True) # push to version branch hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision=revision) - # get stats - stats_pth_path = root / dataset_id / "stats.pth" - if stats_pth_path.exists(): - stats = torch.load(stats_pth_path) - else: - stats = compute_stats(hf_dataset) - torch.save(stats, stats_pth_path) - # create and store meta_data meta_data_dir = root / dataset_id / "meta_data" meta_data_dir.mkdir(parents=True, exist_ok=True) @@ -237,8 +227,8 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10): "observation.image": [PILImage.fromarray(x.numpy()) for x in image], "observation.state": agent_pos, "action": actions[id_from:id_to], - "episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int), - "frame_id": torch.arange(0, num_frames, 1), + "episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int), + "frame_index": torch.arange(0, num_frames, 1), "timestamp": torch.arange(0, num_frames, 1) / fps, # "next.observation.image": image[1:], # "next.observation.state": agent_pos[1:], @@ -262,8 +252,8 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10): length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) ), "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)), - "episode_id": Value(dtype="int64", id=None), - "frame_id": Value(dtype="int64", id=None), + "episode_index": Value(dtype="int64", id=None), + "frame_index": Value(dtype="int64", id=None), "timestamp": Value(dtype="float32", id=None), "next.reward": Value(dtype="float32", id=None), "next.done": Value(dtype="bool", id=None), @@ -272,11 +262,14 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10): } features = Features(features) hf_dataset = Dataset.from_dict(data_dict, features=features) + hf_dataset = hf_dataset.with_format("torch") + hf_dataset.set_transform(convert_images_to_channel_first_tensors) info = { "fps": fps, } - push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id) + stats = compute_stats(hf_dataset) + push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id) def download_and_upload_xarm(root, revision, dataset_id, fps=15): @@ -334,8 +327,8 @@ def download_and_upload_xarm(root, revision, dataset_id, fps=15): "observation.image": [PILImage.fromarray(x.numpy()) for x in image], "observation.state": state, "action": action, - "episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int), - "frame_id": torch.arange(0, num_frames, 1), + "episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int), + "frame_index": torch.arange(0, num_frames, 1), "timestamp": torch.arange(0, num_frames, 1) / fps, # "next.observation.image": next_image, # "next.observation.state": next_state, @@ -358,8 +351,8 @@ def download_and_upload_xarm(root, revision, dataset_id, fps=15): length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) ), "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)), - "episode_id": Value(dtype="int64", id=None), - "frame_id": Value(dtype="int64", id=None), + "episode_index": Value(dtype="int64", id=None), + "frame_index": Value(dtype="int64", id=None), "timestamp": Value(dtype="float32", id=None), "next.reward": Value(dtype="float32", id=None), "next.done": Value(dtype="bool", id=None), @@ -368,11 +361,14 @@ def download_and_upload_xarm(root, revision, dataset_id, fps=15): } features = Features(features) hf_dataset = Dataset.from_dict(data_dict, features=features) + hf_dataset = hf_dataset.with_format("torch") + hf_dataset.set_transform(convert_images_to_channel_first_tensors) info = { "fps": fps, } - push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id) + stats = compute_stats(hf_dataset) + push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id) def download_and_upload_aloha(root, revision, dataset_id, fps=50): @@ -464,8 +460,8 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50): { "observation.state": state, "action": action, - "episode_id": torch.tensor([ep_id] * num_frames), - "frame_id": torch.arange(0, num_frames, 1), + "episode_index": torch.tensor([ep_id] * num_frames), + "frame_index": torch.arange(0, num_frames, 1), "timestamp": torch.arange(0, num_frames, 1) / fps, # "next.observation.state": state, # TODO(rcadene): compute reward and success @@ -493,8 +489,8 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50): length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) ), "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)), - "episode_id": Value(dtype="int64", id=None), - "frame_id": Value(dtype="int64", id=None), + "episode_index": Value(dtype="int64", id=None), + "frame_index": Value(dtype="int64", id=None), "timestamp": Value(dtype="float32", id=None), #'next.reward': Value(dtype='float32', id=None), "next.done": Value(dtype="bool", id=None), @@ -503,11 +499,14 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50): } features = Features(features) hf_dataset = Dataset.from_dict(data_dict, features=features) + hf_dataset = hf_dataset.with_format("torch") + hf_dataset.set_transform(convert_images_to_channel_first_tensors) info = { "fps": fps, } - push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id) + stats = compute_stats(hf_dataset) + push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id) if __name__ == "__main__": diff --git a/examples/1_load_hugging_face_dataset.py b/examples/1_load_hugging_face_dataset.py index 17d289146..d70a12865 100644 --- a/examples/1_load_hugging_face_dataset.py +++ b/examples/1_load_hugging_face_dataset.py @@ -49,7 +49,7 @@ print(f"number of episodes: {len(hf_dataset.unique('episode_id'))=}") print(f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_id')):.3f}") # select the frames belonging to episode number 5 -hf_dataset = hf_dataset.filter(lambda frame: frame["episode_id"] == 5) +hf_dataset = hf_dataset.filter(lambda frame: frame["episode_index"] == 5) # load all frames of episode 5 in RAM in PIL format frames = hf_dataset["observation.image"] diff --git a/examples/2_load_lerobot_dataset.py b/examples/2_load_lerobot_dataset.py index 49a53d8ef..e782e66ff 100644 --- a/examples/2_load_lerobot_dataset.py +++ b/examples/2_load_lerobot_dataset.py @@ -55,7 +55,7 @@ print(f"frames per second used during data collection: {dataset.fps=}") print(f"keys to access images from cameras: {dataset.image_keys=}") # While the LeRobot dataset adds helpers for working within our library, we still expose the underling Hugging Face dataset. It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5. -dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_id"] == 5) +dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5) # LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grap all the image frames. frames = [sample["observation.image"] for sample in dataset] diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 785b68e5b..6d993df0c 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -54,7 +54,7 @@ class AlohaDataset(torch.utils.data.Dataset): @property def num_episodes(self) -> int: - return len(self.hf_dataset.unique("episode_id")) + return len(self.hf_dataset.unique("episode_index")) def __len__(self): return self.num_samples diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 07afb614b..1d4a751ed 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -1,11 +1,9 @@ -import logging import os from pathlib import Path import torch from torchvision.transforms import v2 -from lerobot.common.datasets.utils import compute_stats from lerobot.common.transforms import NormalizeTransform, Prod DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None @@ -52,26 +50,14 @@ def make_dataset( stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) elif stats_path is None: - # load stats if the file exists already or compute stats and save it - if DATA_DIR is None: - # TODO(rcadene): clean stats - precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth" - else: - precomputed_stats_path = DATA_DIR / cfg.dataset_id / "stats.pth" - if precomputed_stats_path.exists(): - stats = torch.load(precomputed_stats_path) - else: - logging.info(f"compute_stats and save to {precomputed_stats_path}") - # Create a dataset for stats computation. - stats_dataset = clsfunc( - dataset_id=cfg.dataset_id, - split="train", - root=DATA_DIR, - transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0), - ) - stats = compute_stats(stats_dataset) - precomputed_stats_path.parent.mkdir(parents=True, exist_ok=True) - torch.save(stats, precomputed_stats_path) + # load a first dataset to access precomputed stats + stats_dataset = clsfunc( + dataset_id=cfg.dataset_id, + split="train", + root=DATA_DIR, + transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0), + ) + stats = stats_dataset.stats else: stats = torch.load(stats_path) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 92799c2a2..296f74313 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -10,6 +10,8 @@ from datasets import load_dataset, load_from_disk from huggingface_hub import hf_hub_download from safetensors.torch import load_file +from lerobot.common.utils.utils import set_global_seed + def flatten_dict(d, parent_key="", sep="/"): items = [] @@ -42,7 +44,9 @@ def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset: else: repo_id = f"lerobot/{dataset_id}" hf_dataset = load_dataset(repo_id, revision=version, split=split) - return hf_dataset.with_format("torch") + hf_dataset = hf_dataset.with_format("torch") + hf_dataset.set_transform(convert_images_to_channel_first_tensors) + return hf_dataset def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor]: @@ -126,7 +130,7 @@ def load_previous_and_future_frames( issues with timestamps during data collection. """ # get indices of the frames associated to the episode, and their timestamps - ep_id = item["episode_id"].item() + ep_id = item["episode_index"].item() ep_data_id_from = episode_data_index["from"][ep_id].item() ep_data_id_to = episode_data_index["to"][ep_id].item() ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1) @@ -168,34 +172,53 @@ def load_previous_and_future_frames( return item -def get_stats_einops_patterns(dataset): - """These einops patterns will be used to aggregate batches and compute statistics.""" - stats_patterns = { - "action": "b c -> c", - "observation.state": "b c -> c", - } - for key in dataset.image_keys: - stats_patterns[key] = "b c h w -> c 1 1" +def convert_images_to_channel_first_tensors(examples): + for key in examples: + if examples[key].ndim == 3: # we assume it's an image + # (h w c) -> (c h w) + h, w, c = examples[key].shape + assert c < h and c < w, f"expect a channel last image, but instead {examples[key].shape}" + examples[key] = [img.permute((2, 0, 1)) for img in examples[key]] + return examples + + +def get_stats_einops_patterns(hf_dataset): + """These einops patterns will be used to aggregate batches and compute statistics. + + Note: We assume the images are returned in channel first format + """ + + dataloader = torch.utils.data.DataLoader( + hf_dataset, + num_workers=0, + batch_size=2, + shuffle=False, + ) + batch = next(iter(dataloader)) + + stats_patterns = {} + for key, feats_type in hf_dataset.features.items(): + if batch[key].ndim == 4 and isinstance(feats_type, datasets.features.image.Image): + # sanity check that images are channel first + _, c, h, w = batch[key].shape + assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}" + # convert from (h w c) to (c h w) to fit pytorch convention, then apply reduce + stats_patterns[key] = "b c h w -> c 1 1" + elif batch[key].ndim == 2: + stats_patterns[key] = "b c -> c " + elif batch[key].ndim == 1: + stats_patterns[key] = "b -> 1" + else: + raise ValueError(f"{key}, {feats_type}, {batch[key].shape}") + return stats_patterns -def compute_stats(dataset, batch_size=32, max_num_samples=None): +def compute_stats(hf_dataset, batch_size=32, max_num_samples=None): if max_num_samples is None: - max_num_samples = len(dataset) - else: - raise NotImplementedError("We need to set shuffle=True, but this violate an assert for now.") + max_num_samples = len(hf_dataset) - dataloader = torch.utils.data.DataLoader( - dataset, - num_workers=4, - batch_size=batch_size, - shuffle=False, - # pin_memory=cfg.device != "cpu", - drop_last=False, - ) - - # get einops patterns to aggregate batches and compute statistics - stats_patterns = get_stats_einops_patterns(dataset) + stats_patterns = get_stats_einops_patterns(hf_dataset) # mean and std will be computed incrementally while max and min will track the running value. mean, std, max, min = {}, {}, {}, {} @@ -205,10 +228,23 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None): max[key] = torch.tensor(-float("inf")).float() min[key] = torch.tensor(float("inf")).float() + def create_seeded_dataloader(hf_dataset, batch_size, seed): + set_global_seed(seed) + dataloader = torch.utils.data.DataLoader( + hf_dataset, + num_workers=4, + batch_size=batch_size, + shuffle=False, + # pin_memory=cfg.device != "cpu", + drop_last=False, + ) + return dataloader + # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get # surprises when rerunning the sampler. first_batch = None running_item_count = 0 # for online mean computation + dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337) for i, batch in enumerate( tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max") ): @@ -234,6 +270,7 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None): first_batch_ = None running_item_count = 0 # for online std computation + dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337) for i, batch in enumerate( tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std") ): diff --git a/lerobot/common/datasets/xarm.py b/lerobot/common/datasets/xarm.py index 385b7d999..4adff9e96 100644 --- a/lerobot/common/datasets/xarm.py +++ b/lerobot/common/datasets/xarm.py @@ -46,7 +46,7 @@ class XarmDataset(torch.utils.data.Dataset): @property def num_episodes(self) -> int: - return len(self.hf_dataset.unique("episode_id")) + return len(self.hf_dataset.unique("episode_index")) def __len__(self): return self.num_samples diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 4d8c2478c..473bf2370 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -157,7 +157,7 @@ def add_episodes_inplace( Raises: - AssertionError: If the first episode_id or index in hf_dataset is not 0 """ - first_episode_id = hf_dataset.select_columns("episode_id")[0]["episode_id"].item() + first_episode_id = hf_dataset.select_columns("episode_index")[0]["episode_index"].item() first_index = hf_dataset.select_columns("index")[0]["index"].item() assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}" assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}" @@ -167,12 +167,12 @@ def add_episodes_inplace( online_dataset.hf_dataset = hf_dataset else: # find episode index and data frame indices according to previous episode in online_dataset - start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1 + start_episode = online_dataset.select_columns("episode_index")[-1]["episode_index"].item() + 1 start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1 def shift_indices(example): - # note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to - example["episode_id"] += start_episode + # note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to + example["episode_index"] += start_episode example["index"] += start_index example["episode_data_index_from"] += start_index example["episode_data_index_to"] += start_index diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 3dee5fba3..8b0428d9a 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,22 +1,21 @@ -import logging -from copy import deepcopy import json +import logging import os +from copy import deepcopy from pathlib import Path import einops import pytest import torch - from datasets import Dataset import lerobot from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.utils import ( compute_stats, + flatten_dict, get_stats_einops_patterns, load_previous_and_future_frames, - flatten_dict, unflatten_dict, ) from lerobot.common.transforms import Prod @@ -44,8 +43,8 @@ def test_factory(env_name, dataset_id, policy_name): keys_ndim_required = [ ("action", 1, True), - ("episode_id", 0, True), - ("frame_id", 0, True), + ("episode_index", 0, True), + ("frame_index", 0, True), ("timestamp", 0, True), # TODO(rcadene): should we rename it agent_pos? ("observation.state", 1, True), @@ -165,13 +164,13 @@ def test_load_previous_and_future_frames_within_tolerance(): { "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], "index": [0, 1, 2, 3, 4], - "episode_id": [0, 0, 0, 0, 0], + "episode_index": [0, 0, 0, 0, 0], } ) hf_dataset = hf_dataset.with_format("torch") episode_data_index = { - "from": torch.tensor([0]), - "to": torch.tensor([5]), + "from": torch.tensor([0]), + "to": torch.tensor([5]), } delta_timestamps = {"index": [-0.2, 0, 0.139]} tol = 0.04 @@ -187,13 +186,13 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range( { "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], "index": [0, 1, 2, 3, 4], - "episode_id": [0, 0, 0, 0, 0], + "episode_index": [0, 0, 0, 0, 0], } ) hf_dataset = hf_dataset.with_format("torch") episode_data_index = { - "from": torch.tensor([0]), - "to": torch.tensor([5]), + "from": torch.tensor([0]), + "to": torch.tensor([5]), } delta_timestamps = {"index": [-0.2, 0, 0.141]} tol = 0.04 @@ -207,13 +206,13 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range { "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], "index": [0, 1, 2, 3, 4], - "episode_id": [0, 0, 0, 0, 0], + "episode_index": [0, 0, 0, 0, 0], } ) hf_dataset = hf_dataset.with_format("torch") episode_data_index = { - "from": torch.tensor([0]), - "to": torch.tensor([5]), + "from": torch.tensor([0]), + "to": torch.tensor([5]), } delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]} tol = 0.04 @@ -224,7 +223,7 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range assert torch.equal( is_pad, torch.tensor([True, False, False, True, True]) ), "Padding does not match expected values" - + def test_flatten_unflatten_dict(): d = {