id -> index, finish moving compute_stats before hf_dataset push_to_hub

This commit is contained in:
Cadene
2024-04-19 10:33:42 +00:00
parent 64b09ea7a7
commit 714a776277
9 changed files with 120 additions and 99 deletions

View File

@@ -19,7 +19,7 @@ from huggingface_hub import HfApi
from PIL import Image as PILImage
from safetensors.torch import save_file
-from lerobot.common.datasets.utils import compute_stats, flatten_dict
+from lerobot.common.datasets.utils import compute_stats, convert_images_to_channel_first_tensors, flatten_dict
def download_and_upload(root, revision, dataset_id):
@@ -75,28 +75,18 @@ def concatenate_episodes(ep_dicts):
for x in ep_dict[key]:
data_dict[key].append(x)
-total_frames = data_dict["frame_id"].shape[0]
+total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)
return data_dict
-def push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id):
-hf_dataset = hf_dataset.with_format("torch")
+def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id):
# push to main to indicate latest version
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
# push to version branch
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision=revision)
-# get stats
-stats_pth_path = root / dataset_id / "stats.pth"
-if stats_pth_path.exists():
-stats = torch.load(stats_pth_path)
-else:
-stats = compute_stats(hf_dataset)
-torch.save(stats, stats_pth_path)
# create and store meta_data
meta_data_dir = root / dataset_id / "meta_data"
meta_data_dir.mkdir(parents=True, exist_ok=True)
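The hunk ends just after meta_data_dir is created, so the actual stats serialization is not shown. A minimal sketch of what push_to_hub can now do with the stats argument, using the flatten_dict and save_file imports visible at the top of this file (the stats.safetensors file name and the helper name are assumptions, not confirmed by this diff):

from safetensors.torch import save_file
from lerobot.common.datasets.utils import flatten_dict

def save_stats_to_meta_data(stats: dict, meta_data_dir) -> None:
    # stats is a nested dict, e.g. {"observation.image": {"mean": t, "std": t, "max": t, "min": t}};
    # safetensors stores a flat {str: tensor} mapping, hence flatten_dict with "/" separators
    save_file(flatten_dict(stats), str(meta_data_dir / "stats.safetensors"))  # file name assumed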
@@ -237,8 +227,8 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
"observation.state": agent_pos,
"action": actions[id_from:id_to],
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1),
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.image": image[1:],
# "next.observation.state": agent_pos[1:],
@@ -262,8 +252,8 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_id": Value(dtype="int64", id=None),
"frame_id": Value(dtype="int64", id=None),
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
"next.reward": Value(dtype="float32", id=None),
"next.done": Value(dtype="bool", id=None),
@@ -272,11 +262,14 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
}
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
+hf_dataset = hf_dataset.with_format("torch")
+hf_dataset.set_transform(convert_images_to_channel_first_tensors)
info = {
"fps": fps,
}
-push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id)
+stats = compute_stats(hf_dataset)
+push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
def download_and_upload_xarm(root, revision, dataset_id, fps=15):
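push_to_hub also keeps receiving episode_data_index, the per-episode "from"/"to" frame ranges that pair with the new episode_index column. A hypothetical helper, just to illustrate the relationship (episode_ranges is not part of the codebase):

import torch

def episode_ranges(episode_index: torch.Tensor) -> dict[str, torch.Tensor]:
    # episode_index holds one episode id per frame, e.g. tensor([0, 0, 0, 1, 1, 2])
    _, counts = torch.unique_consecutive(episode_index, return_counts=True)
    to = torch.cumsum(counts, dim=0)
    frm = torch.cat([torch.tensor([0]), to[:-1]])
    return {"from": frm, "to": to}

episode_ranges(torch.tensor([0, 0, 0, 1, 1, 2]))
# {'from': tensor([0, 3, 5]), 'to': tensor([3, 5, 6])}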
@@ -334,8 +327,8 @@ def download_and_upload_xarm(root, revision, dataset_id, fps=15):
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
"observation.state": state,
"action": action,
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1),
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.image": next_image,
# "next.observation.state": next_state,
@@ -358,8 +351,8 @@ def download_and_upload_xarm(root, revision, dataset_id, fps=15):
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_id": Value(dtype="int64", id=None),
"frame_id": Value(dtype="int64", id=None),
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
"next.reward": Value(dtype="float32", id=None),
"next.done": Value(dtype="bool", id=None),
@@ -368,11 +361,14 @@ def download_and_upload_xarm(root, revision, dataset_id, fps=15):
}
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
+hf_dataset = hf_dataset.with_format("torch")
+hf_dataset.set_transform(convert_images_to_channel_first_tensors)
info = {
"fps": fps,
}
-push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id)
+stats = compute_stats(hf_dataset)
+push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
def download_and_upload_aloha(root, revision, dataset_id, fps=50):
@@ -464,8 +460,8 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
{
"observation.state": state,
"action": action,
"episode_id": torch.tensor([ep_id] * num_frames),
"frame_id": torch.arange(0, num_frames, 1),
"episode_index": torch.tensor([ep_id] * num_frames),
"frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.state": state,
# TODO(rcadene): compute reward and success
@@ -493,8 +489,8 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_id": Value(dtype="int64", id=None),
"frame_id": Value(dtype="int64", id=None),
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
#'next.reward': Value(dtype='float32', id=None),
"next.done": Value(dtype="bool", id=None),
@@ -503,11 +499,14 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
}
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
+hf_dataset = hf_dataset.with_format("torch")
+hf_dataset.set_transform(convert_images_to_channel_first_tensors)
info = {
"fps": fps,
}
-push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id)
+stats = compute_stats(hf_dataset)
+push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
if __name__ == "__main__":
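Since push_to_hub pushes once to main (latest) and once to a version branch, downstream code can pin the revision it was written against. A small usage sketch (the repo id matches the upload above; the revision string is illustrative):

from datasets import load_dataset

hf_dataset = load_dataset("lerobot/pusht", revision="v1.1", split="train")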

View File

@@ -49,7 +49,7 @@ print(f"number of episodes: {len(hf_dataset.unique('episode_id'))=}")
print(f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_id')):.3f}")
# select the frames belonging to episode number 5
-hf_dataset = hf_dataset.filter(lambda frame: frame["episode_id"] == 5)
+hf_dataset = hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
# load all frames of episode 5 in RAM in PIL format
frames = hf_dataset["observation.image"]

View File

@@ -55,7 +55,7 @@ print(f"frames per second used during data collection: {dataset.fps=}")
print(f"keys to access images from cameras: {dataset.image_keys=}")
# While the LeRobot dataset adds helpers for working within our library, we still expose the underlying Hugging Face dataset. It may be freely replaced or modified in place. Here we use filtering to keep only frames from episode 5.
-dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_id"] == 5)
+dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
# LeRobot datasets actually subclass PyTorch datasets, so you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grab all the image frames.
frames = [sample["observation.image"] for sample in dataset]

View File

@@ -54,7 +54,7 @@ class AlohaDataset(torch.utils.data.Dataset):
@property
def num_episodes(self) -> int:
-return len(self.hf_dataset.unique("episode_id"))
+return len(self.hf_dataset.unique("episode_index"))
def __len__(self):
return self.num_samples

View File

@@ -1,11 +1,9 @@
-import logging
import os
from pathlib import Path
import torch
from torchvision.transforms import v2
-from lerobot.common.datasets.utils import compute_stats
from lerobot.common.transforms import NormalizeTransform, Prod
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
@@ -52,26 +50,14 @@ def make_dataset(
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
elif stats_path is None:
-# load stats if the file exists already or compute stats and save it
-if DATA_DIR is None:
-# TODO(rcadene): clean stats
-precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth"
-else:
-precomputed_stats_path = DATA_DIR / cfg.dataset_id / "stats.pth"
-if precomputed_stats_path.exists():
-stats = torch.load(precomputed_stats_path)
-else:
-logging.info(f"compute_stats and save to {precomputed_stats_path}")
-# Create a dataset for stats computation.
-stats_dataset = clsfunc(
-dataset_id=cfg.dataset_id,
-split="train",
-root=DATA_DIR,
-transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
-)
-stats = compute_stats(stats_dataset)
-precomputed_stats_path.parent.mkdir(parents=True, exist_ok=True)
-torch.save(stats, precomputed_stats_path)
+# load a first dataset to access precomputed stats
+stats_dataset = clsfunc(
+dataset_id=cfg.dataset_id,
+split="train",
+root=DATA_DIR,
+transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
+)
+stats = stats_dataset.stats
else:
stats = torch.load(stats_path)
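The factory no longer computes stats itself; it expects the dataset object to expose precomputed stats. A hypothetical sketch of what such a stats property might do, assuming stats were uploaded under meta_data/stats.safetensors as in the upload script above (path and file name assumed):

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from lerobot.common.datasets.utils import unflatten_dict

def load_precomputed_stats(repo_id: str, version: str) -> dict:
    path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version)
    # invert the flatten_dict("/") applied at upload time
    return unflatten_dict(load_file(path))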

View File

@@ -10,6 +10,8 @@ from datasets import load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
+from lerobot.common.utils.utils import set_global_seed
def flatten_dict(d, parent_key="", sep="/"):
items = []
@@ -42,7 +44,9 @@ def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset:
else:
repo_id = f"lerobot/{dataset_id}"
hf_dataset = load_dataset(repo_id, revision=version, split=split)
-return hf_dataset.with_format("torch")
+hf_dataset = hf_dataset.with_format("torch")
+hf_dataset.set_transform(convert_images_to_channel_first_tensors)
+return hf_dataset
def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor]:
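For context on the set_transform call above: Hugging Face datasets pass the transform a dict mapping column names to the values of the requested rows, and whatever the transform returns is what indexing yields. A toy sketch:

import torch
from datasets import Dataset

ds = Dataset.from_dict({"observation.state": [[0.0, 1.0], [2.0, 3.0]]})

def to_torch(batch: dict) -> dict:
    # batch maps each column name to a list with one entry per requested row
    return {key: torch.tensor(values) for key, values in batch.items()}

ds.set_transform(to_torch)
ds[0]["observation.state"]  # expected: tensor([0., 1.])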
@@ -126,7 +130,7 @@ def load_previous_and_future_frames(
issues with timestamps during data collection.
"""
# get indices of the frames associated to the episode, and their timestamps
-ep_id = item["episode_id"].item()
+ep_id = item["episode_index"].item()
ep_data_id_from = episode_data_index["from"][ep_id].item()
ep_data_id_to = episode_data_index["to"][ep_id].item()
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
@@ -168,34 +172,53 @@ def load_previous_and_future_frames(
return item
-def get_stats_einops_patterns(dataset):
-"""These einops patterns will be used to aggregate batches and compute statistics."""
-stats_patterns = {
-"action": "b c -> c",
-"observation.state": "b c -> c",
-}
-for key in dataset.image_keys:
-stats_patterns[key] = "b c h w -> c 1 1"
+def convert_images_to_channel_first_tensors(examples):
+for key in examples:
+if examples[key].ndim == 3: # we assume it's an image
+# (h w c) -> (c h w)
+h, w, c = examples[key].shape
+assert c < h and c < w, f"expect a channel last image, but instead {examples[key].shape}"
+examples[key] = [img.permute((2, 0, 1)) for img in examples[key]]
+return examples
+def get_stats_einops_patterns(hf_dataset):
+"""These einops patterns will be used to aggregate batches and compute statistics.
+Note: We assume the images are returned in channel first format
+"""
+dataloader = torch.utils.data.DataLoader(
+hf_dataset,
+num_workers=0,
+batch_size=2,
+shuffle=False,
+)
+batch = next(iter(dataloader))
+stats_patterns = {}
+for key, feats_type in hf_dataset.features.items():
+if batch[key].ndim == 4 and isinstance(feats_type, datasets.features.image.Image):
+# sanity check that images are channel first
+_, c, h, w = batch[key].shape
+assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
+# convert from (h w c) to (c h w) to fit pytorch convention, then apply reduce
+stats_patterns[key] = "b c h w -> c 1 1"
+elif batch[key].ndim == 2:
+stats_patterns[key] = "b c -> c "
+elif batch[key].ndim == 1:
+stats_patterns[key] = "b -> 1"
+else:
+raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
return stats_patterns
-def compute_stats(dataset, batch_size=32, max_num_samples=None):
+def compute_stats(hf_dataset, batch_size=32, max_num_samples=None):
if max_num_samples is None:
-max_num_samples = len(dataset)
-else:
-raise NotImplementedError("We need to set shuffle=True, but this violate an assert for now.")
+max_num_samples = len(hf_dataset)
-dataloader = torch.utils.data.DataLoader(
-dataset,
-num_workers=4,
-batch_size=batch_size,
-shuffle=False,
-# pin_memory=cfg.device != "cpu",
-drop_last=False,
-)
# get einops patterns to aggregate batches and compute statistics
-stats_patterns = get_stats_einops_patterns(dataset)
+stats_patterns = get_stats_einops_patterns(hf_dataset)
# mean and std will be computed incrementally while max and min will track the running value.
mean, std, max, min = {}, {}, {}, {}
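Downstream, each pattern is consumed with einops.reduce to collapse a batch into per-channel statistics; a quick sketch of the two main cases:

import einops
import torch

images = torch.rand(2, 3, 96, 96)  # batch of channel-first images
states = torch.rand(2, 7)          # batch of state vectors

einops.reduce(images, "b c h w -> c 1 1", "mean").shape  # torch.Size([3, 1, 1])
einops.reduce(states, "b c -> c", "mean").shape          # torch.Size([7])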
@@ -205,10 +228,23 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None):
max[key] = torch.tensor(-float("inf")).float()
min[key] = torch.tensor(float("inf")).float()
+def create_seeded_dataloader(hf_dataset, batch_size, seed):
+set_global_seed(seed)
+dataloader = torch.utils.data.DataLoader(
+hf_dataset,
+num_workers=4,
+batch_size=batch_size,
+shuffle=False,
+# pin_memory=cfg.device != "cpu",
+drop_last=False,
+)
+return dataloader
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
# surprises when rerunning the sampler.
first_batch = None
running_item_count = 0 # for online mean computation
+dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
):
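For reference, the running mean that compute_stats maintains with running_item_count follows the standard online update: after n items, a batch of b items moves the mean by (batch_mean - mean) * b / (n + b). A self-contained sketch of that update (mirroring the idea, not the function's exact code):

import torch

def update_running_mean(mean: torch.Tensor, n: int, batch: torch.Tensor):
    b = batch.shape[0]
    mean = mean + (batch.mean(dim=0) - mean) * (b / (n + b))
    return mean, n + b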
@@ -234,6 +270,7 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None):
first_batch_ = None
running_item_count = 0 # for online std computation
+dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
):

View File

@@ -46,7 +46,7 @@ class XarmDataset(torch.utils.data.Dataset):
@property
def num_episodes(self) -> int:
-return len(self.hf_dataset.unique("episode_id"))
+return len(self.hf_dataset.unique("episode_index"))
def __len__(self):
return self.num_samples

View File

@@ -157,7 +157,7 @@ def add_episodes_inplace(
Raises:
- AssertionError: If the first episode_id or index in hf_dataset is not 0
"""
-first_episode_id = hf_dataset.select_columns("episode_id")[0]["episode_id"].item()
+first_episode_id = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
first_index = hf_dataset.select_columns("index")[0]["index"].item()
assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}"
@@ -167,12 +167,12 @@ def add_episodes_inplace(
online_dataset.hf_dataset = hf_dataset
else:
# find episode index and data frame indices according to previous episode in online_dataset
-start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1
+start_episode = online_dataset.select_columns("episode_index")[-1]["episode_index"].item() + 1
start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1
def shift_indices(example):
-# note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to
-example["episode_id"] += start_episode
+# note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
+example["episode_index"] += start_episode
example["index"] += start_index
example["episode_data_index_from"] += start_index
example["episode_data_index_to"] += start_index

View File

@@ -1,22 +1,21 @@
-import logging
-from copy import deepcopy
+import json
+import logging
+import os
+from copy import deepcopy
from pathlib import Path
import einops
import pytest
import torch
from datasets import Dataset
import lerobot
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import (
compute_stats,
+flatten_dict,
get_stats_einops_patterns,
load_previous_and_future_frames,
-flatten_dict,
unflatten_dict,
)
from lerobot.common.transforms import Prod
@@ -44,8 +43,8 @@ def test_factory(env_name, dataset_id, policy_name):
keys_ndim_required = [
("action", 1, True),
("episode_id", 0, True),
("frame_id", 0, True),
("episode_index", 0, True),
("frame_index", 0, True),
("timestamp", 0, True),
# TODO(rcadene): should we rename it agent_pos?
("observation.state", 1, True),
@@ -165,13 +164,13 @@ def test_load_previous_and_future_frames_within_tolerance():
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_id": [0, 0, 0, 0, 0],
"episode_index": [0, 0, 0, 0, 0],
}
)
hf_dataset = hf_dataset.with_format("torch")
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.2, 0, 0.139]}
tol = 0.04
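The tolerance mechanics these tests exercise: for each query timestamp (current timestamp plus a delta), the loader picks the closest frame in the episode and flags it as padding when it lies farther away than tol. A sketch of that lookup (mirroring the logic, not the exact implementation):

import torch

ep_timestamps = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])
query = 0.3 + torch.tensor([-0.2, 0.0, 0.139])  # -> [0.1, 0.3, 0.439]

dist = torch.cdist(query[:, None], ep_timestamps[:, None], p=1)
min_dist, argmin = dist.min(dim=1)
is_pad = min_dist > 0.04  # tol
# argmin -> tensor([0, 2, 3]): frames at 0.1, 0.3, 0.4; |0.439 - 0.4| < tol, so nothing is padded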
@@ -187,13 +186,13 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_id": [0, 0, 0, 0, 0],
"episode_index": [0, 0, 0, 0, 0],
}
)
hf_dataset = hf_dataset.with_format("torch")
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.2, 0, 0.141]}
tol = 0.04
@@ -207,13 +206,13 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_id": [0, 0, 0, 0, 0],
"episode_index": [0, 0, 0, 0, 0],
}
)
hf_dataset = hf_dataset.with_format("torch")
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
tol = 0.04
@@ -224,7 +223,7 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
assert torch.equal(
is_pad, torch.tensor([True, False, False, True, True])
), "Padding does not match expected values"
def test_flatten_unflatten_dict():
d = {