Merge remote-tracking branch 'origin/main' into user/rcadene/2025_02_19_port_openx

This commit is contained in:
Remi Cadene
2025-03-01 19:17:18 +00:00
123 changed files with 2489 additions and 629 deletions

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:eb7b74f919adf8d4478585f65c54997e6f3bccab67eadb4048300108586a4163
size 5104

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dfbc3b1ad5e3b94311edda0f04db002b26117b0719b73dfdb56dd483dc9c409d
size 31672

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e39afdf1f3db8a72a1095a5a0ffdb7e67f478a28bd73e59cda197687da8d236c
size 68

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5dd39a554c9c3db537e98c9ceade024d172c46c4fa7ce9e27601b94116445417
size 33400

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a5ec46abc5a3c85675a5ee4a1bb362eecb3ff4c546082ff309c89fc7821f38bd
size 515400

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:50303d05caea725c4a240f1389424d6c2361961f2cee729a0010e909ebffed81
size 31672

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9bb9b195d32e05550af0edd5df88fcc761c829ab8c4b129ba970a723f39b46ee
size 68

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:683a2038185f3d070e7d7c0c31e4aa75067c11bf798daa41c9fab336f4183fda
size 33400

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cc67af1d60f95d84c98d6c9ebd648990e0f0705368bd6b72d2b39533950b0179
size 5104

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:64518cf652105d15f5fd2cfc13d0681f66a4ec4797dc5d5dc2f7b0d91fe5dfd6
size 31672

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:32b6d14fab4244b5140adb345e47f662b6739c04974e04b21c3127caa988abbb
size 68

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e1904ef0338f7b6efdec70ec235ee931b5751008bf4eb433edb0b3fa0838a4f1
size 33400

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fa544a97f00bf46393a09b006b44c2499bbf7d177782360a8c21cacbf200c07a
size 515400

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:83c7a8ae912300b5cedba31904f7ba22542059fd60dd86548a95e415713f719e
size 31672

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5a010633237b3a1141603c65174c551daa9e7b4c474af5a1376d73e5425bfb5d
size 68

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ec8b5c440e9fcec190c9be48b28ebb79f82ae63626afe7c811e4bb0c3dd08842
size 33400

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e56a5d30778395534a06ad1742843700424614168fc26d1098558012a5df90c6
size 5104

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c9007dd51c748db4ecd6d75e70bdcabf8c312454ac97bf6710895a12e7288557
size 31672

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:170bd8365dfd1e36e8f56814bf8bc2057aa0d035c41212b7ddd7e4b9feee1633
size 68

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:11884346b41ca102c672bb0f361ea9699d2f8b33bb503038b53cc7e7fafd281b
size 34920

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0c259ea9c40aab3841ca35b2a2e708d8829b0a9163b2f9e5efd28f1c65848293
size 4600

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:77cd4127a45ded2f75d85ca9c17537808517614ef16fb3035cebb1b45547acbf
size 47424

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fcff4b736e95d685d56830b501f4542b081f4334f72d28a7415809f4d9d15d0f
size 68

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:60775e91ed550aae66cb0547ee4b0e38917f29172e942671e9361b3812364df6
size 49120

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a32376dde65a1562403afd1db3e56c7e6b987ebaf6c3c601336e77155b9e608c
size 992

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:12ee532c53173d0361ebb979f087b229cc045aa3d9e6b94cfd4290af54fd1201
size 47424

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:271b00cb2f0cd5fd26b1d53463638e3d1a6e92692ec625fcffb420ca190869e5
size 68

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:010c01181b95625051276d69cb4209423c21f2e30a3fa9464ae67064a2ba4c22
size 49120

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c5edc5600d7206f027cb696a597bc99fcdd9073a15fa130b8031c52c0a7c134b
size 200

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
size 16904

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
size 164

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
size 36312

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a9c08753ddc43b6c02a176418b81eb784146e59f4fc914591cbd3582ade392bb
size 200

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10
size 16904

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b
size 164

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170
size 36312

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:81457cfd193d9d46b6871071a3971c2901fefa544ab225576132772087b4cf3a
size 472

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d796577863740e8fd643a056e9eff891e51a858ff66019eba11f0a982cb9e9c0
size 16904

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4636751d82103a268ac7cf36f1e69f6356f356b9c40561a9fe8557bb9255e2ee
size 240

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b7d08c9518f1f15226e4efc6f2a8542d0f3e620c91421c7cacea07d9bd9025d6
size 36312

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6cdb181ba6acc4aa1209a9ea5dd783f077ff87760257de1026c33f8e2fb2b2b1
size 472

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d796577863740e8fd643a056e9eff891e51a858ff66019eba11f0a982cb9e9c0
size 16904

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4636751d82103a268ac7cf36f1e69f6356f356b9c40561a9fe8557bb9255e2ee
size 240

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b7d08c9518f1f15226e4efc6f2a8542d0f3e620c91421c7cacea07d9bd9025d6
size 36312

