id -> index, finish moving compute_stats before hf_dataset push_to_hub
This commit is contained in:
@@ -19,7 +19,7 @@ from huggingface_hub import HfApi
|
|||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
from lerobot.common.datasets.utils import compute_stats, flatten_dict
|
from lerobot.common.datasets.utils import compute_stats, convert_images_to_channel_first_tensors, flatten_dict
|
||||||
|
|
||||||
|
|
||||||
def download_and_upload(root, revision, dataset_id):
|
def download_and_upload(root, revision, dataset_id):
|
||||||
@@ -75,28 +75,18 @@ def concatenate_episodes(ep_dicts):
|
|||||||
for x in ep_dict[key]:
|
for x in ep_dict[key]:
|
||||||
data_dict[key].append(x)
|
data_dict[key].append(x)
|
||||||
|
|
||||||
total_frames = data_dict["frame_id"].shape[0]
|
total_frames = data_dict["frame_index"].shape[0]
|
||||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||||
return data_dict
|
return data_dict
|
||||||
|
|
||||||
|
|
||||||
def push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id):
|
def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id):
|
||||||
hf_dataset = hf_dataset.with_format("torch")
|
|
||||||
|
|
||||||
# push to main to indicate latest version
|
# push to main to indicate latest version
|
||||||
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
|
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
|
||||||
|
|
||||||
# push to version branch
|
# push to version branch
|
||||||
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision=revision)
|
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision=revision)
|
||||||
|
|
||||||
# get stats
|
|
||||||
stats_pth_path = root / dataset_id / "stats.pth"
|
|
||||||
if stats_pth_path.exists():
|
|
||||||
stats = torch.load(stats_pth_path)
|
|
||||||
else:
|
|
||||||
stats = compute_stats(hf_dataset)
|
|
||||||
torch.save(stats, stats_pth_path)
|
|
||||||
|
|
||||||
# create and store meta_data
|
# create and store meta_data
|
||||||
meta_data_dir = root / dataset_id / "meta_data"
|
meta_data_dir = root / dataset_id / "meta_data"
|
||||||
meta_data_dir.mkdir(parents=True, exist_ok=True)
|
meta_data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -237,8 +227,8 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
|
|||||||
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
|
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
|
||||||
"observation.state": agent_pos,
|
"observation.state": agent_pos,
|
||||||
"action": actions[id_from:id_to],
|
"action": actions[id_from:id_to],
|
||||||
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||||
"frame_id": torch.arange(0, num_frames, 1),
|
"frame_index": torch.arange(0, num_frames, 1),
|
||||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||||
# "next.observation.image": image[1:],
|
# "next.observation.image": image[1:],
|
||||||
# "next.observation.state": agent_pos[1:],
|
# "next.observation.state": agent_pos[1:],
|
||||||
@@ -262,8 +252,8 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
|
|||||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||||
),
|
),
|
||||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
||||||
"episode_id": Value(dtype="int64", id=None),
|
"episode_index": Value(dtype="int64", id=None),
|
||||||
"frame_id": Value(dtype="int64", id=None),
|
"frame_index": Value(dtype="int64", id=None),
|
||||||
"timestamp": Value(dtype="float32", id=None),
|
"timestamp": Value(dtype="float32", id=None),
|
||||||
"next.reward": Value(dtype="float32", id=None),
|
"next.reward": Value(dtype="float32", id=None),
|
||||||
"next.done": Value(dtype="bool", id=None),
|
"next.done": Value(dtype="bool", id=None),
|
||||||
@@ -272,11 +262,14 @@ def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
|
|||||||
}
|
}
|
||||||
features = Features(features)
|
features = Features(features)
|
||||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||||
|
hf_dataset = hf_dataset.with_format("torch")
|
||||||
|
hf_dataset.set_transform(convert_images_to_channel_first_tensors)
|
||||||
|
|
||||||
info = {
|
info = {
|
||||||
"fps": fps,
|
"fps": fps,
|
||||||
}
|
}
|
||||||
push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id)
|
stats = compute_stats(hf_dataset)
|
||||||
|
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
|
||||||
|
|
||||||
|
|
||||||
def download_and_upload_xarm(root, revision, dataset_id, fps=15):
|
def download_and_upload_xarm(root, revision, dataset_id, fps=15):
|
||||||
@@ -334,8 +327,8 @@ def download_and_upload_xarm(root, revision, dataset_id, fps=15):
|
|||||||
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
|
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
|
||||||
"observation.state": state,
|
"observation.state": state,
|
||||||
"action": action,
|
"action": action,
|
||||||
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||||
"frame_id": torch.arange(0, num_frames, 1),
|
"frame_index": torch.arange(0, num_frames, 1),
|
||||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||||
# "next.observation.image": next_image,
|
# "next.observation.image": next_image,
|
||||||
# "next.observation.state": next_state,
|
# "next.observation.state": next_state,
|
||||||
@@ -358,8 +351,8 @@ def download_and_upload_xarm(root, revision, dataset_id, fps=15):
|
|||||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||||
),
|
),
|
||||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
||||||
"episode_id": Value(dtype="int64", id=None),
|
"episode_index": Value(dtype="int64", id=None),
|
||||||
"frame_id": Value(dtype="int64", id=None),
|
"frame_index": Value(dtype="int64", id=None),
|
||||||
"timestamp": Value(dtype="float32", id=None),
|
"timestamp": Value(dtype="float32", id=None),
|
||||||
"next.reward": Value(dtype="float32", id=None),
|
"next.reward": Value(dtype="float32", id=None),
|
||||||
"next.done": Value(dtype="bool", id=None),
|
"next.done": Value(dtype="bool", id=None),
|
||||||
@@ -368,11 +361,14 @@ def download_and_upload_xarm(root, revision, dataset_id, fps=15):
|
|||||||
}
|
}
|
||||||
features = Features(features)
|
features = Features(features)
|
||||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||||
|
hf_dataset = hf_dataset.with_format("torch")
|
||||||
|
hf_dataset.set_transform(convert_images_to_channel_first_tensors)
|
||||||
|
|
||||||
info = {
|
info = {
|
||||||
"fps": fps,
|
"fps": fps,
|
||||||
}
|
}
|
||||||
push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id)
|
stats = compute_stats(hf_dataset)
|
||||||
|
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
|
||||||
|
|
||||||
|
|
||||||
def download_and_upload_aloha(root, revision, dataset_id, fps=50):
|
def download_and_upload_aloha(root, revision, dataset_id, fps=50):
|
||||||
@@ -464,8 +460,8 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
|
|||||||
{
|
{
|
||||||
"observation.state": state,
|
"observation.state": state,
|
||||||
"action": action,
|
"action": action,
|
||||||
"episode_id": torch.tensor([ep_id] * num_frames),
|
"episode_index": torch.tensor([ep_id] * num_frames),
|
||||||
"frame_id": torch.arange(0, num_frames, 1),
|
"frame_index": torch.arange(0, num_frames, 1),
|
||||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||||
# "next.observation.state": state,
|
# "next.observation.state": state,
|
||||||
# TODO(rcadene): compute reward and success
|
# TODO(rcadene): compute reward and success
|
||||||
@@ -493,8 +489,8 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
|
|||||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||||
),
|
),
|
||||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
||||||
"episode_id": Value(dtype="int64", id=None),
|
"episode_index": Value(dtype="int64", id=None),
|
||||||
"frame_id": Value(dtype="int64", id=None),
|
"frame_index": Value(dtype="int64", id=None),
|
||||||
"timestamp": Value(dtype="float32", id=None),
|
"timestamp": Value(dtype="float32", id=None),
|
||||||
#'next.reward': Value(dtype='float32', id=None),
|
#'next.reward': Value(dtype='float32', id=None),
|
||||||
"next.done": Value(dtype="bool", id=None),
|
"next.done": Value(dtype="bool", id=None),
|
||||||
@@ -503,11 +499,14 @@ def download_and_upload_aloha(root, revision, dataset_id, fps=50):
|
|||||||
}
|
}
|
||||||
features = Features(features)
|
features = Features(features)
|
||||||
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
hf_dataset = Dataset.from_dict(data_dict, features=features)
|
||||||
|
hf_dataset = hf_dataset.with_format("torch")
|
||||||
|
hf_dataset.set_transform(convert_images_to_channel_first_tensors)
|
||||||
|
|
||||||
info = {
|
info = {
|
||||||
"fps": fps,
|
"fps": fps,
|
||||||
}
|
}
|
||||||
push_to_hub(hf_dataset, episode_data_index, info, root, revision, dataset_id)
|
stats = compute_stats(hf_dataset)
|
||||||
|
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ print(f"number of episodes: {len(hf_dataset.unique('episode_id'))=}")
|
|||||||
print(f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_id')):.3f}")
|
print(f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_id')):.3f}")
|
||||||
|
|
||||||
# select the frames belonging to episode number 5
|
# select the frames belonging to episode number 5
|
||||||
hf_dataset = hf_dataset.filter(lambda frame: frame["episode_id"] == 5)
|
hf_dataset = hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
|
||||||
|
|
||||||
# load all frames of episode 5 in RAM in PIL format
|
# load all frames of episode 5 in RAM in PIL format
|
||||||
frames = hf_dataset["observation.image"]
|
frames = hf_dataset["observation.image"]
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ print(f"frames per second used during data collection: {dataset.fps=}")
|
|||||||
print(f"keys to access images from cameras: {dataset.image_keys=}")
|
print(f"keys to access images from cameras: {dataset.image_keys=}")
|
||||||
|
|
||||||
# While the LeRobot dataset adds helpers for working within our library, we still expose the underling Hugging Face dataset. It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5.
|
# While the LeRobot dataset adds helpers for working within our library, we still expose the underling Hugging Face dataset. It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5.
|
||||||
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_id"] == 5)
|
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
|
||||||
|
|
||||||
# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grap all the image frames.
|
# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grap all the image frames.
|
||||||
frames = [sample["observation.image"] for sample in dataset]
|
frames = [sample["observation.image"] for sample in dataset]
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ class AlohaDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def num_episodes(self) -> int:
|
def num_episodes(self) -> int:
|
||||||
return len(self.hf_dataset.unique("episode_id"))
|
return len(self.hf_dataset.unique("episode_index"))
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.num_samples
|
return self.num_samples
|
||||||
|
|||||||
@@ -1,11 +1,9 @@
|
|||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torchvision.transforms import v2
|
from torchvision.transforms import v2
|
||||||
|
|
||||||
from lerobot.common.datasets.utils import compute_stats
|
|
||||||
from lerobot.common.transforms import NormalizeTransform, Prod
|
from lerobot.common.transforms import NormalizeTransform, Prod
|
||||||
|
|
||||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||||
@@ -52,26 +50,14 @@ def make_dataset(
|
|||||||
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||||
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||||
elif stats_path is None:
|
elif stats_path is None:
|
||||||
# load stats if the file exists already or compute stats and save it
|
# load a first dataset to access precomputed stats
|
||||||
if DATA_DIR is None:
|
stats_dataset = clsfunc(
|
||||||
# TODO(rcadene): clean stats
|
dataset_id=cfg.dataset_id,
|
||||||
precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth"
|
split="train",
|
||||||
else:
|
root=DATA_DIR,
|
||||||
precomputed_stats_path = DATA_DIR / cfg.dataset_id / "stats.pth"
|
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
||||||
if precomputed_stats_path.exists():
|
)
|
||||||
stats = torch.load(precomputed_stats_path)
|
stats = stats_dataset.stats
|
||||||
else:
|
|
||||||
logging.info(f"compute_stats and save to {precomputed_stats_path}")
|
|
||||||
# Create a dataset for stats computation.
|
|
||||||
stats_dataset = clsfunc(
|
|
||||||
dataset_id=cfg.dataset_id,
|
|
||||||
split="train",
|
|
||||||
root=DATA_DIR,
|
|
||||||
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
|
||||||
)
|
|
||||||
stats = compute_stats(stats_dataset)
|
|
||||||
precomputed_stats_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
torch.save(stats, precomputed_stats_path)
|
|
||||||
else:
|
else:
|
||||||
stats = torch.load(stats_path)
|
stats = torch.load(stats_path)
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ from datasets import load_dataset, load_from_disk
|
|||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
from lerobot.common.utils.utils import set_global_seed
|
||||||
|
|
||||||
|
|
||||||
def flatten_dict(d, parent_key="", sep="/"):
|
def flatten_dict(d, parent_key="", sep="/"):
|
||||||
items = []
|
items = []
|
||||||
@@ -42,7 +44,9 @@ def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset:
|
|||||||
else:
|
else:
|
||||||
repo_id = f"lerobot/{dataset_id}"
|
repo_id = f"lerobot/{dataset_id}"
|
||||||
hf_dataset = load_dataset(repo_id, revision=version, split=split)
|
hf_dataset = load_dataset(repo_id, revision=version, split=split)
|
||||||
return hf_dataset.with_format("torch")
|
hf_dataset = hf_dataset.with_format("torch")
|
||||||
|
hf_dataset.set_transform(convert_images_to_channel_first_tensors)
|
||||||
|
return hf_dataset
|
||||||
|
|
||||||
|
|
||||||
def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor]:
|
def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor]:
|
||||||
@@ -126,7 +130,7 @@ def load_previous_and_future_frames(
|
|||||||
issues with timestamps during data collection.
|
issues with timestamps during data collection.
|
||||||
"""
|
"""
|
||||||
# get indices of the frames associated to the episode, and their timestamps
|
# get indices of the frames associated to the episode, and their timestamps
|
||||||
ep_id = item["episode_id"].item()
|
ep_id = item["episode_index"].item()
|
||||||
ep_data_id_from = episode_data_index["from"][ep_id].item()
|
ep_data_id_from = episode_data_index["from"][ep_id].item()
|
||||||
ep_data_id_to = episode_data_index["to"][ep_id].item()
|
ep_data_id_to = episode_data_index["to"][ep_id].item()
|
||||||
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
|
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
|
||||||
@@ -168,34 +172,53 @@ def load_previous_and_future_frames(
|
|||||||
return item
|
return item
|
||||||
|
|
||||||
|
|
||||||
def get_stats_einops_patterns(dataset):
|
def convert_images_to_channel_first_tensors(examples):
|
||||||
"""These einops patterns will be used to aggregate batches and compute statistics."""
|
for key in examples:
|
||||||
stats_patterns = {
|
if examples[key].ndim == 3: # we assume it's an image
|
||||||
"action": "b c -> c",
|
# (h w c) -> (c h w)
|
||||||
"observation.state": "b c -> c",
|
h, w, c = examples[key].shape
|
||||||
}
|
assert c < h and c < w, f"expect a channel last image, but instead {examples[key].shape}"
|
||||||
for key in dataset.image_keys:
|
examples[key] = [img.permute((2, 0, 1)) for img in examples[key]]
|
||||||
stats_patterns[key] = "b c h w -> c 1 1"
|
return examples
|
||||||
|
|
||||||
|
|
||||||
|
def get_stats_einops_patterns(hf_dataset):
|
||||||
|
"""These einops patterns will be used to aggregate batches and compute statistics.
|
||||||
|
|
||||||
|
Note: We assume the images are returned in channel first format
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
hf_dataset,
|
||||||
|
num_workers=0,
|
||||||
|
batch_size=2,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
batch = next(iter(dataloader))
|
||||||
|
|
||||||
|
stats_patterns = {}
|
||||||
|
for key, feats_type in hf_dataset.features.items():
|
||||||
|
if batch[key].ndim == 4 and isinstance(feats_type, datasets.features.image.Image):
|
||||||
|
# sanity check that images are channel first
|
||||||
|
_, c, h, w = batch[key].shape
|
||||||
|
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
|
||||||
|
# convert from (h w c) to (c h w) to fit pytorch convention, then apply reduce
|
||||||
|
stats_patterns[key] = "b c h w -> c 1 1"
|
||||||
|
elif batch[key].ndim == 2:
|
||||||
|
stats_patterns[key] = "b c -> c "
|
||||||
|
elif batch[key].ndim == 1:
|
||||||
|
stats_patterns[key] = "b -> 1"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
|
||||||
|
|
||||||
return stats_patterns
|
return stats_patterns
|
||||||
|
|
||||||
|
|
||||||
def compute_stats(dataset, batch_size=32, max_num_samples=None):
|
def compute_stats(hf_dataset, batch_size=32, max_num_samples=None):
|
||||||
if max_num_samples is None:
|
if max_num_samples is None:
|
||||||
max_num_samples = len(dataset)
|
max_num_samples = len(hf_dataset)
|
||||||
else:
|
|
||||||
raise NotImplementedError("We need to set shuffle=True, but this violate an assert for now.")
|
|
||||||
|
|
||||||
dataloader = torch.utils.data.DataLoader(
|
stats_patterns = get_stats_einops_patterns(hf_dataset)
|
||||||
dataset,
|
|
||||||
num_workers=4,
|
|
||||||
batch_size=batch_size,
|
|
||||||
shuffle=False,
|
|
||||||
# pin_memory=cfg.device != "cpu",
|
|
||||||
drop_last=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# get einops patterns to aggregate batches and compute statistics
|
|
||||||
stats_patterns = get_stats_einops_patterns(dataset)
|
|
||||||
|
|
||||||
# mean and std will be computed incrementally while max and min will track the running value.
|
# mean and std will be computed incrementally while max and min will track the running value.
|
||||||
mean, std, max, min = {}, {}, {}, {}
|
mean, std, max, min = {}, {}, {}, {}
|
||||||
@@ -205,10 +228,23 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None):
|
|||||||
max[key] = torch.tensor(-float("inf")).float()
|
max[key] = torch.tensor(-float("inf")).float()
|
||||||
min[key] = torch.tensor(float("inf")).float()
|
min[key] = torch.tensor(float("inf")).float()
|
||||||
|
|
||||||
|
def create_seeded_dataloader(hf_dataset, batch_size, seed):
|
||||||
|
set_global_seed(seed)
|
||||||
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
hf_dataset,
|
||||||
|
num_workers=4,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
# pin_memory=cfg.device != "cpu",
|
||||||
|
drop_last=False,
|
||||||
|
)
|
||||||
|
return dataloader
|
||||||
|
|
||||||
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
|
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
|
||||||
# surprises when rerunning the sampler.
|
# surprises when rerunning the sampler.
|
||||||
first_batch = None
|
first_batch = None
|
||||||
running_item_count = 0 # for online mean computation
|
running_item_count = 0 # for online mean computation
|
||||||
|
dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
|
||||||
for i, batch in enumerate(
|
for i, batch in enumerate(
|
||||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
|
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
|
||||||
):
|
):
|
||||||
@@ -234,6 +270,7 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None):
|
|||||||
|
|
||||||
first_batch_ = None
|
first_batch_ = None
|
||||||
running_item_count = 0 # for online std computation
|
running_item_count = 0 # for online std computation
|
||||||
|
dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
|
||||||
for i, batch in enumerate(
|
for i, batch in enumerate(
|
||||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
|
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ class XarmDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def num_episodes(self) -> int:
|
def num_episodes(self) -> int:
|
||||||
return len(self.hf_dataset.unique("episode_id"))
|
return len(self.hf_dataset.unique("episode_index"))
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.num_samples
|
return self.num_samples
|
||||||
|
|||||||
@@ -157,7 +157,7 @@ def add_episodes_inplace(
|
|||||||
Raises:
|
Raises:
|
||||||
- AssertionError: If the first episode_id or index in hf_dataset is not 0
|
- AssertionError: If the first episode_id or index in hf_dataset is not 0
|
||||||
"""
|
"""
|
||||||
first_episode_id = hf_dataset.select_columns("episode_id")[0]["episode_id"].item()
|
first_episode_id = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
|
||||||
first_index = hf_dataset.select_columns("index")[0]["index"].item()
|
first_index = hf_dataset.select_columns("index")[0]["index"].item()
|
||||||
assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
|
assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
|
||||||
assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}"
|
assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}"
|
||||||
@@ -167,12 +167,12 @@ def add_episodes_inplace(
|
|||||||
online_dataset.hf_dataset = hf_dataset
|
online_dataset.hf_dataset = hf_dataset
|
||||||
else:
|
else:
|
||||||
# find episode index and data frame indices according to previous episode in online_dataset
|
# find episode index and data frame indices according to previous episode in online_dataset
|
||||||
start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1
|
start_episode = online_dataset.select_columns("episode_index")[-1]["episode_index"].item() + 1
|
||||||
start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1
|
start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1
|
||||||
|
|
||||||
def shift_indices(example):
|
def shift_indices(example):
|
||||||
# note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to
|
# note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
|
||||||
example["episode_id"] += start_episode
|
example["episode_index"] += start_episode
|
||||||
example["index"] += start_index
|
example["index"] += start_index
|
||||||
example["episode_data_index_from"] += start_index
|
example["episode_data_index_from"] += start_index
|
||||||
example["episode_data_index_to"] += start_index
|
example["episode_data_index_to"] += start_index
|
||||||
|
|||||||
@@ -1,22 +1,21 @@
|
|||||||
import logging
|
|
||||||
from copy import deepcopy
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
|
||||||
import lerobot
|
import lerobot
|
||||||
from lerobot.common.datasets.factory import make_dataset
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
compute_stats,
|
compute_stats,
|
||||||
|
flatten_dict,
|
||||||
get_stats_einops_patterns,
|
get_stats_einops_patterns,
|
||||||
load_previous_and_future_frames,
|
load_previous_and_future_frames,
|
||||||
flatten_dict,
|
|
||||||
unflatten_dict,
|
unflatten_dict,
|
||||||
)
|
)
|
||||||
from lerobot.common.transforms import Prod
|
from lerobot.common.transforms import Prod
|
||||||
@@ -44,8 +43,8 @@ def test_factory(env_name, dataset_id, policy_name):
|
|||||||
|
|
||||||
keys_ndim_required = [
|
keys_ndim_required = [
|
||||||
("action", 1, True),
|
("action", 1, True),
|
||||||
("episode_id", 0, True),
|
("episode_index", 0, True),
|
||||||
("frame_id", 0, True),
|
("frame_index", 0, True),
|
||||||
("timestamp", 0, True),
|
("timestamp", 0, True),
|
||||||
# TODO(rcadene): should we rename it agent_pos?
|
# TODO(rcadene): should we rename it agent_pos?
|
||||||
("observation.state", 1, True),
|
("observation.state", 1, True),
|
||||||
@@ -165,13 +164,13 @@ def test_load_previous_and_future_frames_within_tolerance():
|
|||||||
{
|
{
|
||||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
|
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||||
"index": [0, 1, 2, 3, 4],
|
"index": [0, 1, 2, 3, 4],
|
||||||
"episode_id": [0, 0, 0, 0, 0],
|
"episode_index": [0, 0, 0, 0, 0],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
hf_dataset = hf_dataset.with_format("torch")
|
hf_dataset = hf_dataset.with_format("torch")
|
||||||
episode_data_index = {
|
episode_data_index = {
|
||||||
"from": torch.tensor([0]),
|
"from": torch.tensor([0]),
|
||||||
"to": torch.tensor([5]),
|
"to": torch.tensor([5]),
|
||||||
}
|
}
|
||||||
delta_timestamps = {"index": [-0.2, 0, 0.139]}
|
delta_timestamps = {"index": [-0.2, 0, 0.139]}
|
||||||
tol = 0.04
|
tol = 0.04
|
||||||
@@ -187,13 +186,13 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(
|
|||||||
{
|
{
|
||||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
|
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||||
"index": [0, 1, 2, 3, 4],
|
"index": [0, 1, 2, 3, 4],
|
||||||
"episode_id": [0, 0, 0, 0, 0],
|
"episode_index": [0, 0, 0, 0, 0],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
hf_dataset = hf_dataset.with_format("torch")
|
hf_dataset = hf_dataset.with_format("torch")
|
||||||
episode_data_index = {
|
episode_data_index = {
|
||||||
"from": torch.tensor([0]),
|
"from": torch.tensor([0]),
|
||||||
"to": torch.tensor([5]),
|
"to": torch.tensor([5]),
|
||||||
}
|
}
|
||||||
delta_timestamps = {"index": [-0.2, 0, 0.141]}
|
delta_timestamps = {"index": [-0.2, 0, 0.141]}
|
||||||
tol = 0.04
|
tol = 0.04
|
||||||
@@ -207,13 +206,13 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
|
|||||||
{
|
{
|
||||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
|
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||||
"index": [0, 1, 2, 3, 4],
|
"index": [0, 1, 2, 3, 4],
|
||||||
"episode_id": [0, 0, 0, 0, 0],
|
"episode_index": [0, 0, 0, 0, 0],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
hf_dataset = hf_dataset.with_format("torch")
|
hf_dataset = hf_dataset.with_format("torch")
|
||||||
episode_data_index = {
|
episode_data_index = {
|
||||||
"from": torch.tensor([0]),
|
"from": torch.tensor([0]),
|
||||||
"to": torch.tensor([5]),
|
"to": torch.tensor([5]),
|
||||||
}
|
}
|
||||||
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
|
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
|
||||||
tol = 0.04
|
tol = 0.04
|
||||||
@@ -224,7 +223,7 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
|
|||||||
assert torch.equal(
|
assert torch.equal(
|
||||||
is_pad, torch.tensor([True, False, False, True, True])
|
is_pad, torch.tensor([True, False, False, True, True])
|
||||||
), "Padding does not match expected values"
|
), "Padding does not match expected values"
|
||||||
|
|
||||||
|
|
||||||
def test_flatten_unflatten_dict():
|
def test_flatten_unflatten_dict():
|
||||||
d = {
|
d = {
|
||||||
|
|||||||
Reference in New Issue
Block a user