HF datasets works

This commit is contained in:
Cadene
2024-04-16 12:20:38 +00:00
parent 5edd9a89a0
commit 0980fff6cc
42 changed files with 630 additions and 87 deletions

View File

@@ -1,5 +1,7 @@
from pathlib import Path
import torch import torch
from datasets import load_dataset from datasets import load_dataset, load_from_disk
from lerobot.common.datasets.utils import load_previous_and_future_frames from lerobot.common.datasets.utils import load_previous_and_future_frames
@@ -25,15 +27,24 @@ class AlohaDataset(torch.utils.data.Dataset):
self, self,
dataset_id: str, dataset_id: str,
version: str | None = "v1.0", version: str | None = "v1.0",
root: Path | None = None,
split: str = "train",
transform: callable = None, transform: callable = None,
delta_timestamps: dict[list[float]] | None = None, delta_timestamps: dict[list[float]] | None = None,
): ):
super().__init__() super().__init__()
self.dataset_id = dataset_id self.dataset_id = dataset_id
self.version = version self.version = version
self.root = root
self.split = split
self.transform = transform self.transform = transform
self.delta_timestamps = delta_timestamps self.delta_timestamps = delta_timestamps
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train") if self.root is not None:
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.data_dict = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.data_dict = self.data_dict.with_format("torch") self.data_dict = self.data_dict.with_format("torch")
@property @property
@@ -57,6 +68,15 @@ class AlohaDataset(torch.utils.data.Dataset):
self.delta_timestamps, self.delta_timestamps,
) )
# convert images from channel last (PIL) to channel first (pytorch)
for key in self.image_keys:
if item[key].ndim == 3:
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
elif item[key].ndim == 4:
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
else:
raise ValueError(item[key].ndim)
if self.transform is not None: if self.transform is not None:
item = self.transform(item) item = self.transform(item)

View File

@@ -1,4 +1,5 @@
import logging import logging
import os
from pathlib import Path from pathlib import Path
import torch import torch
@@ -7,12 +8,15 @@ from torchvision.transforms import v2
from lerobot.common.datasets.utils import compute_stats from lerobot.common.datasets.utils import compute_stats
from lerobot.common.transforms import NormalizeTransform, Prod from lerobot.common.transforms import NormalizeTransform, Prod
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
def make_dataset( def make_dataset(
cfg, cfg,
# set normalize=False to remove all transformations and keep images unnormalized in [0,255] # set normalize=False to remove all transformations and keep images unnormalized in [0,255]
normalize=True, normalize=True,
stats_path=None, stats_path=None,
split="train",
): ):
if cfg.env.name == "xarm": if cfg.env.name == "xarm":
from lerobot.common.datasets.xarm import XarmDataset from lerobot.common.datasets.xarm import XarmDataset
@@ -57,6 +61,8 @@ def make_dataset(
# instantiate a one frame dataset with light transform # instantiate a one frame dataset with light transform
stats_dataset = clsfunc( stats_dataset = clsfunc(
dataset_id=cfg.dataset_id, dataset_id=cfg.dataset_id,
split="train",
root=DATA_DIR,
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0), transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
) )
stats = compute_stats(stats_dataset) stats = compute_stats(stats_dataset)
@@ -86,6 +92,8 @@ def make_dataset(
dataset = clsfunc( dataset = clsfunc(
dataset_id=cfg.dataset_id, dataset_id=cfg.dataset_id,
split=split,
root=DATA_DIR,
delta_timestamps=delta_timestamps, delta_timestamps=delta_timestamps,
transform=transforms, transform=transforms,
) )

View File

@@ -1,5 +1,7 @@
from pathlib import Path
import torch import torch
from datasets import load_dataset from datasets import load_dataset, load_from_disk
from lerobot.common.datasets.utils import load_previous_and_future_frames from lerobot.common.datasets.utils import load_previous_and_future_frames
@@ -23,18 +25,25 @@ class PushtDataset(torch.utils.data.Dataset):
self, self,
dataset_id: str, dataset_id: str,
version: str | None = "v1.0", version: str | None = "v1.0",
root: Path | None = None,
split: str = "train",
transform: callable = None, transform: callable = None,
delta_timestamps: dict[list[float]] | None = None, delta_timestamps: dict[list[float]] | None = None,
): ):
super().__init__() super().__init__()
self.dataset_id = dataset_id self.dataset_id = dataset_id
self.version = version self.version = version
self.root = root
self.split = split
self.transform = transform self.transform = transform
self.delta_timestamps = delta_timestamps self.delta_timestamps = delta_timestamps
# self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train") if self.root is not None:
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", split="train") self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.data_dict = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.data_dict = self.data_dict.with_format("torch") self.data_dict = self.data_dict.with_format("torch")
self.data_dict.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
@property @property
def num_samples(self) -> int: def num_samples(self) -> int:
@@ -57,6 +66,15 @@ class PushtDataset(torch.utils.data.Dataset):
self.delta_timestamps, self.delta_timestamps,
) )
# convert images from channel last (PIL) to channel first (pytorch)
for key in self.image_keys:
if item[key].ndim == 3:
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
elif item[key].ndim == 4:
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
else:
raise ValueError(item[key].ndim)
if self.transform is not None: if self.transform is not None:
item = self.transform(item) item = self.transform(item)