View File

@@ -338,14 +338,12 @@ def lerobot_dataset_metadata_factory(
episodes=episodes,
)
with (
patch(
"lerobot.common.datasets.lerobot_dataset.get_safe_revision"
) as mock_get_safe_revision_patch,
patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
):
mock_get_safe_revision_patch.side_effect = lambda repo_id, version: version
mock_get_safe_version_patch.side_effect = lambda repo_id, version: version
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDatasetMetadata(repo_id=repo_id, root=root)
@@ -418,15 +416,13 @@ def lerobot_dataset_factory(
)
with (
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.get_safe_revision"
) as mock_get_safe_revision_patch,
patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
):
mock_metadata_patch.return_value = mock_metadata
mock_get_safe_revision_patch.side_effect = lambda repo_id, version: version
mock_get_safe_version_patch.side_effect = lambda repo_id, version: version
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)

View File

@@ -18,11 +18,11 @@ def _generate_image(width: int, height: int):
return np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8)
def cvtColor(color_image, color_convertion): # noqa: N802
if color_convertion in [COLOR_RGB2BGR, COLOR_BGR2RGB]:
def cvtColor(color_image, color_conversion): # noqa: N802
if color_conversion in [COLOR_RGB2BGR, COLOR_BGR2RGB]:
return color_image[:, :, [2, 1, 0]]
else:
raise NotImplementedError(color_convertion)
raise NotImplementedError(color_conversion)
def rotate(color_image, rotation):

View File

