Add dataset visualization with rerun.io (#131)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Remi
2024-05-04 16:07:14 +02:00
committed by GitHub
parent c015252e20
commit 19812ca470
12 changed files with 280 additions and 148 deletions

View File

@@ -11,7 +11,6 @@ Example usage:
`python tests/scripts/save_dataset_to_safetensors.py`
"""
import os
import shutil
from pathlib import Path
@@ -29,7 +28,7 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
repo_dir.mkdir(parents=True, exist_ok=True)
dataset = LeRobotDataset(
repo_id=repo_id, root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
repo_id=repo_id,
)
# save 2 first frames of first episode

View File

@@ -1,6 +1,5 @@
import json
import logging
import os
from copy import deepcopy
from pathlib import Path
@@ -97,9 +96,7 @@ def test_compute_stats_on_xarm():
We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
because we are working with a small dataset).
"""
dataset = LeRobotDataset(
"lerobot/xarm_lift_medium", root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
)
dataset = LeRobotDataset("lerobot/xarm_lift_medium")
# reduce size of dataset sample on which stats compute is tested to 10 frames
dataset.hf_dataset = dataset.hf_dataset.select(range(10))
@@ -254,7 +251,6 @@ def test_backward_compatibility(repo_id):
dataset = LeRobotDataset(
repo_id,
root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
)
test_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id

View File

@@ -18,7 +18,7 @@ def _run_script(path):
def test_example_1():
path = "examples/1_load_lerobot_dataset.py"
_run_script(path)
assert Path("outputs/examples/1_load_lerobot_dataset/episode_5.mp4").exists()
assert Path("outputs/examples/1_load_lerobot_dataset/episode_0.mp4").exists()
def test_examples_3_and_2():

View File

@@ -1,31 +1,18 @@
import pytest
from lerobot.common.utils.utils import init_hydra_config
from lerobot.scripts.visualize_dataset import visualize_dataset
from .utils import DEFAULT_CONFIG_PATH
@pytest.mark.parametrize(
"repo_id",
[
"lerobot/aloha_sim_insertion_human",
],
["lerobot/pusht"],
)
def test_visualize_dataset(tmpdir, repo_id):
# TODO(rcadene): this test might fail with other datasets/policies/envs, since visualization_dataset
# doesnt support multiple timesteps which requires delta_timestamps to None for images.
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
overrides=[
"policy=act",
"env=aloha",
f"dataset_repo_id={repo_id}",
],
rrd_path = visualize_dataset(
repo_id,
episode_index=0,
batch_size=32,
save=True,
output_dir=tmpdir,
)
video_paths = visualize_dataset(cfg, out_dir=tmpdir)
assert len(video_paths) > 0
for video_path in video_paths:
assert video_path.exists()
assert rrd_path.exists()