View File

@@ -1,5 +1,7 @@
from pathlib import Path
import torch import torch
from datasets import load_dataset from datasets import load_dataset, load_from_disk
from lerobot.common.datasets.utils import load_previous_and_future_frames from lerobot.common.datasets.utils import load_previous_and_future_frames
@@ -19,15 +21,24 @@ class XarmDataset(torch.utils.data.Dataset):
self, self,
dataset_id: str, dataset_id: str,
version: str | None = "v1.0", version: str | None = "v1.0",
root: Path | None = None,
split: str = "train",
transform: callable = None, transform: callable = None,
delta_timestamps: dict[list[float]] | None = None, delta_timestamps: dict[list[float]] | None = None,
): ):
super().__init__() super().__init__()
self.dataset_id = dataset_id self.dataset_id = dataset_id
self.version = version self.version = version
self.root = root
self.split = split
self.transform = transform self.transform = transform
self.delta_timestamps = delta_timestamps self.delta_timestamps = delta_timestamps
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train") if self.root is not None:
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.data_dict = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.data_dict = self.data_dict.with_format("torch") self.data_dict = self.data_dict.with_format("torch")
@property @property
@@ -51,6 +62,15 @@ class XarmDataset(torch.utils.data.Dataset):
self.delta_timestamps, self.delta_timestamps,
) )
# convert images from channel last (PIL) to channel first (pytorch)
for key in self.image_keys:
if item[key].ndim == 3:
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
elif item[key].ndim == 4:
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
else:
raise ValueError(item[key].ndim)
if self.transform is not None: if self.transform is not None:
item = self.transform(item) item = self.transform(item)

View File

