Refactor datasets into LeRobotDataset (#91)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
commit 659c69a1c0
parent e760e4cd63
Author: Remi
Date:   2024-04-25 12:23:12 +02:00
Committed by: GitHub

90 changed files with 167 additions and 352 deletions


@@ -12,7 +12,7 @@ from safetensors.torch import load_file
 import lerobot
 from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.datasets.pusht import PushtDataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.utils import (
     compute_stats,
     flatten_dict,
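
This import swap is the core of the refactor: per-dataset classes such as `PushtDataset` and `XarmDataset` are replaced by a single `LeRobotDataset` keyed by a Hugging Face Hub repo id. A minimal before/after sketch of a call site, using only the constructor arguments visible in the hunks below (the `root` fallback mirrors the tests):

```python
import os
from pathlib import Path

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Before (removed in this commit): one class per dataset.
#   from lerobot.common.datasets.pusht import PushtDataset
#   dataset = PushtDataset(dataset_id="pusht", split="train")

# After: one class for every dataset, keyed by a Hub repo id.
root = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
dataset = LeRobotDataset("lerobot/pusht", root=root)

frame = dataset[0]  # indexing yields a frame dict, as used in test_backward_compatibility
```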
@@ -26,13 +26,13 @@ from lerobot.common.utils.utils import init_hydra_config
 from .utils import DEFAULT_CONFIG_PATH, DEVICE


-@pytest.mark.parametrize("env_name, dataset_id, policy_name", lerobot.env_dataset_policy_triplets)
-def test_factory(env_name, dataset_id, policy_name):
+@pytest.mark.parametrize("env_name, repo_id, policy_name", lerobot.env_dataset_policy_triplets)
+def test_factory(env_name, repo_id, policy_name):
     cfg = init_hydra_config(
         DEFAULT_CONFIG_PATH,
         overrides=[
             f"env={env_name}",
-            f"dataset_id={dataset_id}",
+            f"dataset.repo_id={repo_id}",
             f"policy={policy_name}",
             f"device={DEVICE}",
         ],
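
The Hydra override changes with it: the dataset is now selected through the nested `dataset.repo_id` key instead of a top-level `dataset_id`. A hedged sketch of building a config outside the test harness; the config path and the `env`/`policy` values are placeholders, only the override keys and the `init_hydra_config(path, overrides=...)` call shape come from this hunk:

```python
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.utils.utils import init_hydra_config

# "lerobot/configs/default.yaml", "pusht" and "diffusion" are illustrative
# values, not taken from this diff; only the override keys are.
cfg = init_hydra_config(
    "lerobot/configs/default.yaml",
    overrides=[
        "env=pusht",
        "dataset.repo_id=lerobot/pusht",  # was: "dataset_id=pusht"
        "policy=diffusion",
        "device=cpu",
    ],
)
dataset = make_dataset(cfg)
```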
@@ -94,14 +94,13 @@ def test_compute_stats_on_xarm():
     We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
     because we are working with a small dataset).
     """
-    # TODO(rcadene): Reduce size of dataset sample on which stats compute is tested
-    from lerobot.common.datasets.xarm import XarmDataset
-    dataset = XarmDataset(
-        dataset_id="xarm_lift_medium",
-        root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
+    dataset = LeRobotDataset(
+        "lerobot/xarm_lift_medium", root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
     )
+    # reduce size of dataset sample on which stats compute is tested to 10 frames
+    dataset.hf_dataset = dataset.hf_dataset.select(range(10))
     # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
     # computation of the statistics. While doing this, we also make sure it works when we don't divide the
     # dataset into even batches.
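
For context, the note about uneven batches implies a call roughly like the sketch below. The `batch_size` keyword is an assumption: the hunk only shows that `compute_stats` is imported and that a batch size smaller than the 10-frame sample is used.

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import compute_stats

dataset = LeRobotDataset("lerobot/xarm_lift_medium")
# Shrink the sample exactly as the test does: keep only the first 10 frames.
dataset.hf_dataset = dataset.hf_dataset.select(range(10))

# Assumed signature: compute_stats(dataset, batch_size=...). 3 does not divide
# 10 evenly, so the final, smaller batch exercises the uneven-batch path the
# test comment mentions.
stats = compute_stats(dataset, batch_size=3)
```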
@@ -241,16 +240,16 @@ def test_flatten_unflatten_dict():
 def test_backward_compatibility():
     """This test's artifacts have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
-    dataset_id = "pusht"
-    data_dir = Path("tests/data/save_dataset_to_safetensors") / dataset_id
-    dataset = PushtDataset(
-        dataset_id=dataset_id,
-        split="train",
+    # TODO(rcadene): make it work for all datasets with LeRobotDataset(repo_id)
+    repo_id = "lerobot/pusht"
+    dataset = LeRobotDataset(
+        repo_id,
         root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
     )
+    data_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id

     def load_and_compare(i):
         new_frame = dataset[i]
         old_frame = load_file(data_dir / f"frame_{i}.safetensors")
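
The hunk ends inside `load_and_compare`; below is a plausible completion under the assumption that the safetensors file stores the same flat tensor keys as the frame dict (the diff shows only the two loads, not the assertions):

```python
import torch
from safetensors.torch import load_file

def load_and_compare(dataset, data_dir, i):
    # Hypothetical completion: the real test likely asserts per-key closeness,
    # but only new_frame/old_frame appear in this hunk.
    new_frame = dataset[i]
    old_frame = load_file(data_dir / f"frame_{i}.safetensors")
    for key, old_value in old_frame.items():
        assert key in new_frame, f"`{key}` missing from the new frame"
        assert torch.isclose(new_frame[key], old_value, rtol=1e-5, atol=1e-8).all(), key
```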