Refactor datasets into LeRobotDataset (#91)

Author: Remi (committed by GitHub)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
Date: 2024-04-25 12:23:12 +02:00
Parent: e760e4cd63
Commit: 659c69a1c0

90 changed files with 167 additions and 352 deletions


@@ -16,22 +16,18 @@ from pathlib import Path
 from safetensors.torch import save_file

-from lerobot.common.datasets.pusht import PushtDataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset


-def save_dataset_to_safetensors(output_dir, dataset_id="pusht"):
-    data_dir = Path(output_dir) / dataset_id
+def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
+    data_dir = Path(output_dir) / repo_id

     if data_dir.exists():
         shutil.rmtree(data_dir)
     data_dir.mkdir(parents=True, exist_ok=True)

-    # TODO(rcadene): make it work for all datasets with LeRobotDataset(repo_id)
-    dataset = PushtDataset(
-        dataset_id=dataset_id,
-        split="train",
-    )
+    dataset = LeRobotDataset(repo_id)

     # save 2 first frames of first episode
     i = dataset.episode_data_index["from"][0].item()
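The hunk above captures the heart of the refactor: the per-environment classes (PushtDataset, AlohaDataset, XarmDataset) give way to a single LeRobotDataset addressed by a Hub repo id. A minimal usage sketch assembled only from calls that appear in this commit; the optional root argument mirrors the tests further down, and nothing beyond these attributes is guaranteed here:

import os
from pathlib import Path

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# One class for every dataset, selected by its "lerobot/<name>" repo id.
dataset = LeRobotDataset(
    "lerobot/pusht",
    root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
)

# Index of the first frame of the first episode, as used above.
i = dataset.episode_data_index["from"][0].item()
frame = dataset[i]  # frame indexing, as in test_backward_compatibility below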


@@ -4,9 +4,6 @@ import gymnasium as gym
 import pytest

 import lerobot
-from lerobot.common.datasets.aloha import AlohaDataset
-from lerobot.common.datasets.pusht import PushtDataset
-from lerobot.common.datasets.xarm import XarmDataset
 from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
 from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
@@ -27,25 +24,6 @@ def test_available_env_task(env_name: str, task_name: list):
     assert gym_handle in gym.envs.registry, gym_handle


-@pytest.mark.parametrize(
-    "env_name, dataset_class",
-    [
-        ("aloha", AlohaDataset),
-        ("pusht", PushtDataset),
-        ("xarm", XarmDataset),
-    ],
-)
-def test_available_datasets(env_name, dataset_class):
-    """
-    This test verifies that the class attribute `available_datasets` for all
-    dataset classes is consistent with those listed in `lerobot/__init__.py`.
-    """
-    available_env_datasets = lerobot.available_datasets[env_name]
-    assert set(available_env_datasets) == set(
-        dataset_class.available_datasets
-    ), f"{env_name=} {available_env_datasets=}"
-
-
 def test_available_policies():
     """
     This test verifies that the class attribute `name` for all policies is
@@ -58,3 +36,12 @@ def test_available_policies():
     ]
     policies = [pol_cls.name for pol_cls in policy_classes]
     assert set(policies) == set(lerobot.available_policies), policies
+
+
+def test_print():
+    print(lerobot.available_envs)
+    print(lerobot.available_tasks_per_env)
+    print(lerobot.available_datasets)
+    print(lerobot.available_datasets_per_env)
+    print(lerobot.available_policies)
+    print(lerobot.available_policies_per_env)
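The deleted test compared each dataset class's available_datasets attribute against lerobot.available_datasets[env_name]; with the per-env classes gone, the new test_print simply exercises the registry attributes. A hedged sketch of querying them; only the attribute names appear in this diff, so the per-env mapping shape is an assumption:

import lerobot

# Assumption: available_datasets_per_env maps an env name to its Hub repo ids.
for repo_id in lerobot.available_datasets_per_env["aloha"]:
    print(repo_id)  # e.g. "lerobot/aloha_sim_insertion_human"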


@@ -12,7 +12,7 @@ from safetensors.torch import load_file
 import lerobot
 from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.datasets.pusht import PushtDataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.utils import (
     compute_stats,
     flatten_dict,
@@ -26,13 +26,13 @@ from lerobot.common.utils.utils import init_hydra_config
 from .utils import DEFAULT_CONFIG_PATH, DEVICE


-@pytest.mark.parametrize("env_name, dataset_id, policy_name", lerobot.env_dataset_policy_triplets)
-def test_factory(env_name, dataset_id, policy_name):
+@pytest.mark.parametrize("env_name, repo_id, policy_name", lerobot.env_dataset_policy_triplets)
+def test_factory(env_name, repo_id, policy_name):
     cfg = init_hydra_config(
         DEFAULT_CONFIG_PATH,
         overrides=[
             f"env={env_name}",
-            f"dataset_id={dataset_id}",
+            f"dataset.repo_id={repo_id}",
             f"policy={policy_name}",
             f"device={DEVICE}",
         ],
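test_factory builds a Hydra config with the new nested dataset.repo_id override and hands it to the dataset factory. A sketch of the presumed remainder of the test; make_dataset is imported in this file, but its exact call shape beyond the config argument is an assumption:

from lerobot.common.datasets.factory import make_dataset

# Assumption: the factory consumes the composed Hydra config and returns a dataset.
dataset = make_dataset(cfg)
item = dataset[0]  # dataset[i] indexing is shown elsewhere in this commit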
@@ -94,14 +94,13 @@ def test_compute_stats_on_xarm():
     We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
     because we are working with a small dataset).
     """
-    # TODO(rcadene): Reduce size of dataset sample on which stats compute is tested
-    from lerobot.common.datasets.xarm import XarmDataset
-
-    dataset = XarmDataset(
-        dataset_id="xarm_lift_medium",
-        root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
+    dataset = LeRobotDataset(
+        "lerobot/xarm_lift_medium", root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
     )
+
+    # reduce size of dataset sample on which stats compute is tested to 10 frames
+    dataset.hf_dataset = dataset.hf_dataset.select(range(10))

     # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
     # computation of the statistics. While doing this, we also make sure it works when we don't divide the
     # dataset into even batches.
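The hunk above trims the dataset to 10 frames through the underlying hf_dataset before computing statistics. A minimal sketch of the pattern being tested, assuming compute_stats accepts the dataset and a batch_size keyword (its exact signature is not shown in this diff):

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import compute_stats

dataset = LeRobotDataset("lerobot/xarm_lift_medium")
# Shrink the sample so the test is fast; hf_dataset is the underlying
# Hugging Face dataset, as used in the hunk above.
dataset.hf_dataset = dataset.hf_dataset.select(range(10))

# Assumption: batch_size exists and is deliberately smaller than (and not a
# divisor of) the 10-frame sample, to exercise batched and uneven batching.
stats = compute_stats(dataset, batch_size=3)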
@@ -241,16 +240,16 @@ def test_flatten_unflatten_dict():
 def test_backward_compatibility():
     """This tests artifacts have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""

-    # TODO(rcadene): make it work for all datasets with LeRobotDataset(repo_id)
-    dataset_id = "pusht"
-
-    data_dir = Path("tests/data/save_dataset_to_safetensors") / dataset_id
-
-    dataset = PushtDataset(
-        dataset_id=dataset_id,
-        split="train",
+    repo_id = "lerobot/pusht"
+
+    dataset = LeRobotDataset(
+        repo_id,
         root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
     )

+    data_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
+
     def load_and_compare(i):
         new_frame = dataset[i]
         old_frame = load_file(data_dir / f"frame_{i}.safetensors")
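The backward-compatibility check loads frames saved by the old code path and compares them with frames served by the new LeRobotDataset. A sketch of what load_and_compare plausibly does with these pieces; the comparison itself is truncated in this diff, so the assert shown is an assumption:

import torch
from safetensors.torch import load_file

def load_and_compare(i):
    new_frame = dataset[i]  # the LeRobotDataset built above
    old_frame = load_file(data_dir / f"frame_{i}.safetensors")
    # Assumption: both frames are dicts of tensors with identical keys.
    for key in old_frame:
        assert torch.isclose(old_frame[key], new_frame[key]).all(), key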


@@ -19,10 +19,22 @@ from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
("xarm", "tdmpc", ["policy.mpc=true"]),
("pusht", "tdmpc", ["policy.mpc=false"]),
("pusht", "diffusion", []),
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_human"]),
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_scripted"]),
("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_human"]),
("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]),
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset.repo_id=lerobot/aloha_sim_insertion_human"]),
(
"aloha",
"act",
["env.task=AlohaInsertion-v0", "dataset.repo_id=lerobot/aloha_sim_insertion_scripted"],
),
(
"aloha",
"act",
["env.task=AlohaTransferCube-v0", "dataset.repo_id=lerobot/aloha_sim_transfer_cube_human"],
),
(
"aloha",
"act",
["env.task=AlohaTransferCube-v0", "dataset.repo_id=lerobot/aloha_sim_transfer_cube_scripted"],
),
],
)
@require_env
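Across these tests the Hydra override moves from the top-level dataset_id=... key to the nested dataset.repo_id=... key, with repo ids now namespaced under "lerobot/". A sketch of composing a config with the new override, using only helpers and keys that appear in this commit:

from lerobot.common.utils.utils import init_hydra_config
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE

cfg = init_hydra_config(
    DEFAULT_CONFIG_PATH,
    overrides=[
        "env=aloha",
        "policy=act",
        "env.task=AlohaInsertion-v0",
        # New nested key replacing the old top-level dataset_id.
        "dataset.repo_id=lerobot/aloha_sim_insertion_human",
        f"device={DEVICE}",
    ],
)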


@@ -7,12 +7,12 @@ from .utils import DEFAULT_CONFIG_PATH
 @pytest.mark.parametrize(
-    "dataset_id",
+    "repo_id",
     [
-        "aloha_sim_insertion_human",
+        "lerobot/aloha_sim_insertion_human",
     ],
 )
-def test_visualize_dataset(tmpdir, dataset_id):
+def test_visualize_dataset(tmpdir, repo_id):
     # TODO(rcadene): this test might fail with other datasets/policies/envs, since visualization_dataset
     # doesnt support multiple timesteps which requires delta_timestamps to None for images.
     cfg = init_hydra_config(
@@ -20,7 +20,7 @@ def test_visualize_dataset(tmpdir, dataset_id):
         overrides=[
             "policy=act",
             "env=aloha",
-            f"dataset_id={dataset_id}",
+            f"dataset.repo_id={repo_id}",
         ],
     )
     video_paths = visualize_dataset(cfg, out_dir=tmpdir)