Add LeRobotDatasetMetadata

This commit is contained in:
Simon Alibert
2024-11-03 18:07:37 +01:00
parent ac79e8cb36
commit e4ba084e25
25 changed files with 419 additions and 327 deletions

View File

@@ -8,7 +8,7 @@ import PIL.Image
import pytest
import torch
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_PARQUET_PATH,
@@ -33,8 +33,8 @@ def make_dummy_shapes(keys: list[str] | None = None, camera_keys: list[str] | No
return shapes
def get_task_index(tasks_dicts: dict, task: str) -> int:
tasks = {d["task_index"]: d["task"] for d in tasks_dicts}
def get_task_index(task_dicts: dict, task: str) -> int:
tasks = {d["task_index"]: d["task"] for d in task_dicts}
task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
return task_to_task_index[task]
@@ -313,6 +313,47 @@ def hf_dataset_factory(img_array_factory, episodes, tasks):
return _create_hf_dataset
@pytest.fixture(scope="session")
def lerobot_dataset_metadata_factory(
info,
stats,
tasks,
episodes,
mock_snapshot_download_factory,
):
def _create_lerobot_dataset_metadata(
root: Path,
repo_id: str = DUMMY_REPO_ID,
info_dict: dict = info,
stats_dict: dict = stats,
task_dicts: list[dict] = tasks,
episode_dicts: list[dict] = episodes,
**kwargs,
) -> LeRobotDatasetMetadata:
mock_snapshot_download = mock_snapshot_download_factory(
info_dict=info_dict,
stats_dict=stats_dict,
task_dicts=task_dicts,
episode_dicts=episode_dicts,
)
with (
patch(
"lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
) as mock_get_hub_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
):
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDatasetMetadata(
repo_id=repo_id, root=root, local_files_only=kwargs.get("local_files_only", False)
)
return _create_lerobot_dataset_metadata
@pytest.fixture(scope="session")
def lerobot_dataset_factory(
info,
@@ -321,6 +362,7 @@ def lerobot_dataset_factory(
episodes,
hf_dataset,
mock_snapshot_download_factory,
lerobot_dataset_metadata_factory,
):
def _create_lerobot_dataset(
root: Path,
@@ -335,19 +377,26 @@ def lerobot_dataset_factory(
mock_snapshot_download = mock_snapshot_download_factory(
info_dict=info_dict,
stats_dict=stats_dict,
tasks_dicts=task_dicts,
episodes_dicts=episode_dicts,
task_dicts=task_dicts,
episode_dicts=episode_dicts,
hf_ds=hf_ds,
)
mock_metadata = lerobot_dataset_metadata_factory(
root=root,
repo_id=repo_id,
info_dict=info_dict,
stats_dict=stats_dict,
task_dicts=task_dicts,
episode_dicts=episode_dicts,
**kwargs,
)
with (
patch(
"lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
) as mock_get_hub_safe_version_patch,
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
):
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
mock_metadata_patch.return_value = mock_metadata
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)

View File

@@ -36,11 +36,11 @@ def stats_path(stats):
@pytest.fixture(scope="session")
def tasks_path(tasks):
def _create_tasks_jsonl_file(dir: Path, tasks_dicts: list = tasks) -> Path:
def _create_tasks_jsonl_file(dir: Path, task_dicts: list = tasks) -> Path:
fpath = dir / TASKS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(tasks_dicts)
writer.write_all(task_dicts)
return fpath
return _create_tasks_jsonl_file

View File

@@ -26,7 +26,7 @@ def mock_snapshot_download_factory(
"""
def _mock_snapshot_download_func(
info_dict=info, stats_dict=stats, tasks_dicts=tasks, episodes_dicts=episodes, hf_ds=hf_dataset
info_dict=info, stats_dict=stats, task_dicts=tasks, episode_dicts=episodes, hf_ds=hf_dataset
):
def _extract_episode_index_from_path(fpath: str) -> int:
path = Path(fpath)
@@ -53,7 +53,7 @@ def mock_snapshot_download_factory(
all_files.extend(meta_files)
data_files = []
for episode_dict in episodes_dicts:
for episode_dict in episode_dicts:
ep_idx = episode_dict["episode_index"]
ep_chunk = ep_idx // info_dict["chunks_size"]
data_path = info_dict["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
@@ -75,9 +75,9 @@ def mock_snapshot_download_factory(
elif rel_path == STATS_PATH:
_ = stats_path(local_dir, stats_dict)
elif rel_path == TASKS_PATH:
_ = tasks_path(local_dir, tasks_dicts)
_ = tasks_path(local_dir, task_dicts)
elif rel_path == EPISODES_PATH:
_ = episode_path(local_dir, episodes_dicts)
_ = episode_path(local_dir, episode_dicts)
else:
pass
return str(local_dir)

View File

@@ -76,7 +76,7 @@ def main():
dataset = LeRobotDataset(DATASET_REPO_ID, image_transforms=None)
output_dir = Path(ARTIFACT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)
original_frame = dataset[0][dataset.camera_keys[0]]
original_frame = dataset[0][dataset.meta.camera_keys[0]]
save_single_transforms(original_frame, output_dir)
save_default_config_transform(original_frame, output_dir)

View File

@@ -38,7 +38,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
)
set_global_seed(1337)
dataset = make_dataset(cfg)
policy = make_policy(cfg, dataset_stats=dataset.stats)
policy = make_policy(cfg, dataset_stats=dataset.meta.stats)
policy.train()
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)

View File

@@ -155,7 +155,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
display_cameras=False,
play_sounds=False,
)
assert dataset.total_episodes == 2
assert dataset.meta.total_episodes == 2
assert len(dataset) == 2
replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
@@ -193,7 +193,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
overrides=overrides,
)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
out_dir = tmpdir / "logger"
logger = Logger(cfg, out_dir, wandb_job_name="debug")

View File

@@ -33,7 +33,11 @@ from lerobot.common.datasets.compute_stats import (
get_stats_einops_patterns,
)
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
MultiLeRobotDataset,
)
from lerobot.common.datasets.utils import (
create_branch,
flatten_dict,
@@ -53,14 +57,17 @@ def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
# Instantiate both ways
robot = make_robot("koch", mock=True)
root_create = tmp_path / "create"
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
metadata_create = LeRobotDatasetMetadata.create(
repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create
)
dataset_create = LeRobotDataset.create(metadata_create)
root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init)
# Access the '_hub_version' cached_property in both instances to force its creation
_ = dataset_init._hub_version
_ = dataset_create._hub_version
_ = dataset_init.meta._hub_version
_ = dataset_create.meta._hub_version
init_attr = set(vars(dataset_init).keys())
create_attr = set(vars(dataset_create).keys())
@@ -78,8 +85,8 @@ def test_dataset_initialization(lerobot_dataset_from_episodes_factory, tmp_path)
dataset = lerobot_dataset_from_episodes_factory(root=tmp_path, **kwargs)
assert dataset.repo_id == kwargs["repo_id"]
assert dataset.total_episodes == kwargs["total_episodes"]
assert dataset.total_frames == kwargs["total_frames"]
assert dataset.meta.total_episodes == kwargs["total_episodes"]
assert dataset.meta.total_frames == kwargs["total_frames"]
assert dataset.episodes == kwargs["episodes"]
assert dataset.num_episodes == len(kwargs["episodes"])
assert dataset.num_frames == len(dataset)
@@ -118,7 +125,7 @@ def test_factory(env_name, repo_id, policy_name):
)
dataset = make_dataset(cfg)
delta_timestamps = dataset.delta_timestamps
camera_keys = dataset.camera_keys
camera_keys = dataset.meta.camera_keys
item = dataset[0]
@@ -251,7 +258,7 @@ def test_compute_stats_on_xarm():
assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
# load stats used during training which are expected to match the ones returned by computed_stats
loaded_stats = dataset.stats # noqa: F841
loaded_stats = dataset.meta.stats # noqa: F841
# TODO(rcadene): we can't test this because expected_stats is computed on a subset
# # test loaded stats match expected stats

View File

@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(aliberts): Mute logging for these tests
import io
import subprocess
import sys
@@ -29,6 +29,7 @@ def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> s
return text
# TODO(aliberts): Remove usage of subprocess calls and patch code with fixtures
def _run_script(path):
subprocess.run([sys.executable, path], check=True)

View File

@@ -50,7 +50,7 @@ def test_get_policy_and_config_classes(policy_name: str):
assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation)
# TODO(aliberts): refactor using lerobot/__init__.py variables
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
"env_name,policy_name,extra_overrides",
[
@@ -136,7 +136,7 @@ def test_policy(env_name, policy_name, extra_overrides):
# Check that we can make the policy object.
dataset = make_dataset(cfg)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
# Check that the policy follows the required protocol.
assert isinstance(
policy, Policy
@@ -195,6 +195,7 @@ def test_policy(env_name, policy_name, extra_overrides):
env.step(action)
@pytest.mark.skip("TODO after v2 migration / removing hydra")
def test_act_backbone_lr():
"""
Test that the ACT policy can be instantiated with a different learning rate for the backbone.
@@ -213,7 +214,7 @@ def test_act_backbone_lr():
assert cfg.training.lr_backbone == 0.001
dataset = make_dataset(cfg)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
assert len(optimizer.param_groups) == 2
assert optimizer.param_groups[0]["lr"] == cfg.training.lr

View File

@@ -250,6 +250,7 @@ def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
)
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
"required_packages, raw_format, repo_id, make_test_data",
[