diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 28701931..0b7ed24b 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -1,5 +1,7 @@ +from pathlib import Path + import torch -from datasets import load_dataset +from datasets import load_dataset, load_from_disk from lerobot.common.datasets.utils import load_previous_and_future_frames @@ -25,15 +27,24 @@ class AlohaDataset(torch.utils.data.Dataset): self, dataset_id: str, version: str | None = "v1.0", + root: Path | None = None, + split: str = "train", transform: callable = None, delta_timestamps: dict[list[float]] | None = None, ): super().__init__() self.dataset_id = dataset_id self.version = version + self.root = root + self.split = split self.transform = transform self.delta_timestamps = delta_timestamps - self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train") + if self.root is not None: + self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split) + else: + self.data_dict = load_dataset( + f"lerobot/{self.dataset_id}", revision=self.version, split=self.split + ) self.data_dict = self.data_dict.with_format("torch") @property @@ -57,6 +68,15 @@ class AlohaDataset(torch.utils.data.Dataset): self.delta_timestamps, ) + # convert images from channel last (PIL) to channel first (pytorch) + for key in self.image_keys: + if item[key].ndim == 3: + item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w + elif item[key].ndim == 4: + item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w + else: + raise ValueError(item[key].ndim) + if self.transform is not None: item = self.transform(item) diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index ee9285a4..3bf684c9 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -1,4 +1,5 @@ import logging +import os from pathlib import Path import torch @@ -7,12 +8,15 @@ from torchvision.transforms import v2 from lerobot.common.datasets.utils import compute_stats from lerobot.common.transforms import NormalizeTransform, Prod +DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None + def make_dataset( cfg, # set normalize=False to remove all transformations and keep images unnormalized in [0,255] normalize=True, stats_path=None, + split="train", ): if cfg.env.name == "xarm": from lerobot.common.datasets.xarm import XarmDataset @@ -57,6 +61,8 @@ def make_dataset( # instantiate a one frame dataset with light transform stats_dataset = clsfunc( dataset_id=cfg.dataset_id, + split="train", + root=DATA_DIR, transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0), ) stats = compute_stats(stats_dataset) @@ -86,6 +92,8 @@ def make_dataset( dataset = clsfunc( dataset_id=cfg.dataset_id, + split=split, + root=DATA_DIR, delta_timestamps=delta_timestamps, transform=transforms, ) diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index c2705c2a..93a4a002 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -1,5 +1,7 @@ +from pathlib import Path + import torch -from datasets import load_dataset +from datasets import load_dataset, load_from_disk from lerobot.common.datasets.utils import load_previous_and_future_frames @@ -23,18 +25,25 @@ class PushtDataset(torch.utils.data.Dataset): self, dataset_id: str, version: str | None = "v1.0", + root: Path | None = None, + split: str = "train", transform: callable = None, 
delta_timestamps: dict[list[float]] | None = None, ): super().__init__() self.dataset_id = dataset_id self.version = version + self.root = root + self.split = split self.transform = transform self.delta_timestamps = delta_timestamps - # self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train") - self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", split="train") + if self.root is not None: + self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split) + else: + self.data_dict = load_dataset( + f"lerobot/{self.dataset_id}", revision=self.version, split=self.split + ) self.data_dict = self.data_dict.with_format("torch") - self.data_dict.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0") @property def num_samples(self) -> int: @@ -57,6 +66,15 @@ class PushtDataset(torch.utils.data.Dataset): self.delta_timestamps, ) + # convert images from channel last (PIL) to channel first (pytorch) + for key in self.image_keys: + if item[key].ndim == 3: + item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w + elif item[key].ndim == 4: + item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w + else: + raise ValueError(item[key].ndim) + if self.transform is not None: item = self.transform(item) diff --git a/lerobot/common/datasets/xarm.py b/lerobot/common/datasets/xarm.py index 729b8c6c..605dd1eb 100644 --- a/lerobot/common/datasets/xarm.py +++ b/lerobot/common/datasets/xarm.py @@ -1,5 +1,7 @@ +from pathlib import Path + import torch -from datasets import load_dataset +from datasets import load_dataset, load_from_disk from lerobot.common.datasets.utils import load_previous_and_future_frames @@ -19,15 +21,24 @@ class XarmDataset(torch.utils.data.Dataset): self, dataset_id: str, version: str | None = "v1.0", + root: Path | None = None, + split: str = "train", transform: callable = None, delta_timestamps: dict[list[float]] | None = None, ): super().__init__() self.dataset_id = dataset_id self.version = version + self.root = root + self.split = split self.transform = transform self.delta_timestamps = delta_timestamps - self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train") + if self.root is not None: + self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split) + else: + self.data_dict = load_dataset( + f"lerobot/{self.dataset_id}", revision=self.version, split=self.split + ) self.data_dict = self.data_dict.with_format("torch") @property @@ -51,6 +62,15 @@ class XarmDataset(torch.utils.data.Dataset): self.delta_timestamps, ) + # convert images from channel last (PIL) to channel first (pytorch) + for key in self.image_keys: + if item[key].ndim == 3: + item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w + elif item[key].ndim == 4: + item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w + else: + raise ValueError(item[key].ndim) + if self.transform is not None: item = self.transform(item) diff --git a/lerobot/scripts/download_and_upload_dataset.py b/lerobot/scripts/download_and_upload_dataset.py index 40d218ab..267b619d 100644 --- a/lerobot/scripts/download_and_upload_dataset.py +++ b/lerobot/scripts/download_and_upload_dataset.py @@ -12,7 +12,8 @@ import h5py import numpy as np import torch import tqdm -from datasets import Dataset +from datasets import Dataset, Features, Image, Sequence, Value +from PIL import Image as PILImage def download_and_extract_zip(url: str, destination_folder: Path) -> bool: @@ -73,10 +74,6 @@ def 
download_and_upload_pusht(root, dataset_id="pusht", fps=10): episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs()) num_episodes = dataset_dict.meta["episode_ends"].shape[0] - total_frames = dataset_dict["action"].shape[0] - # to create test artifact - # num_episodes = 1 - # total_frames = 50 assert len( {dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118 ), "Some data type dont have the same number of total frames." @@ -85,28 +82,27 @@ def download_and_upload_pusht(root, dataset_id="pusht", fps=10): goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians) goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle) - imgs = torch.from_numpy(dataset_dict["img"]) - imgs = einops.rearrange(imgs, "b h w c -> b c h w") + imgs = torch.from_numpy(dataset_dict["img"]) # b h w c states = torch.from_numpy(dataset_dict["state"]) actions = torch.from_numpy(dataset_dict["action"]) data_ids_per_episode = {} ep_dicts = [] - idx0 = 0 + id_from = 0 for episode_id in tqdm.tqdm(range(num_episodes)): - idx1 = dataset_dict.meta["episode_ends"][episode_id] + id_to = dataset_dict.meta["episode_ends"][episode_id] - num_frames = idx1 - idx0 + num_frames = id_to - id_from - assert (episode_ids[idx0:idx1] == episode_id).all() + assert (episode_ids[id_from:id_to] == episode_id).all() - image = imgs[idx0:idx1] + image = imgs[id_from:id_to] assert image.min() >= 0.0 assert image.max() <= 255.0 image = image.type(torch.uint8) - state = states[idx0:idx1] + state = states[id_from:id_to] agent_pos = state[:, :2] block_pos = state[:, 2:4] block_angle = state[:, 4] @@ -141,9 +137,9 @@ def download_and_upload_pusht(root, dataset_id="pusht", fps=10): done[-1] = True ep_dict = { - "observation.image": image, + "observation.image": [PILImage.fromarray(x.numpy()) for x in image], "observation.state": agent_pos, - "action": actions[idx0:idx1], + "action": actions[id_from:id_to], "episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int), "frame_id": torch.arange(0, num_frames, 1), "timestamp": torch.arange(0, num_frames, 1) / fps, @@ -153,33 +149,55 @@ def download_and_upload_pusht(root, dataset_id="pusht", fps=10): "next.reward": torch.cat([reward[1:], reward[[-1]]]), "next.done": torch.cat([done[1:], done[[-1]]]), "next.success": torch.cat([success[1:], success[[-1]]]), + "episode_data_id_from": torch.tensor([id_from] * num_frames), + "episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames), } ep_dicts.append(ep_dict) assert isinstance(episode_id, int) - data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1) + data_ids_per_episode[episode_id] = torch.arange(id_from, id_to, 1) assert len(data_ids_per_episode[episode_id]) == num_frames - idx0 = idx1 + id_from += num_frames data_dict = {} keys = ep_dicts[0].keys() for key in keys: - data_dict[key] = torch.cat([x[key] for x in ep_dicts]) + if torch.is_tensor(ep_dicts[0][key][0]): + data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts]) + else: + if key not in data_dict: + data_dict[key] = [] + for ep_dict in ep_dicts: + for x in ep_dict[key]: + data_dict[key].append(x) + total_frames = id_from data_dict["index"] = torch.arange(0, total_frames, 1) - dataset = Dataset.from_dict(data_dict) + features = { + "observation.image": Image(), + "observation.state": Sequence( + length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) + ), + "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)), + "episode_id": 
Value(dtype="int64", id=None), + "frame_id": Value(dtype="int64", id=None), + "timestamp": Value(dtype="float32", id=None), + "next.reward": Value(dtype="float32", id=None), + "next.done": Value(dtype="bool", id=None), + "next.success": Value(dtype="bool", id=None), + "index": Value(dtype="int64", id=None), + "episode_data_id_from": Value(dtype="int64", id=None), + "episode_data_id_to": Value(dtype="int64", id=None), + } + features = Features(features) + dataset = Dataset.from_dict(data_dict, features=features) dataset = dataset.with_format("torch") - def add_episode_data_id_from_to(frame): - ep_id = frame["episode_id"].item() - frame["episode_data_id_from"] = data_ids_per_episode[ep_id][0] - frame["episode_data_id_to"] = data_ids_per_episode[ep_id][-1] - return frame - - dataset = dataset.map(add_episode_data_id_from_to, num_proc=4) + num_items_first_ep = ep_dicts[0]["frame_id"].shape[0] + dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train") dataset.push_to_hub(f"lerobot/{dataset_id}", token=True) dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0") @@ -211,32 +229,32 @@ def download_and_upload_xarm(root, dataset_id, fps=15): total_frames = dataset_dict["actions"].shape[0] - data_ids_per_episode = {} ep_dicts = [] - idx0 = 0 - idx1 = 0 + id_from = 0 + id_to = 0 episode_id = 0 for i in tqdm.tqdm(range(total_frames)): - idx1 += 1 + id_to += 1 if not dataset_dict["dones"][i]: continue - num_frames = idx1 - idx0 + num_frames = id_to - id_from - image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1]) - state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1]) - action = torch.tensor(dataset_dict["actions"][idx0:idx1]) + image = torch.tensor(dataset_dict["observations"]["rgb"][id_from:id_to]) + image = einops.rearrange(image, "b c h w -> b h w c") + state = torch.tensor(dataset_dict["observations"]["state"][id_from:id_to]) + action = torch.tensor(dataset_dict["actions"][id_from:id_to]) # TODO(rcadene): we have a missing last frame which is the observation when the env is done # it is critical to have this frame for tdmpc to predict a "done observation/state" - # next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1]) - # next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1]) - next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1]) - next_done = torch.tensor(dataset_dict["dones"][idx0:idx1]) + # next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][id_from:id_to]) + # next_state = torch.tensor(dataset_dict["next_observations"]["state"][id_from:id_to]) + next_reward = torch.tensor(dataset_dict["rewards"][id_from:id_to]) + next_done = torch.tensor(dataset_dict["dones"][id_from:id_to]) ep_dict = { - "observation.image": image, + "observation.image": [PILImage.fromarray(x.numpy()) for x in image], "observation.state": state, "action": action, "episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int), @@ -246,34 +264,51 @@ def download_and_upload_xarm(root, dataset_id, fps=15): # "next.observation.state": next_state, "next.reward": next_reward, "next.done": next_done, + "episode_data_id_from": torch.tensor([id_from] * num_frames), + "episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames), } ep_dicts.append(ep_dict) - assert isinstance(episode_id, int) - data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1) - assert len(data_ids_per_episode[episode_id]) == num_frames - - idx0 = idx1 + id_from = id_to 
episode_id += 1 data_dict = {} - keys = ep_dicts[0].keys() for key in keys: - data_dict[key] = torch.cat([x[key] for x in ep_dicts]) + if torch.is_tensor(ep_dicts[0][key][0]): + data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts]) + else: + if key not in data_dict: + data_dict[key] = [] + for ep_dict in ep_dicts: + for x in ep_dict[key]: + data_dict[key].append(x) + total_frames = id_from data_dict["index"] = torch.arange(0, total_frames, 1) - dataset = Dataset.from_dict(data_dict) + features = { + "observation.image": Image(), + "observation.state": Sequence( + length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) + ), + "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)), + "episode_id": Value(dtype="int64", id=None), + "frame_id": Value(dtype="int64", id=None), + "timestamp": Value(dtype="float32", id=None), + "next.reward": Value(dtype="float32", id=None), + "next.done": Value(dtype="bool", id=None), + #'next.success': Value(dtype='bool', id=None), + "index": Value(dtype="int64", id=None), + "episode_data_id_from": Value(dtype="int64", id=None), + "episode_data_id_to": Value(dtype="int64", id=None), + } + features = Features(features) + dataset = Dataset.from_dict(data_dict, features=features) dataset = dataset.with_format("torch") - def add_episode_data_id_from_to(frame): - ep_id = frame["episode_id"].item() - frame["episode_data_id_from"] = data_ids_per_episode[ep_id][0] - frame["episode_data_id_to"] = data_ids_per_episode[ep_id][-1] - return frame - - dataset = dataset.map(add_episode_data_id_from_to, num_proc=4) + num_items_first_ep = ep_dicts[0]["frame_id"].shape[0] + dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train") dataset.push_to_hub(f"lerobot/{dataset_id}", token=True) dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0") @@ -338,10 +373,9 @@ def download_and_upload_aloha(root, dataset_id, fps=50): gdown.download(ep48_urls[dataset_id], output=str(raw_dir / "episode_48.hdf5"), fuzzy=True) gdown.download(ep49_urls[dataset_id], output=str(raw_dir / "episode_49.hdf5"), fuzzy=True) - data_ids_per_episode = {} ep_dicts = [] - frame_idx = 0 + id_from = 0 for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])): ep_path = raw_dir / f"episode_{ep_id}.hdf5" with h5py.File(ep_path, "r") as ep: @@ -366,49 +400,68 @@ def download_and_upload_aloha(root, dataset_id, fps=50): # "next.reward": reward, "next.done": done, # "next.success": success, + "episode_data_id_from": torch.tensor([id_from] * num_frames), + "episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames), } for cam in cameras[dataset_id]: - image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) - image = einops.rearrange(image, "b h w c -> b c h w").contiguous() - ep_dict[f"observation.images.{cam}"] = image + image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) # b h w c + # image = einops.rearrange(image, "b h w c -> b c h w").contiguous() + ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image] # ep_dict[f"next.observation.images.{cam}"] = image assert isinstance(ep_id, int) - data_ids_per_episode[ep_id] = torch.arange(frame_idx, frame_idx + num_frames, 1) - assert len(data_ids_per_episode[ep_id]) == num_frames - ep_dicts.append(ep_dict) - frame_idx += num_frames + id_from += num_frames data_dict = {} + data_dict = {} keys = ep_dicts[0].keys() for key in keys: - data_dict[key] = torch.cat([x[key] for x in 
ep_dicts]) + if torch.is_tensor(ep_dicts[0][key][0]): + data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts]) + else: + if key not in data_dict: + data_dict[key] = [] + for ep_dict in ep_dicts: + for x in ep_dict[key]: + data_dict[key].append(x) - total_frames = frame_idx + total_frames = id_from data_dict["index"] = torch.arange(0, total_frames, 1) - dataset = Dataset.from_dict(data_dict) + features = { + "observation.images.top": Image(), + "observation.state": Sequence( + length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) + ), + "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)), + "episode_id": Value(dtype="int64", id=None), + "frame_id": Value(dtype="int64", id=None), + "timestamp": Value(dtype="float32", id=None), + #'next.reward': Value(dtype='float32', id=None), + "next.done": Value(dtype="bool", id=None), + #'next.success': Value(dtype='bool', id=None), + "index": Value(dtype="int64", id=None), + "episode_data_id_from": Value(dtype="int64", id=None), + "episode_data_id_to": Value(dtype="int64", id=None), + } + features = Features(features) + dataset = Dataset.from_dict(data_dict, features=features) dataset = dataset.with_format("torch") - def add_episode_data_id_from_to(frame): - ep_id = frame["episode_id"].item() - frame["episode_data_id_from"] = data_ids_per_episode[ep_id][0] - frame["episode_data_id_to"] = data_ids_per_episode[ep_id][-1] - return frame - - dataset = dataset.map(add_episode_data_id_from_to) + num_items_first_ep = ep_dicts[0]["frame_id"].shape[0] + dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train") dataset.push_to_hub(f"lerobot/{dataset_id}", token=True) dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0") if __name__ == "__main__": root = "data" - # download_and_upload_pusht(root, dataset_id="pusht") - # download_and_upload_xarm(root, dataset_id="xarm_lift_medium") + download_and_upload_pusht(root, dataset_id="pusht") + download_and_upload_xarm(root, dataset_id="xarm_lift_medium") download_and_upload_aloha(root, dataset_id="aloha_sim_insertion_human") download_and_upload_aloha(root, dataset_id="aloha_sim_insertion_scripted") download_and_upload_aloha(root, dataset_id="aloha_sim_transfer_cube_human") diff --git a/tests/data/aloha_sim_insertion_human/data_dict.pth b/tests/data/aloha_sim_insertion_human/data_dict.pth deleted file mode 100644 index 1370c9ea..00000000 Binary files a/tests/data/aloha_sim_insertion_human/data_dict.pth and /dev/null differ diff --git a/tests/data/aloha_sim_insertion_human/data_ids_per_episode.pth b/tests/data/aloha_sim_insertion_human/data_ids_per_episode.pth deleted file mode 100644 index a1d481dd..00000000 Binary files a/tests/data/aloha_sim_insertion_human/data_ids_per_episode.pth and /dev/null differ diff --git a/tests/data/aloha_sim_insertion_human/stats.pth b/tests/data/aloha_sim_insertion_human/stats.pth deleted file mode 100644 index a7b9248f..00000000 Binary files a/tests/data/aloha_sim_insertion_human/stats.pth and /dev/null differ diff --git a/tests/data/aloha_sim_insertion_human/train/data-00000-of-00001.arrow b/tests/data/aloha_sim_insertion_human/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..4d357e34 Binary files /dev/null and b/tests/data/aloha_sim_insertion_human/train/data-00000-of-00001.arrow differ diff --git a/tests/data/aloha_sim_insertion_human/train/dataset_info.json b/tests/data/aloha_sim_insertion_human/train/dataset_info.json new file 
mode 100644 index 00000000..473812f3 --- /dev/null +++ b/tests/data/aloha_sim_insertion_human/train/dataset_info.json @@ -0,0 +1,55 @@ +{ + "citation": "", + "description": "", + "features": { + "observation.state": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 14, + "_type": "Sequence" + }, + "action": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 14, + "_type": "Sequence" + }, + "episode_id": { + "dtype": "int64", + "_type": "Value" + }, + "frame_id": { + "dtype": "int64", + "_type": "Value" + }, + "timestamp": { + "dtype": "float32", + "_type": "Value" + }, + "next.done": { + "dtype": "bool", + "_type": "Value" + }, + "episode_data_id_from": { + "dtype": "int64", + "_type": "Value" + }, + "episode_data_id_to": { + "dtype": "int64", + "_type": "Value" + }, + "observation.images.top": { + "_type": "Image" + }, + "index": { + "dtype": "int64", + "_type": "Value" + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_human/train/state.json b/tests/data/aloha_sim_insertion_human/train/state.json new file mode 100644 index 00000000..5b56e98c --- /dev/null +++ b/tests/data/aloha_sim_insertion_human/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "05980bca35112ebd", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": "torch", + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_scripted/data_dict.pth b/tests/data/aloha_sim_insertion_scripted/data_dict.pth deleted file mode 100644 index 00c9f335..00000000 Binary files a/tests/data/aloha_sim_insertion_scripted/data_dict.pth and /dev/null differ diff --git a/tests/data/aloha_sim_insertion_scripted/data_ids_per_episode.pth b/tests/data/aloha_sim_insertion_scripted/data_ids_per_episode.pth deleted file mode 100644 index a1d481dd..00000000 Binary files a/tests/data/aloha_sim_insertion_scripted/data_ids_per_episode.pth and /dev/null differ diff --git a/tests/data/aloha_sim_insertion_scripted/stats.pth b/tests/data/aloha_sim_insertion_scripted/stats.pth deleted file mode 100644 index 990d4647..00000000 Binary files a/tests/data/aloha_sim_insertion_scripted/stats.pth and /dev/null differ diff --git a/tests/data/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow b/tests/data/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..421474a2 Binary files /dev/null and b/tests/data/aloha_sim_insertion_scripted/train/data-00000-of-00001.arrow differ diff --git a/tests/data/aloha_sim_insertion_scripted/train/dataset_info.json b/tests/data/aloha_sim_insertion_scripted/train/dataset_info.json new file mode 100644 index 00000000..473812f3 --- /dev/null +++ b/tests/data/aloha_sim_insertion_scripted/train/dataset_info.json @@ -0,0 +1,55 @@ +{ + "citation": "", + "description": "", + "features": { + "observation.state": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 14, + "_type": "Sequence" + }, + "action": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 14, + "_type": "Sequence" + }, + "episode_id": { + "dtype": "int64", + "_type": "Value" + }, + "frame_id": { + "dtype": "int64", + "_type": "Value" + }, + "timestamp": { + "dtype": "float32", + "_type": "Value" + }, + "next.done": { + "dtype": "bool", + "_type": "Value" + }, + "episode_data_id_from": { + "dtype": "int64", + 
"_type": "Value" + }, + "episode_data_id_to": { + "dtype": "int64", + "_type": "Value" + }, + "observation.images.top": { + "_type": "Image" + }, + "index": { + "dtype": "int64", + "_type": "Value" + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/tests/data/aloha_sim_insertion_scripted/train/state.json b/tests/data/aloha_sim_insertion_scripted/train/state.json new file mode 100644 index 00000000..8f202c3a --- /dev/null +++ b/tests/data/aloha_sim_insertion_scripted/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "f3330a7e1d8bc55b", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": "torch", + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_human/data_dict.pth b/tests/data/aloha_sim_transfer_cube_human/data_dict.pth deleted file mode 100644 index ab851779..00000000 Binary files a/tests/data/aloha_sim_transfer_cube_human/data_dict.pth and /dev/null differ diff --git a/tests/data/aloha_sim_transfer_cube_human/data_ids_per_episode.pth b/tests/data/aloha_sim_transfer_cube_human/data_ids_per_episode.pth deleted file mode 100644 index a1d481dd..00000000 Binary files a/tests/data/aloha_sim_transfer_cube_human/data_ids_per_episode.pth and /dev/null differ diff --git a/tests/data/aloha_sim_transfer_cube_human/stats.pth b/tests/data/aloha_sim_transfer_cube_human/stats.pth deleted file mode 100644 index 1ae356e3..00000000 Binary files a/tests/data/aloha_sim_transfer_cube_human/stats.pth and /dev/null differ diff --git a/tests/data/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow b/tests/data/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..9e371c8f Binary files /dev/null and b/tests/data/aloha_sim_transfer_cube_human/train/data-00000-of-00001.arrow differ diff --git a/tests/data/aloha_sim_transfer_cube_human/train/dataset_info.json b/tests/data/aloha_sim_transfer_cube_human/train/dataset_info.json new file mode 100644 index 00000000..473812f3 --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_human/train/dataset_info.json @@ -0,0 +1,55 @@ +{ + "citation": "", + "description": "", + "features": { + "observation.state": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 14, + "_type": "Sequence" + }, + "action": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 14, + "_type": "Sequence" + }, + "episode_id": { + "dtype": "int64", + "_type": "Value" + }, + "frame_id": { + "dtype": "int64", + "_type": "Value" + }, + "timestamp": { + "dtype": "float32", + "_type": "Value" + }, + "next.done": { + "dtype": "bool", + "_type": "Value" + }, + "episode_data_id_from": { + "dtype": "int64", + "_type": "Value" + }, + "episode_data_id_to": { + "dtype": "int64", + "_type": "Value" + }, + "observation.images.top": { + "_type": "Image" + }, + "index": { + "dtype": "int64", + "_type": "Value" + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_human/train/state.json b/tests/data/aloha_sim_transfer_cube_human/train/state.json new file mode 100644 index 00000000..ec1fdf06 --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_human/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "42aa77ffb6863924", + "_format_columns": null, + "_format_kwargs": {}, + 
"_format_type": "torch", + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_scripted/data_dict.pth b/tests/data/aloha_sim_transfer_cube_scripted/data_dict.pth deleted file mode 100644 index bd308bb0..00000000 Binary files a/tests/data/aloha_sim_transfer_cube_scripted/data_dict.pth and /dev/null differ diff --git a/tests/data/aloha_sim_transfer_cube_scripted/data_ids_per_episode.pth b/tests/data/aloha_sim_transfer_cube_scripted/data_ids_per_episode.pth deleted file mode 100644 index a1d481dd..00000000 Binary files a/tests/data/aloha_sim_transfer_cube_scripted/data_ids_per_episode.pth and /dev/null differ diff --git a/tests/data/aloha_sim_transfer_cube_scripted/stats.pth b/tests/data/aloha_sim_transfer_cube_scripted/stats.pth deleted file mode 100644 index 71547f09..00000000 Binary files a/tests/data/aloha_sim_transfer_cube_scripted/stats.pth and /dev/null differ diff --git a/tests/data/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow b/tests/data/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..99d3363b Binary files /dev/null and b/tests/data/aloha_sim_transfer_cube_scripted/train/data-00000-of-00001.arrow differ diff --git a/tests/data/aloha_sim_transfer_cube_scripted/train/dataset_info.json b/tests/data/aloha_sim_transfer_cube_scripted/train/dataset_info.json new file mode 100644 index 00000000..473812f3 --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_scripted/train/dataset_info.json @@ -0,0 +1,55 @@ +{ + "citation": "", + "description": "", + "features": { + "observation.state": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 14, + "_type": "Sequence" + }, + "action": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 14, + "_type": "Sequence" + }, + "episode_id": { + "dtype": "int64", + "_type": "Value" + }, + "frame_id": { + "dtype": "int64", + "_type": "Value" + }, + "timestamp": { + "dtype": "float32", + "_type": "Value" + }, + "next.done": { + "dtype": "bool", + "_type": "Value" + }, + "episode_data_id_from": { + "dtype": "int64", + "_type": "Value" + }, + "episode_data_id_to": { + "dtype": "int64", + "_type": "Value" + }, + "observation.images.top": { + "_type": "Image" + }, + "index": { + "dtype": "int64", + "_type": "Value" + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/tests/data/aloha_sim_transfer_cube_scripted/train/state.json b/tests/data/aloha_sim_transfer_cube_scripted/train/state.json new file mode 100644 index 00000000..ee3cc1fe --- /dev/null +++ b/tests/data/aloha_sim_transfer_cube_scripted/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "43f176a3740fe622", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": "torch", + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/tests/data/pusht/data_dict.pth b/tests/data/pusht/data_dict.pth deleted file mode 100644 index a083c86c..00000000 Binary files a/tests/data/pusht/data_dict.pth and /dev/null differ diff --git a/tests/data/pusht/data_ids_per_episode.pth b/tests/data/pusht/data_ids_per_episode.pth deleted file mode 100644 index a1d481dd..00000000 Binary files a/tests/data/pusht/data_ids_per_episode.pth and /dev/null differ diff --git a/tests/data/pusht/stats.pth b/tests/data/pusht/stats.pth deleted file mode 100644 index 636985fd..00000000 Binary 
files a/tests/data/pusht/stats.pth and /dev/null differ diff --git a/tests/data/pusht/train/data-00000-of-00001.arrow b/tests/data/pusht/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..71f657ab Binary files /dev/null and b/tests/data/pusht/train/data-00000-of-00001.arrow differ diff --git a/tests/data/pusht/train/dataset_info.json b/tests/data/pusht/train/dataset_info.json new file mode 100644 index 00000000..b21231fe --- /dev/null +++ b/tests/data/pusht/train/dataset_info.json @@ -0,0 +1,63 @@ +{ + "citation": "", + "description": "", + "features": { + "observation.image": { + "_type": "Image" + }, + "observation.state": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 2, + "_type": "Sequence" + }, + "action": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 2, + "_type": "Sequence" + }, + "episode_id": { + "dtype": "int64", + "_type": "Value" + }, + "frame_id": { + "dtype": "int64", + "_type": "Value" + }, + "timestamp": { + "dtype": "float32", + "_type": "Value" + }, + "next.reward": { + "dtype": "float32", + "_type": "Value" + }, + "next.done": { + "dtype": "bool", + "_type": "Value" + }, + "next.success": { + "dtype": "bool", + "_type": "Value" + }, + "episode_data_id_from": { + "dtype": "int64", + "_type": "Value" + }, + "episode_data_id_to": { + "dtype": "int64", + "_type": "Value" + }, + "index": { + "dtype": "int64", + "_type": "Value" + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/tests/data/pusht/train/state.json b/tests/data/pusht/train/state.json new file mode 100644 index 00000000..090326e1 --- /dev/null +++ b/tests/data/pusht/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "f7ed966ae18000ae", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": "torch", + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/tests/data/xarm_lift_medium/data_dict.pth b/tests/data/xarm_lift_medium/data_dict.pth deleted file mode 100644 index 5c166576..00000000 Binary files a/tests/data/xarm_lift_medium/data_dict.pth and /dev/null differ diff --git a/tests/data/xarm_lift_medium/data_ids_per_episode.pth b/tests/data/xarm_lift_medium/data_ids_per_episode.pth deleted file mode 100644 index 21095017..00000000 Binary files a/tests/data/xarm_lift_medium/data_ids_per_episode.pth and /dev/null differ diff --git a/tests/data/xarm_lift_medium/stats.pth b/tests/data/xarm_lift_medium/stats.pth deleted file mode 100644 index 3ab4e05b..00000000 Binary files a/tests/data/xarm_lift_medium/stats.pth and /dev/null differ diff --git a/tests/data/xarm_lift_medium/train/data-00000-of-00001.arrow b/tests/data/xarm_lift_medium/train/data-00000-of-00001.arrow new file mode 100644 index 00000000..f6ee0e50 Binary files /dev/null and b/tests/data/xarm_lift_medium/train/data-00000-of-00001.arrow differ diff --git a/tests/data/xarm_lift_medium/train/dataset_info.json b/tests/data/xarm_lift_medium/train/dataset_info.json new file mode 100644 index 00000000..81ba7c8c --- /dev/null +++ b/tests/data/xarm_lift_medium/train/dataset_info.json @@ -0,0 +1,59 @@ +{ + "citation": "", + "description": "", + "features": { + "observation.image": { + "_type": "Image" + }, + "observation.state": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 4, + "_type": "Sequence" + }, + "action": { + "feature": { + "dtype": "float32", + "_type": "Value" + }, + "length": 4, + 
"_type": "Sequence" + }, + "episode_id": { + "dtype": "int64", + "_type": "Value" + }, + "frame_id": { + "dtype": "int64", + "_type": "Value" + }, + "timestamp": { + "dtype": "float32", + "_type": "Value" + }, + "next.reward": { + "dtype": "float32", + "_type": "Value" + }, + "next.done": { + "dtype": "bool", + "_type": "Value" + }, + "episode_data_id_from": { + "dtype": "int64", + "_type": "Value" + }, + "episode_data_id_to": { + "dtype": "int64", + "_type": "Value" + }, + "index": { + "dtype": "int64", + "_type": "Value" + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/tests/data/xarm_lift_medium/train/state.json b/tests/data/xarm_lift_medium/train/state.json new file mode 100644 index 00000000..500bdb85 --- /dev/null +++ b/tests/data/xarm_lift_medium/train/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "7dcd82fc3815bba6", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": "torch", + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 348256b6..c40d478a 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -37,7 +37,7 @@ def test_factory(env_name, dataset_id, policy_name): keys_ndim_required = [ ("action", 1, True), - ("episode", 0, True), + ("episode_id", 0, True), ("frame_id", 0, True), ("timestamp", 0, True), # TODO(rcadene): should we rename it agent_pos? @@ -95,14 +95,12 @@ def test_compute_stats(): """ from lerobot.common.datasets.xarm import XarmDataset - DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None # get transform to convert images from uint8 [0,255] to float32 [0,1] transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0) dataset = XarmDataset( dataset_id="xarm_lift_medium", - root=DATA_DIR, transform=transform, ) @@ -115,7 +113,13 @@ def test_compute_stats(): stats_patterns = get_stats_einops_patterns(dataset) # get all frames from the dataset in the same dtype and range as during compute_stats - data_dict = transform(dataset.data_dict) + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=16, + batch_size=len(dataset), + shuffle=False, + ) + data_dict = next(iter(dataloader)) # takes 23 seconds # compute stats based on all frames from the dataset without any batching expected_stats = {}