@@ -27,16 +27,13 @@ from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs):
# TODO(rcadene, aliberts): env_name?
def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
set_seed(1337)
train_cfg = TrainPipelineConfig(
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
policy=make_policy_config(policy_name, **policy_kwargs),
device="cpu",
**train_kwargs,
)
train_cfg.validate() # Needed for auto-setting some parameters
@@ -54,8 +51,11 @@ def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwa
batch = next(iter(dataloader))
loss, output_dict = policy.forward(batch)
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
output_dict["loss"] = loss
if output_dict is not None:
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
output_dict["loss"] = loss
else:
output_dict = {"loss": loss}
loss.backward()
grad_stats = {}
@@ -101,30 +101,27 @@ def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwa
return output_dict, grad_stats, param_stats, actions
def save_policy_to_safetensors(output_dir, env_name, policy_name, policy_kwargs, file_name_extra):
env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}{file_name_extra}"
def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict):
if output_dir.exists():
print(f"Overwrite existing safetensors in '{output_dir}':")
print(f" - Validate with: `git add {output_dir}`")
print(f" - Revert with: `git checkout -- {output_dir}`")
shutil.rmtree(output_dir)
if env_policy_dir.exists():
print(f"Overwrite existing safetensors in '{env_policy_dir}':")
print(f" - Validate with: `git add {env_policy_dir}`")
print(f" - Revert with: `git checkout -- {env_policy_dir}`")
shutil.rmtree(env_policy_dir)
env_policy_dir.mkdir(parents=True, exist_ok=True)
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, policy_kwargs)
save_file(output_dict, env_policy_dir / "output_dict.safetensors")
save_file(grad_stats, env_policy_dir / "grad_stats.safetensors")
save_file(param_stats, env_policy_dir / "param_stats.safetensors")
save_file(actions, env_policy_dir / "actions.safetensors")
output_dir.mkdir(parents=True, exist_ok=True)
output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
save_file(output_dict, output_dir / "output_dict.safetensors")
save_file(grad_stats, output_dir / "grad_stats.safetensors")
save_file(param_stats, output_dir / "param_stats.safetensors")
save_file(actions, output_dir / "actions.safetensors")
if __name__ == "__main__":
env_policies = [
("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": False}, "use_policy"),
("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": True}, "use_mpc"),
artifacts_cfg = [
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"),
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
(
"lerobot/pusht",
"pusht",
"diffusion",
{
"n_action_steps": 8,
@@ -133,18 +130,17 @@ if __name__ == "__main__":
},
"",
),
("lerobot/aloha_sim_insertion_human", "aloha", "act", {"n_action_steps": 10}, ""),
("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""),
(
"lerobot/aloha_sim_insertion_human",
"aloha",
"act",
{"n_action_steps": 1000, "chunk_size": 1000},
"_1000_steps",
"1000_steps",
),
]
if len(env_policies) == 0:
if len(artifacts_cfg) == 0:
raise RuntimeError("No policies were provided!")
for ds_repo_id, env, policy, policy_kwargs, file_name_extra in env_policies:
save_policy_to_safetensors(
"tests/data/save_policy_to_safetensors", ds_repo_id, env, policy, policy_kwargs, file_name_extra
)
for ds_repo_id, policy, policy_kwargs, file_name_extra in artifacts_cfg:
ds_name = ds_repo_id.split("/")[-1]
output_dir = Path("tests/data/save_policy_to_safetensors") / f"{ds_name}_{policy}_{file_name_extra}"
save_policy_to_safetensors(output_dir, ds_repo_id, policy, policy_kwargs)

View File

@@ -27,7 +27,7 @@ import pytest
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from tests.utils import TEST_CAMERA_TYPES, make_camera, require_camera
# Maximum absolute difference between two consecutive images recored by a camera.
# Maximum absolute difference between two consecutive images recorded by a camera.
# This value differs with respect to the camera.
MAX_PIXEL_DIFFERENCE = 25

View File

@@ -179,7 +179,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
policy.save_pretrained(pretrained_policy_path)
# In `examples/9_use_aloha.md`, we advise using `num_image_writer_processes=1`
# during inference, to reach constent fps, so we test this here.
# during inference, to reach constant fps, so we test this here.
if robot_type == "aloha":
num_image_writer_processes = 1

View File

@@ -486,7 +486,7 @@ def test_backward_compatibility(repo_id):
old_frame = load_file(test_dir / f"frame_{i}.safetensors") # noqa: B023
# ignore language instructions (if exists) in language conditioned datasets
# TODO (michel-aractingi): transform language obs to langauge embeddings via tokenizer
# TODO (michel-aractingi): transform language obs to language embeddings via tokenizer
new_frame.pop("language_instruction", None)
old_frame.pop("language_instruction", None)
new_frame.pop("task", None)

View File

@@ -1,55 +1,78 @@
from itertools import accumulate
import datasets
import numpy as np
import pyarrow.compute as pc
import pytest
import torch
from datasets import Dataset
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
check_delta_timestamps,
check_timestamps_sync,
get_delta_indices,
hf_transform_to_torch,
)
from tests.fixtures.constants import DUMMY_MOTOR_FEATURES
@pytest.fixture(scope="module")
def synced_hf_dataset_factory(hf_dataset_factory):
def _create_synced_hf_dataset(fps: int = 30) -> Dataset:
return hf_dataset_factory(fps=fps)
def calculate_total_episode(
hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
) -> dict[str, torch.Tensor]:
episode_indices = sorted(hf_dataset.unique("episode_index"))
total_episodes = len(episode_indices)
if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
raise ValueError("episode_index values are not sorted and contiguous.")
return total_episodes
return _create_synced_hf_dataset
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.ndarray]:
episode_lengths = []
table = hf_dataset.data.table
total_episodes = calculate_total_episode(hf_dataset)
for ep_idx in range(total_episodes):
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
episode_lengths.insert(ep_idx, len(ep_table))
cumulative_lengths = list(accumulate(episode_lengths))
return {
"from": np.array([0] + cumulative_lengths[:-1], dtype=np.int64),
"to": np.array(cumulative_lengths, dtype=np.int64),
}
@pytest.fixture(scope="module")
def unsynced_hf_dataset_factory(synced_hf_dataset_factory):
def _create_unsynced_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
hf_dataset = synced_hf_dataset_factory(fps=fps)
features = hf_dataset.features
df = hf_dataset.to_pandas()
dtype = df["timestamp"].dtype # This is to avoid pandas type warning
# Modify a single timestamp just outside tolerance
df.at[30, "timestamp"] = dtype.type(df.at[30, "timestamp"] + (tolerance_s * 1.1))
unsynced_hf_dataset = Dataset.from_pandas(df, features=features)
unsynced_hf_dataset.set_transform(hf_transform_to_torch)
return unsynced_hf_dataset
def synced_timestamps_factory(hf_dataset_factory):
def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
hf_dataset = hf_dataset_factory(fps=fps)
timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
episode_indices = torch.stack(hf_dataset["episode_index"]).numpy()
episode_data_index = calculate_episode_data_index(hf_dataset)
return timestamps, episode_indices, episode_data_index
return _create_unsynced_hf_dataset
return _create_synced_timestamps
@pytest.fixture(scope="module")
def slightly_off_hf_dataset_factory(synced_hf_dataset_factory):
def _create_slightly_off_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
hf_dataset = synced_hf_dataset_factory(fps=fps)
features = hf_dataset.features
df = hf_dataset.to_pandas()
dtype = df["timestamp"].dtype # This is to avoid pandas type warning
# Modify a single timestamp just inside tolerance
df.at[30, "timestamp"] = dtype.type(df.at[30, "timestamp"] + (tolerance_s * 0.9))
unsynced_hf_dataset = Dataset.from_pandas(df, features=features)
unsynced_hf_dataset.set_transform(hf_transform_to_torch)
return unsynced_hf_dataset
def unsynced_timestamps_factory(synced_timestamps_factory):
def _create_unsynced_timestamps(
fps: int = 30, tolerance_s: float = 1e-4
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
timestamps[30] += tolerance_s * 1.1 # Modify a single timestamp just outside tolerance
return timestamps, episode_indices, episode_data_index
return _create_slightly_off_hf_dataset
return _create_unsynced_timestamps
@pytest.fixture(scope="module")
def slightly_off_timestamps_factory(synced_timestamps_factory):
def _create_slightly_off_timestamps(
fps: int = 30, tolerance_s: float = 1e-4
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
timestamps[30] += tolerance_s * 0.9 # Modify a single timestamp just inside tolerance
return timestamps, episode_indices, episode_data_index
return _create_slightly_off_timestamps
@pytest.fixture(scope="module")
@@ -100,42 +123,42 @@ def delta_indices_factory():
return _delta_indices
def test_check_timestamps_sync_synced(synced_hf_dataset_factory):
def test_check_timestamps_sync_synced(synced_timestamps_factory):
fps = 30
tolerance_s = 1e-4
synced_hf_dataset = synced_hf_dataset_factory(fps)
episode_data_index = calculate_episode_data_index(synced_hf_dataset)
timestamps, ep_idx, ep_data_index = synced_timestamps_factory(fps)
result = check_timestamps_sync(
hf_dataset=synced_hf_dataset,
episode_data_index=episode_data_index,
timestamps=timestamps,
episode_indices=ep_idx,
episode_data_index=ep_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
def test_check_timestamps_sync_unsynced(unsynced_hf_dataset_factory):
def test_check_timestamps_sync_unsynced(unsynced_timestamps_factory):
fps = 30
tolerance_s = 1e-4
unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s)
episode_data_index = calculate_episode_data_index(unsynced_hf_dataset)
timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s)
with pytest.raises(ValueError):
check_timestamps_sync(
hf_dataset=unsynced_hf_dataset,
episode_data_index=episode_data_index,
timestamps=timestamps,
episode_indices=ep_idx,
episode_data_index=ep_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory):
def test_check_timestamps_sync_unsynced_no_exception(unsynced_timestamps_factory):
fps = 30
tolerance_s = 1e-4
unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s)
episode_data_index = calculate_episode_data_index(unsynced_hf_dataset)
timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s)
result = check_timestamps_sync(
hf_dataset=unsynced_hf_dataset,
episode_data_index=episode_data_index,
timestamps=timestamps,
episode_indices=ep_idx,
episode_data_index=ep_data_index,
fps=fps,
tolerance_s=tolerance_s,
raise_value_error=False,
@@ -143,14 +166,14 @@ def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory
assert result is False
def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory):
def test_check_timestamps_sync_slightly_off(slightly_off_timestamps_factory):
fps = 30
tolerance_s = 1e-4
slightly_off_hf_dataset = slightly_off_hf_dataset_factory(fps, tolerance_s)
episode_data_index = calculate_episode_data_index(slightly_off_hf_dataset)
timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(fps, tolerance_s)
result = check_timestamps_sync(
hf_dataset=slightly_off_hf_dataset,
episode_data_index=episode_data_index,
timestamps=timestamps,
episode_indices=ep_idx,
episode_data_index=ep_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
@@ -158,33 +181,13 @@ def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory):
def test_check_timestamps_sync_single_timestamp():
single_timestamp_hf_dataset = Dataset.from_dict({"timestamp": [0.0], "episode_index": [0]})
single_timestamp_hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {"to": torch.tensor([1]), "from": torch.tensor([0])}
fps = 30
tolerance_s = 1e-4
timestamps, ep_idx = np.array([0.0]), np.array([0])
episode_data_index = {"to": np.array([1]), "from": np.array([0])}
result = check_timestamps_sync(
hf_dataset=single_timestamp_hf_dataset,
episode_data_index=episode_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
# TODO(aliberts): Change behavior of hf_transform_to_torch so that it can work with empty dataset
@pytest.mark.skip("TODO: fix")
def test_check_timestamps_sync_empty_dataset():
fps = 30
tolerance_s = 1e-4
empty_hf_dataset = Dataset.from_dict({"timestamp": [], "episode_index": []})
empty_hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {
"to": torch.tensor([], dtype=torch.int64),
"from": torch.tensor([], dtype=torch.int64),
}
result = check_timestamps_sync(
hf_dataset=empty_hf_dataset,
timestamps=timestamps,
episode_indices=ep_idx,
episode_data_index=episode_data_index,
fps=fps,
tolerance_s=tolerance_s,

View File

@@ -53,7 +53,7 @@ def test_example_1(tmp_path, lerobot_dataset_factory):
('repo_id = "lerobot/pusht"', f'repo_id = "{DUMMY_REPO_ID}"'),
(
"LeRobotDataset(repo_id",
f"LeRobotDataset(repo_id, root='{str(tmp_path)}', local_files_only=True",
f"LeRobotDataset(repo_id, root='{str(tmp_path)}'",
),
],
)

View File

@@ -88,7 +88,7 @@ def test_motors_bus(request, motor_type, mock):
motors_bus = make_motors_bus(motor_type, mock=mock)
# Test reading and writting before connecting raises an error
# Test reading and writing before connecting raises an error
with pytest.raises(RobotDeviceNotConnectedError):
motors_bus.read("Torque_Enable")
with pytest.raises(RobotDeviceNotConnectedError):

View File

@@ -166,7 +166,7 @@ def test_delta_timestamps_within_tolerance():
buffer.tolerance_s = 0.04
item = buffer[2]
data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"]
assert torch.allclose(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
torch.testing.assert_close(data, torch.tensor([0, 2, 3]), msg="Data does not match expected values")
assert not is_pad.any(), "Unexpected padding detected"
@@ -236,7 +236,7 @@ def test_compute_sampler_weights_trivial(
elif online_sampling_ratio == 1:
expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)])
expected_weights /= expected_weights.sum()
assert torch.allclose(weights, expected_weights)
torch.testing.assert_close(weights, expected_weights)
def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
@@ -248,7 +248,7 @@ def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_p
weights = compute_sampler_weights(
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
)
assert torch.allclose(
torch.testing.assert_close(
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
)
@@ -261,7 +261,7 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_datase
weights = compute_sampler_weights(
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1
)
assert torch.allclose(
torch.testing.assert_close(
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
)
@@ -279,4 +279,4 @@ def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp
online_sampling_ratio=0.5,
online_drop_n_last_frames=1,
)
assert torch.allclose(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))
torch.testing.assert_close(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))

