Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_25_refactor_robots

This commit is contained in:
Simon Alibert
2025-03-13 14:24:50 +01:00
88 changed files with 151 additions and 154 deletions

View File

@@ -23,7 +23,7 @@ If you know that your change will break backward compatibility, you should write
doesnt need to be merged into the `main` branch. Then you need to run this script and update the tests artifacts.
Example usage:
`python tests/scripts/save_dataset_to_safetensors.py`
`python tests/artifacts/datasets/save_dataset_to_safetensors.py`
"""
import shutil
@@ -88,4 +88,4 @@ if __name__ == "__main__":
"lerobot/nyu_franka_play_dataset",
"lerobot/cmu_stretch",
]:
save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=dataset)
save_dataset_to_safetensors("tests/artifacts/datasets", repo_id=dataset)

View File

@@ -27,7 +27,7 @@ from lerobot.common.datasets.transforms import (
)
from lerobot.common.utils.random_utils import seeded_context
ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
ARTIFACT_DIR = Path("tests/artifacts/image_transforms")
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"

View File

@@ -141,5 +141,5 @@ if __name__ == "__main__":
raise RuntimeError("No policies were provided!")
for ds_repo_id, policy, policy_kwargs, file_name_extra in artifacts_cfg:
ds_name = ds_repo_id.split("/")[-1]
output_dir = Path("tests/data/save_policy_to_safetensors") / f"{ds_name}_{policy}_{file_name_extra}"
output_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy}_{file_name_extra}"
save_policy_to_safetensors(output_dir, ds_repo_id, policy, policy_kwargs)

View File

@@ -146,7 +146,7 @@ def test_camera(request, camera_type, mock):
camera.connect()
if mock:
import tests.mock_cv2 as cv2
import tests.cameras.mock_cv2 as cv2
else:
import cv2

View File

@@ -473,12 +473,12 @@ def test_flatten_unflatten_dict():
)
@require_x86_64_kernel
def test_backward_compatibility(repo_id):
"""The artifacts for this test have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
"""The artifacts for this test have been generated by `tests/artifacts/datasets/save_dataset_to_safetensors.py`."""
# TODO(rcadene, aliberts): remove dataset download
dataset = LeRobotDataset(repo_id, episodes=[0])
test_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
test_dir = Path("tests/artifacts/datasets") / repo_id
def load_and_compare(i):
new_frame = dataset[i] # noqa: B023

View File

@@ -33,7 +33,7 @@ from lerobot.scripts.visualize_image_transforms import (
save_all_transforms,
save_each_transform,
)
from tests.scripts.save_image_transforms_to_safetensors import ARTIFACT_DIR
from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import ARTIFACT_DIR
from tests.utils import require_x86_64_kernel

View File

@@ -1,3 +1,5 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,13 +13,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from datasets import Dataset
from huggingface_hub import DatasetCard
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.utils import (
hf_transform_to_torch,
)
from lerobot.common.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
def test_default_parameters():
card = create_lerobot_dataset_card()
assert isinstance(card, DatasetCard)
assert card.data.tags == ["LeRobot"]
assert card.data.task_categories == ["robotics"]
assert card.data.configs == [
{
"config_name": "default",
"data_files": "data/*/*.parquet",
}
]
def test_with_tags():
tags = ["tag1", "tag2"]
card = create_lerobot_dataset_card(tags=tags)
assert card.data.tags == ["LeRobot", "tag1", "tag2"]
def test_calculate_episode_data_index():

View File

@@ -23,8 +23,7 @@ from gymnasium.utils.env_checker import check_env
import lerobot
from lerobot.common.envs.factory import make_env, make_env_config
from lerobot.common.envs.utils import preprocess_observation
from .utils import require_env
from tests.utils import require_env
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]

View File

@@ -1,38 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from huggingface_hub import DatasetCard
from lerobot.common.datasets.utils import create_lerobot_dataset_card
def test_default_parameters():
card = create_lerobot_dataset_card()
assert isinstance(card, DatasetCard)
assert card.data.tags == ["LeRobot"]
assert card.data.task_categories == ["robotics"]
assert card.data.configs == [
{
"config_name": "default",
"data_files": "data/*/*.parquet",
}
]
def test_with_tags():
tags = ["tag1", "tag2"]
card = create_lerobot_dataset_card(tags=tags)
assert card.data.tags == ["LeRobot", "tag1", "tag2"]

View File

@@ -40,7 +40,7 @@ from lerobot.common.utils.random_utils import seeded_context
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from tests.scripts.save_policy_to_safetensors import get_policy_stats
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
@@ -407,12 +407,10 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
should be updated.
4. Check that this test now passes.
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
6. Remember to stage and commit the resulting changes to `tests/data`.
6. Remember to stage and commit the resulting changes to `tests/artifacts`.
"""
ds_name = ds_repo_id.split("/")[-1]
artifact_dir = (
Path("tests/data/save_policy_to_safetensors") / f"{ds_name}_{policy_name}_{file_name_extra}"
)
artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")
saved_grad_stats = load_file(artifact_dir / "grad_stats.safetensors")
saved_param_stats = load_file(artifact_dir / "param_stats.safetensors")

View File

@@ -51,7 +51,7 @@ from lerobot.configs.control import (
)
from lerobot.configs.policies import PreTrainedConfig
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
from tests.test_robots import make_robot
from tests.robots.test_robots import make_robot
from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot