forked from tangger/lerobot
Add dataset visualization with rerun.io (#131)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
@@ -11,7 +11,6 @@ Example usage:
|
||||
`python tests/scripts/save_dataset_to_safetensors.py`
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
@@ -29,7 +28,7 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
|
||||
|
||||
repo_dir.mkdir(parents=True, exist_ok=True)
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=repo_id, root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
repo_id=repo_id,
|
||||
)
|
||||
|
||||
# save 2 first frames of first episode
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
@@ -97,9 +96,7 @@ def test_compute_stats_on_xarm():
|
||||
We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
|
||||
because we are working with a small dataset).
|
||||
"""
|
||||
dataset = LeRobotDataset(
|
||||
"lerobot/xarm_lift_medium", root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
)
|
||||
dataset = LeRobotDataset("lerobot/xarm_lift_medium")
|
||||
|
||||
# reduce size of dataset sample on which stats compute is tested to 10 frames
|
||||
dataset.hf_dataset = dataset.hf_dataset.select(range(10))
|
||||
@@ -254,7 +251,6 @@ def test_backward_compatibility(repo_id):
|
||||
|
||||
dataset = LeRobotDataset(
|
||||
repo_id,
|
||||
root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
|
||||
)
|
||||
|
||||
test_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
|
||||
|
||||
@@ -18,7 +18,7 @@ def _run_script(path):
|
||||
def test_example_1():
|
||||
path = "examples/1_load_lerobot_dataset.py"
|
||||
_run_script(path)
|
||||
assert Path("outputs/examples/1_load_lerobot_dataset/episode_5.mp4").exists()
|
||||
assert Path("outputs/examples/1_load_lerobot_dataset/episode_0.mp4").exists()
|
||||
|
||||
|
||||
def test_examples_3_and_2():
|
||||
|
||||
@@ -1,31 +1,18 @@
|
||||
import pytest
|
||||
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from lerobot.scripts.visualize_dataset import visualize_dataset
|
||||
|
||||
from .utils import DEFAULT_CONFIG_PATH
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id",
|
||||
[
|
||||
"lerobot/aloha_sim_insertion_human",
|
||||
],
|
||||
["lerobot/pusht"],
|
||||
)
|
||||
def test_visualize_dataset(tmpdir, repo_id):
|
||||
# TODO(rcadene): this test might fail with other datasets/policies/envs, since visualization_dataset
|
||||
# doesnt support multiple timesteps which requires delta_timestamps to None for images.
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH,
|
||||
overrides=[
|
||||
"policy=act",
|
||||
"env=aloha",
|
||||
f"dataset_repo_id={repo_id}",
|
||||
],
|
||||
rrd_path = visualize_dataset(
|
||||
repo_id,
|
||||
episode_index=0,
|
||||
batch_size=32,
|
||||
save=True,
|
||||
output_dir=tmpdir,
|
||||
)
|
||||
video_paths = visualize_dataset(cfg, out_dir=tmpdir)
|
||||
|
||||
assert len(video_paths) > 0
|
||||
|
||||
for video_path in video_paths:
|
||||
assert video_path.exists()
|
||||
assert rrd_path.exists()
|
||||
|
||||
Reference in New Issue
Block a user