Compare commits

...

12 Commits

| Author | SHA1 | Message | Date |
|---|---|---|---|
| Remi Cadene | 7cee7a0f20 | Add mobile Aloha and visu with rerun.io | 2024-04-20 16:19:55 +02:00 |
| Cadene | 2a59825a00 | fix online training | 2024-04-20 00:12:34 +00:00 |
| Cadene | 06628ba059 | fix online training | 2024-04-19 23:58:38 +00:00 |
| Cadene | b2b5329683 | fix online training | 2024-04-19 23:48:43 +00:00 |
| Cadene | 85f1554da8 | fix visualize_dataset | 2024-04-19 23:40:35 +00:00 |
| Cadene | 9b4c2e2a9f | small fix | 2024-04-19 23:30:39 +00:00 |
| Cadene | 20928021c0 | Add tests/data | 2024-04-19 23:27:11 +00:00 |
| Cadene | c20cf2fbbc | Remove Prod, Tests are passind | 2024-04-19 23:27:10 +00:00 |
| Cadene | 35a573c98e | Use v1.1, hf_transform_to_torch, Add 3 xarm datasets | 2024-04-19 23:26:13 +00:00 |
| Cadene | 714a776277 | id -> index, finish moving compute_stats before hf_dataset push_to_hub | 2024-04-19 23:25:06 +00:00 |
| Cadene | 64b09ea7a7 | WIP add load functions + episode_data_index | 2024-04-19 23:24:08 +00:00 |
| Cadene | 0bd2ca8d82 | Add meta_data, revision v1.1 | 2024-04-19 23:24:08 +00:00 |
77 changed files with 929 additions and 391 deletions

View File

@@ -208,7 +208,7 @@ HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload $HF_USER/$DATASET data/$DATAS
You will need to set the corresponding version as a default argument in your dataset class:
```python
version: str | None = "v1.0",
version: str | None = "v1.1",
```
See: [`lerobot/common/datasets/pusht.py`](https://github.com/Cadene/lerobot/blob/main/lerobot/common/datasets/pusht.py)
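For reference, a minimal sketch (assuming the `PushtDataset` class touched by this PR) of what the bumped default means at the call site:

```python
from lerobot.common.datasets.pusht import PushtDataset

# Picks up the new default revision "v1.1"; pass version="v1.0"
# explicitly to keep loading the previous Hub revision.
dataset = PushtDataset()
```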

View File

@@ -4,6 +4,7 @@ useless dependencies when using datasets.
"""
import io
import json
import pickle
import shutil
from pathlib import Path
@@ -14,16 +15,22 @@ import numpy as np
import torch
import tqdm
from datasets import Dataset, Features, Image, Sequence, Value
from huggingface_hub import HfApi
from PIL import Image as PILImage
from safetensors.torch import save_file
from lerobot.common.datasets.utils import compute_stats, flatten_dict, hf_transform_to_torch
def download_and_upload(root, root_tests, dataset_id):
def download_and_upload(root, revision, dataset_id):
if "pusht" in dataset_id:
download_and_upload_pusht(root, root_tests, dataset_id)
download_and_upload_pusht(root, revision, dataset_id)
elif "xarm" in dataset_id:
download_and_upload_xarm(root, root_tests, dataset_id)
elif "aloha" in dataset_id:
download_and_upload_aloha(root, root_tests, dataset_id)
download_and_upload_xarm(root, revision, dataset_id)
elif "aloha_sim" in dataset_id:
download_and_upload_aloha(root, revision, dataset_id)
elif "aloha_mobile" in dataset_id:
download_and_upload_aloha_mobile(root, revision, dataset_id)
else:
raise ValueError(dataset_id)
@@ -56,7 +63,102 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
return False
def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
def concatenate_episodes(ep_dicts):
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
if torch.is_tensor(ep_dicts[0][key][0]):
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
data_dict[key].append(x)
total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)
return data_dict
def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id):
# push to main to indicate latest version
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
# push to version branch
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision=revision)
# create and store meta_data
meta_data_dir = root / dataset_id / "meta_data"
meta_data_dir.mkdir(parents=True, exist_ok=True)
api = HfApi()
# info
info_path = meta_data_dir / "info.json"
with open(str(info_path), "w") as f:
json.dump(info, f, indent=4)
api.upload_file(
path_or_fileobj=info_path,
path_in_repo=str(info_path).replace(f"{root}/{dataset_id}", ""),
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
)
api.upload_file(
path_or_fileobj=info_path,
path_in_repo=str(info_path).replace(f"{root}/{dataset_id}", ""),
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
revision=revision,
)
# stats
stats_path = meta_data_dir / "stats.safetensors"
save_file(flatten_dict(stats), stats_path)
api.upload_file(
path_or_fileobj=stats_path,
path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""),
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
)
api.upload_file(
path_or_fileobj=stats_path,
path_in_repo=str(stats_path).replace(f"{root}/{dataset_id}", ""),
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
revision=revision,
)
# episode_data_index
episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
save_file(episode_data_index, ep_data_idx_path)
api.upload_file(
path_or_fileobj=ep_data_idx_path,
path_in_repo=str(ep_data_idx_path).replace(f"{root}/{dataset_id}", ""),
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
)
api.upload_file(
path_or_fileobj=ep_data_idx_path,
path_in_repo=str(ep_data_idx_path).replace(f"{root}/{dataset_id}", ""),
repo_id=f"lerobot/{dataset_id}",
repo_type="dataset",
revision=revision,
)
# copy in tests folder, the first episode and the meta_data directory
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk(
f"tests/data/{dataset_id}/train"
)
if Path(f"tests/data/{dataset_id}/meta_data").exists():
shutil.rmtree(f"tests/data/{dataset_id}/meta_data")
shutil.copytree(meta_data_dir, f"tests/data/{dataset_id}/meta_data")
def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):
try:
import pymunk
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
@@ -99,6 +201,7 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
actions = torch.from_numpy(dataset_dict["action"])
ep_dicts = []
episode_data_index = {"from": [], "to": []}
id_from = 0
for episode_id in tqdm.tqdm(range(num_episodes)):
@@ -151,8 +254,8 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
"observation.state": agent_pos,
"action": actions[id_from:id_to],
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1),
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.image": image[1:],
# "next.observation.state": agent_pos[1:],
@@ -160,28 +263,15 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
"next.done": torch.cat([done[1:], done[[-1]]]),
"next.success": torch.cat([success[1:], success[[-1]]]),
"episode_data_index_from": torch.tensor([id_from] * num_frames),
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
}
ep_dicts.append(ep_dict)
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(id_from + num_frames)
id_from += num_frames
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
if torch.is_tensor(ep_dicts[0][key][0]):
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
data_dict[key].append(x)
total_frames = id_from
data_dict["index"] = torch.arange(0, total_frames, 1)
data_dict = concatenate_episodes(ep_dicts)
features = {
"observation.image": Image(),
@@ -189,35 +279,35 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_id": Value(dtype="int64", id=None),
"frame_id": Value(dtype="int64", id=None),
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
"next.reward": Value(dtype="float32", id=None),
"next.done": Value(dtype="bool", id=None),
"next.success": Value(dtype="bool", id=None),
"index": Value(dtype="int64", id=None),
"episode_data_index_from": Value(dtype="int64", id=None),
"episode_data_index_to": Value(dtype="int64", id=None),
}
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset = hf_dataset.with_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
info = {
"fps": fps,
}
stats = compute_stats(hf_dataset)
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
def download_and_upload_xarm(root, revision, dataset_id, fps=15):
root = Path(root)
raw_dir = root / f"{dataset_id}_raw"
raw_dir = root / "xarm_datasets_raw"
if not raw_dir.exists():
import zipfile
import gdown
raw_dir.mkdir(parents=True, exist_ok=True)
# from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
zip_path = raw_dir / "data.zip"
gdown.download(url, str(zip_path), quiet=False)
@@ -234,13 +324,13 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f)
total_frames = dataset_dict["actions"].shape[0]
ep_dicts = []
episode_data_index = {"from": [], "to": []}
id_from = 0
id_to = 0
episode_id = 0
total_frames = dataset_dict["actions"].shape[0]
for i in tqdm.tqdm(range(total_frames)):
id_to += 1
@@ -264,35 +354,23 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
"observation.state": state,
"action": action,
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1),
"episode_index": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.image": next_image,
# "next.observation.state": next_state,
"next.reward": next_reward,
"next.done": next_done,
"episode_data_index_from": torch.tensor([id_from] * num_frames),
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
}
ep_dicts.append(ep_dict)
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(id_from + num_frames)
id_from = id_to
episode_id += 1
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
if torch.is_tensor(ep_dicts[0][key][0]):
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
data_dict[key].append(x)
total_frames = id_from
data_dict["index"] = torch.arange(0, total_frames, 1)
data_dict = concatenate_episodes(ep_dicts)
features = {
"observation.image": Image(),
@@ -300,27 +378,26 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_id": Value(dtype="int64", id=None),
"frame_id": Value(dtype="int64", id=None),
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
"next.reward": Value(dtype="float32", id=None),
"next.done": Value(dtype="bool", id=None),
#'next.success': Value(dtype='bool', id=None),
"index": Value(dtype="int64", id=None),
"episode_data_index_from": Value(dtype="int64", id=None),
"episode_data_index_to": Value(dtype="int64", id=None),
}
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset = hf_dataset.with_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
info = {
"fps": fps,
}
stats = compute_stats(hf_dataset)
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
def download_and_upload_aloha(root, revision, dataset_id, fps=50):
folder_urls = {
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
"aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
@@ -381,6 +458,7 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
gdown.download(ep49_urls[dataset_id], output=str(raw_dir / "episode_49.hdf5"), fuzzy=True)
ep_dicts = []
episode_data_index = {"from": [], "to": []}
id_from = 0
for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])):
@@ -408,40 +486,26 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
{
"observation.state": state,
"action": action,
"episode_id": torch.tensor([ep_id] * num_frames),
"frame_id": torch.arange(0, num_frames, 1),
"episode_index": torch.tensor([ep_id] * num_frames),
"frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.state": state,
# TODO(rcadene): compute reward and success
# "next.reward": reward,
"next.done": done,
# "next.success": success,
"episode_data_index_from": torch.tensor([id_from] * num_frames),
"episode_data_index_to": torch.tensor([id_from + num_frames] * num_frames),
}
)
assert isinstance(ep_id, int)
ep_dicts.append(ep_dict)
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(id_from + num_frames)
id_from += num_frames
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
if torch.is_tensor(ep_dicts[0][key][0]):
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
data_dict[key].append(x)
total_frames = id_from
data_dict["index"] = torch.arange(0, total_frames, 1)
data_dict = concatenate_episodes(ep_dicts)
features = {
"observation.images.top": Image(),
@@ -449,39 +513,144 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_id": Value(dtype="int64", id=None),
"frame_id": Value(dtype="int64", id=None),
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
#'next.reward': Value(dtype='float32', id=None),
"next.done": Value(dtype="bool", id=None),
#'next.success': Value(dtype='bool', id=None),
"index": Value(dtype="int64", id=None),
"episode_data_index_from": Value(dtype="int64", id=None),
"episode_data_index_to": Value(dtype="int64", id=None),
}
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset = hf_dataset.with_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)
info = {
"fps": fps,
}
stats = compute_stats(hf_dataset)
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
def download_and_upload_aloha_mobile(root, revision, dataset_id, fps=50):
num_episodes = {
"aloha_mobile_trossen_block_handoff": 5,
}
# episode_len = {
# "aloha_sim_insertion_human": 500,
# "aloha_sim_insertion_scripted": 400,
# "aloha_sim_transfer_cube_human": 400,
# "aloha_sim_transfer_cube_scripted": 400,
# }
cameras = {
"aloha_mobile_trossen_block_handoff": ['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
}
root = Path(root)
raw_dir = root / f"{dataset_id}_raw"
ep_dicts = []
episode_data_index = {"from": [], "to": []}
id_from = 0
for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])):
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
with h5py.File(ep_path, "r") as ep:
num_frames = ep["/action"].shape[0]
#assert episode_len[dataset_id] == num_frames
# last step of demonstration is considered done
done = torch.zeros(num_frames, dtype=torch.bool)
done[-1] = True
state = torch.from_numpy(ep["/observations/qpos"][:num_frames])
action = torch.from_numpy(ep["/action"][:num_frames])
ep_dict = {}
for cam in cameras[dataset_id]:
image = ep[f"/observations/images/{cam}"][:num_frames] # b h w c
import cv2
# un-pad and uncompress from: https://github.com/MarkFzp/act-plus-plus/blob/26bab0789d05b7496bacef04f5c6b2541a4403b5/postprocess_episodes.py#L50
image = np.array([cv2.imdecode(x, 1) for x in image])
image = [PILImage.fromarray(x) for x in image]
ep_dict[f"observation.images.{cam}"] = image
ep_dict.update(
{
"observation.state": state,
"action": action,
"episode_index": torch.tensor([ep_id] * num_frames),
"frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
# "next.observation.state": state,
# TODO(rcadene): compute reward and success
# "next.reward": reward,
"next.done": done,
# "next.success": success,
}
)
assert isinstance(ep_id, int)
ep_dicts.append(ep_dict)
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(id_from + num_frames)
id_from += num_frames
break
data_dict = concatenate_episodes(ep_dicts)
features = {}
for cam in cameras[dataset_id]:
features[f"observation.images.{cam}"] = Image()
features.update({
"observation.state": Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_index": Value(dtype="int64", id=None),
"frame_index": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
#'next.reward': Value(dtype='float32', id=None),
"next.done": Value(dtype="bool", id=None),
#'next.success': Value(dtype='bool', id=None),
"index": Value(dtype="int64", id=None),
})
features = Features(features)
hf_dataset = Dataset.from_dict(data_dict, features=features)
hf_dataset.set_transform(hf_transform_to_torch)
info = {
"fps": fps,
}
stats = compute_stats(hf_dataset)
push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dataset_id)
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
if __name__ == "__main__":
root = "data"
root_tests = "tests/data"
revision = "v1.1"
dataset_ids = [
# "pusht",
# "xarm_lift_medium",
# "xarm_lift_medium_replay",
# "xarm_push_medium",
# "xarm_push_medium_replay",
# "aloha_sim_insertion_human",
# "aloha_sim_insertion_scripted",
# "aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
# "aloha_sim_transfer_cube_scripted",
"aloha_mobile_trossen_block_handoff",
]
for dataset_id in dataset_ids:
download_and_upload(root, root_tests, dataset_id)
# assume stats have been precomputed
shutil.copy(f"{root}/{dataset_id}/stats.pth", f"{root_tests}/{dataset_id}/stats.pth")
download_and_upload(root, revision, dataset_id)
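As a quick illustration of the `episode_data_index` layout this script produces (hypothetical values for two episodes of 50 and 60 frames):

```python
import torch

episode_data_index = {"from": torch.tensor([0, 50]), "to": torch.tensor([50, 110])}

# each episode spans the half-open frame range [from, to)
ep = 1
from_id = episode_data_index["from"][ep].item()
to_id = episode_data_index["to"][ep].item()
assert to_id - from_id == 60  # episode 1 covers global frames [50, 110)
```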

View File

@@ -10,10 +10,13 @@ As an example, this script saves frames of episode number 5 of the PushT dataset
This script supports several Hugging Face datasets, among which:
1. [Pusht](https://huggingface.co/datasets/lerobot/pusht)
2. [Xarm Lift Medium](https://huggingface.co/datasets/lerobot/xarm_lift_medium)
3. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
4. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
5. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
6. [Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
3. [Xarm Lift Medium Replay](https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay)
4. [Xarm Push Medium](https://huggingface.co/datasets/lerobot/xarm_push_medium)
5. [Xarm Push Medium Replay](https://huggingface.co/datasets/lerobot/xarm_push_medium_replay)
6. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
7. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
8. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
9. [Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
To try a different Hugging Face dataset, you can replace this line:
```python
@@ -22,6 +25,9 @@ hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
by one of these:
```python
hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium", split="train"), 15
hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium_replay", split="train"), 15
hf_dataset, fps = load_dataset("lerobot/xarm_push_medium", split="train"), 15
hf_dataset, fps = load_dataset("lerobot/xarm_push_medium_replay", split="train"), 15
hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_human", split="train"), 50
hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_scripted", split="train"), 50
hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_human", split="train"), 50
@@ -37,7 +43,7 @@ from datasets import load_dataset
# TODO(rcadene): list available datasets on lerobot page using `datasets`
# download/load hugging face dataset in pyarrow format
hf_dataset, fps = load_dataset("lerobot/pusht", revision="v1.0", split="train"), 10
hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
# display name of dataset and its features
print(f"{hf_dataset=}")
@@ -45,11 +51,13 @@ print(f"{hf_dataset.features=}")
# display useful statistics about frames and episodes, which are sequences of frames from the same video
print(f"number of frames: {len(hf_dataset)=}")
print(f"number of episodes: {len(hf_dataset.unique('episode_id'))=}")
print(f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_id')):.3f}")
print(f"number of episodes: {len(hf_dataset.unique('episode_index'))=}")
print(
f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_index')):.3f}"
)
# select the frames belonging to episode number 5
hf_dataset = hf_dataset.filter(lambda frame: frame["episode_id"] == 5)
hf_dataset = hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
# load all frames of episode 5 in RAM in PIL format
frames = hf_dataset["observation.image"]

View File

@@ -18,7 +18,10 @@ dataset = PushtDataset()
```
by one of these:
```python
dataset = XarmDataset()
dataset = XarmDataset("xarm_lift_medium")
dataset = XarmDataset("xarm_lift_medium_replay")
dataset = XarmDataset("xarm_push_medium")
dataset = XarmDataset("xarm_push_medium_replay")
dataset = AlohaDataset("aloha_sim_insertion_human")
dataset = AlohaDataset("aloha_sim_insertion_scripted")
dataset = AlohaDataset("aloha_sim_transfer_cube_human")
@@ -55,13 +58,14 @@ print(f"frames per second used during data collection: {dataset.fps=}")
print(f"keys to access images from cameras: {dataset.image_keys=}")
# While the LeRobot dataset adds helpers for working within our library, we still expose the underlying Hugging Face dataset. It may be freely replaced or modified in place. Here we use filtering to keep only frames from episode 5.
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_id"] == 5)
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
# LeRobot datasets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grab all the image frames.
frames = [sample["observation.image"] for sample in dataset]
# but frames are now channel first to follow pytorch convention,
# to view them, we convert to channel last
# but frames are now float32 range [0,1] channel first to follow pytorch convention,
# to view them, we convert to uint8 range [0,255] channel last
frames = [(frame * 255).type(torch.uint8) for frame in frames]
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
# and finally save them to a mp4 video
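That last step could look like this (a sketch using `imageio`; `frames` are the uint8, channel-last numpy arrays built above):

```python
import imageio

# write the frames as an mp4 at the dataset's native frame rate
imageio.mimsave("episode_5.mp4", frames, fps=dataset.fps)
```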

View File

@@ -50,7 +50,12 @@ available_datasets = {
"aloha_sim_transfer_cube_scripted",
],
"pusht": ["pusht"],
"xarm": ["xarm_lift_medium"],
"xarm": [
"xarm_lift_medium",
"xarm_lift_medium_replay",
"xarm_push_medium",
"xarm_push_medium_replay",
],
}
available_policies = [

View File

@@ -1,9 +1,13 @@
from pathlib import Path
import torch
from datasets import load_dataset, load_from_disk
from lerobot.common.datasets.utils import load_previous_and_future_frames
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_previous_and_future_frames,
load_stats,
)
class AlohaDataset(torch.utils.data.Dataset):
@@ -27,7 +31,7 @@ class AlohaDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset_id: str,
version: str | None = "v1.0",
version: str | None = "v1.1",
root: Path | None = None,
split: str = "train",
transform: callable = None,
@@ -40,13 +44,10 @@ class AlohaDataset(torch.utils.data.Dataset):
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
if self.root is not None:
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.hf_dataset = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.hf_dataset = self.hf_dataset.with_format("torch")
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
@property
def num_samples(self) -> int:
@@ -54,7 +55,7 @@ class AlohaDataset(torch.utils.data.Dataset):
@property
def num_episodes(self) -> int:
return len(self.hf_dataset.unique("episode_id"))
return len(self.hf_dataset.unique("episode_index"))
def __len__(self):
return self.num_samples
@@ -66,19 +67,11 @@ class AlohaDataset(torch.utils.data.Dataset):
item = load_previous_and_future_frames(
item,
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
)
# convert images from channel last (PIL) to channel first (pytorch)
for key in self.image_keys:
if item[key].ndim == 3:
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
elif item[key].ndim == 4:
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
else:
raise ValueError(item[key].ndim)
if self.transform is not None:
item = self.transform(item)
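For context, a sketch of how `delta_timestamps` plays with the refactored loader, which now gets episode boundaries from `episode_data_index` instead of per-frame columns (import path assumed from this diff; at fps=50 one frame lasts 0.02 s):

```python
from lerobot.common.datasets.aloha import AlohaDataset

dataset = AlohaDataset(
    "aloha_sim_insertion_human",
    delta_timestamps={
        # current frame plus the two previous observations (1 and 2 frames back)
        "observation.images.top": [-0.04, -0.02, 0.0],
        # current action plus the next three actions
        "action": [0.0, 0.02, 0.04, 0.06],
    },
)
item = dataset[0]  # image and action tensors gain a leading time dimension
```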

View File

@@ -1,12 +1,10 @@
import logging
import os
from pathlib import Path
import torch
from torchvision.transforms import v2
from lerobot.common.datasets.utils import compute_stats
from lerobot.common.transforms import NormalizeTransform, Prod
from lerobot.common.transforms import NormalizeTransform
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
@@ -52,32 +50,18 @@ def make_dataset(
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
elif stats_path is None:
# load stats if the file exists already or compute stats and save it
if DATA_DIR is None:
# TODO(rcadene): clean stats
precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth"
else:
precomputed_stats_path = DATA_DIR / cfg.dataset_id / "stats.pth"
if precomputed_stats_path.exists():
stats = torch.load(precomputed_stats_path)
else:
logging.info(f"compute_stats and save to {precomputed_stats_path}")
# Create a dataset for stats computation.
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
split="train",
root=DATA_DIR,
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
)
stats = compute_stats(stats_dataset)
precomputed_stats_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(stats, precomputed_stats_path)
# load a first dataset to access precomputed stats
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
split="train",
root=DATA_DIR,
)
stats = stats_dataset.stats
else:
stats = torch.load(stats_path)
transforms = v2.Compose(
[
Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
NormalizeTransform(
stats,
in_keys=[
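In other words, normalization statistics now ship with the dataset instead of being recomputed locally; a minimal sketch of the resulting flow:

```python
import torch
from lerobot.common.datasets.pusht import PushtDataset

# stats are fetched from the Hub's meta_data/stats.safetensors by the class itself
dataset = PushtDataset(split="train")
raw_action = torch.tensor([250.0, 300.0])  # hypothetical un-normalized action
normalized = (raw_action - dataset.stats["action"]["mean"]) / dataset.stats["action"]["std"]
```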

View File

@@ -1,9 +1,13 @@
from pathlib import Path
import torch
from datasets import load_dataset, load_from_disk
from lerobot.common.datasets.utils import load_previous_and_future_frames
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_previous_and_future_frames,
load_stats,
)
class PushtDataset(torch.utils.data.Dataset):
@@ -25,7 +29,7 @@ class PushtDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset_id: str = "pusht",
version: str | None = "v1.0",
version: str | None = "v1.1",
root: Path | None = None,
split: str = "train",
transform: callable = None,
@@ -38,13 +42,10 @@ class PushtDataset(torch.utils.data.Dataset):
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
if self.root is not None:
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.hf_dataset = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.hf_dataset = self.hf_dataset.with_format("torch")
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
@property
def num_samples(self) -> int:
@@ -52,7 +53,7 @@ class PushtDataset(torch.utils.data.Dataset):
@property
def num_episodes(self) -> int:
return len(self.hf_dataset.unique("episode_id"))
return len(self.episode_data_index["from"])
def __len__(self):
return self.num_samples
@@ -64,19 +65,11 @@ class PushtDataset(torch.utils.data.Dataset):
item = load_previous_and_future_frames(
item,
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
)
# convert images from channel last (PIL) to channel first (pytorch)
for key in self.image_keys:
if item[key].ndim == 3:
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
elif item[key].ndim == 4:
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
else:
raise ValueError(item[key].ndim)
if self.transform is not None:
item = self.transform(item)

View File

@@ -1,15 +1,115 @@
from copy import deepcopy
from math import ceil
from pathlib import Path
import datasets
import einops
import torch
import tqdm
from datasets import Image, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
from PIL import Image as PILImage
from safetensors.torch import load_file
from torchvision import transforms
from lerobot.common.utils.utils import set_global_seed
def flatten_dict(d, parent_key="", sep="/"):
items = []
for k, v in d.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
if isinstance(v, dict):
items.extend(flatten_dict(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
def unflatten_dict(d, sep="/"):
outdict = {}
for key, value in d.items():
parts = key.split(sep)
d = outdict
for part in parts[:-1]:
if part not in d:
d[part] = {}
d = d[part]
d[parts[-1]] = value
return outdict
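A round-trip example of these two helpers (they exist because safetensors only accepts flat string keys):

```python
nested = {"action": {"mean": [0.1], "std": [0.2]}}
flat = flatten_dict(nested)  # {"action/mean": [0.1], "action/std": [0.2]}
assert unflatten_dict(flat) == nested
```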
def hf_transform_to_torch(items_dict):
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. Importantly, images are converted from PIL, which corresponds to
a channel last representation (h w c) of uint8 type, to a torch image representation
with channel first (c h w) of float32 type in range [0,1].
"""
for key in items_dict:
first_item = items_dict[key][0]
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
else:
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
return items_dict
def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
if root is not None:
hf_dataset = load_from_disk(Path(root) / dataset_id / split)
else:
# TODO(rcadene): remove dataset_id everywhere and use repo_id instead
repo_id = f"lerobot/{dataset_id}"
hf_dataset = load_dataset(repo_id, revision=version, split=split)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor]:
"""episode_data_index contains the range of indices for each episode
Example:
```python
from_id = episode_data_index["from"][episode_id].item()
to_id = episode_data_index["to"][episode_id].item()
episode_frames = [dataset[i] for i in range(from_id, to_id)]
```
"""
if root is not None:
path = Path(root) / dataset_id / "meta_data" / "episode_data_index.safetensors"
else:
repo_id = f"lerobot/{dataset_id}"
path = hf_hub_download(
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version
)
return load_file(path)
def load_stats(dataset_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
Example:
```python
normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
```
"""
if root is not None:
path = Path(root) / dataset_id / "meta_data" / "stats.safetensors"
else:
repo_id = f"lerobot/{dataset_id}"
path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version)
stats = load_file(path)
return unflatten_dict(stats)
def load_previous_and_future_frames(
item: dict[str, torch.Tensor],
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
delta_timestamps: dict[str, list[float]],
tol: float,
) -> dict[torch.Tensor]:
@@ -31,6 +131,8 @@ def load_previous_and_future_frames(
corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
- hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different
modality (e.g., "timestamp", "observation.image", "action").
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated with dataset indices.
They indicate the start index and end index of each episode in the dataset.
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
retrieved. These deltas are added to the item timestamp to form the query timestamps.
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query
@@ -46,12 +148,14 @@ def load_previous_and_future_frames(
issues with timestamps during data collection.
"""
# get indices of the frames associated to the episode, and their timestamps
ep_data_id_from = item["episode_data_index_from"].item()
ep_data_id_to = item["episode_data_index_to"].item()
ep_id = item["episode_index"].item()
ep_data_id_from = episode_data_index["from"][ep_id].item()
ep_data_id_to = episode_data_index["to"][ep_id].item()
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
# load timestamps
ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
ep_timestamps = torch.stack(ep_timestamps)
# we make the assumption that the timestamps are sorted
ep_first_ts = ep_timestamps[0]
@@ -82,39 +186,57 @@ def load_previous_and_future_frames(
# load frames modality
item[key] = hf_dataset.select_columns(key)[data_ids][key]
item[key] = torch.stack(item[key])
item[f"{key}_is_pad"] = is_pad
return item
def get_stats_einops_patterns(dataset):
"""These einops patterns will be used to aggregate batches and compute statistics."""
stats_patterns = {
"action": "b c -> c",
"observation.state": "b c -> c",
}
for key in dataset.image_keys:
stats_patterns[key] = "b c h w -> c 1 1"
def get_stats_einops_patterns(hf_dataset):
"""These einops patterns will be used to aggregate batches and compute statistics.
Note: We assume the images are returned in channel first format
"""
dataloader = torch.utils.data.DataLoader(
hf_dataset,
num_workers=0,
batch_size=2,
shuffle=False,
)
batch = next(iter(dataloader))
stats_patterns = {}
for key, feats_type in hf_dataset.features.items():
# sanity check that tensors are not float64
assert batch[key].dtype != torch.float64
if isinstance(feats_type, Image):
# sanity check that images are channel first
_, c, h, w = batch[key].shape
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
# sanity check that images are float32 in range [0,1]
assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
assert batch[key].min() >= 0, f"expect pixels greater than 0, but instead {batch[key].min()=}"
stats_patterns[key] = "b c h w -> c 1 1"
elif batch[key].ndim == 2:
stats_patterns[key] = "b c -> c "
elif batch[key].ndim == 1:
stats_patterns[key] = "b -> 1"
else:
raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
return stats_patterns
def compute_stats(dataset, batch_size=32, max_num_samples=None):
def compute_stats(hf_dataset, batch_size=32, max_num_samples=None):
if max_num_samples is None:
max_num_samples = len(dataset)
else:
raise NotImplementedError("We need to set shuffle=True, but this violate an assert for now.")
max_num_samples = len(hf_dataset)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=batch_size,
shuffle=False,
# pin_memory=cfg.device != "cpu",
drop_last=False,
)
# get einops patterns to aggregate batches and compute statistics
stats_patterns = get_stats_einops_patterns(dataset)
stats_patterns = get_stats_einops_patterns(hf_dataset)
# mean and std will be computed incrementally while max and min will track the running value.
mean, std, max, min = {}, {}, {}, {}
@@ -124,10 +246,22 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None):
max[key] = torch.tensor(-float("inf")).float()
min[key] = torch.tensor(float("inf")).float()
def create_seeded_dataloader(hf_dataset, batch_size, seed):
set_global_seed(seed)
dataloader = torch.utils.data.DataLoader(
hf_dataset,
num_workers=4,
batch_size=batch_size,
shuffle=True,
drop_last=False,
)
return dataloader
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
# surprises when rerunning the sampler.
first_batch = None
running_item_count = 0 # for online mean computation
dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
):
@@ -153,6 +287,7 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None):
first_batch_ = None
running_item_count = 0 # for online std computation
dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
):
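Taken together, a sketch of how these utilities chain (as the conversion script above does):

```python
from lerobot.common.datasets.utils import compute_stats, load_hf_dataset

hf_dataset = load_hf_dataset("pusht", version="v1.1", root=None, split="train")
stats = compute_stats(hf_dataset, batch_size=32)
print(stats["action"])  # per-dimension mean/std/max/min for the action modality
```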

View File

@@ -1,25 +1,37 @@
from pathlib import Path
import torch
from datasets import load_dataset, load_from_disk
from lerobot.common.datasets.utils import load_previous_and_future_frames
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_previous_and_future_frames,
load_stats,
)
class XarmDataset(torch.utils.data.Dataset):
"""
https://huggingface.co/datasets/lerobot/xarm_lift_medium
https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay
https://huggingface.co/datasets/lerobot/xarm_push_medium
https://huggingface.co/datasets/lerobot/xarm_push_medium_replay
"""
# Copied from lerobot/__init__.py
available_datasets = ["xarm_lift_medium"]
available_datasets = [
"xarm_lift_medium",
"xarm_lift_medium_replay",
"xarm_push_medium",
"xarm_push_medium_replay",
]
fps = 15
image_keys = ["observation.image"]
def __init__(
self,
dataset_id: str = "xarm_lift_medium",
version: str | None = "v1.0",
dataset_id: str,
version: str | None = "v1.1",
root: Path | None = None,
split: str = "train",
transform: callable = None,
@@ -32,13 +44,10 @@ class XarmDataset(torch.utils.data.Dataset):
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
if self.root is not None:
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.hf_dataset = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.hf_dataset = self.hf_dataset.with_format("torch")
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
@property
def num_samples(self) -> int:
@@ -46,7 +55,7 @@ class XarmDataset(torch.utils.data.Dataset):
@property
def num_episodes(self) -> int:
return len(self.hf_dataset.unique("episode_id"))
return len(self.hf_dataset.unique("episode_index"))
def __len__(self):
return self.num_samples
@@ -58,19 +67,11 @@ class XarmDataset(torch.utils.data.Dataset):
item = load_previous_and_future_frames(
item,
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
)
# convert images from channel last (PIL) to channel first (pytorch)
for key in self.image_keys:
if item[key].ndim == 3:
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
elif item[key].ndim == 4:
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
else:
raise ValueError(item[key].ndim)
if self.transform is not None:
item = self.transform(item)

View File

@@ -15,8 +15,19 @@ def preprocess_observation(observation, transform=None):
for imgkey, img in imgs.items():
img = torch.from_numpy(img)
# convert to (b c h w) torch format
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead {img.shape}"
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w")
img = img.type(torch.float32)
img /= 255
obs[imgkey] = img
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"

View File

@@ -1,4 +1,3 @@
import torch
from torchvision.transforms.v2 import Compose, Transform
@@ -12,40 +11,6 @@ def apply_inverse_transform(item, transform):
return item
class Prod(Transform):
invertible = True
def __init__(self, in_keys: list[str], prod: float):
super().__init__()
self.in_keys = in_keys
self.prod = prod
self.original_dtypes = {}
def forward(self, item):
for key in self.in_keys:
if key not in item:
continue
self.original_dtypes[key] = item[key].dtype
item[key] = item[key].type(torch.float32) * self.prod
return item
def inverse_transform(self, item):
for key in self.in_keys:
if key not in item:
continue
item[key] = (item[key] / self.prod).type(self.original_dtypes[key])
return item
# def transform_observation_spec(self, obs_spec):
# for key in self.in_keys:
# if obs_spec.get(key, None) is None:
# continue
# obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
# obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
# obs_spec[key].dtype = torch.float32
# return obs_spec
class NormalizeTransform(Transform):
invertible = True

View File

@@ -47,6 +47,7 @@ from PIL import Image as PILImage
from tqdm import trange
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import hf_transform_to_torch
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.logger import log_output_dir
@@ -208,11 +209,12 @@ def eval_policy(
max_rewards.extend(batch_max_reward.tolist())
all_successes.extend(batch_success.tolist())
# similar logic is implemented in dataset preprocessing
# similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`)
ep_dicts = []
episode_data_index = {"from": [], "to": []}
num_episodes = dones.shape[0]
total_frames = 0
idx_from = 0
id_from = 0
for ep_id in range(num_episodes):
num_frames = done_indices[ep_id].item() + 1
total_frames += num_frames
@@ -222,19 +224,20 @@ def eval_policy(
if return_episode_data:
ep_dict = {
"action": actions[ep_id, :num_frames],
"episode_id": torch.tensor([ep_id] * num_frames),
"frame_id": torch.arange(0, num_frames, 1),
"episode_index": torch.tensor([ep_id] * num_frames),
"frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
"next.done": dones[ep_id, :num_frames],
"next.reward": rewards[ep_id, :num_frames].type(torch.float32),
"episode_data_index_from": torch.tensor([idx_from] * num_frames),
"episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
}
for key in observations:
ep_dict[key] = observations[key][ep_id][:num_frames]
ep_dicts.append(ep_dict)
idx_from += num_frames
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(id_from + num_frames)
id_from += num_frames
# similar logic is implemented in dataset preprocessing
if return_episode_data:
@@ -247,14 +250,29 @@ def eval_policy(
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
# c h w -> h w c
img = PILImage.fromarray(x.permute(1, 2, 0).numpy())
for img in ep_dict[key]:
# sanity check that images are channel first
c, h, w = img.shape
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
# sanity check that images are float32 in range [0,1]
assert img.dtype == torch.float32, f"expect torch.float32, but instead {img.dtype=}"
assert img.max() <= 1, f"expect pixels lower than 1, but instead {img.max()=}"
assert img.min() >= 0, f"expect pixels greater than 0, but instead {img.min()=}"
# from float32 in range [0,1] to uint8 in range [0,255]
img *= 255
img = img.type(torch.uint8)
# convert to channel last and numpy as expected by PIL
img = PILImage.fromarray(img.permute(1, 2, 0).numpy())
data_dict[key].append(img)
data_dict["index"] = torch.arange(0, total_frames, 1)
hf_dataset = Dataset.from_dict(data_dict).with_format("torch")
hf_dataset = Dataset.from_dict(data_dict)
hf_dataset.set_transform(hf_transform_to_torch)
if max_episodes_rendered > 0:
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
@@ -307,7 +325,10 @@ def eval_policy(
},
}
if return_episode_data:
info["episodes"] = hf_dataset
info["episodes"] = {
"hf_dataset": hf_dataset,
"episode_data_index": episode_data_index,
}
if max_episodes_rendered > 0:
info["videos"] = videos
return info

View File

@@ -136,6 +136,7 @@ def add_episodes_inplace(
concat_dataset: torch.utils.data.ConcatDataset,
sampler: torch.utils.data.WeightedRandomSampler,
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
pc_online_samples: float,
):
"""
@@ -151,13 +152,15 @@ def add_episodes_inplace(
- sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
reflect changes in the dataset sizes and specified sampling weights.
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated with dataset indices.
They indicate the start index and end index of each episode in the dataset.
- pc_online_samples (float): The target percentage of samples that should come from
the online dataset during sampling operations.
Raises:
- AssertionError: If the first episode_index or index in hf_dataset is not 0
"""
first_episode_id = hf_dataset.select_columns("episode_id")[0]["episode_id"].item()
first_episode_id = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
first_index = hf_dataset.select_columns("index")[0]["index"].item()
assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
assert first_index == 0, f"We expect the first index to be 0 and not {first_index}"
@@ -167,21 +170,22 @@ def add_episodes_inplace(
online_dataset.hf_dataset = hf_dataset
else:
# find episode index and data frame indices according to previous episode in online_dataset
start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1
start_episode = online_dataset.select_columns("episode_index")[-1]["episode_index"].item() + 1
start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1
def shift_indices(example):
# note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to
example["episode_id"] += start_episode
# note: we don't shift "frame_index" since it represents the index of the frame in the episode it belongs to
example["episode_index"] += start_episode
example["index"] += start_index
example["episode_data_index_from"] += start_index
example["episode_data_index_to"] += start_index
return example
disable_progress_bars() # map has a tqdm progress bar
hf_dataset = hf_dataset.map(shift_indices)
enable_progress_bars()
episode_data_index["from"] += start_index
episode_data_index["to"] += start_index
# extend online dataset
online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
@@ -334,9 +338,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
seed=cfg.seed,
)
online_pc_sampling = cfg.get("demo_schedule", 0.5)
add_episodes_inplace(
online_dataset, concat_dataset, sampler, eval_info["episodes"], online_pc_sampling
online_dataset,
concat_dataset,
sampler,
hf_dataset=eval_info["episodes"]["hf_dataset"],
episode_data_index=eval_info["episodes"]["episode_data_index"],
pc_online_samples=cfg.get("demo_schedule", 0.5),
)
for _ in range(cfg.policy.utd):
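A tiny worked example of the index shifting done by `add_episodes_inplace` (hypothetical sizes: the online dataset already holds 4 episodes / 100 frames):

```python
start_episode, start_index = 4, 100

example = {"episode_index": 0, "frame_index": 0, "index": 0}
example["episode_index"] += start_episode  # rollout episode 0 becomes episode 4
example["index"] += start_index            # global frame 0 becomes global frame 100
# frame_index stays 0: it is the frame's position within its own episode
assert example == {"episode_index": 4, "frame_index": 0, "index": 100}
```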

View File

@@ -22,11 +22,22 @@ def visualize_dataset_cli(cfg: dict):
def cat_and_write_video(video_path, frames, fps):
# Expects images in [0, 255].
frames = torch.cat(frames)
assert frames.dtype == torch.uint8
frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
imageio.mimsave(video_path, frames, fps=fps)
# Expects images in [0, 1].
frame = frames[0]
_, c, h, w = frame.shape
assert c < h and c < w, f"expect channel first images, but instead {frame.shape}"
# sanity check that images are float32 in range [0,1]
assert frame.dtype == torch.float32, f"expect torch.float32, but instead {frame.dtype=}"
assert frame.max() <= 1, f"expect pixels lower than 1, but instead {frame.max()=}"
assert frame.min() >= 0, f"expect pixels greater than 0, but instead {frame.min()=}"
# convert to channel last uint8 [0, 255]
frames = einops.rearrange(frames, "b c h w -> b h w c")
frames = (frames * 255).type(torch.uint8)
imageio.mimsave(video_path, frames.numpy(), fps=fps)
def visualize_dataset(cfg: dict, out_dir=None):

poetry.lock generated
View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "absl-py"

test.py Normal file
View File

@@ -0,0 +1,30 @@
import rerun as rr
from datasets import load_from_disk
# download/load dataset in pyarrow format
print("Loading dataset…")
#dataset = load_dataset("lerobot/aloha_mobile_trossen_block_handoff", split="train")
dataset = load_from_disk("tests/data/aloha_mobile_trossen_block_handoff/train")
# select the frames belonging to episode number 5
print("Select specific episode…")
print("Starting Rerun…")
rr.init("rerun_example_lerobot", spawn=True)
print("Logging to Rerun…")
# for frame_index, timestamp, cam_high, cam_left_wrist, cam_right_wrist, state, action, next_reward in zip(
for d in dataset:
rr.set_time_sequence("frame_index", d["frame_index"])
rr.set_time_seconds("timestamp", d["timestamp"])
rr.log("observation.images.cam_high", rr.Image( d["observation.images.cam_high"]))
rr.log("observation.images.cam_left_wrist", rr.Image(d["observation.images.cam_left_wrist"]))
rr.log("observation.images.cam_right_wrist", rr.Image(d["observation.images.cam_right_wrist"]))
#rr.log("observation/state", rr.BarChart(state))
#rr.log("observation/action", rr.BarChart(action))
for idx, val in enumerate(d["action"]):
rr.log(f"action_{idx}", rr.Scalar(val))
for idx, val in enumerate(d["observation.state"]):
rr.log(f"state_{idx}", rr.Scalar(val))

View File

@@ -0,0 +1,3 @@
{
"fps": 50
}

View File

@@ -21,11 +21,11 @@
"length": 14,
"_type": "Sequence"
},
"episode_id": {
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
@@ -37,14 +37,6 @@
"dtype": "bool",
"_type": "Value"
},
"episode_data_index_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_index_to": {
"dtype": "int64",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"

View File

@@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "d79cf82ffc86f110",
"_fingerprint": "22eeca7a3f4725ee",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",

View File

@@ -0,0 +1,3 @@
{
"fps": 50
}

View File

@@ -21,11 +21,11 @@
"length": 14,
"_type": "Sequence"
},
"episode_id": {
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
@@ -37,14 +37,6 @@
"dtype": "bool",
"_type": "Value"
},
"episode_data_index_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_index_to": {
"dtype": "int64",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"

View File

@@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "d8e4a817b5449498",
"_fingerprint": "97c28d4ad1536e4c",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",

View File

@@ -0,0 +1,3 @@
{
"fps": 50
}

View File

@@ -21,11 +21,11 @@
"length": 14,
"_type": "Sequence"
},
"episode_id": {
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
@@ -37,14 +37,6 @@
"dtype": "bool",
"_type": "Value"
},
"episode_data_index_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_index_to": {
"dtype": "int64",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"

View File

@@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "f03482befa767127",
"_fingerprint": "cb9349b5c92951e8",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",

View File

@@ -0,0 +1,3 @@
{
"fps": 50
}

View File

@@ -21,11 +21,11 @@
"length": 14,
"_type": "Sequence"
},
"episode_id": {
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
@@ -37,14 +37,6 @@
"dtype": "bool",
"_type": "Value"
},
"episode_data_index_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_index_to": {
"dtype": "int64",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"

View File

@@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "93e03c6320c7d56e",
"_fingerprint": "e4d7ad2b360db1af",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",

View File

@@ -0,0 +1,3 @@
{
"fps": 10
}

Binary file not shown.

View File

@@ -21,11 +21,11 @@
"length": 2,
"_type": "Sequence"
},
"episode_id": {
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
@@ -45,14 +45,6 @@
"dtype": "bool",
"_type": "Value"
},
"episode_data_index_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_index_to": {
"dtype": "int64",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"

View File

@@ -0,0 +1,3 @@
{
"fps": 10
}

View File

@@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "21bb9a76ed78a475",
"_fingerprint": "a04a9ce660122e23",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",

View File

@@ -0,0 +1,3 @@
{
    "fps": 15
}

View File

@@ -21,11 +21,11 @@
"length": 4,
"_type": "Sequence"
},
"episode_id": {
"episode_index": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"frame_index": {
"dtype": "int64",
"_type": "Value"
},
@@ -41,14 +41,6 @@
"dtype": "bool",
"_type": "Value"
},
"episode_data_index_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_index_to": {
"dtype": "int64",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"

View File

@@ -4,7 +4,7 @@
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "a95cbec45e3bb9d6",
"_fingerprint": "cc6afdfcdd6f63ab",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",

View File

@@ -0,0 +1,3 @@
{
    "fps": 15
}

View File

@@ -0,0 +1,51 @@
{
    "citation": "",
    "description": "",
    "features": {
        "observation.image": {
            "_type": "Image"
        },
        "observation.state": {
            "feature": {
                "dtype": "float32",
                "_type": "Value"
            },
            "length": 4,
            "_type": "Sequence"
        },
        "action": {
            "feature": {
                "dtype": "float32",
                "_type": "Value"
            },
            "length": 4,
            "_type": "Sequence"
        },
        "episode_index": {
            "dtype": "int64",
            "_type": "Value"
        },
        "frame_index": {
            "dtype": "int64",
            "_type": "Value"
        },
        "timestamp": {
            "dtype": "float32",
            "_type": "Value"
        },
        "next.reward": {
            "dtype": "float32",
            "_type": "Value"
        },
        "next.done": {
            "dtype": "bool",
            "_type": "Value"
        },
        "index": {
            "dtype": "int64",
            "_type": "Value"
        }
    },
    "homepage": "",
    "license": ""
}
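
This `dataset_info.json` is the serialized form of the Hugging Face `datasets` schema. Constructing the equivalent `Features` object in Python would look roughly like the following sketch, which mirrors the file above rather than quoting the actual conversion code:

```python
from datasets import Features, Image, Sequence, Value

features = Features({
    "observation.image": Image(),
    "observation.state": Sequence(feature=Value(dtype="float32"), length=4),
    "action": Sequence(feature=Value(dtype="float32"), length=4),
    "episode_index": Value(dtype="int64"),
    "frame_index": Value(dtype="int64"),
    "timestamp": Value(dtype="float32"),
    "next.reward": Value(dtype="float32"),
    "next.done": Value(dtype="bool"),
    "index": Value(dtype="int64"),
})
```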

View File

@@ -0,0 +1,13 @@
{
    "_data_files": [
        {
            "filename": "data-00000-of-00001.arrow"
        }
    ],
    "_fingerprint": "9f8e1a8c1845df55",
    "_format_columns": null,
    "_format_kwargs": {},
    "_format_type": "torch",
    "_output_all_columns": false,
    "_split": null
}

View File

@@ -0,0 +1,3 @@
{
    "fps": 15
}

View File

@@ -0,0 +1,51 @@
{
    "citation": "",
    "description": "",
    "features": {
        "observation.image": {
            "_type": "Image"
        },
        "observation.state": {
            "feature": {
                "dtype": "float32",
                "_type": "Value"
            },
            "length": 4,
            "_type": "Sequence"
        },
        "action": {
            "feature": {
                "dtype": "float32",
                "_type": "Value"
            },
            "length": 3,
            "_type": "Sequence"
        },
        "episode_index": {
            "dtype": "int64",
            "_type": "Value"
        },
        "frame_index": {
            "dtype": "int64",
            "_type": "Value"
        },
        "timestamp": {
            "dtype": "float32",
            "_type": "Value"
        },
        "next.reward": {
            "dtype": "float32",
            "_type": "Value"
        },
        "next.done": {
            "dtype": "bool",
            "_type": "Value"
        },
        "index": {
            "dtype": "int64",
            "_type": "Value"
        }
    },
    "homepage": "",
    "license": ""
}

View File

@@ -0,0 +1,13 @@
{
    "_data_files": [
        {
            "filename": "data-00000-of-00001.arrow"
        }
    ],
    "_fingerprint": "c900258061dd0b3f",
    "_format_columns": null,
    "_format_kwargs": {},
    "_format_type": "torch",
    "_output_all_columns": false,
    "_split": null
}

View File

@@ -0,0 +1,3 @@
{
    "fps": 15
}

View File

@@ -0,0 +1,51 @@
{
    "citation": "",
    "description": "",
    "features": {
        "observation.image": {
            "_type": "Image"
        },
        "observation.state": {
            "feature": {
                "dtype": "float32",
                "_type": "Value"
            },
            "length": 4,
            "_type": "Sequence"
        },
        "action": {
            "feature": {
                "dtype": "float32",
                "_type": "Value"
            },
            "length": 3,
            "_type": "Sequence"
        },
        "episode_index": {
            "dtype": "int64",
            "_type": "Value"
        },
        "frame_index": {
            "dtype": "int64",
            "_type": "Value"
        },
        "timestamp": {
            "dtype": "float32",
            "_type": "Value"
        },
        "next.reward": {
            "dtype": "float32",
            "_type": "Value"
        },
        "next.done": {
            "dtype": "bool",
            "_type": "Value"
        },
        "index": {
            "dtype": "int64",
            "_type": "Value"
        }
    },
    "homepage": "",
    "license": ""
}

View File

@@ -0,0 +1,13 @@
{
    "_data_files": [
        {
            "filename": "data-00000-of-00001.arrow"
        }
    ],
    "_fingerprint": "e51c80a33c7688c0",
    "_format_columns": null,
    "_format_kwargs": {},
    "_format_type": "torch",
    "_output_all_columns": false,
    "_split": null
}

View File

@@ -1,5 +1,7 @@
import json
import logging
import os
from copy import deepcopy
from pathlib import Path
import einops
@@ -11,10 +13,12 @@ import lerobot
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import (
    compute_stats,
    flatten_dict,
    get_stats_einops_patterns,
    hf_transform_to_torch,
    load_previous_and_future_frames,
    unflatten_dict,
)
from lerobot.common.transforms import Prod
from lerobot.common.utils.utils import init_hydra_config
from .utils import DEFAULT_CONFIG_PATH, DEVICE
@@ -39,8 +43,8 @@ def test_factory(env_name, dataset_id, policy_name):
    keys_ndim_required = [
        ("action", 1, True),
        ("episode_id", 0, True),
        ("frame_id", 0, True),
        ("episode_index", 0, True),
        ("frame_index", 0, True),
        ("timestamp", 0, True),
        # TODO(rcadene): should we rename it agent_pos?
        ("observation.state", 1, True),
@@ -48,12 +52,6 @@ def test_factory(env_name, dataset_id, policy_name):
        ("next.done", 0, False),
    ]
    for key in image_keys:
        keys_ndim_required.append(
            (key, 3, True),
        )
        assert dataset.hf_dataset[key].dtype == torch.uint8, f"{key}"
    # test number of dimensions
    for key, ndim, required in keys_ndim_required:
        if key not in item:
@@ -98,22 +96,18 @@ def test_compute_stats_on_xarm():
    data_dir = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
    # get transform to convert images from uint8 [0,255] to float32 [0,1]
    transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
    dataset = XarmDataset(
        dataset_id="xarm_lift_medium",
        root=data_dir,
        transform=transform,
    )
    # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
    # computation of the statistics. While doing this, we also make sure it works when we don't divide the
    # dataset into even batches.
    computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25))
    computed_stats = compute_stats(dataset.hf_dataset, batch_size=int(len(dataset) * 0.25))
    # get einops patterns to aggregate batches and compute statistics
    stats_patterns = get_stats_einops_patterns(dataset)
    stats_patterns = get_stats_einops_patterns(dataset.hf_dataset)
    # get all frames from the dataset in the same dtype and range as during compute_stats
    dataloader = torch.utils.data.DataLoader(
@@ -122,18 +116,19 @@ def test_compute_stats_on_xarm():
        batch_size=len(dataset),
        shuffle=False,
    )
    hf_dataset = next(iter(dataloader))
    full_batch = next(iter(dataloader))
    # compute stats based on all frames from the dataset without any batching
    expected_stats = {}
    for k, pattern in stats_patterns.items():
        full_batch[k] = full_batch[k].float()
        expected_stats[k] = {}
        expected_stats[k]["mean"] = einops.reduce(hf_dataset[k], pattern, "mean")
        expected_stats[k]["mean"] = einops.reduce(full_batch[k], pattern, "mean")
        expected_stats[k]["std"] = torch.sqrt(
            einops.reduce((hf_dataset[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
            einops.reduce((full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
        )
        expected_stats[k]["min"] = einops.reduce(hf_dataset[k], pattern, "min")
        expected_stats[k]["max"] = einops.reduce(hf_dataset[k], pattern, "max")
        expected_stats[k]["min"] = einops.reduce(full_batch[k], pattern, "min")
        expected_stats[k]["max"] = einops.reduce(full_batch[k], pattern, "max")
    # test computed stats match expected stats
    for k in stats_patterns:
@@ -142,11 +137,10 @@ def test_compute_stats_on_xarm():
        assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"])
        assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
    # TODO(rcadene): check that the stats used for training are correct too
    # # load stats that are expected to match the ones returned by computed_stats
    # assert (dataset.data_dir / "stats.pth").exists()
    # loaded_stats = torch.load(dataset.data_dir / "stats.pth")
    # load stats used during training which are expected to match the ones returned by computed_stats
    loaded_stats = dataset.stats  # noqa: F841
    # TODO(rcadene): we can't test this because expected_stats is computed on a subset
    # # test loaded stats match expected stats
    # for k in stats_patterns:
    #     assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
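
The test above deliberately uses a batch size of a quarter of the dataset so that batched computation, including a ragged final batch, is exercised. Min, max, and mean aggregate straightforwardly across batches, and a batched std can be recovered from running sums of x and x²; a toy sketch independent of the actual `compute_stats` implementation:

```python
import torch

def batched_mean_std(batches):
    # Running-sum aggregation over uneven batches (sketch only).
    n, s, ss = 0, torch.tensor(0.0), torch.tensor(0.0)
    for x in batches:
        x = x.float()
        n += x.numel()
        s += x.sum()
        ss += (x**2).sum()
    mean = s / n
    std = torch.sqrt(ss / n - mean**2)  # population std via E[x^2] - E[x]^2
    return mean, std

mean, std = batched_mean_std(torch.arange(10.0).split(3))  # batches of 3, 3, 3, 1
assert torch.allclose(mean, torch.tensor(4.5))
```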
@@ -160,15 +154,18 @@ def test_load_previous_and_future_frames_within_tolerance():
        {
            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
            "index": [0, 1, 2, 3, 4],
            "episode_data_index_from": [0, 0, 0, 0, 0],
            "episode_data_index_to": [5, 5, 5, 5, 5],
            "episode_index": [0, 0, 0, 0, 0],
        }
    )
    hf_dataset = hf_dataset.with_format("torch")
    item = hf_dataset[2]
    hf_dataset.set_transform(hf_transform_to_torch)
    episode_data_index = {
        "from": torch.tensor([0]),
        "to": torch.tensor([5]),
    }
    delta_timestamps = {"index": [-0.2, 0, 0.139]}
    tol = 0.04
    item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
    item = hf_dataset[2]
    item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
    data, is_pad = item["index"], item["index_is_pad"]
    assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
    assert not is_pad.any(), "Unexpected padding detected"
@@ -179,16 +176,19 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(
        {
            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
            "index": [0, 1, 2, 3, 4],
            "episode_data_index_from": [0, 0, 0, 0, 0],
            "episode_data_index_to": [5, 5, 5, 5, 5],
            "episode_index": [0, 0, 0, 0, 0],
        }
    )
    hf_dataset = hf_dataset.with_format("torch")
    item = hf_dataset[2]
    hf_dataset.set_transform(hf_transform_to_torch)
    episode_data_index = {
        "from": torch.tensor([0]),
        "to": torch.tensor([5]),
    }
    delta_timestamps = {"index": [-0.2, 0, 0.141]}
    tol = 0.04
    item = hf_dataset[2]
    with pytest.raises(AssertionError):
        load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
        load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range():
@@ -196,17 +196,43 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
        {
            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
            "index": [0, 1, 2, 3, 4],
            "episode_data_index_from": [0, 0, 0, 0, 0],
            "episode_data_index_to": [5, 5, 5, 5, 5],
            "episode_index": [0, 0, 0, 0, 0],
        }
    )
    hf_dataset = hf_dataset.with_format("torch")
    item = hf_dataset[2]
    hf_dataset.set_transform(hf_transform_to_torch)
    episode_data_index = {
        "from": torch.tensor([0]),
        "to": torch.tensor([5]),
    }
    delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
    tol = 0.04
    item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
    item = hf_dataset[2]
    item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
    data, is_pad = item["index"], item["index_is_pad"]
    assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
    assert torch.equal(
        is_pad, torch.tensor([True, False, False, True, True])
    ), "Padding does not match expected values"
def test_flatten_unflatten_dict():
    d = {
        "obs": {
            "min": 0,
            "max": 1,
            "mean": 2,
            "std": 3,
        },
        "action": {
            "min": 4,
            "max": 5,
            "mean": 6,
            "std": 7,
        },
    }
    original_d = deepcopy(d)
    d = unflatten_dict(flatten_dict(d))
    # test equality between nested dicts
    assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
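
For reference, a minimal sketch of what `flatten_dict`/`unflatten_dict` are expected to do for this round-trip test to pass (the `/` separator is an assumption; the real helpers live in `lerobot/common/datasets/utils.py`):

```python
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
    # {"obs": {"min": 0}} -> {"obs/min": 0}
    items = {}
    for k, v in d.items():
        key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.update(flatten_dict(v, key, sep=sep))
        else:
            items[key] = v
    return items

def unflatten_dict(d: dict, sep: str = "/") -> dict:
    # {"obs/min": 0} -> {"obs": {"min": 0}}
    out: dict = {}
    for key, v in d.items():
        node = out
        *parents, leaf = key.split(sep)
        for p in parents:
            node = node.setdefault(p, {})
        node[leaf] = v
    return out

assert unflatten_dict(flatten_dict({"obs": {"min": 0}})) == {"obs": {"min": 0}}
```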