Refactor datasets into LeRobotDataset (#91)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
@@ -12,7 +12,7 @@ from safetensors.torch import load_file
 
 import lerobot
 from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.datasets.pusht import PushtDataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.utils import (
     compute_stats,
     flatten_dict,
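The import swap above is the heart of the refactor: per-dataset classes such as PushtDataset and XarmDataset give way to a single LeRobotDataset keyed by a Hugging Face Hub repo_id. A minimal sketch of the resulting usage, assembled from calls that appear in the hunks below (the `root` fallback mirrors the test code exactly):

import os
from pathlib import Path

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# One class for every dataset; the Hub repo_id ("lerobot/pusht",
# "lerobot/xarm_lift_medium", ...) selects which one to load.
dataset = LeRobotDataset(
    "lerobot/pusht",
    root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
)
frame = dataset[0]  # frames are indexed like any map-style PyTorch dataset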
@@ -26,13 +26,13 @@ from lerobot.common.utils.utils import init_hydra_config
 from .utils import DEFAULT_CONFIG_PATH, DEVICE
 
 
-@pytest.mark.parametrize("env_name, dataset_id, policy_name", lerobot.env_dataset_policy_triplets)
-def test_factory(env_name, dataset_id, policy_name):
+@pytest.mark.parametrize("env_name, repo_id, policy_name", lerobot.env_dataset_policy_triplets)
+def test_factory(env_name, repo_id, policy_name):
     cfg = init_hydra_config(
         DEFAULT_CONFIG_PATH,
         overrides=[
             f"env={env_name}",
-            f"dataset_id={dataset_id}",
+            f"dataset.repo_id={repo_id}",
             f"policy={policy_name}",
             f"device={DEVICE}",
         ],
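Alongside the parameter rename, the Hydra override moves from a flat `dataset_id` key to the nested `dataset.repo_id`. A hedged sketch of what one parametrization of the test resolves to; the concrete env/policy values below are illustrative, not taken from `lerobot.env_dataset_policy_triplets`:

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.utils.utils import init_hydra_config

from tests.utils import DEFAULT_CONFIG_PATH, DEVICE

cfg = init_hydra_config(
    DEFAULT_CONFIG_PATH,
    overrides=[
        "env=pusht",                      # example env_name
        "dataset.repo_id=lerobot/pusht",  # nested key replaces flat dataset_id
        "policy=diffusion",               # example policy_name
        f"device={DEVICE}",
    ],
)
dataset = make_dataset(cfg)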
@@ -94,14 +94,13 @@ def test_compute_stats_on_xarm():
     We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
     because we are working with a small dataset).
     """
-    # TODO(rcadene): Reduce size of dataset sample on which stats compute is tested
-    from lerobot.common.datasets.xarm import XarmDataset
-
-    dataset = XarmDataset(
-        dataset_id="xarm_lift_medium",
-        root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
+    dataset = LeRobotDataset(
+        "lerobot/xarm_lift_medium", root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
     )
 
+    # reduce size of dataset sample on which stats compute is tested to 10 frames
+    dataset.hf_dataset = dataset.hf_dataset.select(range(10))
+
     # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
     # computation of the statistics. While doing this, we also make sure it works when we don't divide the
     # dataset into even batches.
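The `select(range(10))` trick keeps the stats computation cheap, while the batch-size note pins down what is actually under test: batched statistics must agree with a single pass over the data even when the last batch is ragged. A sketch of that comparison, assuming `compute_stats` takes a `batch_size` argument (as the note implies) and returns per-key dicts of min/max/mean/std; the `"action"` key and the exact return structure are assumptions:

import torch

# 10 frames with batch_size=3 forces uneven batches: 3, 3, 3, 1.
stats = compute_stats(dataset, batch_size=3)

# One-pass reference over the same 10 frames for one key.
actions = torch.stack([dataset[i]["action"] for i in range(10)])
assert torch.allclose(stats["action"]["mean"], actions.mean(dim=0))
assert torch.allclose(stats["action"]["max"], actions.max(dim=0).values)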
@@ -241,16 +240,16 @@ def test_flatten_unflatten_dict():
 
 def test_backward_compatibility():
     """This tests artifacts have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
-    # TODO(rcadene): make it work for all datasets with LeRobotDataset(repo_id)
-    dataset_id = "pusht"
-    data_dir = Path("tests/data/save_dataset_to_safetensors") / dataset_id
-
-    dataset = PushtDataset(
-        dataset_id=dataset_id,
-        split="train",
+    repo_id = "lerobot/pusht"
+
+    dataset = LeRobotDataset(
+        repo_id,
         root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
     )
 
+    data_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
+
     def load_and_compare(i):
         new_frame = dataset[i]
         old_frame = load_file(data_dir / f"frame_{i}.safetensors")
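The hunk cuts off just after `load_and_compare` loads the old frame; a plausible completion is a key-by-key closeness check between the freshly built LeRobotDataset item and the safetensors artifact (the loop body below is an assumption, not the literal test code):

import torch

def load_and_compare(i):
    new_frame = dataset[i]
    old_frame = load_file(data_dir / f"frame_{i}.safetensors")
    # Every tensor saved by save_dataset_to_safetensors.py should be
    # reproduced (within tolerance) by the refactored dataset class.
    for key in old_frame:
        torch.testing.assert_close(new_frame[key], old_frame[key])

load_and_compare(0)
load_and_compare(1)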