HF datasets works

This commit is contained in:
Cadene
2024-04-16 12:20:38 +00:00
parent 5edd9a89a0
commit 0980fff6cc
42 changed files with 630 additions and 87 deletions

View File

@@ -1,5 +1,7 @@
from pathlib import Path
import torch
from datasets import load_dataset
from datasets import load_dataset, load_from_disk
from lerobot.common.datasets.utils import load_previous_and_future_frames
@@ -25,15 +27,24 @@ class AlohaDataset(torch.utils.data.Dataset):
self,
dataset_id: str,
version: str | None = "v1.0",
root: Path | None = None,
split: str = "train",
transform: callable = None,
delta_timestamps: dict[str, list[float]] | None = None,
):
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
if self.root is not None:
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.data_dict = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.data_dict = self.data_dict.with_format("torch")
@property
@@ -57,6 +68,15 @@ class AlohaDataset(torch.utils.data.Dataset):
self.delta_timestamps,
)
# convert images from channel last (PIL) to channel first (pytorch)
for key in self.image_keys:
if item[key].ndim == 3:
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
elif item[key].ndim == 4:
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
else:
raise ValueError(f"Unexpected number of image dimensions: {item[key].ndim} (expected 3 or 4)")
if self.transform is not None:
item = self.transform(item)
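The permute above handles both shapes produced by load_previous_and_future_frames: a plain decoded frame is channel-last (h, w, c), while a temporal stack requested via delta_timestamps is (t, h, w, c). A minimal sketch of the same conversion on dummy tensors (shapes illustrative, not from the commit):

import torch

frame = torch.zeros(96, 96, 3)  # h w c, as decoded from a PIL image
stack = torch.zeros(4, 96, 96, 3)  # t h w c, when delta_timestamps stacks frames

assert frame.permute(2, 0, 1).shape == (3, 96, 96)  # -> c h w
assert stack.permute(0, 3, 1, 2).shape == (4, 3, 96, 96)  # -> t c h w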

View File

@@ -1,4 +1,5 @@
import logging
import os
from pathlib import Path
import torch
@@ -7,12 +8,15 @@ from torchvision.transforms import v2
from lerobot.common.datasets.utils import compute_stats
from lerobot.common.transforms import NormalizeTransform, Prod
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
def make_dataset(
cfg,
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
normalize=True,
stats_path=None,
split="train",
):
if cfg.env.name == "xarm":
from lerobot.common.datasets.xarm import XarmDataset
@@ -57,6 +61,8 @@ def make_dataset(
# instantiate a one frame dataset with light transform
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
split="train",
root=DATA_DIR,
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
)
stats = compute_stats(stats_dataset)
@@ -86,6 +92,8 @@ def make_dataset(
dataset = clsfunc(
dataset_id=cfg.dataset_id,
split=split,
root=DATA_DIR,
delta_timestamps=delta_timestamps,
transform=transforms,
)
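Note that DATA_DIR is read once at import time, so the environment variable must be set before lerobot.common.datasets.factory is imported; when set, every dataset class loads the local copy written by save_to_disk instead of downloading from the hub. A hedged usage sketch (the cfg object is assumed to come from the project's config system and is not defined here):

import os
os.environ["DATA_DIR"] = "tests/data"  # must be set before the import below

from lerobot.common.datasets.factory import make_dataset

dataset = make_dataset(cfg)  # resolves to load_from_disk("tests/data/<dataset_id>/train")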

View File

@@ -1,5 +1,7 @@
from pathlib import Path
import torch
from datasets import load_dataset
from datasets import load_dataset, load_from_disk
from lerobot.common.datasets.utils import load_previous_and_future_frames
@@ -23,18 +25,25 @@ class PushtDataset(torch.utils.data.Dataset):
self,
dataset_id: str,
version: str | None = "v1.0",
root: Path | None = None,
split: str = "train",
transform: callable = None,
delta_timestamps: dict[str, list[float]] | None = None,
):
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
# self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", split="train")
if self.root is not None:
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.data_dict = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.data_dict = self.data_dict.with_format("torch")
@property
def num_samples(self) -> int:
@@ -57,6 +66,15 @@ class PushtDataset(torch.utils.data.Dataset):
self.delta_timestamps,
)
# convert images from channel last (PIL) to channel first (pytorch)
for key in self.image_keys:
if item[key].ndim == 3:
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
elif item[key].ndim == 4:
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
else:
raise ValueError(f"Unexpected number of image dimensions: {item[key].ndim} (expected 3 or 4)")
if self.transform is not None:
item = self.transform(item)

View File

@@ -1,5 +1,7 @@
from pathlib import Path
import torch
from datasets import load_dataset
from datasets import load_dataset, load_from_disk
from lerobot.common.datasets.utils import load_previous_and_future_frames
@@ -19,15 +21,24 @@ class XarmDataset(torch.utils.data.Dataset):
self,
dataset_id: str,
version: str | None = "v1.0",
root: Path | None = None,
split: str = "train",
transform: callable = None,
delta_timestamps: dict[str, list[float]] | None = None,
):
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
if self.root is not None:
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.data_dict = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.data_dict = self.data_dict.with_format("torch")
@property
@@ -51,6 +62,15 @@ class XarmDataset(torch.utils.data.Dataset):
self.delta_timestamps,
)
# convert images from channel last (PIL) to channel first (pytorch)
for key in self.image_keys:
if item[key].ndim == 3:
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
elif item[key].ndim == 4:
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
else:
raise ValueError(f"Unexpected number of image dimensions: {item[key].ndim} (expected 3 or 4)")
if self.transform is not None:
item = self.transform(item)

View File

@@ -12,7 +12,8 @@ import h5py
import numpy as np
import torch
import tqdm
from datasets import Dataset
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
@@ -73,10 +74,6 @@ def download_and_upload_pusht(root, dataset_id="pusht", fps=10):
episode_ids = torch.from_numpy(dataset_dict.get_episode_idxs())
num_episodes = dataset_dict.meta["episode_ends"].shape[0]
total_frames = dataset_dict["action"].shape[0]
# to create test artifact
# num_episodes = 1
# total_frames = 50
assert len(
{dataset_dict[key].shape[0] for key in dataset_dict.keys()} # noqa: SIM118
) == 1, "Some data types don't have the same number of total frames."
@@ -85,28 +82,27 @@ def download_and_upload_pusht(root, dataset_id="pusht", fps=10):
goal_pos_angle = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
imgs = torch.from_numpy(dataset_dict["img"])
imgs = einops.rearrange(imgs, "b h w c -> b c h w")
imgs = torch.from_numpy(dataset_dict["img"]) # b h w c
states = torch.from_numpy(dataset_dict["state"])
actions = torch.from_numpy(dataset_dict["action"])
data_ids_per_episode = {}
ep_dicts = []
idx0 = 0
id_from = 0
for episode_id in tqdm.tqdm(range(num_episodes)):
idx1 = dataset_dict.meta["episode_ends"][episode_id]
id_to = dataset_dict.meta["episode_ends"][episode_id]
num_frames = idx1 - idx0
num_frames = id_to - id_from
assert (episode_ids[idx0:idx1] == episode_id).all()
assert (episode_ids[id_from:id_to] == episode_id).all()
image = imgs[idx0:idx1]
image = imgs[id_from:id_to]
assert image.min() >= 0.0
assert image.max() <= 255.0
image = image.type(torch.uint8)
state = states[idx0:idx1]
state = states[id_from:id_to]
agent_pos = state[:, :2]
block_pos = state[:, 2:4]
block_angle = state[:, 4]
@@ -141,9 +137,9 @@ def download_and_upload_pusht(root, dataset_id="pusht", fps=10):
done[-1] = True
ep_dict = {
"observation.image": image,
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
"observation.state": agent_pos,
"action": actions[idx0:idx1],
"action": actions[id_from:id_to],
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
@@ -153,33 +149,55 @@ def download_and_upload_pusht(root, dataset_id="pusht", fps=10):
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
"next.done": torch.cat([done[1:], done[[-1]]]),
"next.success": torch.cat([success[1:], success[[-1]]]),
"episode_data_id_from": torch.tensor([id_from] * num_frames),
"episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
}
ep_dicts.append(ep_dict)
assert isinstance(episode_id, int)
data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
data_ids_per_episode[episode_id] = torch.arange(id_from, id_to, 1)
assert len(data_ids_per_episode[episode_id]) == num_frames
idx0 = idx1
id_from += num_frames
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
if torch.is_tensor(ep_dicts[0][key][0]):
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
data_dict[key].append(x)
total_frames = id_from
data_dict["index"] = torch.arange(0, total_frames, 1)
dataset = Dataset.from_dict(data_dict)
features = {
"observation.image": Image(),
"observation.state": Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_id": Value(dtype="int64", id=None),
"frame_id": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
"next.reward": Value(dtype="float32", id=None),
"next.done": Value(dtype="bool", id=None),
"next.success": Value(dtype="bool", id=None),
"index": Value(dtype="int64", id=None),
"episode_data_id_from": Value(dtype="int64", id=None),
"episode_data_id_to": Value(dtype="int64", id=None),
}
features = Features(features)
dataset = Dataset.from_dict(data_dict, features=features)
dataset = dataset.with_format("torch")
def add_episode_data_id_from_to(frame):
ep_id = frame["episode_id"].item()
frame["episode_data_id_from"] = data_ids_per_episode[ep_id][0]
frame["episode_data_id_to"] = data_ids_per_episode[ep_id][-1]
return frame
dataset = dataset.map(add_episode_data_id_from_to, num_proc=4)
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
@@ -211,32 +229,32 @@ def download_and_upload_xarm(root, dataset_id, fps=15):
total_frames = dataset_dict["actions"].shape[0]
data_ids_per_episode = {}
ep_dicts = []
idx0 = 0
idx1 = 0
id_from = 0
id_to = 0
episode_id = 0
for i in tqdm.tqdm(range(total_frames)):
idx1 += 1
id_to += 1
if not dataset_dict["dones"][i]:
continue
num_frames = idx1 - idx0
num_frames = id_to - id_from
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
action = torch.tensor(dataset_dict["actions"][idx0:idx1])
image = torch.tensor(dataset_dict["observations"]["rgb"][id_from:id_to])
image = einops.rearrange(image, "b c h w -> b h w c")
state = torch.tensor(dataset_dict["observations"]["state"][id_from:id_to])
action = torch.tensor(dataset_dict["actions"][id_from:id_to])
# TODO(rcadene): we have a missing last frame which is the observation when the env is done
# it is critical to have this frame for tdmpc to predict a "done observation/state"
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][id_from:id_to])
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][id_from:id_to])
next_reward = torch.tensor(dataset_dict["rewards"][id_from:id_to])
next_done = torch.tensor(dataset_dict["dones"][id_from:id_to])
ep_dict = {
"observation.image": image,
"observation.image": [PILImage.fromarray(x.numpy()) for x in image],
"observation.state": state,
"action": action,
"episode_id": torch.tensor([episode_id] * num_frames, dtype=torch.int),
@@ -246,34 +264,51 @@ def download_and_upload_xarm(root, dataset_id, fps=15):
# "next.observation.state": next_state,
"next.reward": next_reward,
"next.done": next_done,
"episode_data_id_from": torch.tensor([id_from] * num_frames),
"episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
}
ep_dicts.append(ep_dict)
assert isinstance(episode_id, int)
data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
data_ids_per_episode[episode_id] = torch.arange(id_from, id_to, 1)
assert len(data_ids_per_episode[episode_id]) == num_frames
idx0 = idx1
id_from = id_to
episode_id += 1
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
if torch.is_tensor(ep_dicts[0][key][0]):
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
data_dict[key].append(x)
total_frames = id_from
data_dict["index"] = torch.arange(0, total_frames, 1)
dataset = Dataset.from_dict(data_dict)
features = {
"observation.image": Image(),
"observation.state": Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_id": Value(dtype="int64", id=None),
"frame_id": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
"next.reward": Value(dtype="float32", id=None),
"next.done": Value(dtype="bool", id=None),
#'next.success': Value(dtype='bool', id=None),
"index": Value(dtype="int64", id=None),
"episode_data_id_from": Value(dtype="int64", id=None),
"episode_data_id_to": Value(dtype="int64", id=None),
}
features = Features(features)
dataset = Dataset.from_dict(data_dict, features=features)
dataset = dataset.with_format("torch")
def add_episode_data_id_from_to(frame):
ep_id = frame["episode_id"].item()
frame["episode_data_id_from"] = data_ids_per_episode[ep_id][0]
frame["episode_data_id_to"] = data_ids_per_episode[ep_id][-1]
return frame
dataset = dataset.map(add_episode_data_id_from_to, num_proc=4)
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
@@ -338,10 +373,9 @@ def download_and_upload_aloha(root, dataset_id, fps=50):
gdown.download(ep48_urls[dataset_id], output=str(raw_dir / "episode_48.hdf5"), fuzzy=True)
gdown.download(ep49_urls[dataset_id], output=str(raw_dir / "episode_49.hdf5"), fuzzy=True)
data_ids_per_episode = {}
ep_dicts = []
frame_idx = 0
id_from = 0
for ep_id in tqdm.tqdm(range(num_episodes[dataset_id])):
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
with h5py.File(ep_path, "r") as ep:
@@ -366,49 +400,68 @@ def download_and_upload_aloha(root, dataset_id, fps=50):
# "next.reward": reward,
"next.done": done,
# "next.success": success,
"episode_data_id_from": torch.tensor([id_from] * num_frames),
"episode_data_id_to": torch.tensor([id_from + num_frames - 1] * num_frames),
}
for cam in cameras[dataset_id]:
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
ep_dict[f"observation.images.{cam}"] = image
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:]) # b h w c
# image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image]
# ep_dict[f"next.observation.images.{cam}"] = image
assert isinstance(ep_id, int)
data_ids_per_episode[ep_id] = torch.arange(frame_idx, frame_idx + num_frames, 1)
data_ids_per_episode[ep_id] = torch.arange(id_from, id_from + num_frames, 1)
assert len(data_ids_per_episode[ep_id]) == num_frames
ep_dicts.append(ep_dict)
frame_idx += num_frames
id_from += num_frames
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
if torch.is_tensor(ep_dicts[0][key][0]):
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
data_dict[key].append(x)
total_frames = frame_idx
total_frames = id_from
data_dict["index"] = torch.arange(0, total_frames, 1)
dataset = Dataset.from_dict(data_dict)
features = {
"observation.images.top": Image(),
"observation.state": Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
),
"action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
"episode_id": Value(dtype="int64", id=None),
"frame_id": Value(dtype="int64", id=None),
"timestamp": Value(dtype="float32", id=None),
#'next.reward': Value(dtype='float32', id=None),
"next.done": Value(dtype="bool", id=None),
#'next.success': Value(dtype='bool', id=None),
"index": Value(dtype="int64", id=None),
"episode_data_id_from": Value(dtype="int64", id=None),
"episode_data_id_to": Value(dtype="int64", id=None),
}
features = Features(features)
dataset = Dataset.from_dict(data_dict, features=features)
dataset = dataset.with_format("torch")
def add_episode_data_id_from_to(frame):
ep_id = frame["episode_id"].item()
frame["episode_data_id_from"] = data_ids_per_episode[ep_id][0]
frame["episode_data_id_to"] = data_ids_per_episode[ep_id][-1]
return frame
dataset = dataset.map(add_episode_data_id_from_to)
num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
dataset.select(range(num_items_first_ep)).save_to_disk(f"tests/data/{dataset_id}/train")
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
if __name__ == "__main__":
root = "data"
# download_and_upload_pusht(root, dataset_id="pusht")
# download_and_upload_xarm(root, dataset_id="xarm_lift_medium")
download_and_upload_pusht(root, dataset_id="pusht")
download_and_upload_xarm(root, dataset_id="xarm_lift_medium")
download_and_upload_aloha(root, dataset_id="aloha_sim_insertion_human")
download_and_upload_aloha(root, dataset_id="aloha_sim_insertion_scripted")
download_and_upload_aloha(root, dataset_id="aloha_sim_transfer_cube_human")

View File

@@ -0,0 +1,55 @@
{
"citation": "",
"description": "",
"features": {
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"action": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"episode_id": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"next.done": {
"dtype": "bool",
"_type": "Value"
},
"episode_data_id_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_id_to": {
"dtype": "int64",
"_type": "Value"
},
"observation.images.top": {
"_type": "Image"
},
"index": {
"dtype": "int64",
"_type": "Value"
}
},
"homepage": "",
"license": ""
}

View File

@@ -0,0 +1,13 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "05980bca35112ebd",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",
"_output_all_columns": false,
"_split": null
}
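The "_format_type": "torch" entry records that with_format("torch") was applied before save_to_disk, so loading the artifact restores torch formatting automatically. A quick check against one of the artifacts written by the converter above (path pattern from the commit):

from datasets import load_from_disk

ds = load_from_disk("tests/data/pusht/train")
assert ds.format["type"] == "torch"  # persisted via state.json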

View File

@@ -0,0 +1,55 @@
{
"citation": "",
"description": "",
"features": {
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"action": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"episode_id": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"next.done": {
"dtype": "bool",
"_type": "Value"
},
"episode_data_id_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_id_to": {
"dtype": "int64",
"_type": "Value"
},
"observation.images.top": {
"_type": "Image"
},
"index": {
"dtype": "int64",
"_type": "Value"
}
},
"homepage": "",
"license": ""
}

View File

@@ -0,0 +1,13 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "f3330a7e1d8bc55b",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",
"_output_all_columns": false,
"_split": null
}

View File

@@ -0,0 +1,55 @@
{
"citation": "",
"description": "",
"features": {
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"action": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"episode_id": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"next.done": {
"dtype": "bool",
"_type": "Value"
},
"episode_data_id_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_id_to": {
"dtype": "int64",
"_type": "Value"
},
"observation.images.top": {
"_type": "Image"
},
"index": {
"dtype": "int64",
"_type": "Value"
}
},
"homepage": "",
"license": ""
}

View File

@@ -0,0 +1,13 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "42aa77ffb6863924",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",
"_output_all_columns": false,
"_split": null
}

View File

@@ -0,0 +1,55 @@
{
"citation": "",
"description": "",
"features": {
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"action": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 14,
"_type": "Sequence"
},
"episode_id": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"next.done": {
"dtype": "bool",
"_type": "Value"
},
"episode_data_id_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_id_to": {
"dtype": "int64",
"_type": "Value"
},
"observation.images.top": {
"_type": "Image"
},
"index": {
"dtype": "int64",
"_type": "Value"
}
},
"homepage": "",
"license": ""
}

View File

@@ -0,0 +1,13 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "43f176a3740fe622",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",
"_output_all_columns": false,
"_split": null
}

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,63 @@
{
"citation": "",
"description": "",
"features": {
"observation.image": {
"_type": "Image"
},
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 2,
"_type": "Sequence"
},
"action": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 2,
"_type": "Sequence"
},
"episode_id": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"next.reward": {
"dtype": "float32",
"_type": "Value"
},
"next.done": {
"dtype": "bool",
"_type": "Value"
},
"next.success": {
"dtype": "bool",
"_type": "Value"
},
"episode_data_id_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_id_to": {
"dtype": "int64",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"
}
},
"homepage": "",
"license": ""
}

View File

@@ -0,0 +1,13 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "f7ed966ae18000ae",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",
"_output_all_columns": false,
"_split": null
}

View File

@@ -0,0 +1,59 @@
{
"citation": "",
"description": "",
"features": {
"observation.image": {
"_type": "Image"
},
"observation.state": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 4,
"_type": "Sequence"
},
"action": {
"feature": {
"dtype": "float32",
"_type": "Value"
},
"length": 4,
"_type": "Sequence"
},
"episode_id": {
"dtype": "int64",
"_type": "Value"
},
"frame_id": {
"dtype": "int64",
"_type": "Value"
},
"timestamp": {
"dtype": "float32",
"_type": "Value"
},
"next.reward": {
"dtype": "float32",
"_type": "Value"
},
"next.done": {
"dtype": "bool",
"_type": "Value"
},
"episode_data_id_from": {
"dtype": "int64",
"_type": "Value"
},
"episode_data_id_to": {
"dtype": "int64",
"_type": "Value"
},
"index": {
"dtype": "int64",
"_type": "Value"
}
},
"homepage": "",
"license": ""
}

View File

@@ -0,0 +1,13 @@
{
"_data_files": [
{
"filename": "data-00000-of-00001.arrow"
}
],
"_fingerprint": "7dcd82fc3815bba6",
"_format_columns": null,
"_format_kwargs": {},
"_format_type": "torch",
"_output_all_columns": false,
"_split": null
}

View File

@@ -37,7 +37,7 @@ def test_factory(env_name, dataset_id, policy_name):
keys_ndim_required = [
("action", 1, True),
("episode", 0, True),
("episode_id", 0, True),
("frame_id", 0, True),
("timestamp", 0, True),
# TODO(rcadene): should we rename it agent_pos?
@@ -95,14 +95,12 @@ def test_compute_stats():
"""
from lerobot.common.datasets.xarm import XarmDataset
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
# get transform to convert images from uint8 [0,255] to float32 [0,1]
transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
dataset = XarmDataset(
dataset_id="xarm_lift_medium",
root=DATA_DIR,
transform=transform,
)
@@ -115,7 +113,13 @@ def test_compute_stats():
stats_patterns = get_stats_einops_patterns(dataset)
# get all frames from the dataset in the same dtype and range as during compute_stats
data_dict = transform(dataset.data_dict)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=16,
batch_size=len(dataset),
shuffle=False,
)
data_dict = next(iter(dataloader)) # takes 23 seconds
# compute stats based on all frames from the dataset without any batching
expected_stats = {}