HF datasets works
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
from lerobot.common.datasets.utils import load_previous_and_future_frames
|
||||
|
||||
@@ -25,15 +27,24 @@ class AlohaDataset(torch.utils.data.Dataset):
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.0",
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
|
||||
if self.root is not None:
|
||||
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
else:
|
||||
self.data_dict = load_dataset(
|
||||
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
|
||||
)
|
||||
self.data_dict = self.data_dict.with_format("torch")
|
||||
|
||||
@property
|
||||
@@ -57,6 +68,15 @@ class AlohaDataset(torch.utils.data.Dataset):
|
||||
self.delta_timestamps,
|
||||
)
|
||||
|
||||
# convert images from channel last (PIL) to channel first (pytorch)
|
||||
for key in self.image_keys:
|
||||
if item[key].ndim == 3:
|
||||
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
|
||||
elif item[key].ndim == 4:
|
||||
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
|
||||
else:
|
||||
raise ValueError(item[key].ndim)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
@@ -7,12 +8,15 @@ from torchvision.transforms import v2
|
||||
from lerobot.common.datasets.utils import compute_stats
|
||||
from lerobot.common.transforms import NormalizeTransform, Prod
|
||||
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
|
||||
def make_dataset(
|
||||
cfg,
|
||||
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
|
||||
normalize=True,
|
||||
stats_path=None,
|
||||
split="train",
|
||||
):
|
||||
if cfg.env.name == "xarm":
|
||||
from lerobot.common.datasets.xarm import XarmDataset
|
||||
@@ -57,6 +61,8 @@ def make_dataset(
|
||||
# instantiate a one frame dataset with light transform
|
||||
stats_dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
split="train",
|
||||
root=DATA_DIR,
|
||||
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
||||
)
|
||||
stats = compute_stats(stats_dataset)
|
||||
@@ -86,6 +92,8 @@ def make_dataset(
|
||||
|
||||
dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
split=split,
|
||||
root=DATA_DIR,
|
||||
delta_timestamps=delta_timestamps,
|
||||
transform=transforms,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
from lerobot.common.datasets.utils import load_previous_and_future_frames
|
||||
|
||||
@@ -23,18 +25,25 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.0",
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
# self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
|
||||
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", split="train")
|
||||
if self.root is not None:
|
||||
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
else:
|
||||
self.data_dict = load_dataset(
|
||||
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
|
||||
)
|
||||
self.data_dict = self.data_dict.with_format("torch")
|
||||
self.data_dict.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
@@ -57,6 +66,15 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
self.delta_timestamps,
|
||||
)
|
||||
|
||||
# convert images from channel last (PIL) to channel first (pytorch)
|
||||
for key in self.image_keys:
|
||||
if item[key].ndim == 3:
|
||||
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
|
||||
elif item[key].ndim == 4:
|
||||
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
|
||||
else:
|
||||
raise ValueError(item[key].ndim)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
from lerobot.common.datasets.utils import load_previous_and_future_frames
|
||||
|
||||
@@ -19,15 +21,24 @@ class XarmDataset(torch.utils.data.Dataset):
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.0",
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
|
||||
if self.root is not None:
|
||||
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
else:
|
||||
self.data_dict = load_dataset(
|
||||
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
|
||||
)
|
||||
self.data_dict = self.data_dict.with_format("torch")
|
||||
|
||||
@property
|
||||
@@ -51,6 +62,15 @@ class XarmDataset(torch.utils.data.Dataset):
|
||||
self.delta_timestamps,
|
||||
)
|
||||
|
||||
# convert images from channel last (PIL) to channel first (pytorch)
|
||||
for key in self.image_keys:
|
||||
if item[key].ndim == 3:
|
||||
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
|
||||
elif item[key].ndim == 4:
|
||||
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
|
||||
else:
|
||||
raise ValueError(item[key].ndim)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
|
||||
@@ -12,7 +12,8 @@ import h5py
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from datasets import Dataset
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
|
||||
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||
@@ -73,10 +74,6 @@ def download_and_upload_pusht(root, dataset_id="pusht", fps=10):
|
||||
|
||||
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
|
||||
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
|
||||
total_frames = dataset_dict["action"].shape[0]
|
||||
# to create test artifact
|
||||
# num_episodes = 1
|
||||
# total_frames = 50
|
||||
assert len(
|
||||
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
|
||||
), "Some data type dont have the same number of total frames."
|
||||
@@ -85,28 +82,27 @@ def download_and_upload_pusht(root, dataset_id="pusht", fps=10):
|
||||
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
|
||||
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
|
||||
|
||||
imgs = torch.from_numpy(dataset_dict["img"])
|
||||
imgs = einops.rearrange(imgs, "b h w c -> b c h w")
|
||||
imgs = torch.from_numpy(dataset_dict["img"]) # b h w c
|
||||
states = torch.from_numpy(dataset_dict["state"])
|
||||
actions = torch.from_numpy(dataset_dict["action"])
|
||||
|
||||
data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
idx0 = 0
|
||||
id_from = 0
|
||||
for episode_id in tqdm.tqdm(range(num_episodes)):
|
||||
idx1 = dataset_dict.meta["episode_ends"][episode_id]
|
||||
id_to = dataset_dict.meta["episode_ends"][episode_id]
|
||||
|
||||
num_frames = idx1 - idx0
|
||||
num_frames = id_to - id_from
|
||||
|
||||
assert (episode_ids[idx0:idx1] == episode_id).all()
|
||||
assert (episode_ids[id_from:id_to] == episode_id).all()
|
||||
|
||||
image = imgs[idx0:idx1]
|
||||
image = imgs[id_from:id_to]
|
||||
assert image.min() >= 0.0
|
||||
assert image.max() <= 255.0
|
||||
image = image.type(torch.uint8)
|
||||
|
||||
state = states[idx0:idx1]
|
||||
state = states[id_from:id_to]
|
||||
agent_pos = state[:, :2]
|
||||
block_pos = state[:, 2:4]
|
||||
block_angle = state[:, 4]
|
||||
@@ -141,9 +137,9 @@ def download_and_upload_pusht(root, dataset_id="pusht", fps=10):
|
||||
done[-1] = True
|
||||
|
||||
ep_dict = {
|
||||
"observation.image": image,
|
||||
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
|
||||
"observation.state": agent_pos,
|
||||
"action": actions[idx0:idx1],
|
||||
"action": actions[id_from:id_to],
|
||||
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
@@ -153,33 +149,55 @@ def download_and_upload_pusht(root, dataset_id="pusht", fps=10):
|
||||
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
|
||||
"next.done": torch.cat([done[1:], done[[-1]]]),
|
||||
"next.success": torch.cat([success[1:], success[[-1]]]),
|
||||
"episode_data_id_from": torch.tensor([id_from] * num_frames),
|
||||
"episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
|
||||
}
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
assert isinstance(episode_id, int)
|
||||
data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
|
||||
data_ids_per_episode[episode_id] = torch.arange(id_from, id_to, 1)
|
||||
assert len(data_ids_per_episode[episode_id]) == num_frames
|
||||
|
||||
idx0 = idx1
|
||||
id_from += num_frames
|
||||
|
||||
data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
if torch.is_tensor(ep_dicts[0][key][0]):
|
||||
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
|
||||
else:
|
||||
if key not in data_dict:
|
||||
data_dict[key] = []
|
||||
for ep_dict in ep_dicts:
|
||||
for x in ep_dict[key]:
|
||||
data_dict[key].append(x)
|
||||
|
||||
total_frames = id_from
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
dataset = Dataset.from_dict(data_dict)
|
||||
features = {
|
||||
"observation.image": Image(),
|
||||
"observation.state": Sequence(
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
),
|
||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
||||
"episode_id": Value(dtype="int64", id=None),
|
||||
"frame_id": Value(dtype="int64", id=None),
|
||||
"timestamp": Value(dtype="float32", id=None),
|
||||
"next.reward": Value(dtype="float32", id=None),
|
||||
"next.done": Value(dtype="bool", id=None),
|
||||
"next.success": Value(dtype="bool", id=None),
|
||||
"index": Value(dtype="int64", id=None),
|
||||
"episode_data_id_from": Value(dtype="int64", id=None),
|
||||
"episode_data_id_to": Value(dtype="int64", id=None),
|
||||
}
|
||||
features = Features(features)
|
||||
dataset = Dataset.from_dict(data_dict, features=features)
|
||||
dataset = dataset.with_format("torch")
|
||||
|
||||
def add_episode_data_id_from_to(frame):
|
||||
ep_id = frame["episode_id"].item()
|
||||
frame["episode_data_id_from"] = data_ids_per_episode[ep_id][0]
|
||||
frame["episode_data_id_to"] = data_ids_per_episode[ep_id][-1]
|
||||
return frame
|
||||
|
||||
dataset = dataset.map(add_episode_data_id_from_to, num_proc=4)
|
||||
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
|
||||
dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
|
||||
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
|
||||
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
|
||||
|
||||
@@ -211,32 +229,32 @@ def download_and_upload_xarm(root, dataset_id, fps=15):
|
||||
|
||||
total_frames = dataset_dict["actions"].shape[0]
|
||||
|
||||
data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
idx0 = 0
|
||||
idx1 = 0
|
||||
id_from = 0
|
||||
id_to = 0
|
||||
episode_id = 0
|
||||
for i in tqdm.tqdm(range(total_frames)):
|
||||
idx1 += 1
|
||||
id_to += 1
|
||||
|
||||
if not dataset_dict["dones"][i]:
|
||||
continue
|
||||
|
||||
num_frames = idx1 - idx0
|
||||
num_frames = id_to - id_from
|
||||
|
||||
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
|
||||
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
|
||||
action = torch.tensor(dataset_dict["actions"][idx0:idx1])
|
||||
image = torch.tensor(dataset_dict["observations"]["rgb"][id_from:id_to])
|
||||
image = einops.rearrange(image, "b c h w -> b h w c")
|
||||
state = torch.tensor(dataset_dict["observations"]["state"][id_from:id_to])
|
||||
action = torch.tensor(dataset_dict["actions"][id_from:id_to])
|
||||
# TODO(rcadene): we have a missing last frame which is the observation when the env is done
|
||||
# it is critical to have this frame for tdmpc to predict a "done observation/state"
|
||||
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
|
||||
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
|
||||
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
|
||||
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
|
||||
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][id_from:id_to])
|
||||
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][id_from:id_to])
|
||||
next_reward = torch.tensor(dataset_dict["rewards"][id_from:id_to])
|
||||
next_done = torch.tensor(dataset_dict["dones"][id_from:id_to])
|
||||
|
||||
ep_dict = {
|
||||
"observation.image": image,
|
||||
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
|
||||
"observation.state": state,
|
||||
"action": action,
|
||||
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
@@ -246,34 +264,51 @@ def download_and_upload_xarm(root, dataset_id, fps=15):
|
||||
# "next.observation.state": next_state,
|
||||
"next.reward": next_reward,
|
||||
"next.done": next_done,
|
||||
"episode_data_id_from": torch.tensor([id_from] * num_frames),
|
||||
"episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
|
||||
}
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
assert isinstance(episode_id, int)
|
||||
data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
|
||||
assert len(data_ids_per_episode[episode_id]) == num_frames
|
||||
|
||||
idx0 = idx1
|
||||
id_from = id_to
|
||||
episode_id += 1
|
||||
|
||||
data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
if torch.is_tensor(ep_dicts[0][key][0]):
|
||||
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
|
||||
else:
|
||||
if key not in data_dict:
|
||||
data_dict[key] = []
|
||||
for ep_dict in ep_dicts:
|
||||
for x in ep_dict[key]:
|
||||
data_dict[key].append(x)
|
||||
|
||||
total_frames = id_from
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
dataset = Dataset.from_dict(data_dict)
|
||||
features = {
|
||||
"observation.image": Image(),
|
||||
"observation.state": Sequence(
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
),
|
||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
||||
"episode_id": Value(dtype="int64", id=None),
|
||||
"frame_id": Value(dtype="int64", id=None),
|
||||
"timestamp": Value(dtype="float32", id=None),
|
||||
"next.reward": Value(dtype="float32", id=None),
|
||||
"next.done": Value(dtype="bool", id=None),
|
||||
#'next.success': Value(dtype='bool', id=None),
|
||||
"index": Value(dtype="int64", id=None),
|
||||
"episode_data_id_from": Value(dtype="int64", id=None),
|
||||
"episode_data_id_to": Value(dtype="int64", id=None),
|
||||
}
|
||||
features = Features(features)
|
||||
dataset = Dataset.from_dict(data_dict, features=features)
|
||||
dataset = dataset.with_format("torch")
|
||||
|
||||
def add_episode_data_id_from_to(frame):
|
||||
ep_id = frame["episode_id"].item()
|
||||
frame["episode_data_id_from"] = data_ids_per_episode[ep_id][0]
|
||||
frame["episode_data_id_to"] = data_ids_per_episode[ep_id][-1]
|
||||
return frame
|
||||
|
||||
dataset = dataset.map(add_episode_data_id_from_to, num_proc=4)
|
||||
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
|
||||
dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
|
||||
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
|
||||
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
|
||||
|
||||
@@ -338,10 +373,9 @@ def download_and_upload_aloha(root, dataset_id, fps=50):
|
||||
gdown.download(ep48_urls[dataset_id], output=str(raw_dir / "episode_48.hdf5"), fuzzy=True)
|
||||
gdown.download(ep49_urls[dataset_id], output=str(raw_dir / "episode_49.hdf5"), fuzzy=True)
|
||||
|
||||
data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
frame_idx = 0
|
||||
id_from = 0
|
||||
for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])):
|
||||
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
@@ -366,49 +400,68 @@ def download_and_upload_aloha(root, dataset_id, fps=50):
|
||||
# "next.reward": reward,
|
||||
"next.done": done,
|
||||
# "next.success": success,
|
||||
"episode_data_id_from": torch.tensor([id_from] * num_frames),
|
||||
"episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
|
||||
}
|
||||
|
||||
for cam in cameras[dataset_id]:
|
||||
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
|
||||
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
|
||||
ep_dict[f"observation.images.{cam}"] = image
|
||||
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) # b h w c
|
||||
# image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
|
||||
ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image]
|
||||
# ep_dict[f"next.observation.images.{cam}"] = image
|
||||
|
||||
assert isinstance(ep_id, int)
|
||||
data_ids_per_episode[ep_id] = torch.arange(frame_idx, frame_idx + num_frames, 1)
|
||||
assert len(data_ids_per_episode[ep_id]) == num_frames
|
||||
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
frame_idx += num_frames
|
||||
id_from += num_frames
|
||||
|
||||
data_dict = {}
|
||||
|
||||
data_dict = {}
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
if torch.is_tensor(ep_dicts[0][key][0]):
|
||||
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
|
||||
else:
|
||||
if key not in data_dict:
|
||||
data_dict[key] = []
|
||||
for ep_dict in ep_dicts:
|
||||
for x in ep_dict[key]:
|
||||
data_dict[key].append(x)
|
||||
|
||||
total_frames = frame_idx
|
||||
total_frames = id_from
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
dataset = Dataset.from_dict(data_dict)
|
||||
features = {
|
||||
"observation.images.top": Image(),
|
||||
"observation.state": Sequence(
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
),
|
||||
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
|
||||
"episode_id": Value(dtype="int64", id=None),
|
||||
"frame_id": Value(dtype="int64", id=None),
|
||||
"timestamp": Value(dtype="float32", id=None),
|
||||
#'next.reward': Value(dtype='float32', id=None),
|
||||
"next.done": Value(dtype="bool", id=None),
|
||||
#'next.success': Value(dtype='bool', id=None),
|
||||
"index": Value(dtype="int64", id=None),
|
||||
"episode_data_id_from": Value(dtype="int64", id=None),
|
||||
"episode_data_id_to": Value(dtype="int64", id=None),
|
||||
}
|
||||
features = Features(features)
|
||||
dataset = Dataset.from_dict(data_dict, features=features)
|
||||
dataset = dataset.with_format("torch")
|
||||
|
||||
def add_episode_data_id_from_to(frame):
|
||||
ep_id = frame["episode_id"].item()
|
||||
frame["episode_data_id_from"] = data_ids_per_episode[ep_id][0]
|
||||
frame["episode_data_id_to"] = data_ids_per_episode[ep_id][-1]
|
||||
return frame
|
||||
|
||||
dataset = dataset.map(add_episode_data_id_from_to)
|
||||
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
|
||||
dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
|
||||
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
|
||||
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
root = "data"
|
||||
# download_and_upload_pusht(root, dataset_id="pusht")
|
||||
# download_and_upload_xarm(root, dataset_id="xarm_lift_medium")
|
||||
download_and_upload_pusht(root, dataset_id="pusht")
|
||||
download_and_upload_xarm(root, dataset_id="xarm_lift_medium")
|
||||
download_and_upload_aloha(root, dataset_id="aloha_sim_insertion_human")
|
||||
download_and_upload_aloha(root, dataset_id="aloha_sim_insertion_scripted")
|
||||
download_and_upload_aloha(root, dataset_id="aloha_sim_transfer_cube_human")
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
55
tests/data/aloha_sim_insertion_human/train/dataset_info.json
Normal file
55
tests/data/aloha_sim_insertion_human/train/dataset_info.json
Normal file
@@ -0,0 +1,55 @@
|
||||
{
|
||||
"citation": "",
|
||||
"description": "",
|
||||
"features": {
|
||||
"observation.state": {
|
||||
"feature": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"length": 14,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"action": {
|
||||
"feature": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"length": 14,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"episode_id": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"frame_id": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"timestamp": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"next.done": {
|
||||
"dtype": "bool",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_from": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_to": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"observation.images.top": {
|
||||
"_type": "Image"
|
||||
},
|
||||
"index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
}
|
||||
},
|
||||
"homepage": "",
|
||||
"license": ""
|
||||
}
|
||||
13
tests/data/aloha_sim_insertion_human/train/state.json
Normal file
13
tests/data/aloha_sim_insertion_human/train/state.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"_data_files": [
|
||||
{
|
||||
"filename": "data-00000-of-00001.arrow"
|
||||
}
|
||||
],
|
||||
"_fingerprint": "05980bca35112ebd",
|
||||
"_format_columns": null,
|
||||
"_format_kwargs": {},
|
||||
"_format_type": "torch",
|
||||
"_output_all_columns": false,
|
||||
"_split": null
|
||||
}
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,55 @@
|
||||
{
|
||||
"citation": "",
|
||||
"description": "",
|
||||
"features": {
|
||||
"observation.state": {
|
||||
"feature": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"length": 14,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"action": {
|
||||
"feature": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"length": 14,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"episode_id": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"frame_id": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"timestamp": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"next.done": {
|
||||
"dtype": "bool",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_from": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_to": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"observation.images.top": {
|
||||
"_type": "Image"
|
||||
},
|
||||
"index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
}
|
||||
},
|
||||
"homepage": "",
|
||||
"license": ""
|
||||
}
|
||||
13
tests/data/aloha_sim_insertion_scripted/train/state.json
Normal file
13
tests/data/aloha_sim_insertion_scripted/train/state.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"_data_files": [
|
||||
{
|
||||
"filename": "data-00000-of-00001.arrow"
|
||||
}
|
||||
],
|
||||
"_fingerprint": "f3330a7e1d8bc55b",
|
||||
"_format_columns": null,
|
||||
"_format_kwargs": {},
|
||||
"_format_type": "torch",
|
||||
"_output_all_columns": false,
|
||||
"_split": null
|
||||
}
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,55 @@
|
||||
{
|
||||
"citation": "",
|
||||
"description": "",
|
||||
"features": {
|
||||
"observation.state": {
|
||||
"feature": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"length": 14,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"action": {
|
||||
"feature": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"length": 14,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"episode_id": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"frame_id": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"timestamp": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"next.done": {
|
||||
"dtype": "bool",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_from": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_to": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"observation.images.top": {
|
||||
"_type": "Image"
|
||||
},
|
||||
"index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
}
|
||||
},
|
||||
"homepage": "",
|
||||
"license": ""
|
||||
}
|
||||
13
tests/data/aloha_sim_transfer_cube_human/train/state.json
Normal file
13
tests/data/aloha_sim_transfer_cube_human/train/state.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"_data_files": [
|
||||
{
|
||||
"filename": "data-00000-of-00001.arrow"
|
||||
}
|
||||
],
|
||||
"_fingerprint": "42aa77ffb6863924",
|
||||
"_format_columns": null,
|
||||
"_format_kwargs": {},
|
||||
"_format_type": "torch",
|
||||
"_output_all_columns": false,
|
||||
"_split": null
|
||||
}
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,55 @@
|
||||
{
|
||||
"citation": "",
|
||||
"description": "",
|
||||
"features": {
|
||||
"observation.state": {
|
||||
"feature": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"length": 14,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"action": {
|
||||
"feature": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"length": 14,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"episode_id": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"frame_id": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"timestamp": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"next.done": {
|
||||
"dtype": "bool",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_from": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_to": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"observation.images.top": {
|
||||
"_type": "Image"
|
||||
},
|
||||
"index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
}
|
||||
},
|
||||
"homepage": "",
|
||||
"license": ""
|
||||
}
|
||||
13
tests/data/aloha_sim_transfer_cube_scripted/train/state.json
Normal file
13
tests/data/aloha_sim_transfer_cube_scripted/train/state.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"_data_files": [
|
||||
{
|
||||
"filename": "data-00000-of-00001.arrow"
|
||||
}
|
||||
],
|
||||
"_fingerprint": "43f176a3740fe622",
|
||||
"_format_columns": null,
|
||||
"_format_kwargs": {},
|
||||
"_format_type": "torch",
|
||||
"_output_all_columns": false,
|
||||
"_split": null
|
||||
}
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
tests/data/pusht/train/data-00000-of-00001.arrow
Normal file
BIN
tests/data/pusht/train/data-00000-of-00001.arrow
Normal file
Binary file not shown.
63
tests/data/pusht/train/dataset_info.json
Normal file
63
tests/data/pusht/train/dataset_info.json
Normal file
@@ -0,0 +1,63 @@
|
||||
{
|
||||
"citation": "",
|
||||
"description": "",
|
||||
"features": {
|
||||
"observation.image": {
|
||||
"_type": "Image"
|
||||
},
|
||||
"observation.state": {
|
||||
"feature": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"length": 2,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"action": {
|
||||
"feature": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"length": 2,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"episode_id": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"frame_id": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"timestamp": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"next.reward": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"next.done": {
|
||||
"dtype": "bool",
|
||||
"_type": "Value"
|
||||
},
|
||||
"next.success": {
|
||||
"dtype": "bool",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_from": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_to": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
}
|
||||
},
|
||||
"homepage": "",
|
||||
"license": ""
|
||||
}
|
||||
13
tests/data/pusht/train/state.json
Normal file
13
tests/data/pusht/train/state.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"_data_files": [
|
||||
{
|
||||
"filename": "data-00000-of-00001.arrow"
|
||||
}
|
||||
],
|
||||
"_fingerprint": "f7ed966ae18000ae",
|
||||
"_format_columns": null,
|
||||
"_format_kwargs": {},
|
||||
"_format_type": "torch",
|
||||
"_output_all_columns": false,
|
||||
"_split": null
|
||||
}
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
tests/data/xarm_lift_medium/train/data-00000-of-00001.arrow
Normal file
BIN
tests/data/xarm_lift_medium/train/data-00000-of-00001.arrow
Normal file
Binary file not shown.
59
tests/data/xarm_lift_medium/train/dataset_info.json
Normal file
59
tests/data/xarm_lift_medium/train/dataset_info.json
Normal file
@@ -0,0 +1,59 @@
|
||||
{
|
||||
"citation": "",
|
||||
"description": "",
|
||||
"features": {
|
||||
"observation.image": {
|
||||
"_type": "Image"
|
||||
},
|
||||
"observation.state": {
|
||||
"feature": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"length": 4,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"action": {
|
||||
"feature": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"length": 4,
|
||||
"_type": "Sequence"
|
||||
},
|
||||
"episode_id": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"frame_id": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"timestamp": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"next.reward": {
|
||||
"dtype": "float32",
|
||||
"_type": "Value"
|
||||
},
|
||||
"next.done": {
|
||||
"dtype": "bool",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_from": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"episode_data_id_to": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
},
|
||||
"index": {
|
||||
"dtype": "int64",
|
||||
"_type": "Value"
|
||||
}
|
||||
},
|
||||
"homepage": "",
|
||||
"license": ""
|
||||
}
|
||||
13
tests/data/xarm_lift_medium/train/state.json
Normal file
13
tests/data/xarm_lift_medium/train/state.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"_data_files": [
|
||||
{
|
||||
"filename": "data-00000-of-00001.arrow"
|
||||
}
|
||||
],
|
||||
"_fingerprint": "7dcd82fc3815bba6",
|
||||
"_format_columns": null,
|
||||
"_format_kwargs": {},
|
||||
"_format_type": "torch",
|
||||
"_output_all_columns": false,
|
||||
"_split": null
|
||||
}
|
||||
@@ -37,7 +37,7 @@ def test_factory(env_name, dataset_id, policy_name):
|
||||
|
||||
keys_ndim_required = [
|
||||
("action", 1, True),
|
||||
("episode", 0, True),
|
||||
("episode_id", 0, True),
|
||||
("frame_id", 0, True),
|
||||
("timestamp", 0, True),
|
||||
# TODO(rcadene): should we rename it agent_pos?
|
||||
@@ -95,14 +95,12 @@ def test_compute_stats():
|
||||
"""
|
||||
from lerobot.common.datasets.xarm import XarmDataset
|
||||
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
# get transform to convert images from uint8 [0,255] to float32 [0,1]
|
||||
transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
|
||||
|
||||
dataset = XarmDataset(
|
||||
dataset_id="xarm_lift_medium",
|
||||
root=DATA_DIR,
|
||||
transform=transform,
|
||||
)
|
||||
|
||||
@@ -115,7 +113,13 @@ def test_compute_stats():
|
||||
stats_patterns = get_stats_einops_patterns(dataset)
|
||||
|
||||
# get all frames from the dataset in the same dtype and range as during compute_stats
|
||||
data_dict = transform(dataset.data_dict)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=16,
|
||||
batch_size=len(dataset),
|
||||
shuffle=False,
|
||||
)
|
||||
data_dict = next(iter(dataloader)) # takes 23 seconds
|
||||
|
||||
# compute stats based on all frames from the dataset without any batching
|
||||
expected_stats = {}
|
||||
|
||||
Reference in New Issue
Block a user