@@ -12,7 +12,8 @@ import h5py
import numpy as np import numpy as np
import torch import torch
import tqdm import tqdm
from datasets import Dataset from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage
def download_and_extract_zip(url: str, destination_folder: Path) -> bool: def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
@@ -73,10 +74,6 @@ def download_and_upload_pusht(root, dataset_id="pusht", fps=10):
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs()) episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
num_episodes = dataset_dict.meta["episode_ends"].shape[0] num_episodes = dataset_dict.meta["episode_ends"].shape[0]
total_frames = dataset_dict["action"].shape[0]
# to create test artifact
# num_episodes = 1
# total_frames = 50
assert len( assert len(
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118 {dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
), "Some data type dont have the same number of total frames." ), "Some data type dont have the same number of total frames."
@@ -85,28 +82,27 @@ def download_and_upload_pusht(root, dataset_id="pusht", fps=10):
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians) goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle) goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
imgs = torch.from_numpy(dataset_dict["img"]) imgs = torch.from_numpy(dataset_dict["img"]) # b h w c
imgs = einops.rearrange(imgs, "b h w c -> b c h w")
states = torch.from_numpy(dataset_dict["state"]) states = torch.from_numpy(dataset_dict["state"])
actions = torch.from_numpy(dataset_dict["action"]) actions = torch.from_numpy(dataset_dict["action"])
data_ids_per_episode = {} data_ids_per_episode = {}
ep_dicts = [] ep_dicts = []
idx0 = 0 id_from = 0
for episode_id in tqdm.tqdm(range(num_episodes)): for episode_id in tqdm.tqdm(range(num_episodes)):
idx1 = dataset_dict.meta["episode_ends"][episode_id] id_to = dataset_dict.meta["episode_ends"][episode_id]
num_frames = idx1 - idx0 num_frames = id_to - id_from
assert (episode_ids[idx0:idx1] == episode_id).all() assert (episode_ids[id_from:id_to] == episode_id).all()
image = imgs[idx0:idx1] image = imgs[id_from:id_to]
assert image.min() >= 0.0 assert image.min() >= 0.0
assert image.max() <= 255.0 assert image.max() <= 255.0
image = image.type(torch.uint8) image = image.type(torch.uint8)
state = states[idx0:idx1] state = states[id_from:id_to]
agent_pos = state[:, :2] agent_pos = state[:, :2]
block_pos = state[:, 2:4] block_pos = state[:, 2:4]
block_angle = state[:, 4] block_angle = state[:, 4]
@@ -141,9 +137,9 @@ def download_and_upload_pusht(root, dataset_id="pusht", fps=10):
done[-1] = True done[-1] = True
ep_dict = { ep_dict = {
"observation.image": image, "observation.image": [PILImage.fromarray(x.numpy()) for x in image],
"observation.state": agent_pos, "observation.state": agent_pos,
"action": actions[idx0:idx1], "action": actions[id_from:id_to],
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int), "episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1), "frame_id": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps, "timestamp": torch.arange(0, num_frames, 1) / fps,
@@ -153,33 +149,55 @@ def download_and_upload_pusht(root, dataset_id="pusht", fps=10):
"next.reward": torch.cat([reward[1:], reward[[-1]]]), "next.reward": torch.cat([reward[1:], reward[[-1]]]),
"next.done": torch.cat([done[1:], done[[-1]]]), "next.done": torch.cat([done[1:], done[[-1]]]),
"next.success": torch.cat([success[1:], success[[-1]]]), "next.success": torch.cat([success[1:], success[[-1]]]),
"episode_data_id_from": torch.tensor([id_from] * num_frames),
"episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
} }
ep_dicts.append(ep_dict) ep_dicts.append(ep_dict)
assert isinstance(episode_id, int) assert isinstance(episode_id, int)
data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1) data_ids_per_episode[episode_id] = torch.arange(id_from, id_to, 1)
assert len(data_ids_per_episode[episode_id]) == num_frames assert len(data_ids_per_episode[episode_id]) == num_frames
idx0 = idx1 id_from += num_frames
data_dict = {} data_dict = {}
keys = ep_dicts[0].keys() keys = ep_dicts[0].keys()
for key in keys: for key in keys:
data_dict[key] = torch.cat([x[key] for x in ep_dicts]) if torch.is_tensor(ep_dicts[0][key][0]):
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
data_dict[key].append(x)
total_frames = id_from
data_dict["index"] = torch.arange(0, total_frames, 1) data_dict["index"] = torch.arange(0, total_frames, 1)
dataset = Dataset.from_dict(data_dict) features = {
"observation.image": Image(),
"observation.state": Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_id": Value(dtype="int64", id=None),
"frame_id": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
"next.reward": Value(dtype="float32", id=None),
"next.done": Value(dtype="bool", id=None),
"next.success": Value(dtype="bool", id=None),
"index": Value(dtype="int64", id=None),
"episode_data_id_from": Value(dtype="int64", id=None),
"episode_data_id_to": Value(dtype="int64", id=None),
}
features = Features(features)
dataset = Dataset.from_dict(data_dict, features=features)
dataset = dataset.with_format("torch") dataset = dataset.with_format("torch")
def add_episode_data_id_from_to(frame): num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
ep_id = frame["episode_id"].item() dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
frame["episode_data_id_from"] = data_ids_per_episode[ep_id][0]
frame["episode_data_id_to"] = data_ids_per_episode[ep_id][-1]
return frame
dataset = dataset.map(add_episode_data_id_from_to, num_proc=4)
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True) dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0") dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
@@ -211,32 +229,32 @@ def download_and_upload_xarm(root, dataset_id, fps=15):
total_frames = dataset_dict["actions"].shape[0] total_frames = dataset_dict["actions"].shape[0]
data_ids_per_episode = {}
ep_dicts = [] ep_dicts = []
idx0 = 0 id_from = 0
idx1 = 0 id_to = 0
episode_id = 0 episode_id = 0
for i in tqdm.tqdm(range(total_frames)): for i in tqdm.tqdm(range(total_frames)):
idx1 += 1 id_to += 1
if not dataset_dict["dones"][i]: if not dataset_dict["dones"][i]:
continue continue
num_frames = idx1 - idx0 num_frames = id_to - id_from
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1]) image = torch.tensor(dataset_dict["observations"]["rgb"][id_from:id_to])
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1]) image = einops.rearrange(image, "b c h w -> b h w c")
action = torch.tensor(dataset_dict["actions"][idx0:idx1]) state = torch.tensor(dataset_dict["observations"]["state"][id_from:id_to])
action = torch.tensor(dataset_dict["actions"][id_from:id_to])
# TODO(rcadene): we have a missing last frame which is the observation when the env is done # TODO(rcadene): we have a missing last frame which is the observation when the env is done
# it is critical to have this frame for tdmpc to predict a "done observation/state" # it is critical to have this frame for tdmpc to predict a "done observation/state"
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1]) # next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][id_from:id_to])
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1]) # next_state = torch.tensor(dataset_dict["next_observations"]["state"][id_from:id_to])
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1]) next_reward = torch.tensor(dataset_dict["rewards"][id_from:id_to])
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1]) next_done = torch.tensor(dataset_dict["dones"][id_from:id_to])
ep_dict = { ep_dict = {
"observation.image": image, "observation.image": [PILImage.fromarray(x.numpy()) for x in image],
"observation.state": state, "observation.state": state,
"action": action, "action": action,
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int), "episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
@@ -246,34 +264,51 @@ def download_and_upload_xarm(root, dataset_id, fps=15):
# "next.observation.state": next_state, # "next.observation.state": next_state,
"next.reward": next_reward, "next.reward": next_reward,
"next.done": next_done, "next.done": next_done,
"episode_data_id_from": torch.tensor([id_from] * num_frames),
"episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
} }
ep_dicts.append(ep_dict) ep_dicts.append(ep_dict)
assert isinstance(episode_id, int) id_from = id_to
data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
assert len(data_ids_per_episode[episode_id]) == num_frames
idx0 = idx1
episode_id += 1 episode_id += 1
data_dict = {} data_dict = {}
keys = ep_dicts[0].keys() keys = ep_dicts[0].keys()
for key in keys: for key in keys:
data_dict[key] = torch.cat([x[key] for x in ep_dicts]) if torch.is_tensor(ep_dicts[0][key][0]):
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
data_dict[key].append(x)
total_frames = id_from
data_dict["index"] = torch.arange(0, total_frames, 1) data_dict["index"] = torch.arange(0, total_frames, 1)
dataset = Dataset.from_dict(data_dict) features = {
"observation.image": Image(),
"observation.state": Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_id": Value(dtype="int64", id=None),
"frame_id": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
"next.reward": Value(dtype="float32", id=None),
"next.done": Value(dtype="bool", id=None),
#'next.success': Value(dtype='bool', id=None),
"index": Value(dtype="int64", id=None),
"episode_data_id_from": Value(dtype="int64", id=None),
"episode_data_id_to": Value(dtype="int64", id=None),
}
features = Features(features)
dataset = Dataset.from_dict(data_dict, features=features)
dataset = dataset.with_format("torch") dataset = dataset.with_format("torch")
def add_episode_data_id_from_to(frame): num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
ep_id = frame["episode_id"].item() dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
frame["episode_data_id_from"] = data_ids_per_episode[ep_id][0]
frame["episode_data_id_to"] = data_ids_per_episode[ep_id][-1]
return frame
dataset = dataset.map(add_episode_data_id_from_to, num_proc=4)
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True) dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0") dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
@@ -338,10 +373,9 @@ def download_and_upload_aloha(root, dataset_id, fps=50):
gdown.download(ep48_urls[dataset_id], output=str(raw_dir / "episode_48.hdf5"), fuzzy=True) gdown.download(ep48_urls[dataset_id], output=str(raw_dir / "episode_48.hdf5"), fuzzy=True)
gdown.download(ep49_urls[dataset_id], output=str(raw_dir / "episode_49.hdf5"), fuzzy=True) gdown.download(ep49_urls[dataset_id], output=str(raw_dir / "episode_49.hdf5"), fuzzy=True)
data_ids_per_episode = {}
ep_dicts = [] ep_dicts = []
frame_idx = 0 id_from = 0
for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])): for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])):
ep_path = raw_dir / f"episode_{ep_id}.hdf5" ep_path = raw_dir / f"episode_{ep_id}.hdf5"
with h5py.File(ep_path, "r") as ep: with h5py.File(ep_path, "r") as ep:
@@ -366,49 +400,68 @@ def download_and_upload_aloha(root, dataset_id, fps=50):
# "next.reward": reward, # "next.reward": reward,
"next.done": done, "next.done": done,
# "next.success": success, # "next.success": success,
"episode_data_id_from": torch.tensor([id_from] * num_frames),
"episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
} }
for cam in cameras[dataset_id]: for cam in cameras[dataset_id]:
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) # b h w c
image = einops.rearrange(image, "b h w c -> b c h w").contiguous() # image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
ep_dict[f"observation.images.{cam}"] = image ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image]
# ep_dict[f"next.observation.images.{cam}"] = image # ep_dict[f"next.observation.images.{cam}"] = image
assert isinstance(ep_id, int) assert isinstance(ep_id, int)
data_ids_per_episode[ep_id] = torch.arange(frame_idx, frame_idx + num_frames, 1)
assert len(data_ids_per_episode[ep_id]) == num_frames
ep_dicts.append(ep_dict) ep_dicts.append(ep_dict)
frame_idx += num_frames id_from += num_frames
data_dict = {} data_dict = {}
data_dict = {}
keys = ep_dicts[0].keys() keys = ep_dicts[0].keys()
for key in keys: for key in keys:
data_dict[key] = torch.cat([x[key] for x in ep_dicts]) if torch.is_tensor(ep_dicts[0][key][0]):
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
data_dict[key].append(x)
total_frames = frame_idx total_frames = id_from
data_dict["index"] = torch.arange(0, total_frames, 1) data_dict["index"] = torch.arange(0, total_frames, 1)
dataset = Dataset.from_dict(data_dict) features = {
"observation.images.top": Image(),
"observation.state": Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_id": Value(dtype="int64", id=None),
"frame_id": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
#'next.reward': Value(dtype='float32', id=None),
"next.done": Value(dtype="bool", id=None),
#'next.success': Value(dtype='bool', id=None),
"index": Value(dtype="int64", id=None),
"episode_data_id_from": Value(dtype="int64", id=None),
"episode_data_id_to": Value(dtype="int64", id=None),
}
features = Features(features)
dataset = Dataset.from_dict(data_dict, features=features)
dataset = dataset.with_format("torch") dataset = dataset.with_format("torch")
def add_episode_data_id_from_to(frame): num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
ep_id = frame["episode_id"].item() dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
frame["episode_data_id_from"] = data_ids_per_episode[ep_id][0]
frame["episode_data_id_to"] = data_ids_per_episode[ep_id][-1]
return frame
dataset = dataset.map(add_episode_data_id_from_to)
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True) dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0") dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
if __name__ == "__main__": if __name__ == "__main__":
root = "data" root = "data"
# download_and_upload_pusht(root, dataset_id="pusht") download_and_upload_pusht(root, dataset_id="pusht")
# download_and_upload_xarm(root, dataset_id="xarm_lift_medium") download_and_upload_xarm(root, dataset_id="xarm_lift_medium")
download_and_upload_aloha(root, dataset_id="aloha_sim_insertion_human") download_and_upload_aloha(root, dataset_id="aloha_sim_insertion_human")
download_and_upload_aloha(root, dataset_id="aloha_sim_insertion_scripted") download_and_upload_aloha(root, dataset_id="aloha_sim_insertion_scripted")
download_and_upload_aloha(root, dataset_id="aloha_sim_transfer_cube_human") download_and_upload_aloha(root, dataset_id="aloha_sim_transfer_cube_human")