View File

@@ -363,37 +363,33 @@ def test_normalize(insert_temporal_dim):
@pytest.mark.parametrize(
"ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs, file_name_extra",
"ds_repo_id, policy_name, policy_kwargs, file_name_extra",
[
# TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it
# was changed to true. For some reason, tests would pass locally, but not in CI. So here we override
# to test with `policy.use_mpc=false`.
("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": False}, {"batch_size": 25}, "use_policy"),
# ("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": True}, {}, "use_mpc"),
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"),
("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
# TODO(rcadene): the diffusion model was normalizing the image in mean=0.5 std=0.5 which is a hack supposed to
# to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference
# that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass.
# Thus, we deactivate this test for now.
# (
# "lerobot/pusht",
# "pusht",
# "diffusion",
# {
# "n_action_steps": 8,
# "num_inference_steps": 10,
# "down_dims": [128, 256, 512],
# },
# {"batch_size": 64},
# "",
# ),
("lerobot/aloha_sim_insertion_human", "aloha", "act", {"n_action_steps": 10}, {}, ""),
(
"lerobot/pusht",
"diffusion",
{
"n_action_steps": 8,
"num_inference_steps": 10,
"down_dims": [128, 256, 512],
},
"",
),
("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""),
(
"lerobot/aloha_sim_insertion_human",
"aloha",
"act",
{"n_action_steps": 1000, "chunk_size": 1000},
{},
"_1000_steps",
"1000_steps",
),
],
)
@@ -401,9 +397,7 @@ def test_normalize(insert_temporal_dim):
# pass if it's run on another platform due to floating point errors
@require_x86_64_kernel
@require_cpu
def test_backward_compatibility(
ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs, file_name_extra
):
def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs: dict, file_name_extra: str):
"""
NOTE: If this test does not pass, and you have intentionally changed something in the policy:
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
@@ -416,26 +410,26 @@ def test_backward_compatibility(
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
6. Remember to stage and commit the resulting changes to `tests/data`.
"""
env_policy_dir = (
Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}{file_name_extra}"
ds_name = ds_repo_id.split("/")[-1]
artifact_dir = (
Path("tests/data/save_policy_to_safetensors") / f"{ds_name}_{policy_name}_{file_name_extra}"
)
saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors")
saved_actions = load_file(env_policy_dir / "actions.safetensors")
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")
saved_grad_stats = load_file(artifact_dir / "grad_stats.safetensors")
saved_param_stats = load_file(artifact_dir / "param_stats.safetensors")
saved_actions = load_file(artifact_dir / "actions.safetensors")
output_dict, grad_stats, param_stats, actions = get_policy_stats(
ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs
)
output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
for key in saved_output_dict:
assert torch.allclose(output_dict[key], saved_output_dict[key], rtol=0.1, atol=1e-7)
torch.testing.assert_close(output_dict[key], saved_output_dict[key])
for key in saved_grad_stats:
assert torch.allclose(grad_stats[key], saved_grad_stats[key], rtol=0.1, atol=1e-7)
torch.testing.assert_close(grad_stats[key], saved_grad_stats[key])
for key in saved_param_stats:
assert torch.allclose(param_stats[key], saved_param_stats[key], rtol=0.1, atol=1e-7)
torch.testing.assert_close(param_stats[key], saved_param_stats[key])
for key in saved_actions:
assert torch.allclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7)
rtol, atol = (2e-3, 5e-6) if policy_name == "diffusion" else (None, None) # HACK
torch.testing.assert_close(actions[key], saved_actions[key], rtol=rtol, atol=atol)
def test_act_temporal_ensembler():
@@ -490,4 +484,4 @@ def test_act_temporal_ensembler():
assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg)
assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
# Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error.
assert torch.allclose(online_avg, offline_avg, atol=1e-4)
torch.testing.assert_close(online_avg, offline_avg, rtol=1e-4, atol=1e-4)

View File

@@ -86,7 +86,7 @@ def test_robot(tmp_path, request, robot_type, mock):
robot.connect()
robot.teleop_step()
# Test data recorded during teleop are well formated
# Test data recorded during teleop are well formatted
observation, action = robot.teleop_step(record_data=True)
# State
assert "observation.state" in observation
@@ -114,7 +114,7 @@ def test_robot(tmp_path, request, robot_type, mock):
if "image" in name:
# TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames
continue
assert torch.allclose(captured_observation[name], observation[name], atol=1)
torch.testing.assert_close(captured_observation[name], observation[name], rtol=1e-4, atol=1)
assert captured_observation[name].shape == observation[name].shape
# Test send_action can run