forked from tangger/lerobot
Add LeRobotDatasetMetadata
This commit is contained in:
67
tests/fixtures/dataset_factories.py
vendored
67
tests/fixtures/dataset_factories.py
vendored
@@ -8,7 +8,7 @@ import PIL.Image
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_PARQUET_PATH,
|
||||
@@ -33,8 +33,8 @@ def make_dummy_shapes(keys: list[str] | None = None, camera_keys: list[str] | No
|
||||
return shapes
|
||||
|
||||
|
||||
def get_task_index(tasks_dicts: dict, task: str) -> int:
|
||||
tasks = {d["task_index"]: d["task"] for d in tasks_dicts}
|
||||
def get_task_index(task_dicts: dict, task: str) -> int:
|
||||
tasks = {d["task_index"]: d["task"] for d in task_dicts}
|
||||
task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
|
||||
return task_to_task_index[task]
|
||||
|
||||
@@ -313,6 +313,47 @@ def hf_dataset_factory(img_array_factory, episodes, tasks):
|
||||
return _create_hf_dataset
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def lerobot_dataset_metadata_factory(
|
||||
info,
|
||||
stats,
|
||||
tasks,
|
||||
episodes,
|
||||
mock_snapshot_download_factory,
|
||||
):
|
||||
def _create_lerobot_dataset_metadata(
|
||||
root: Path,
|
||||
repo_id: str = DUMMY_REPO_ID,
|
||||
info_dict: dict = info,
|
||||
stats_dict: dict = stats,
|
||||
task_dicts: list[dict] = tasks,
|
||||
episode_dicts: list[dict] = episodes,
|
||||
**kwargs,
|
||||
) -> LeRobotDatasetMetadata:
|
||||
mock_snapshot_download = mock_snapshot_download_factory(
|
||||
info_dict=info_dict,
|
||||
stats_dict=stats_dict,
|
||||
task_dicts=task_dicts,
|
||||
episode_dicts=episode_dicts,
|
||||
)
|
||||
with (
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
|
||||
) as mock_get_hub_safe_version_patch,
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
|
||||
) as mock_snapshot_download_patch,
|
||||
):
|
||||
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
|
||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
||||
|
||||
return LeRobotDatasetMetadata(
|
||||
repo_id=repo_id, root=root, local_files_only=kwargs.get("local_files_only", False)
|
||||
)
|
||||
|
||||
return _create_lerobot_dataset_metadata
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def lerobot_dataset_factory(
|
||||
info,
|
||||
@@ -321,6 +362,7 @@ def lerobot_dataset_factory(
|
||||
episodes,
|
||||
hf_dataset,
|
||||
mock_snapshot_download_factory,
|
||||
lerobot_dataset_metadata_factory,
|
||||
):
|
||||
def _create_lerobot_dataset(
|
||||
root: Path,
|
||||
@@ -335,19 +377,26 @@ def lerobot_dataset_factory(
|
||||
mock_snapshot_download = mock_snapshot_download_factory(
|
||||
info_dict=info_dict,
|
||||
stats_dict=stats_dict,
|
||||
tasks_dicts=task_dicts,
|
||||
episodes_dicts=episode_dicts,
|
||||
task_dicts=task_dicts,
|
||||
episode_dicts=episode_dicts,
|
||||
hf_ds=hf_ds,
|
||||
)
|
||||
mock_metadata = lerobot_dataset_metadata_factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
info_dict=info_dict,
|
||||
stats_dict=stats_dict,
|
||||
task_dicts=task_dicts,
|
||||
episode_dicts=episode_dicts,
|
||||
**kwargs,
|
||||
)
|
||||
with (
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
|
||||
) as mock_get_hub_safe_version_patch,
|
||||
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
|
||||
) as mock_snapshot_download_patch,
|
||||
):
|
||||
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version
|
||||
mock_metadata_patch.return_value = mock_metadata
|
||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
||||
|
||||
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)
|
||||
|
||||
4
tests/fixtures/files.py
vendored
4
tests/fixtures/files.py
vendored
@@ -36,11 +36,11 @@ def stats_path(stats):
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tasks_path(tasks):
|
||||
def _create_tasks_jsonl_file(dir: Path, tasks_dicts: list = tasks) -> Path:
|
||||
def _create_tasks_jsonl_file(dir: Path, task_dicts: list = tasks) -> Path:
|
||||
fpath = dir / TASKS_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with jsonlines.open(fpath, "w") as writer:
|
||||
writer.write_all(tasks_dicts)
|
||||
writer.write_all(task_dicts)
|
||||
return fpath
|
||||
|
||||
return _create_tasks_jsonl_file
|
||||
|
||||
8
tests/fixtures/hub.py
vendored
8
tests/fixtures/hub.py
vendored
@@ -26,7 +26,7 @@ def mock_snapshot_download_factory(
|
||||
"""
|
||||
|
||||
def _mock_snapshot_download_func(
|
||||
info_dict=info, stats_dict=stats, tasks_dicts=tasks, episodes_dicts=episodes, hf_ds=hf_dataset
|
||||
info_dict=info, stats_dict=stats, task_dicts=tasks, episode_dicts=episodes, hf_ds=hf_dataset
|
||||
):
|
||||
def _extract_episode_index_from_path(fpath: str) -> int:
|
||||
path = Path(fpath)
|
||||
@@ -53,7 +53,7 @@ def mock_snapshot_download_factory(
|
||||
all_files.extend(meta_files)
|
||||
|
||||
data_files = []
|
||||
for episode_dict in episodes_dicts:
|
||||
for episode_dict in episode_dicts:
|
||||
ep_idx = episode_dict["episode_index"]
|
||||
ep_chunk = ep_idx // info_dict["chunks_size"]
|
||||
data_path = info_dict["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||
@@ -75,9 +75,9 @@ def mock_snapshot_download_factory(
|
||||
elif rel_path == STATS_PATH:
|
||||
_ = stats_path(local_dir, stats_dict)
|
||||
elif rel_path == TASKS_PATH:
|
||||
_ = tasks_path(local_dir, tasks_dicts)
|
||||
_ = tasks_path(local_dir, task_dicts)
|
||||
elif rel_path == EPISODES_PATH:
|
||||
_ = episode_path(local_dir, episodes_dicts)
|
||||
_ = episode_path(local_dir, episode_dicts)
|
||||
else:
|
||||
pass
|
||||
return str(local_dir)
|
||||
|
||||
@@ -76,7 +76,7 @@ def main():
|
||||
dataset = LeRobotDataset(DATASET_REPO_ID, image_transforms=None)
|
||||
output_dir = Path(ARTIFACT_DIR)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
original_frame = dataset[0][dataset.camera_keys[0]]
|
||||
original_frame = dataset[0][dataset.meta.camera_keys[0]]
|
||||
|
||||
save_single_transforms(original_frame, output_dir)
|
||||
save_default_config_transform(original_frame, output_dir)
|
||||
|
||||
@@ -38,7 +38,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
|
||||
)
|
||||
set_global_seed(1337)
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(cfg, dataset_stats=dataset.stats)
|
||||
policy = make_policy(cfg, dataset_stats=dataset.meta.stats)
|
||||
policy.train()
|
||||
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
|
||||
|
||||
|
||||
@@ -155,7 +155,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
)
|
||||
assert dataset.total_episodes == 2
|
||||
assert dataset.meta.total_episodes == 2
|
||||
assert len(dataset) == 2
|
||||
|
||||
replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
|
||||
@@ -193,7 +193,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||
overrides=overrides,
|
||||
)
|
||||
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
out_dir = tmpdir / "logger"
|
||||
logger = Logger(cfg, out_dir, wandb_job_name="debug")
|
||||
|
||||
@@ -33,7 +33,11 @@ from lerobot.common.datasets.compute_stats import (
|
||||
get_stats_einops_patterns,
|
||||
)
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
LeRobotDatasetMetadata,
|
||||
MultiLeRobotDataset,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
create_branch,
|
||||
flatten_dict,
|
||||
@@ -53,14 +57,17 @@ def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
|
||||
# Instantiate both ways
|
||||
robot = make_robot("koch", mock=True)
|
||||
root_create = tmp_path / "create"
|
||||
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
|
||||
metadata_create = LeRobotDatasetMetadata.create(
|
||||
repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create
|
||||
)
|
||||
dataset_create = LeRobotDataset.create(metadata_create)
|
||||
|
||||
root_init = tmp_path / "init"
|
||||
dataset_init = lerobot_dataset_factory(root=root_init)
|
||||
|
||||
# Access the '_hub_version' cached_property in both instances to force its creation
|
||||
_ = dataset_init._hub_version
|
||||
_ = dataset_create._hub_version
|
||||
_ = dataset_init.meta._hub_version
|
||||
_ = dataset_create.meta._hub_version
|
||||
|
||||
init_attr = set(vars(dataset_init).keys())
|
||||
create_attr = set(vars(dataset_create).keys())
|
||||
@@ -78,8 +85,8 @@ def test_dataset_initialization(lerobot_dataset_from_episodes_factory, tmp_path)
|
||||
dataset = lerobot_dataset_from_episodes_factory(root=tmp_path, **kwargs)
|
||||
|
||||
assert dataset.repo_id == kwargs["repo_id"]
|
||||
assert dataset.total_episodes == kwargs["total_episodes"]
|
||||
assert dataset.total_frames == kwargs["total_frames"]
|
||||
assert dataset.meta.total_episodes == kwargs["total_episodes"]
|
||||
assert dataset.meta.total_frames == kwargs["total_frames"]
|
||||
assert dataset.episodes == kwargs["episodes"]
|
||||
assert dataset.num_episodes == len(kwargs["episodes"])
|
||||
assert dataset.num_frames == len(dataset)
|
||||
@@ -118,7 +125,7 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
)
|
||||
dataset = make_dataset(cfg)
|
||||
delta_timestamps = dataset.delta_timestamps
|
||||
camera_keys = dataset.camera_keys
|
||||
camera_keys = dataset.meta.camera_keys
|
||||
|
||||
item = dataset[0]
|
||||
|
||||
@@ -251,7 +258,7 @@ def test_compute_stats_on_xarm():
|
||||
assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
|
||||
|
||||
# load stats used during training which are expected to match the ones returned by computed_stats
|
||||
loaded_stats = dataset.stats # noqa: F841
|
||||
loaded_stats = dataset.meta.stats # noqa: F841
|
||||
|
||||
# TODO(rcadene): we can't test this because expected_stats is computed on a subset
|
||||
# # test loaded stats match expected stats
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# TODO(aliberts): Mute logging for these tests
|
||||
|
||||
import io
|
||||
import subprocess
|
||||
import sys
|
||||
@@ -29,6 +29,7 @@ def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> s
|
||||
return text
|
||||
|
||||
|
||||
# TODO(aliberts): Remove usage of subprocess calls and patch code with fixtures
|
||||
def _run_script(path):
|
||||
subprocess.run([sys.executable, path], check=True)
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ def test_get_policy_and_config_classes(policy_name: str):
|
||||
assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation)
|
||||
|
||||
|
||||
# TODO(aliberts): refactor using lerobot/__init__.py variables
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"env_name,policy_name,extra_overrides",
|
||||
[
|
||||
@@ -136,7 +136,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
|
||||
# Check that we can make the policy object.
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
|
||||
# Check that the policy follows the required protocol.
|
||||
assert isinstance(
|
||||
policy, Policy
|
||||
@@ -195,6 +195,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
env.step(action)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
def test_act_backbone_lr():
|
||||
"""
|
||||
Test that the ACT policy can be instantiated with a different learning rate for the backbone.
|
||||
@@ -213,7 +214,7 @@ def test_act_backbone_lr():
|
||||
assert cfg.training.lr_backbone == 0.001
|
||||
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
|
||||
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
|
||||
assert len(optimizer.param_groups) == 2
|
||||
assert optimizer.param_groups[0]["lr"] == cfg.training.lr
|
||||
|
||||
@@ -250,6 +250,7 @@ def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"required_packages, raw_format, repo_id, make_test_data",
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user