View File

@@ -0,0 +1,55 @@
{
"citation": "",
"description": "",
"features": {
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"action": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"episode_id": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"next.done": {
"dtype": "bool",
"_type": "Value"
},
"episode_data_id_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_id_to": {
"dtype": "int64",
"_type": "Value"
},
"observation.images.top": {
"_type": "Image"
},
"index": {
"dtype": "int64",
"_type": "Value"
}
},
"homepage": "",
"license": ""
}

View File

@@ -0,0 +1,13 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "05980bca35112ebd",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",
"_output_all_columns": false,
"_split": null
}

View File

@@ -0,0 +1,55 @@
{
"citation": "",
"description": "",
"features": {
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"action": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"episode_id": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"next.done": {
"dtype": "bool",
"_type": "Value"
},
"episode_data_id_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_id_to": {
"dtype": "int64",
"_type": "Value"
},
"observation.images.top": {
"_type": "Image"
},
"index": {
"dtype": "int64",
"_type": "Value"
}
},
"homepage": "",
"license": ""
}

View File

@@ -0,0 +1,13 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "f3330a7e1d8bc55b",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",
"_output_all_columns": false,
"_split": null
}

View File

@@ -0,0 +1,55 @@
{
"citation": "",
"description": "",
"features": {
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"action": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"episode_id": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"next.done": {
"dtype": "bool",
"_type": "Value"
},
"episode_data_id_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_id_to": {
"dtype": "int64",
"_type": "Value"
},
"observation.images.top": {
"_type": "Image"
},
"index": {
"dtype": "int64",
"_type": "Value"
}
},
"homepage": "",
"license": ""
}

