Simplify configs (#550)
Co-authored-by: Remi <remi.cadene@huggingface.co> Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
This commit is contained in:
@@ -20,8 +20,8 @@ import pytest
|
||||
from serial import SerialException
|
||||
|
||||
from lerobot import available_cameras, available_motors, available_robots
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from tests.utils import DEVICE, ROBOT_CONFIG_PATH_TEMPLATE, make_camera, make_motors_bus
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot
|
||||
from tests.utils import DEVICE, make_camera, make_motors_bus
|
||||
|
||||
# Import fixture modules as plugins
|
||||
pytest_plugins = [
|
||||
@@ -43,11 +43,7 @@ def is_robot_available(robot_type):
|
||||
)
|
||||
|
||||
try:
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
|
||||
config_path = ROBOT_CONFIG_PATH_TEMPLATE.format(robot=robot_type)
|
||||
robot_cfg = init_hydra_config(config_path)
|
||||
robot = make_robot(robot_cfg)
|
||||
robot = make_robot(robot_type)
|
||||
robot.connect()
|
||||
del robot
|
||||
return True
|
||||
|
||||
@@ -43,6 +43,7 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
|
||||
repo_dir.mkdir(parents=True, exist_ok=True)
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
episodes=[0],
|
||||
)
|
||||
|
||||
# save 2 first frames of first episode
|
||||
|
||||
@@ -19,29 +19,21 @@ import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.transforms import get_image_transforms
|
||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||
from tests.test_image_transforms import ARTIFACT_DIR, DATASET_REPO_ID
|
||||
from tests.utils import DEFAULT_CONFIG_PATH
|
||||
from lerobot.common.datasets.transforms import (
|
||||
ImageTransformConfig,
|
||||
ImageTransforms,
|
||||
ImageTransformsConfig,
|
||||
make_transform_from_config,
|
||||
)
|
||||
from lerobot.common.utils.utils import seeded_context
|
||||
|
||||
ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
|
||||
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
|
||||
|
||||
|
||||
def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path):
|
||||
cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
|
||||
cfg_tf = cfg.training.image_transforms
|
||||
default_tf = get_image_transforms(
|
||||
brightness_weight=cfg_tf.brightness.weight,
|
||||
brightness_min_max=cfg_tf.brightness.min_max,
|
||||
contrast_weight=cfg_tf.contrast.weight,
|
||||
contrast_min_max=cfg_tf.contrast.min_max,
|
||||
saturation_weight=cfg_tf.saturation.weight,
|
||||
saturation_min_max=cfg_tf.saturation.min_max,
|
||||
hue_weight=cfg_tf.hue.weight,
|
||||
hue_min_max=cfg_tf.hue.min_max,
|
||||
sharpness_weight=cfg_tf.sharpness.weight,
|
||||
sharpness_min_max=cfg_tf.sharpness.min_max,
|
||||
max_num_transforms=cfg_tf.max_num_transforms,
|
||||
random_order=cfg_tf.random_order,
|
||||
)
|
||||
cfg = ImageTransformsConfig(enable=True)
|
||||
default_tf = ImageTransforms(cfg)
|
||||
|
||||
with seeded_context(1337):
|
||||
img_tf = default_tf(original_frame)
|
||||
@@ -51,29 +43,26 @@ def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path
|
||||
|
||||
def save_single_transforms(original_frame: torch.Tensor, output_dir: Path):
|
||||
transforms = {
|
||||
"brightness": [(0.5, 0.5), (2.0, 2.0)],
|
||||
"contrast": [(0.5, 0.5), (2.0, 2.0)],
|
||||
"saturation": [(0.5, 0.5), (2.0, 2.0)],
|
||||
"hue": [(-0.25, -0.25), (0.25, 0.25)],
|
||||
"sharpness": [(0.5, 0.5), (2.0, 2.0)],
|
||||
("ColorJitter", "brightness", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("ColorJitter", "contrast", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("ColorJitter", "saturation", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("ColorJitter", "hue", [(-0.25, -0.25), (0.25, 0.25)]),
|
||||
("SharpnessJitter", "sharpness", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
}
|
||||
|
||||
frames = {"original_frame": original_frame}
|
||||
for transform, values in transforms.items():
|
||||
for min_max in values:
|
||||
kwargs = {
|
||||
f"{transform}_weight": 1.0,
|
||||
f"{transform}_min_max": min_max,
|
||||
}
|
||||
tf = get_image_transforms(**kwargs)
|
||||
key = f"{transform}_{min_max[0]}_{min_max[1]}"
|
||||
for tf_type, tf_name, min_max_values in transforms.items():
|
||||
for min_max in min_max_values:
|
||||
tf_cfg = ImageTransformConfig(type=tf_type, kwargs={tf_name: min_max})
|
||||
tf = make_transform_from_config(tf_cfg)
|
||||
key = f"{tf_name}_{min_max[0]}_{min_max[1]}"
|
||||
frames[key] = tf(original_frame)
|
||||
|
||||
save_file(frames, output_dir / "single_transforms.safetensors")
|
||||
|
||||
|
||||
def main():
|
||||
dataset = LeRobotDataset(DATASET_REPO_ID, image_transforms=None)
|
||||
dataset = LeRobotDataset(DATASET_REPO_ID, episodes=[0], image_transforms=None)
|
||||
output_dir = Path(ARTIFACT_DIR)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
original_frame = dataset[0][dataset.meta.camera_keys[0]]
|
||||
|
||||
@@ -20,32 +20,35 @@ import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils.utils import init_hydra_config, set_global_seed
|
||||
from lerobot.scripts.train import make_optimizer_and_scheduler
|
||||
from tests.utils import DEFAULT_CONFIG_PATH
|
||||
from lerobot.common.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.common.policies.factory import make_policy, make_policy_config
|
||||
from lerobot.common.utils.utils import set_global_seed
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
def get_policy_stats(env_name, policy_name, extra_overrides):
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH,
|
||||
overrides=[
|
||||
f"env={env_name}",
|
||||
f"policy={policy_name}",
|
||||
"device=cpu",
|
||||
]
|
||||
+ extra_overrides,
|
||||
)
|
||||
def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs):
|
||||
# TODO(rcadene, aliberts): env_name?
|
||||
set_global_seed(1337)
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(cfg, dataset_stats=dataset.meta.stats)
|
||||
policy.train()
|
||||
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
|
||||
|
||||
train_cfg = TrainPipelineConfig(
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
|
||||
policy=make_policy_config(policy_name, **policy_kwargs),
|
||||
device="cpu",
|
||||
**train_kwargs,
|
||||
)
|
||||
train_cfg.validate() # Needed for auto-setting some parameters
|
||||
|
||||
dataset = make_dataset(train_cfg)
|
||||
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta, device=train_cfg.device)
|
||||
policy.train()
|
||||
|
||||
optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=cfg.training.batch_size,
|
||||
batch_size=train_cfg.batch_size,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
@@ -72,24 +75,28 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
|
||||
optimizer.zero_grad()
|
||||
policy.reset()
|
||||
|
||||
# HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension
|
||||
dataset.delta_timestamps = None
|
||||
# HACK: We reload a batch with no delta_indices as `select_action` won't expect a timestamps dimension
|
||||
# We simulate having an environment using a dataset by setting delta_indices to None and dropping tensors
|
||||
# indicating padding (those ending with "_is_pad")
|
||||
dataset.delta_indices = None
|
||||
batch = next(iter(dataloader))
|
||||
obs = {}
|
||||
for k in batch:
|
||||
if k.endswith("_is_pad"):
|
||||
continue
|
||||
if k.startswith("observation"):
|
||||
obs[k] = batch[k]
|
||||
|
||||
if "n_action_steps" in cfg.policy:
|
||||
actions_queue = cfg.policy.n_action_steps
|
||||
if hasattr(train_cfg.policy, "n_action_steps"):
|
||||
actions_queue = train_cfg.policy.n_action_steps
|
||||
else:
|
||||
actions_queue = cfg.policy.n_action_repeats
|
||||
actions_queue = train_cfg.policy.n_action_repeats
|
||||
|
||||
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
|
||||
return output_dict, grad_stats, param_stats, actions
|
||||
|
||||
|
||||
def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides, file_name_extra):
|
||||
def save_policy_to_safetensors(output_dir, env_name, policy_name, policy_kwargs, file_name_extra):
|
||||
env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}{file_name_extra}"
|
||||
|
||||
if env_policy_dir.exists():
|
||||
@@ -99,7 +106,7 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
|
||||
shutil.rmtree(env_policy_dir)
|
||||
|
||||
env_policy_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, policy_kwargs)
|
||||
save_file(output_dict, env_policy_dir / "output_dict.safetensors")
|
||||
save_file(grad_stats, env_policy_dir / "grad_stats.safetensors")
|
||||
save_file(param_stats, env_policy_dir / "param_stats.safetensors")
|
||||
@@ -108,26 +115,31 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
|
||||
|
||||
if __name__ == "__main__":
|
||||
env_policies = [
|
||||
# ("xarm", "tdmpc", ["policy.use_mpc=false"], "use_policy"),
|
||||
# ("xarm", "tdmpc", ["policy.use_mpc=true"], "use_mpc"),
|
||||
# (
|
||||
# "pusht",
|
||||
# "diffusion",
|
||||
# [
|
||||
# "policy.n_action_steps=8",
|
||||
# "policy.num_inference_steps=10",
|
||||
# "policy.down_dims=[128, 256, 512]",
|
||||
# ],
|
||||
# "",
|
||||
# ),
|
||||
# ("aloha", "act", ["policy.n_action_steps=10"], ""),
|
||||
# ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
|
||||
# ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
|
||||
# ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
|
||||
("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": False}, "use_policy"),
|
||||
("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": True}, "use_mpc"),
|
||||
(
|
||||
"lerobot/pusht",
|
||||
"pusht",
|
||||
"diffusion",
|
||||
{
|
||||
"n_action_steps": 8,
|
||||
"num_inference_steps": 10,
|
||||
"down_dims": [128, 256, 512],
|
||||
},
|
||||
"",
|
||||
),
|
||||
("lerobot/aloha_sim_insertion_human", "aloha", "act", {"n_action_steps": 10}, ""),
|
||||
(
|
||||
"lerobot/aloha_sim_insertion_human",
|
||||
"aloha",
|
||||
"act",
|
||||
{"n_action_steps": 1000, "chunk_size": 1000},
|
||||
"_1000_steps",
|
||||
),
|
||||
]
|
||||
if len(env_policies) == 0:
|
||||
raise RuntimeError("No policies were provided!")
|
||||
for env, policy, extra_overrides, file_name_extra in env_policies:
|
||||
for ds_repo_id, env, policy, policy_kwargs, file_name_extra in env_policies:
|
||||
save_policy_to_safetensors(
|
||||
"tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra
|
||||
"tests/data/save_policy_to_safetensors", ds_repo_id, env, policy, policy_kwargs, file_name_extra
|
||||
)
|
||||
|
||||
@@ -51,8 +51,10 @@ def test_camera(request, camera_type, mock):
|
||||
if camera_type == "opencv" and not mock:
|
||||
pytest.skip("TODO(rcadene): fix test for opencv physical camera")
|
||||
|
||||
camera_kwargs = {"camera_type": camera_type, "mock": mock}
|
||||
|
||||
# Test instantiating
|
||||
camera = make_camera(camera_type, mock=mock)
|
||||
camera = make_camera(**camera_kwargs)
|
||||
|
||||
# Test reading, async reading, disconnecting before connecting raises an error
|
||||
with pytest.raises(RobotDeviceNotConnectedError):
|
||||
@@ -66,7 +68,7 @@ def test_camera(request, camera_type, mock):
|
||||
del camera
|
||||
|
||||
# Test connecting
|
||||
camera = make_camera(camera_type, mock=mock)
|
||||
camera = make_camera(**camera_kwargs)
|
||||
camera.connect()
|
||||
assert camera.is_connected
|
||||
assert camera.fps is not None
|
||||
@@ -106,12 +108,12 @@ def test_camera(request, camera_type, mock):
|
||||
assert camera.thread is None
|
||||
|
||||
# Test disconnecting with `__del__`
|
||||
camera = make_camera(camera_type, mock=mock)
|
||||
camera = make_camera(**camera_kwargs)
|
||||
camera.connect()
|
||||
del camera
|
||||
|
||||
# Test acquiring a bgr image
|
||||
camera = make_camera(camera_type, color_mode="bgr", mock=mock)
|
||||
camera = make_camera(**camera_kwargs, color_mode="bgr")
|
||||
camera.connect()
|
||||
assert camera.color_mode == "bgr"
|
||||
bgr_color_image = camera.read()
|
||||
@@ -121,13 +123,13 @@ def test_camera(request, camera_type, mock):
|
||||
del camera
|
||||
|
||||
# Test acquiring a rotated image
|
||||
camera = make_camera(camera_type, mock=mock)
|
||||
camera = make_camera(**camera_kwargs)
|
||||
camera.connect()
|
||||
ori_color_image = camera.read()
|
||||
del camera
|
||||
|
||||
for rotation in [None, 90, 180, -90]:
|
||||
camera = make_camera(camera_type, rotation=rotation, mock=mock)
|
||||
camera = make_camera(**camera_kwargs, rotation=rotation)
|
||||
camera.connect()
|
||||
|
||||
if mock:
|
||||
@@ -159,7 +161,7 @@ def test_camera(request, camera_type, mock):
|
||||
# TODO(rcadene): Add a test for a camera that supports fps=60
|
||||
|
||||
# Test width and height can be set
|
||||
camera = make_camera(camera_type, fps=30, width=1280, height=720, mock=mock)
|
||||
camera = make_camera(**camera_kwargs, fps=30, width=1280, height=720)
|
||||
camera.connect()
|
||||
assert camera.fps == 30
|
||||
assert camera.width == 1280
|
||||
@@ -172,7 +174,7 @@ def test_camera(request, camera_type, mock):
|
||||
del camera
|
||||
|
||||
# Test not supported width and height raise an error
|
||||
camera = make_camera(camera_type, fps=30, width=0, height=0, mock=mock)
|
||||
camera = make_camera(**camera_kwargs, fps=30, width=0, height=0)
|
||||
with pytest.raises(OSError):
|
||||
camera.connect()
|
||||
del camera
|
||||
|
||||
@@ -30,17 +30,27 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from lerobot.common.logger import Logger
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from lerobot.common.robot_devices.control_configs import (
|
||||
CalibrateControlConfig,
|
||||
RecordControlConfig,
|
||||
ReplayControlConfig,
|
||||
TeleoperateControlConfig,
|
||||
)
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
|
||||
from lerobot.scripts.train import make_optimizer_and_scheduler
|
||||
from tests.test_robots import make_robot
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
|
||||
from tests.utils import DEVICE, TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
|
||||
@require_robot
|
||||
def test_teleoperate(tmpdir, request, robot_type, mock):
|
||||
robot_kwargs = {"robot_type": robot_type, "mock": mock}
|
||||
|
||||
if mock and robot_type != "aloha":
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
@@ -49,39 +59,44 @@ def test_teleoperate(tmpdir, request, robot_type, mock):
|
||||
tmpdir = Path(tmpdir)
|
||||
calibration_dir = tmpdir / robot_type
|
||||
mock_calibration_dir(calibration_dir)
|
||||
overrides = [f"calibration_dir={calibration_dir}"]
|
||||
robot_kwargs["calibration_dir"] = calibration_dir
|
||||
else:
|
||||
# Use the default .cache/calibration folder when mock=False
|
||||
overrides = None
|
||||
pass
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
teleoperate(robot, teleop_time_s=1)
|
||||
teleoperate(robot, fps=30, teleop_time_s=1)
|
||||
teleoperate(robot, fps=60, teleop_time_s=1)
|
||||
robot = make_robot(**robot_kwargs)
|
||||
teleoperate(robot, TeleoperateControlConfig(teleop_time_s=1))
|
||||
teleoperate(robot, TeleoperateControlConfig(fps=30, teleop_time_s=1))
|
||||
teleoperate(robot, TeleoperateControlConfig(fps=60, teleop_time_s=1))
|
||||
del robot
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
|
||||
@require_robot
|
||||
def test_calibrate(tmpdir, request, robot_type, mock):
|
||||
robot_kwargs = {"robot_type": robot_type, "mock": mock}
|
||||
|
||||
if mock:
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
tmpdir = Path(tmpdir)
|
||||
calibration_dir = tmpdir / robot_type
|
||||
overrides_calibration_dir = [f"calibration_dir={calibration_dir}"]
|
||||
robot_kwargs["calibration_dir"] = calibration_dir
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides_calibration_dir, mock=mock)
|
||||
calibrate(robot, arms=robot.available_arms)
|
||||
robot = make_robot(**robot_kwargs)
|
||||
calib_cfg = CalibrateControlConfig(arms=robot.available_arms)
|
||||
calibrate(robot, calib_cfg)
|
||||
del robot
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
|
||||
@require_robot
|
||||
def test_record_without_cameras(tmpdir, request, robot_type, mock):
|
||||
robot_kwargs = {"robot_type": robot_type, "mock": mock}
|
||||
|
||||
# Avoid using cameras
|
||||
overrides = ["~cameras"]
|
||||
robot_kwargs["cameras"] = {}
|
||||
|
||||
if mock and robot_type != "aloha":
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
@@ -90,33 +105,38 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
|
||||
# and avoid writing calibration files in user .cache/calibration folder
|
||||
calibration_dir = Path(tmpdir) / robot_type
|
||||
mock_calibration_dir(calibration_dir)
|
||||
overrides.append(f"calibration_dir={calibration_dir}")
|
||||
robot_kwargs["calibration_dir"] = calibration_dir
|
||||
else:
|
||||
# Use the default .cache/calibration folder when mock=False
|
||||
pass
|
||||
|
||||
repo_id = "lerobot/debug"
|
||||
root = Path(tmpdir) / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
record(
|
||||
robot,
|
||||
fps=30,
|
||||
root=root,
|
||||
robot = make_robot(**robot_kwargs)
|
||||
rec_cfg = RecordControlConfig(
|
||||
repo_id=repo_id,
|
||||
single_task=single_task,
|
||||
warmup_time_s=1,
|
||||
root=root,
|
||||
fps=30,
|
||||
warmup_time_s=0.1,
|
||||
episode_time_s=1,
|
||||
reset_time_s=0.1,
|
||||
num_episodes=2,
|
||||
run_compute_stats=False,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
play_sounds=False,
|
||||
)
|
||||
record(robot, rec_cfg)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
|
||||
@require_robot
|
||||
def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||
tmpdir = Path(tmpdir)
|
||||
robot_kwargs = {"robot_type": robot_type, "mock": mock}
|
||||
|
||||
if mock and robot_type != "aloha":
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
@@ -125,28 +145,24 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||
# and avoid writing calibration files in user .cache/calibration folder
|
||||
calibration_dir = tmpdir / robot_type
|
||||
mock_calibration_dir(calibration_dir)
|
||||
overrides = [f"calibration_dir={calibration_dir}"]
|
||||
robot_kwargs["calibration_dir"] = calibration_dir
|
||||
else:
|
||||
# Use the default .cache/calibration folder when mock=False or for aloha
|
||||
overrides = None
|
||||
# Use the default .cache/calibration folder when mock=False
|
||||
pass
|
||||
|
||||
env_name = "koch_real"
|
||||
policy_name = "act_koch_real"
|
||||
|
||||
repo_id = "lerobot/debug"
|
||||
repo_id = "lerobot_test/debug"
|
||||
root = tmpdir / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
dataset = record(
|
||||
robot,
|
||||
root,
|
||||
repo_id,
|
||||
single_task,
|
||||
robot = make_robot(**robot_kwargs)
|
||||
rec_cfg = RecordControlConfig(
|
||||
repo_id=repo_id,
|
||||
single_task=single_task,
|
||||
root=root,
|
||||
fps=1,
|
||||
warmup_time_s=0.5,
|
||||
warmup_time_s=0.1,
|
||||
episode_time_s=1,
|
||||
reset_time_s=1,
|
||||
reset_time_s=0.1,
|
||||
num_episodes=2,
|
||||
push_to_hub=False,
|
||||
# TODO(rcadene, aliberts): test video=True
|
||||
@@ -155,56 +171,34 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
)
|
||||
dataset = record(robot, rec_cfg)
|
||||
assert dataset.meta.total_episodes == 2
|
||||
assert len(dataset) == 2
|
||||
|
||||
replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False, local_files_only=True)
|
||||
|
||||
# TODO(rcadene, aliberts): rethink this design
|
||||
if robot_type == "aloha":
|
||||
env_name = "aloha_real"
|
||||
policy_name = "act_aloha_real"
|
||||
elif robot_type in ["koch", "koch_bimanual"]:
|
||||
env_name = "koch_real"
|
||||
policy_name = "act_koch_real"
|
||||
elif robot_type == "so100":
|
||||
env_name = "so100_real"
|
||||
policy_name = "act_so100_real"
|
||||
elif robot_type == "moss":
|
||||
env_name = "moss_real"
|
||||
policy_name = "act_moss_real"
|
||||
else:
|
||||
raise NotImplementedError(robot_type)
|
||||
|
||||
overrides = [
|
||||
f"env={env_name}",
|
||||
f"policy={policy_name}",
|
||||
f"device={DEVICE}",
|
||||
]
|
||||
|
||||
if robot_type == "koch_bimanual":
|
||||
overrides += ["env.state_dim=12", "env.action_dim=12"]
|
||||
|
||||
overrides += ["wandb.enable=false"]
|
||||
overrides += ["env.fps=1"]
|
||||
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH,
|
||||
overrides=overrides,
|
||||
replay_cfg = ReplayControlConfig(
|
||||
episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False, local_files_only=True
|
||||
)
|
||||
replay(robot, replay_cfg)
|
||||
|
||||
policy_cfg = ACTConfig()
|
||||
policy = make_policy(policy_cfg, ds_meta=dataset.meta, device=DEVICE)
|
||||
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
out_dir = tmpdir / "logger"
|
||||
logger = Logger(cfg, out_dir, wandb_job_name="debug")
|
||||
logger.save_checkpoint(
|
||||
0,
|
||||
policy,
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
identifier=0,
|
||||
|
||||
ds_cfg = DatasetConfig(repo_id, local_files_only=True)
|
||||
train_cfg = TrainPipelineConfig(
|
||||
dataset=ds_cfg,
|
||||
policy=policy_cfg,
|
||||
output_dir=out_dir,
|
||||
device=DEVICE,
|
||||
)
|
||||
pretrained_policy_name_or_path = out_dir / "checkpoints/last/pretrained_model"
|
||||
logger = Logger(train_cfg)
|
||||
logger.save_checkpoint(
|
||||
train_step=0,
|
||||
identifier=0,
|
||||
policy=policy,
|
||||
)
|
||||
pretrained_policy_path = out_dir / "checkpoints/last/pretrained_model"
|
||||
|
||||
# In `examples/9_use_aloha.md`, we advise using `num_image_writer_processes=1`
|
||||
# during inference, to reach constent fps, so we test this here.
|
||||
@@ -230,15 +224,14 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||
eval_repo_id = "lerobot/eval_debug"
|
||||
eval_root = tmpdir / "data" / eval_repo_id
|
||||
|
||||
dataset = record(
|
||||
robot,
|
||||
eval_root,
|
||||
eval_repo_id,
|
||||
single_task,
|
||||
pretrained_policy_name_or_path,
|
||||
warmup_time_s=1,
|
||||
rec_eval_cfg = RecordControlConfig(
|
||||
repo_id=eval_repo_id,
|
||||
root=eval_root,
|
||||
single_task=single_task,
|
||||
fps=1,
|
||||
warmup_time_s=0.1,
|
||||
episode_time_s=1,
|
||||
reset_time_s=1,
|
||||
reset_time_s=0.1,
|
||||
num_episodes=2,
|
||||
run_compute_stats=False,
|
||||
push_to_hub=False,
|
||||
@@ -246,8 +239,14 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
num_image_writer_processes=num_image_writer_processes,
|
||||
device=DEVICE,
|
||||
use_amp=False,
|
||||
)
|
||||
|
||||
rec_eval_cfg.policy = PreTrainedConfig.from_pretrained(pretrained_policy_path)
|
||||
rec_eval_cfg.policy.pretrained_path = pretrained_policy_path
|
||||
|
||||
dataset = record(robot, rec_eval_cfg)
|
||||
assert dataset.num_episodes == 2
|
||||
assert len(dataset) == 2
|
||||
|
||||
@@ -257,6 +256,8 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
|
||||
@require_robot
|
||||
def test_resume_record(tmpdir, request, robot_type, mock):
|
||||
robot_kwargs = {"robot_type": robot_type, "mock": mock}
|
||||
|
||||
if mock and robot_type != "aloha":
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
@@ -264,48 +265,50 @@ def test_resume_record(tmpdir, request, robot_type, mock):
|
||||
# and avoid writing calibration files in user .cache/calibration folder
|
||||
calibration_dir = tmpdir / robot_type
|
||||
mock_calibration_dir(calibration_dir)
|
||||
overrides = [f"calibration_dir={calibration_dir}"]
|
||||
robot_kwargs["calibration_dir"] = calibration_dir
|
||||
else:
|
||||
# Use the default .cache/calibration folder when mock=False or for aloha
|
||||
overrides = []
|
||||
# Use the default .cache/calibration folder when mock=False
|
||||
pass
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
robot = make_robot(**robot_kwargs)
|
||||
|
||||
repo_id = "lerobot/debug"
|
||||
root = Path(tmpdir) / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
record_kwargs = {
|
||||
"robot": robot,
|
||||
"root": root,
|
||||
"repo_id": repo_id,
|
||||
"single_task": single_task,
|
||||
"fps": 1,
|
||||
"warmup_time_s": 0,
|
||||
"episode_time_s": 1,
|
||||
"push_to_hub": False,
|
||||
"video": False,
|
||||
"display_cameras": False,
|
||||
"play_sounds": False,
|
||||
"run_compute_stats": False,
|
||||
"local_files_only": True,
|
||||
"num_episodes": 1,
|
||||
}
|
||||
rec_cfg = RecordControlConfig(
|
||||
repo_id=repo_id,
|
||||
root=root,
|
||||
single_task=single_task,
|
||||
fps=1,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
run_compute_stats=False,
|
||||
local_files_only=True,
|
||||
num_episodes=1,
|
||||
)
|
||||
|
||||
dataset = record(**record_kwargs)
|
||||
dataset = record(robot, rec_cfg)
|
||||
assert len(dataset) == 1, f"`dataset` should contain 1 frame, not {len(dataset)}"
|
||||
|
||||
with pytest.raises(FileExistsError):
|
||||
# Dataset already exists, but resume=False by default
|
||||
record(**record_kwargs)
|
||||
record(robot, rec_cfg)
|
||||
|
||||
dataset = record(**record_kwargs, resume=True)
|
||||
rec_cfg.resume = True
|
||||
dataset = record(robot, rec_cfg)
|
||||
assert len(dataset) == 2, f"`dataset` should contain 2 frames, not {len(dataset)}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
|
||||
@require_robot
|
||||
def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
|
||||
robot_kwargs = {"robot_type": robot_type, "mock": mock}
|
||||
|
||||
if mock and robot_type != "aloha":
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
@@ -313,12 +316,13 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
|
||||
# and avoid writing calibration files in user .cache/calibration folder
|
||||
calibration_dir = tmpdir / robot_type
|
||||
mock_calibration_dir(calibration_dir)
|
||||
overrides = [f"calibration_dir={calibration_dir}"]
|
||||
robot_kwargs["calibration_dir"] = calibration_dir
|
||||
else:
|
||||
# Use the default .cache/calibration folder when mock=False or for aloha
|
||||
overrides = []
|
||||
# Use the default .cache/calibration folder when mock=False
|
||||
pass
|
||||
|
||||
robot = make_robot(**robot_kwargs)
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
|
||||
mock_events = {}
|
||||
mock_events["exit_early"] = True
|
||||
@@ -330,11 +334,10 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
|
||||
root = Path(tmpdir) / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
dataset = record(
|
||||
robot,
|
||||
root,
|
||||
repo_id,
|
||||
single_task,
|
||||
rec_cfg = RecordControlConfig(
|
||||
repo_id=repo_id,
|
||||
root=root,
|
||||
single_task=single_task,
|
||||
fps=1,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
@@ -345,6 +348,7 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
|
||||
play_sounds=False,
|
||||
run_compute_stats=False,
|
||||
)
|
||||
dataset = record(robot, rec_cfg)
|
||||
|
||||
assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False"
|
||||
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||
@@ -354,6 +358,8 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
|
||||
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
|
||||
@require_robot
|
||||
def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
|
||||
robot_kwargs = {"robot_type": robot_type, "mock": mock}
|
||||
|
||||
if mock:
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
@@ -361,12 +367,13 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
|
||||
# and avoid writing calibration files in user .cache/calibration folder
|
||||
calibration_dir = tmpdir / robot_type
|
||||
mock_calibration_dir(calibration_dir)
|
||||
overrides = [f"calibration_dir={calibration_dir}"]
|
||||
robot_kwargs["calibration_dir"] = calibration_dir
|
||||
else:
|
||||
# Use the default .cache/calibration folder when mock=False or for aloha
|
||||
overrides = []
|
||||
# Use the default .cache/calibration folder when mock=False
|
||||
pass
|
||||
|
||||
robot = make_robot(**robot_kwargs)
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
|
||||
mock_events = {}
|
||||
mock_events["exit_early"] = True
|
||||
@@ -378,12 +385,11 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
|
||||
root = Path(tmpdir) / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
dataset = record(
|
||||
robot,
|
||||
fps=2,
|
||||
rec_cfg = RecordControlConfig(
|
||||
repo_id=repo_id,
|
||||
root=root,
|
||||
single_task=single_task,
|
||||
repo_id=repo_id,
|
||||
fps=2,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
num_episodes=1,
|
||||
@@ -394,6 +400,8 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
|
||||
run_compute_stats=False,
|
||||
)
|
||||
|
||||
dataset = record(robot, rec_cfg)
|
||||
|
||||
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
|
||||
@@ -403,6 +411,8 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
|
||||
)
|
||||
@require_robot
|
||||
def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num_image_writer_processes):
|
||||
robot_kwargs = {"robot_type": robot_type, "mock": mock}
|
||||
|
||||
if mock:
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
@@ -410,12 +420,13 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
|
||||
# and avoid writing calibration files in user .cache/calibration folder
|
||||
calibration_dir = tmpdir / robot_type
|
||||
mock_calibration_dir(calibration_dir)
|
||||
overrides = [f"calibration_dir={calibration_dir}"]
|
||||
robot_kwargs["calibration_dir"] = calibration_dir
|
||||
else:
|
||||
# Use the default .cache/calibration folder when mock=False or for aloha
|
||||
overrides = []
|
||||
# Use the default .cache/calibration folder when mock=False
|
||||
pass
|
||||
|
||||
robot = make_robot(**robot_kwargs)
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
|
||||
mock_events = {}
|
||||
mock_events["exit_early"] = True
|
||||
@@ -427,14 +438,14 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
|
||||
root = Path(tmpdir) / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
dataset = record(
|
||||
robot,
|
||||
root,
|
||||
repo_id,
|
||||
rec_cfg = RecordControlConfig(
|
||||
repo_id=repo_id,
|
||||
root=root,
|
||||
single_task=single_task,
|
||||
fps=1,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
reset_time_s=0.1,
|
||||
num_episodes=2,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
@@ -444,5 +455,7 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
|
||||
num_image_writer_processes=num_image_writer_processes,
|
||||
)
|
||||
|
||||
dataset = record(robot, rec_cfg)
|
||||
|
||||
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
|
||||
@@ -43,9 +43,14 @@ from lerobot.common.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
unflatten_dict,
|
||||
)
|
||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||
from lerobot.common.envs.factory import make_env_config
|
||||
from lerobot.common.policies.factory import make_policy_config
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot
|
||||
from lerobot.common.utils.utils import seeded_context
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, make_robot
|
||||
from tests.utils import DEVICE, require_x86_64_kernel
|
||||
|
||||
|
||||
def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
|
||||
@@ -98,11 +103,13 @@ def test_dataset_initialization(lerobot_dataset_factory, tmp_path):
|
||||
# - [ ] test smaller methods
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"env_name, repo_id, policy_name",
|
||||
lerobot.env_dataset_policy_triplets
|
||||
+ [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")],
|
||||
# Single dataset
|
||||
lerobot.env_dataset_policy_triplets,
|
||||
# Multi-dataset
|
||||
# TODO after fix multidataset
|
||||
# + [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")],
|
||||
)
|
||||
def test_factory(env_name, repo_id, policy_name):
|
||||
"""
|
||||
@@ -110,15 +117,14 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
- we can create a dataset with the factory.
|
||||
- for a commonly used set of data keys, the data dimensions are correct.
|
||||
"""
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH,
|
||||
overrides=[
|
||||
f"env={env_name}",
|
||||
f"dataset_repo_id={repo_id}",
|
||||
f"policy={policy_name}",
|
||||
f"device={DEVICE}",
|
||||
],
|
||||
cfg = TrainPipelineConfig(
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
dataset=DatasetConfig(repo_id=repo_id, episodes=[0]),
|
||||
env=make_env_config(env_name),
|
||||
policy=make_policy_config(policy_name),
|
||||
device=DEVICE,
|
||||
)
|
||||
|
||||
dataset = make_dataset(cfg)
|
||||
delta_timestamps = dataset.delta_timestamps
|
||||
camera_keys = dataset.meta.camera_keys
|
||||
@@ -171,8 +177,8 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
|
||||
|
||||
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
def test_multilerobotdataset_frames():
|
||||
@pytest.mark.skip("TODO after fix multidataset")
|
||||
def test_multidataset_frames():
|
||||
"""Check that all dataset frames are incorporated."""
|
||||
# Note: use the image variants of the dataset to make the test approx 3x faster.
|
||||
# Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
|
||||
@@ -205,14 +211,14 @@ def test_multilerobotdataset_frames():
|
||||
|
||||
|
||||
# TODO(aliberts, rcadene): Refactor and move this to a tests/test_compute_stats.py
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
def test_compute_stats_on_xarm():
|
||||
"""Check that the statistics are computed correctly according to the stats_patterns property.
|
||||
|
||||
We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
|
||||
because we are working with a small dataset).
|
||||
"""
|
||||
dataset = LeRobotDataset("lerobot/xarm_lift_medium")
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
dataset = LeRobotDataset("lerobot/xarm_lift_medium", episodes=[0])
|
||||
|
||||
# reduce size of dataset sample on which stats compute is tested to 10 frames
|
||||
dataset.hf_dataset = dataset.hf_dataset.select(range(10))
|
||||
@@ -289,7 +295,6 @@ def test_flatten_unflatten_dict():
|
||||
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id",
|
||||
[
|
||||
@@ -301,11 +306,12 @@ def test_flatten_unflatten_dict():
|
||||
# "lerobot/cmu_stretch",
|
||||
],
|
||||
)
|
||||
# TODO(rcadene, aliberts): all these tests fail locally on Mac M1, but not on Linux
|
||||
@require_x86_64_kernel
|
||||
def test_backward_compatibility(repo_id):
|
||||
"""The artifacts for this test have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
|
||||
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
dataset = LeRobotDataset(repo_id, episodes=[0])
|
||||
|
||||
test_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
|
||||
|
||||
@@ -318,6 +324,11 @@ def test_backward_compatibility(repo_id):
|
||||
new_frame.pop("language_instruction", None)
|
||||
old_frame.pop("language_instruction", None)
|
||||
|
||||
# Remove task_index to allow for backward compatibility
|
||||
# TODO(rcadene): remove when new features have been generated
|
||||
if "task_index" not in old_frame:
|
||||
del new_frame["task_index"]
|
||||
|
||||
new_keys = set(new_frame.keys())
|
||||
old_keys = set(old_frame.keys())
|
||||
assert new_keys == old_keys, f"{new_keys=} and {old_keys=} are not the same"
|
||||
@@ -361,8 +372,8 @@ def test_backward_compatibility(repo_id):
|
||||
# load_and_compare(i - 1)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
def test_aggregate_stats():
|
||||
@pytest.mark.skip("TODO after fix multidataset")
|
||||
def test_multidataset_aggregate_stats():
|
||||
"""Makes 3 basic datasets and checks that aggregate stats are computed correctly."""
|
||||
with seeded_context(0):
|
||||
data_a = torch.rand(30, dtype=torch.float32)
|
||||
|
||||
@@ -54,8 +54,10 @@ def slightly_off_hf_dataset_factory(synced_hf_dataset_factory):
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def valid_delta_timestamps_factory():
|
||||
def _create_valid_delta_timestamps(fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES) -> dict:
|
||||
delta_timestamps = {key: [i * (1 / fps) for i in range(-10, 10)] for key in keys}
|
||||
def _create_valid_delta_timestamps(
|
||||
fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)
|
||||
) -> dict:
|
||||
delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys}
|
||||
return delta_timestamps
|
||||
|
||||
return _create_valid_delta_timestamps
|
||||
@@ -91,8 +93,11 @@ def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def delta_indices(keys: list = DUMMY_MOTOR_FEATURES) -> dict:
|
||||
return {key: list(range(-10, 10)) for key in keys}
|
||||
def delta_indices_factory():
|
||||
def _delta_indices(keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)) -> dict:
|
||||
return {key: list(range(*min_max_range)) for key in keys}
|
||||
|
||||
return _delta_indices
|
||||
|
||||
|
||||
def test_check_timestamps_sync_synced(synced_hf_dataset_factory):
|
||||
@@ -248,9 +253,10 @@ def test_check_delta_timestamps_empty():
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_delta_indices(valid_delta_timestamps_factory, delta_indices):
|
||||
fps = 30
|
||||
delta_timestamps = valid_delta_timestamps_factory(fps)
|
||||
expected_delta_indices = delta_indices
|
||||
def test_delta_indices(valid_delta_timestamps_factory, delta_indices_factory):
|
||||
fps = 50
|
||||
min_max_range = (-100, 100)
|
||||
delta_timestamps = valid_delta_timestamps_factory(fps, min_max_range=min_max_range)
|
||||
expected_delta_indices = delta_indices_factory(min_max_range=min_max_range)
|
||||
actual_delta_indices = get_delta_indices(delta_timestamps, fps)
|
||||
assert expected_delta_indices == actual_delta_indices
|
||||
|
||||
@@ -21,11 +21,10 @@ import torch
|
||||
from gymnasium.utils.env_checker import check_env
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.factory import make_env, make_env_config
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
|
||||
from .utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
|
||||
from .utils import require_env
|
||||
|
||||
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
|
||||
|
||||
@@ -47,11 +46,7 @@ def test_env(env_name, env_task, obs_type):
|
||||
@pytest.mark.parametrize("env_name", lerobot.available_envs)
|
||||
@require_env
|
||||
def test_factory(env_name):
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH,
|
||||
overrides=[f"env={env_name}", f"device={DEVICE}"],
|
||||
)
|
||||
|
||||
cfg = make_env_config(env_name)
|
||||
env = make_env(cfg, n_envs=1)
|
||||
obs, _ = env.reset()
|
||||
obs = preprocess_observation(obs)
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -21,13 +20,21 @@ from safetensors.torch import load_file
|
||||
from torchvision.transforms import v2
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
|
||||
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
|
||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||
from lerobot.scripts.visualize_image_transforms import visualize_transforms
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, require_x86_64_kernel
|
||||
|
||||
ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
|
||||
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
|
||||
from lerobot.common.datasets.transforms import (
|
||||
ImageTransformConfig,
|
||||
ImageTransforms,
|
||||
ImageTransformsConfig,
|
||||
RandomSubsetApply,
|
||||
SharpnessJitter,
|
||||
make_transform_from_config,
|
||||
)
|
||||
from lerobot.common.utils.utils import seeded_context
|
||||
from lerobot.scripts.visualize_image_transforms import (
|
||||
save_all_transforms,
|
||||
save_each_transform,
|
||||
)
|
||||
from tests.scripts.save_image_transforms_to_safetensors import ARTIFACT_DIR
|
||||
from tests.utils import require_x86_64_kernel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -44,21 +51,38 @@ def single_transforms():
|
||||
return load_file(ARTIFACT_DIR / "single_transforms.safetensors")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def img_tensor(single_transforms):
|
||||
return single_transforms["original_frame"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_transforms():
|
||||
return load_file(ARTIFACT_DIR / "default_transforms.safetensors")
|
||||
|
||||
|
||||
def test_get_image_transforms_no_transform(img_tensor_factory):
|
||||
def test_get_image_transforms_no_transform_enable_false(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0)
|
||||
tf_cfg = ImageTransformsConfig() # default is enable=False
|
||||
tf_actual = ImageTransforms(tf_cfg)
|
||||
torch.testing.assert_close(tf_actual(img_tensor), img_tensor)
|
||||
|
||||
|
||||
def test_get_image_transforms_no_transform_max_num_transforms_0(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_cfg = ImageTransformsConfig(enable=True, max_num_transforms=0)
|
||||
tf_actual = ImageTransforms(tf_cfg)
|
||||
torch.testing.assert_close(tf_actual(img_tensor), img_tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||
def test_get_image_transforms_brightness(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=min_max)
|
||||
tf_cfg = ImageTransformsConfig(
|
||||
enable=True,
|
||||
tfs={"brightness": ImageTransformConfig(type="ColorJitter", kwargs={"brightness": min_max})},
|
||||
)
|
||||
tf_actual = ImageTransforms(tf_cfg)
|
||||
tf_expected = v2.ColorJitter(brightness=min_max)
|
||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||
|
||||
@@ -66,7 +90,10 @@ def test_get_image_transforms_brightness(img_tensor_factory, min_max):
|
||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||
def test_get_image_transforms_contrast(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=min_max)
|
||||
tf_cfg = ImageTransformsConfig(
|
||||
enable=True, tfs={"contrast": ImageTransformConfig(type="ColorJitter", kwargs={"contrast": min_max})}
|
||||
)
|
||||
tf_actual = ImageTransforms(tf_cfg)
|
||||
tf_expected = v2.ColorJitter(contrast=min_max)
|
||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||
|
||||
@@ -74,7 +101,11 @@ def test_get_image_transforms_contrast(img_tensor_factory, min_max):
|
||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||
def test_get_image_transforms_saturation(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=min_max)
|
||||
tf_cfg = ImageTransformsConfig(
|
||||
enable=True,
|
||||
tfs={"saturation": ImageTransformConfig(type="ColorJitter", kwargs={"saturation": min_max})},
|
||||
)
|
||||
tf_actual = ImageTransforms(tf_cfg)
|
||||
tf_expected = v2.ColorJitter(saturation=min_max)
|
||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||
|
||||
@@ -82,7 +113,10 @@ def test_get_image_transforms_saturation(img_tensor_factory, min_max):
|
||||
@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)])
|
||||
def test_get_image_transforms_hue(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=min_max)
|
||||
tf_cfg = ImageTransformsConfig(
|
||||
enable=True, tfs={"hue": ImageTransformConfig(type="ColorJitter", kwargs={"hue": min_max})}
|
||||
)
|
||||
tf_actual = ImageTransforms(tf_cfg)
|
||||
tf_expected = v2.ColorJitter(hue=min_max)
|
||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||
|
||||
@@ -90,21 +124,49 @@ def test_get_image_transforms_hue(img_tensor_factory, min_max):
|
||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||
def test_get_image_transforms_sharpness(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=min_max)
|
||||
tf_cfg = ImageTransformsConfig(
|
||||
enable=True,
|
||||
tfs={"sharpness": ImageTransformConfig(type="SharpnessJitter", kwargs={"sharpness": min_max})},
|
||||
)
|
||||
tf_actual = ImageTransforms(tf_cfg)
|
||||
tf_expected = SharpnessJitter(sharpness=min_max)
|
||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||
|
||||
|
||||
def test_get_image_transforms_max_num_transforms(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(
|
||||
brightness_min_max=(0.5, 0.5),
|
||||
contrast_min_max=(0.5, 0.5),
|
||||
saturation_min_max=(0.5, 0.5),
|
||||
hue_min_max=(0.5, 0.5),
|
||||
sharpness_min_max=(0.5, 0.5),
|
||||
random_order=False,
|
||||
tf_cfg = ImageTransformsConfig(
|
||||
enable=True,
|
||||
max_num_transforms=5,
|
||||
tfs={
|
||||
"brightness": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"brightness": (0.5, 0.5)},
|
||||
),
|
||||
"contrast": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"contrast": (0.5, 0.5)},
|
||||
),
|
||||
"saturation": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"saturation": (0.5, 0.5)},
|
||||
),
|
||||
"hue": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"hue": (0.5, 0.5)},
|
||||
),
|
||||
"sharpness": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="SharpnessJitter",
|
||||
kwargs={"sharpness": (0.5, 0.5)},
|
||||
),
|
||||
},
|
||||
)
|
||||
tf_actual = ImageTransforms(tf_cfg)
|
||||
tf_expected = v2.Compose(
|
||||
[
|
||||
v2.ColorJitter(brightness=(0.5, 0.5)),
|
||||
@@ -121,68 +183,79 @@ def test_get_image_transforms_max_num_transforms(img_tensor_factory):
|
||||
def test_get_image_transforms_random_order(img_tensor_factory):
|
||||
out_imgs = []
|
||||
img_tensor = img_tensor_factory()
|
||||
tf = get_image_transforms(
|
||||
brightness_min_max=(0.5, 0.5),
|
||||
contrast_min_max=(0.5, 0.5),
|
||||
saturation_min_max=(0.5, 0.5),
|
||||
hue_min_max=(0.5, 0.5),
|
||||
sharpness_min_max=(0.5, 0.5),
|
||||
tf_cfg = ImageTransformsConfig(
|
||||
enable=True,
|
||||
random_order=True,
|
||||
tfs={
|
||||
"brightness": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"brightness": (0.5, 0.5)},
|
||||
),
|
||||
"contrast": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"contrast": (0.5, 0.5)},
|
||||
),
|
||||
"saturation": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"saturation": (0.5, 0.5)},
|
||||
),
|
||||
"hue": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"hue": (0.5, 0.5)},
|
||||
),
|
||||
"sharpness": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="SharpnessJitter",
|
||||
kwargs={"sharpness": (0.5, 0.5)},
|
||||
),
|
||||
},
|
||||
)
|
||||
with seeded_context(1337):
|
||||
tf = ImageTransforms(tf_cfg)
|
||||
|
||||
with seeded_context(1338):
|
||||
for _ in range(10):
|
||||
out_imgs.append(tf(img_tensor))
|
||||
|
||||
tmp_img_tensor = img_tensor
|
||||
for sub_tf in tf.tf.selected_transforms:
|
||||
tmp_img_tensor = sub_tf(tmp_img_tensor)
|
||||
torch.testing.assert_close(tmp_img_tensor, out_imgs[-1])
|
||||
|
||||
for i in range(1, len(out_imgs)):
|
||||
with pytest.raises(AssertionError):
|
||||
torch.testing.assert_close(out_imgs[0], out_imgs[i])
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"transform, min_max_values",
|
||||
"tf_type, tf_name, min_max_values",
|
||||
[
|
||||
("brightness", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("contrast", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("saturation", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("hue", [(-0.25, -0.25), (0.25, 0.25)]),
|
||||
("sharpness", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("ColorJitter", "brightness", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("ColorJitter", "contrast", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("ColorJitter", "saturation", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("ColorJitter", "hue", [(-0.25, -0.25), (0.25, 0.25)]),
|
||||
("SharpnessJitter", "sharpness", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
],
|
||||
)
|
||||
def test_backward_compatibility_torchvision(img_tensor_factory, transform, min_max_values, single_transforms):
|
||||
img_tensor = img_tensor_factory()
|
||||
def test_backward_compatibility_single_transforms(
|
||||
img_tensor, tf_type, tf_name, min_max_values, single_transforms
|
||||
):
|
||||
for min_max in min_max_values:
|
||||
kwargs = {
|
||||
f"{transform}_weight": 1.0,
|
||||
f"{transform}_min_max": min_max,
|
||||
}
|
||||
tf = get_image_transforms(**kwargs)
|
||||
tf_cfg = ImageTransformConfig(type=tf_type, kwargs={tf_name: min_max})
|
||||
tf = make_transform_from_config(tf_cfg)
|
||||
actual = tf(img_tensor)
|
||||
key = f"{transform}_{min_max[0]}_{min_max[1]}"
|
||||
key = f"{tf_name}_{min_max[0]}_{min_max[1]}"
|
||||
expected = single_transforms[key]
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@require_x86_64_kernel
|
||||
def test_backward_compatibility_default_config(img_tensor_factory, default_transforms):
|
||||
img_tensor = img_tensor_factory()
|
||||
cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
|
||||
cfg_tf = cfg.training.image_transforms
|
||||
default_tf = get_image_transforms(
|
||||
brightness_weight=cfg_tf.brightness.weight,
|
||||
brightness_min_max=cfg_tf.brightness.min_max,
|
||||
contrast_weight=cfg_tf.contrast.weight,
|
||||
contrast_min_max=cfg_tf.contrast.min_max,
|
||||
saturation_weight=cfg_tf.saturation.weight,
|
||||
saturation_min_max=cfg_tf.saturation.min_max,
|
||||
hue_weight=cfg_tf.hue.weight,
|
||||
hue_min_max=cfg_tf.hue.min_max,
|
||||
sharpness_weight=cfg_tf.sharpness.weight,
|
||||
sharpness_min_max=cfg_tf.sharpness.min_max,
|
||||
max_num_transforms=cfg_tf.max_num_transforms,
|
||||
random_order=cfg_tf.random_order,
|
||||
)
|
||||
def test_backward_compatibility_default_config(img_tensor, default_transforms):
|
||||
cfg = ImageTransformsConfig(enable=True)
|
||||
default_tf = ImageTransforms(cfg)
|
||||
|
||||
with seeded_context(1337):
|
||||
actual = default_tf(img_tensor)
|
||||
@@ -260,26 +333,36 @@ def test_sharpness_jitter_invalid_range_max_smaller():
|
||||
SharpnessJitter((2.0, 0.1))
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id, n_examples",
|
||||
[
|
||||
("lerobot/aloha_sim_transfer_cube_human", 3),
|
||||
],
|
||||
)
|
||||
def test_visualize_image_transforms(repo_id, n_examples):
|
||||
cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=[f"dataset_repo_id={repo_id}"])
|
||||
output_dir = Path(__file__).parent / "outputs" / "image_transforms"
|
||||
visualize_transforms(cfg, output_dir=output_dir, n_examples=n_examples)
|
||||
output_dir = output_dir / repo_id.split("/")[-1]
|
||||
def test_save_all_transforms(img_tensor_factory, tmp_path):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_cfg = ImageTransformsConfig(enable=True)
|
||||
n_examples = 3
|
||||
|
||||
# Check if the original frame image exists
|
||||
assert (output_dir / "original_frame.png").exists(), "Original frame image was not saved."
|
||||
save_all_transforms(tf_cfg, img_tensor, tmp_path, n_examples)
|
||||
|
||||
# Check if the combined transforms directory exists and contains the right files
|
||||
combined_transforms_dir = tmp_path / "all"
|
||||
assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
|
||||
assert any(
|
||||
combined_transforms_dir.iterdir()
|
||||
), "No transformed images found in combined transforms directory."
|
||||
for i in range(1, n_examples + 1):
|
||||
assert (
|
||||
combined_transforms_dir / f"{i}.png"
|
||||
).exists(), f"Combined transform image {i}.png was not found."
|
||||
|
||||
|
||||
def test_save_each_transform(img_tensor_factory, tmp_path):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_cfg = ImageTransformsConfig(enable=True)
|
||||
n_examples = 3
|
||||
|
||||
save_each_transform(tf_cfg, img_tensor, tmp_path, n_examples)
|
||||
|
||||
# Check if the transformed images exist for each transform type
|
||||
transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"]
|
||||
for transform in transforms:
|
||||
transform_dir = output_dir / transform
|
||||
transform_dir = tmp_path / transform
|
||||
assert transform_dir.exists(), f"{transform} directory was not created."
|
||||
assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory."
|
||||
|
||||
@@ -289,14 +372,3 @@ def test_visualize_image_transforms(repo_id, n_examples):
|
||||
assert (
|
||||
transform_dir / file_name
|
||||
).exists(), f"{file_name} was not found in {transform} directory."
|
||||
|
||||
# Check if the combined transforms directory exists and contains the right files
|
||||
combined_transforms_dir = output_dir / "all"
|
||||
assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
|
||||
assert any(
|
||||
combined_transforms_dir.iterdir()
|
||||
), "No transformed images found in combined transforms directory."
|
||||
for i in range(1, n_examples + 1):
|
||||
assert (
|
||||
combined_transforms_dir / f"{i}.png"
|
||||
).exists(), f"Combined transform image {i}.png was not found."
|
||||
|
||||
@@ -20,73 +20,111 @@ from pathlib import Path
|
||||
import einops
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from lerobot import available_policies
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.datasets.utils import cycle, dataset_to_policy_features
|
||||
from lerobot.common.envs.factory import make_env, make_env_config
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.common.policies.act.modeling_act import ACTTemporalEnsembler
|
||||
from lerobot.common.policies.factory import (
|
||||
_policy_cfg_from_hydra_cfg,
|
||||
get_policy_and_config_classes,
|
||||
get_policy_class,
|
||||
make_policy,
|
||||
make_policy_config,
|
||||
)
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||
from lerobot.scripts.train import make_optimizer_and_scheduler
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.utils.utils import seeded_context
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from tests.scripts.save_policy_to_safetensors import get_policy_stats
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_path):
|
||||
# Create only one camera input which is squared to fit all current policy constraints
|
||||
# e.g. vqbet and tdmpc works with one camera only, and tdmpc requires it to be squared
|
||||
camera_features = {
|
||||
"observation.images.laptop": {
|
||||
"shape": (84, 84, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
},
|
||||
}
|
||||
motor_features = {
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (6,),
|
||||
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
||||
},
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (6,),
|
||||
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
||||
},
|
||||
}
|
||||
info = info_factory(
|
||||
total_episodes=1, total_frames=1, camera_features=camera_features, motor_features=motor_features
|
||||
)
|
||||
ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info)
|
||||
return ds_meta
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_name", available_policies)
|
||||
def test_get_policy_and_config_classes(policy_name: str):
|
||||
"""Check that the correct policy and config classes are returned."""
|
||||
policy_cls, config_cls = get_policy_and_config_classes(policy_name)
|
||||
policy_cls = get_policy_class(policy_name)
|
||||
policy_cfg = make_policy_config(policy_name)
|
||||
assert policy_cls.name == policy_name
|
||||
assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation)
|
||||
assert issubclass(
|
||||
policy_cfg.__class__, inspect.signature(policy_cls.__init__).parameters["config"].annotation
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"env_name,policy_name,extra_overrides",
|
||||
"ds_repo_id,env_name,env_kwargs,policy_name,policy_kwargs",
|
||||
[
|
||||
("xarm", "tdmpc", ["policy.use_mpc=true", "dataset_repo_id=lerobot/xarm_lift_medium"]),
|
||||
("pusht", "diffusion", []),
|
||||
("pusht", "vqbet", []),
|
||||
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"]),
|
||||
("lerobot/xarm_lift_medium", "xarm", {}, "tdmpc", {"use_mpc": True}),
|
||||
("lerobot/pusht", "pusht", {}, "diffusion", {}),
|
||||
("lerobot/pusht", "pusht", {}, "vqbet", {}),
|
||||
("lerobot/pusht", "pusht", {}, "act", {}),
|
||||
("lerobot/aloha_sim_insertion_human", "aloha", {"task": "AlohaInsertion-v0"}, "act", {}),
|
||||
(
|
||||
"lerobot/aloha_sim_insertion_scripted",
|
||||
"aloha",
|
||||
{"task": "AlohaInsertion-v0"},
|
||||
"act",
|
||||
["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_scripted"],
|
||||
{},
|
||||
),
|
||||
(
|
||||
"lerobot/aloha_sim_insertion_human",
|
||||
"aloha",
|
||||
"act",
|
||||
["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_human"],
|
||||
),
|
||||
(
|
||||
"aloha",
|
||||
"act",
|
||||
["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted"],
|
||||
),
|
||||
# Note: these parameters also need custom logic in the test function for overriding the Hydra config.
|
||||
(
|
||||
"aloha",
|
||||
{"task": "AlohaInsertion-v0"},
|
||||
"diffusion",
|
||||
["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"],
|
||||
{},
|
||||
),
|
||||
(
|
||||
"lerobot/aloha_sim_transfer_cube_human",
|
||||
"aloha",
|
||||
{"task": "AlohaTransferCube-v0"},
|
||||
"act",
|
||||
{},
|
||||
),
|
||||
(
|
||||
"lerobot/aloha_sim_transfer_cube_scripted",
|
||||
"aloha",
|
||||
{"task": "AlohaTransferCube-v0"},
|
||||
"act",
|
||||
{},
|
||||
),
|
||||
# Note: these parameters also need custom logic in the test function for overriding the Hydra config.
|
||||
("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
|
||||
("dora_aloha_real", "act_real", []),
|
||||
("dora_aloha_real", "act_real_no_state", []),
|
||||
],
|
||||
)
|
||||
@require_env
|
||||
def test_policy(env_name, policy_name, extra_overrides):
|
||||
def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
||||
"""
|
||||
Tests:
|
||||
- Making the policy object.
|
||||
@@ -99,53 +137,22 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
|
||||
and for now we add tests as we see fit.
|
||||
"""
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH,
|
||||
overrides=[
|
||||
f"env={env_name}",
|
||||
f"policy={policy_name}",
|
||||
f"device={DEVICE}",
|
||||
]
|
||||
+ extra_overrides,
|
||||
|
||||
train_cfg = TrainPipelineConfig(
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
|
||||
policy=make_policy_config(policy_name, **policy_kwargs),
|
||||
env=make_env_config(env_name, **env_kwargs),
|
||||
device=DEVICE,
|
||||
)
|
||||
|
||||
# Additional config override logic.
|
||||
if env_name == "aloha" and policy_name == "diffusion":
|
||||
for keys in [
|
||||
("training", "delta_timestamps"),
|
||||
("policy", "input_shapes"),
|
||||
("policy", "input_normalization_modes"),
|
||||
]:
|
||||
dct = dict(cfg[keys[0]][keys[1]])
|
||||
dct["observation.images.top"] = dct["observation.image"]
|
||||
del dct["observation.image"]
|
||||
cfg[keys[0]][keys[1]] = dct
|
||||
cfg.override_dataset_stats = None
|
||||
|
||||
# Additional config override logic.
|
||||
if env_name == "pusht" and policy_name == "act":
|
||||
for keys in [
|
||||
("policy", "input_shapes"),
|
||||
("policy", "input_normalization_modes"),
|
||||
]:
|
||||
dct = dict(cfg[keys[0]][keys[1]])
|
||||
dct["observation.image"] = dct["observation.images.top"]
|
||||
del dct["observation.images.top"]
|
||||
cfg[keys[0]][keys[1]] = dct
|
||||
cfg.override_dataset_stats = None
|
||||
|
||||
# Check that we can make the policy object.
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
|
||||
# Check that the policy follows the required protocol.
|
||||
assert isinstance(
|
||||
policy, Policy
|
||||
), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."
|
||||
assert isinstance(policy, torch.nn.Module)
|
||||
assert isinstance(policy, PyTorchModelHubMixin)
|
||||
dataset = make_dataset(train_cfg)
|
||||
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta, device=DEVICE)
|
||||
assert isinstance(policy, PreTrainedPolicy)
|
||||
|
||||
# Check that we run select_actions and get the appropriate output.
|
||||
env = make_env(cfg, n_envs=2)
|
||||
env = make_env(train_cfg.env, n_envs=2)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
@@ -172,7 +179,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
|
||||
# reset the policy and environment
|
||||
policy.reset()
|
||||
observation, _ = env.reset(seed=cfg.seed)
|
||||
observation, _ = env.reset(seed=train_cfg.seed)
|
||||
|
||||
# apply transform to normalize the observations
|
||||
observation = preprocess_observation(observation)
|
||||
@@ -195,65 +202,59 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
env.step(action)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
# TODO(rcadene, aliberts): This test is quite end-to-end. Move this test in test_optimizer?
|
||||
def test_act_backbone_lr():
|
||||
"""
|
||||
Test that the ACT policy can be instantiated with a different learning rate for the backbone.
|
||||
"""
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH,
|
||||
overrides=[
|
||||
"env=aloha",
|
||||
"policy=act",
|
||||
f"device={DEVICE}",
|
||||
"training.lr_backbone=0.001",
|
||||
"training.lr=0.01",
|
||||
],
|
||||
|
||||
cfg = TrainPipelineConfig(
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
|
||||
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001),
|
||||
device=DEVICE,
|
||||
)
|
||||
assert cfg.training.lr == 0.01
|
||||
assert cfg.training.lr_backbone == 0.001
|
||||
cfg.validate() # Needed for auto-setting some parameters
|
||||
|
||||
assert cfg.policy.optimizer_lr == 0.01
|
||||
assert cfg.policy.optimizer_lr_backbone == 0.001
|
||||
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
|
||||
policy = make_policy(cfg.policy, device=DEVICE, ds_meta=dataset.meta)
|
||||
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
|
||||
assert len(optimizer.param_groups) == 2
|
||||
assert optimizer.param_groups[0]["lr"] == cfg.training.lr
|
||||
assert optimizer.param_groups[1]["lr"] == cfg.training.lr_backbone
|
||||
assert optimizer.param_groups[0]["lr"] == cfg.policy.optimizer_lr
|
||||
assert optimizer.param_groups[1]["lr"] == cfg.policy.optimizer_lr_backbone
|
||||
assert len(optimizer.param_groups[0]["params"]) == 133
|
||||
assert len(optimizer.param_groups[1]["params"]) == 20
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_name", available_policies)
|
||||
def test_policy_defaults(policy_name: str):
|
||||
def test_policy_defaults(dummy_dataset_metadata, policy_name: str):
|
||||
"""Check that the policy can be instantiated with defaults."""
|
||||
policy_cls, _ = get_policy_and_config_classes(policy_name)
|
||||
policy_cls()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_name,policy_name",
|
||||
[
|
||||
("xarm", "tdmpc"),
|
||||
("pusht", "diffusion"),
|
||||
("aloha", "act"),
|
||||
],
|
||||
)
|
||||
def test_yaml_matches_dataclass(env_name: str, policy_name: str):
|
||||
"""Check that dataclass configs match their respective yaml configs."""
|
||||
hydra_cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=[f"env={env_name}", f"policy={policy_name}"])
|
||||
_, policy_cfg_cls = get_policy_and_config_classes(policy_name)
|
||||
policy_cfg_from_hydra = _policy_cfg_from_hydra_cfg(policy_cfg_cls, hydra_cfg)
|
||||
policy_cfg_from_dataclass = policy_cfg_cls()
|
||||
assert policy_cfg_from_hydra == policy_cfg_from_dataclass
|
||||
policy_cls = get_policy_class(policy_name)
|
||||
policy_cfg = make_policy_config(policy_name)
|
||||
features = dataset_to_policy_features(dummy_dataset_metadata.features)
|
||||
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
policy_cfg.input_features = {
|
||||
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
|
||||
}
|
||||
policy_cls(policy_cfg)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_name", available_policies)
|
||||
def test_save_and_load_pretrained(policy_name: str):
|
||||
policy_cls, _ = get_policy_and_config_classes(policy_name)
|
||||
policy: Policy = policy_cls()
|
||||
save_dir = "/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"
|
||||
def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name: str):
|
||||
policy_cls = get_policy_class(policy_name)
|
||||
policy_cfg = make_policy_config(policy_name)
|
||||
features = dataset_to_policy_features(dummy_dataset_metadata.features)
|
||||
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
policy_cfg.input_features = {
|
||||
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
|
||||
}
|
||||
policy = policy_cls(policy_cfg)
|
||||
save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}"
|
||||
policy.save_pretrained(save_dir)
|
||||
policy_ = policy_cls.from_pretrained(save_dir)
|
||||
policy_ = policy_cls.from_pretrained(save_dir, config=policy_cfg)
|
||||
assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))
|
||||
|
||||
|
||||
@@ -267,20 +268,27 @@ def test_normalize(insert_temporal_dim):
|
||||
expected.
|
||||
"""
|
||||
|
||||
input_shapes = {
|
||||
"observation.image": [3, 96, 96],
|
||||
"observation.state": [10],
|
||||
input_features = {
|
||||
"observation.image": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 96, 96),
|
||||
),
|
||||
"observation.state": PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(10,),
|
||||
),
|
||||
}
|
||||
output_shapes = {
|
||||
"action": [5],
|
||||
output_features = {
|
||||
"action": PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(5,),
|
||||
),
|
||||
}
|
||||
|
||||
normalize_input_modes = {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
}
|
||||
unnormalize_output_modes = {
|
||||
"action": "min_max",
|
||||
norm_map = {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
dataset_stats = {
|
||||
@@ -324,59 +332,76 @@ def test_normalize(insert_temporal_dim):
|
||||
output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1)
|
||||
|
||||
# test without stats
|
||||
normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
|
||||
normalize = Normalize(input_features, norm_map, stats=None)
|
||||
with pytest.raises(AssertionError):
|
||||
normalize(input_batch)
|
||||
|
||||
# test with stats
|
||||
normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats)
|
||||
normalize = Normalize(input_features, norm_map, stats=dataset_stats)
|
||||
normalize(input_batch)
|
||||
|
||||
# test loading pretrained models
|
||||
new_normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
|
||||
new_normalize = Normalize(input_features, norm_map, stats=None)
|
||||
new_normalize.load_state_dict(normalize.state_dict())
|
||||
new_normalize(input_batch)
|
||||
|
||||
# test without stats
|
||||
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
|
||||
unnormalize = Unnormalize(output_features, norm_map, stats=None)
|
||||
with pytest.raises(AssertionError):
|
||||
unnormalize(output_batch)
|
||||
|
||||
# test with stats
|
||||
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats)
|
||||
unnormalize = Unnormalize(output_features, norm_map, stats=dataset_stats)
|
||||
unnormalize(output_batch)
|
||||
|
||||
# test loading pretrained models
|
||||
new_unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
|
||||
new_unnormalize = Unnormalize(output_features, norm_map, stats=None)
|
||||
new_unnormalize.load_state_dict(unnormalize.state_dict())
|
||||
unnormalize(output_batch)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"env_name, policy_name, extra_overrides, file_name_extra",
|
||||
"ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs, file_name_extra",
|
||||
[
|
||||
# TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it
|
||||
# was changed to true. For some reason, tests would pass locally, but not in CI. So here we override
|
||||
# to test with `policy.use_mpc=false`.
|
||||
("xarm", "tdmpc", ["policy.use_mpc=false"], "use_policy"),
|
||||
# ("xarm", "tdmpc", ["policy.use_mpc=true"], "use_mpc"),
|
||||
("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": False}, {"batch_size": 25}, "use_policy"),
|
||||
# ("lerobot/xarm_lift_medium", "xarm", "tdmpc", {"use_mpc": True}, {}, "use_mpc"),
|
||||
# TODO(rcadene): the diffusion model was normalizing the image in mean=0.5 std=0.5 which is a hack supposed to
|
||||
# to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference
|
||||
# that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass.
|
||||
# Thus, we deactivate this test for now.
|
||||
# (
|
||||
# "lerobot/pusht",
|
||||
# "pusht",
|
||||
# "diffusion",
|
||||
# {
|
||||
# "n_action_steps": 8,
|
||||
# "num_inference_steps": 10,
|
||||
# "down_dims": [128, 256, 512],
|
||||
# },
|
||||
# {"batch_size": 64},
|
||||
# "",
|
||||
# ),
|
||||
("lerobot/aloha_sim_insertion_human", "aloha", "act", {"n_action_steps": 10}, {}, ""),
|
||||
(
|
||||
"pusht",
|
||||
"diffusion",
|
||||
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
||||
"",
|
||||
"lerobot/aloha_sim_insertion_human",
|
||||
"aloha",
|
||||
"act",
|
||||
{"n_action_steps": 1000, "chunk_size": 1000},
|
||||
{},
|
||||
"_1000_steps",
|
||||
),
|
||||
("aloha", "act", ["policy.n_action_steps=10"], ""),
|
||||
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
|
||||
("dora_aloha_real", "act_aloha_real", ["policy.n_action_steps=10"], ""),
|
||||
],
|
||||
)
|
||||
# As artifacts have been generated on an x86_64 kernel, this test won't
|
||||
# pass if it's run on another platform due to floating point errors
|
||||
@require_x86_64_kernel
|
||||
@require_cpu
|
||||
def test_backward_compatibility(env_name, policy_name, extra_overrides, file_name_extra):
|
||||
def test_backward_compatibility(
|
||||
ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs, file_name_extra
|
||||
):
|
||||
"""
|
||||
NOTE: If this test does not pass, and you have intentionally changed something in the policy:
|
||||
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
|
||||
@@ -397,16 +422,18 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
|
||||
saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors")
|
||||
saved_actions = load_file(env_policy_dir / "actions.safetensors")
|
||||
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(
|
||||
ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs
|
||||
)
|
||||
|
||||
for key in saved_output_dict:
|
||||
assert torch.isclose(output_dict[key], saved_output_dict[key], rtol=0.1, atol=1e-7).all()
|
||||
assert torch.allclose(output_dict[key], saved_output_dict[key], rtol=0.1, atol=1e-7)
|
||||
for key in saved_grad_stats:
|
||||
assert torch.isclose(grad_stats[key], saved_grad_stats[key], rtol=0.1, atol=1e-7).all()
|
||||
assert torch.allclose(grad_stats[key], saved_grad_stats[key], rtol=0.1, atol=1e-7)
|
||||
for key in saved_param_stats:
|
||||
assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=50, atol=1e-7).all()
|
||||
assert torch.allclose(param_stats[key], saved_param_stats[key], rtol=0.1, atol=1e-7)
|
||||
for key in saved_actions:
|
||||
assert torch.isclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7).all()
|
||||
assert torch.allclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7)
|
||||
|
||||
|
||||
def test_act_temporal_ensembler():
|
||||
@@ -462,7 +489,3 @@ def test_act_temporal_ensembler():
|
||||
assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
|
||||
# Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error.
|
||||
assert torch.allclose(online_avg, offline_avg, atol=1e-4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_act_temporal_ensembler()
|
||||
|
||||
@@ -229,11 +229,13 @@ def _mock_download_raw(raw_dir, repo_id):
|
||||
raise ValueError(repo_id)
|
||||
|
||||
|
||||
@pytest.mark.skip("push_dataset_to_hub is deprecated")
|
||||
def test_push_dataset_to_hub_invalid_repo_id(tmpdir):
|
||||
with pytest.raises(ValueError):
|
||||
push_dataset_to_hub(Path(tmpdir), "raw_format", "invalid_repo_id")
|
||||
|
||||
|
||||
@pytest.mark.skip("push_dataset_to_hub is deprecated")
|
||||
def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
|
||||
tmpdir = Path(tmpdir)
|
||||
out_dir = tmpdir / "out"
|
||||
@@ -250,7 +252,7 @@ def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.skip("push_dataset_to_hub is deprecated")
|
||||
@pytest.mark.parametrize(
|
||||
"required_packages, raw_format, repo_id, make_test_data",
|
||||
[
|
||||
@@ -318,6 +320,7 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_
|
||||
assert torch.equal(test_dataset.episode_data_index[k], lerobot_dataset.episode_data_index[k][:1])
|
||||
|
||||
|
||||
@pytest.mark.skip("push_dataset_to_hub is deprecated")
|
||||
@pytest.mark.parametrize(
|
||||
"raw_format, repo_id",
|
||||
[
|
||||
@@ -329,9 +332,6 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_
|
||||
("dora_parquet", "cadene/wrist_gripper"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.skip(
|
||||
"Not compatible with our CI since it downloads raw datasets. Run with `python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`"
|
||||
)
|
||||
def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id):
|
||||
_, dataset_id = repo_id.split("/")
|
||||
|
||||
|
||||
@@ -28,9 +28,9 @@ from pathlib import Path
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from tests.utils import TEST_ROBOT_TYPES, make_robot, mock_calibration_dir, require_robot
|
||||
from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
|
||||
@@ -39,12 +39,12 @@ def test_robot(tmpdir, request, robot_type, mock):
|
||||
# TODO(rcadene): measure fps in nightly?
|
||||
# TODO(rcadene): test logs
|
||||
# TODO(rcadene): add compatibility with other robots
|
||||
robot_kwargs = {"robot_type": robot_type}
|
||||
robot_kwargs = {"robot_type": robot_type, "mock": mock}
|
||||
|
||||
if robot_type == "aloha" and mock:
|
||||
# To simplify unit test, we do not rerun manual calibration for Aloha mock=True.
|
||||
# Instead, we use the files from '.cache/calibration/aloha_default'
|
||||
overrides_calibration_dir = None
|
||||
pass
|
||||
else:
|
||||
if mock:
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
@@ -52,18 +52,11 @@ def test_robot(tmpdir, request, robot_type, mock):
|
||||
# Create an empty calibration directory to trigger manual calibration
|
||||
tmpdir = Path(tmpdir)
|
||||
calibration_dir = tmpdir / robot_type
|
||||
overrides_calibration_dir = [f"calibration_dir={calibration_dir}"]
|
||||
mock_calibration_dir(calibration_dir)
|
||||
robot_kwargs["calibration_dir"] = calibration_dir
|
||||
|
||||
# Test connecting without devices raises an error
|
||||
robot = ManipulatorRobot(**robot_kwargs)
|
||||
with pytest.raises(ValueError):
|
||||
robot.connect()
|
||||
del robot
|
||||
|
||||
# Test using robot before connecting raises an error
|
||||
robot = ManipulatorRobot(**robot_kwargs)
|
||||
robot = make_robot(**robot_kwargs)
|
||||
with pytest.raises(RobotDeviceNotConnectedError):
|
||||
robot.teleop_step()
|
||||
with pytest.raises(RobotDeviceNotConnectedError):
|
||||
@@ -79,7 +72,7 @@ def test_robot(tmpdir, request, robot_type, mock):
|
||||
del robot
|
||||
|
||||
# Test connecting (triggers manual calibration)
|
||||
robot = make_robot(robot_type, overrides=overrides_calibration_dir, mock=mock)
|
||||
robot = make_robot(**robot_kwargs)
|
||||
robot.connect()
|
||||
assert robot.is_connected
|
||||
|
||||
@@ -92,9 +85,7 @@ def test_robot(tmpdir, request, robot_type, mock):
|
||||
robot.disconnect()
|
||||
|
||||
# Test teleop can run
|
||||
robot = make_robot(robot_type, overrides=overrides_calibration_dir, mock=mock)
|
||||
if overrides_calibration_dir is not None:
|
||||
robot.calibration_dir = calibration_dir
|
||||
robot = make_robot(**robot_kwargs)
|
||||
robot.connect()
|
||||
robot.teleop_step()
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import random
|
||||
from typing import Callable
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -13,7 +12,6 @@ from lerobot.common.datasets.utils import (
|
||||
)
|
||||
from lerobot.common.utils.utils import (
|
||||
get_global_random_state,
|
||||
init_hydra_config,
|
||||
seeded_context,
|
||||
set_global_random_state,
|
||||
set_global_seed,
|
||||
@@ -70,10 +68,3 @@ def test_calculate_episode_data_index():
|
||||
episode_data_index = calculate_episode_data_index(dataset)
|
||||
assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
|
||||
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
|
||||
|
||||
|
||||
def test_init_hydra_config_empty():
|
||||
test_file = f"/tmp/test_init_hydra_config_empty_{uuid4().hex}.yaml"
|
||||
with open(test_file, "w") as f:
|
||||
f.write("\n")
|
||||
init_hydra_config(test_file)
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
from copy import copy
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
|
||||
@@ -25,18 +24,12 @@ import torch
|
||||
|
||||
from lerobot import available_cameras, available_motors, available_robots
|
||||
from lerobot.common.robot_devices.cameras.utils import Camera
|
||||
from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot as make_robot_from_cfg
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device
|
||||
from lerobot.common.utils.import_utils import is_package_available
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# Pass this as the first argument to init_hydra_config.
|
||||
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
|
||||
|
||||
ROBOT_CONFIG_PATH_TEMPLATE = "lerobot/configs/robot/{robot}.yaml"
|
||||
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
|
||||
|
||||
TEST_ROBOT_TYPES = []
|
||||
for robot_type in available_robots:
|
||||
@@ -52,7 +45,7 @@ for motor_type in available_motors:
|
||||
|
||||
# Camera indices used for connecting physical cameras
|
||||
OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
|
||||
INTELREALSENSE_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_CAMERA_INDEX", 128422271614))
|
||||
INTELREALSENSE_SERIAL_NUMBER = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614))
|
||||
|
||||
DYNAMIXEL_PORT = os.environ.get("LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081")
|
||||
DYNAMIXEL_MOTORS = {
|
||||
@@ -309,79 +302,30 @@ def mock_calibration_dir(calibration_dir):
|
||||
json.dump(example_calib, f)
|
||||
|
||||
|
||||
def make_robot(robot_type: str, overrides: list[str] | None = None, mock=False) -> Robot:
|
||||
if mock:
|
||||
overrides = [] if overrides is None else copy(overrides)
|
||||
|
||||
# Explicitely add mock argument to the cameras and set it to true
|
||||
# TODO(rcadene, aliberts): redesign when we drop hydra
|
||||
if robot_type in ["koch", "so100", "moss"]:
|
||||
overrides.append("+leader_arms.main.mock=true")
|
||||
overrides.append("+follower_arms.main.mock=true")
|
||||
if "~cameras" not in overrides:
|
||||
overrides.append("+cameras.laptop.mock=true")
|
||||
overrides.append("+cameras.phone.mock=true")
|
||||
|
||||
elif robot_type == "koch_bimanual":
|
||||
overrides.append("+leader_arms.left.mock=true")
|
||||
overrides.append("+leader_arms.right.mock=true")
|
||||
overrides.append("+follower_arms.left.mock=true")
|
||||
overrides.append("+follower_arms.right.mock=true")
|
||||
if "~cameras" not in overrides:
|
||||
overrides.append("+cameras.laptop.mock=true")
|
||||
overrides.append("+cameras.phone.mock=true")
|
||||
|
||||
elif robot_type == "aloha":
|
||||
overrides.append("+leader_arms.left.mock=true")
|
||||
overrides.append("+leader_arms.right.mock=true")
|
||||
overrides.append("+follower_arms.left.mock=true")
|
||||
overrides.append("+follower_arms.right.mock=true")
|
||||
if "~cameras" not in overrides:
|
||||
overrides.append("+cameras.cam_high.mock=true")
|
||||
overrides.append("+cameras.cam_low.mock=true")
|
||||
overrides.append("+cameras.cam_left_wrist.mock=true")
|
||||
overrides.append("+cameras.cam_right_wrist.mock=true")
|
||||
|
||||
else:
|
||||
raise NotImplementedError(robot_type)
|
||||
|
||||
config_path = ROBOT_CONFIG_PATH_TEMPLATE.format(robot=robot_type)
|
||||
robot_cfg = init_hydra_config(config_path, overrides)
|
||||
robot = make_robot_from_cfg(robot_cfg)
|
||||
return robot
|
||||
|
||||
|
||||
def make_camera(camera_type, **kwargs) -> Camera:
|
||||
# TODO(rcadene, aliberts): remove this dark pattern that overrides
|
||||
def make_camera(camera_type: str, **kwargs) -> Camera:
|
||||
if camera_type == "opencv":
|
||||
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
||||
|
||||
camera_index = kwargs.pop("camera_index", OPENCV_CAMERA_INDEX)
|
||||
return OpenCVCamera(camera_index, **kwargs)
|
||||
return make_camera_device(camera_type, camera_index=camera_index, **kwargs)
|
||||
|
||||
elif camera_type == "intelrealsense":
|
||||
from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera
|
||||
|
||||
camera_index = kwargs.pop("camera_index", INTELREALSENSE_CAMERA_INDEX)
|
||||
return IntelRealSenseCamera(camera_index, **kwargs)
|
||||
|
||||
serial_number = kwargs.pop("serial_number", INTELREALSENSE_SERIAL_NUMBER)
|
||||
return make_camera_device(camera_type, serial_number=serial_number, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"The camera type '{camera_type}' is not valid.")
|
||||
|
||||
|
||||
# TODO(rcadene, aliberts): remove this dark pattern that overrides
|
||||
def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
|
||||
if motor_type == "dynamixel":
|
||||
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
|
||||
|
||||
port = kwargs.pop("port", DYNAMIXEL_PORT)
|
||||
motors = kwargs.pop("motors", DYNAMIXEL_MOTORS)
|
||||
return DynamixelMotorsBus(port, motors, **kwargs)
|
||||
return make_motors_bus_device(motor_type, port=port, motors=motors, **kwargs)
|
||||
|
||||
elif motor_type == "feetech":
|
||||
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus
|
||||
|
||||
port = kwargs.pop("port", FEETECH_PORT)
|
||||
motors = kwargs.pop("motors", FEETECH_MOTORS)
|
||||
return FeetechMotorsBus(port, motors, **kwargs)
|
||||
return make_motors_bus_device(motor_type, port=port, motors=motors, **kwargs)
|
||||
|
||||
else:
|
||||
raise ValueError(f"The motor type '{motor_type}' is not valid.")
|
||||
|
||||
Reference in New Issue
Block a user