View File

@@ -0,0 +1,13 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "42aa77ffb6863924",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",
"_output_all_columns": false,
"_split": null
}

View File

@@ -0,0 +1,55 @@
{
"citation": "",
"description": "",
"features": {
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"action": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"episode_id": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"next.done": {
"dtype": "bool",
"_type": "Value"
},
"episode_data_id_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_id_to": {
"dtype": "int64",
"_type": "Value"
},
"observation.images.top": {
"_type": "Image"
},
"index": {
"dtype": "int64",
"_type": "Value"
}
},
"homepage": "",
"license": ""
}

View File

@@ -0,0 +1,13 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "43f176a3740fe622",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",
"_output_all_columns": false,
"_split": null
}

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,63 @@
{
"citation": "",
"description": "",
"features": {
"observation.image": {
"_type": "Image"
},
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 2,
"_type": "Sequence"
},
"action": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 2,
"_type": "Sequence"
},
"episode_id": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"next.reward": {
"dtype": "float32",
"_type": "Value"
},
"next.done": {
"dtype": "bool",
"_type": "Value"
},
"next.success": {
"dtype": "bool",
"_type": "Value"
},
"episode_data_id_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_id_to": {
"dtype": "int64",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"
}
},
"homepage": "",
"license": ""
}

View File

@@ -0,0 +1,13 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "f7ed966ae18000ae",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",
"_output_all_columns": false,
"_split": null
}

View File

@@ -0,0 +1,59 @@
{
"citation": "",
"description": "",
"features": {
"observation.image": {
"_type": "Image"
},
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 4,
"_type": "Sequence"
},
"action": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 4,
"_type": "Sequence"
},
"episode_id": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"next.reward": {
"dtype": "float32",
"_type": "Value"
},
"next.done": {
"dtype": "bool",
"_type": "Value"
},
"episode_data_id_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_id_to": {
"dtype": "int64",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"
}
},
"homepage": "",
"license": ""
}

View File

@@ -0,0 +1,13 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "7dcd82fc3815bba6",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",
"_output_all_columns": false,
"_split": null
}

View File

@@ -37,7 +37,7 @@ def test_factory(env_name, dataset_id, policy_name):
keys_ndim_required = [ keys_ndim_required = [
("action", 1, True), ("action", 1, True),
("episode", 0, True), ("episode_id", 0, True),
("frame_id", 0, True), ("frame_id", 0, True),
("timestamp", 0, True), ("timestamp", 0, True),
# TODO(rcadene): should we rename it agent_pos? # TODO(rcadene): should we rename it agent_pos?
@@ -95,14 +95,12 @@ def test_compute_stats():
""" """
from lerobot.common.datasets.xarm import XarmDataset from lerobot.common.datasets.xarm import XarmDataset
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
# get transform to convert images from uint8 [0,255] to float32 [0,1] # get transform to convert images from uint8 [0,255] to float32 [0,1]
transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0) transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
dataset = XarmDataset( dataset = XarmDataset(
dataset_id="xarm_lift_medium", dataset_id="xarm_lift_medium",
root=DATA_DIR,
transform=transform, transform=transform,
) )
@@ -115,7 +113,13 @@ def test_compute_stats():
stats_patterns = get_stats_einops_patterns(dataset) stats_patterns = get_stats_einops_patterns(dataset)
# get all frames from the dataset in the same dtype and range as during compute_stats # get all frames from the dataset in the same dtype and range as during compute_stats
data_dict = transform(dataset.data_dict) dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=16,
batch_size=len(dataset),
shuffle=False,
)
data_dict = next(iter(dataloader)) # takes 23 seconds
# compute stats based on all frames from the dataset without any batching # compute stats based on all frames from the dataset without any batching
expected_stats = {} expected_stats = {}