Simplify configs (#550)

Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
This commit is contained in:
Simon Alibert
2025-01-31 13:57:37 +01:00
committed by GitHub
parent 1ee1acf8ad
commit 3c0a209f9f
119 changed files with 5761 additions and 5466 deletions

View File

@@ -30,17 +30,27 @@ from unittest.mock import patch
import pytest
from lerobot.common.logger import Logger
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config
from lerobot.common.robot_devices.control_configs import (
CalibrateControlConfig,
RecordControlConfig,
ReplayControlConfig,
TeleoperateControlConfig,
)
from lerobot.configs.default import DatasetConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
from lerobot.scripts.train import make_optimizer_and_scheduler
from tests.test_robots import make_robot
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
from tests.utils import DEVICE, TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot
def test_teleoperate(tmpdir, request, robot_type, mock):
robot_kwargs = {"robot_type": robot_type, "mock": mock}
if mock and robot_type != "aloha":
request.getfixturevalue("patch_builtins_input")
@@ -49,39 +59,44 @@ def test_teleoperate(tmpdir, request, robot_type, mock):
tmpdir = Path(tmpdir)
calibration_dir = tmpdir / robot_type
mock_calibration_dir(calibration_dir)
overrides = [f"calibration_dir={calibration_dir}"]
robot_kwargs["calibration_dir"] = calibration_dir
else:
# Use the default .cache/calibration folder when mock=False
overrides = None
pass
robot = make_robot(robot_type, overrides=overrides, mock=mock)
teleoperate(robot, teleop_time_s=1)
teleoperate(robot, fps=30, teleop_time_s=1)
teleoperate(robot, fps=60, teleop_time_s=1)
robot = make_robot(**robot_kwargs)
teleoperate(robot, TeleoperateControlConfig(teleop_time_s=1))
teleoperate(robot, TeleoperateControlConfig(fps=30, teleop_time_s=1))
teleoperate(robot, TeleoperateControlConfig(fps=60, teleop_time_s=1))
del robot
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot
def test_calibrate(tmpdir, request, robot_type, mock):
robot_kwargs = {"robot_type": robot_type, "mock": mock}
if mock:
request.getfixturevalue("patch_builtins_input")
# Create an empty calibration directory to trigger manual calibration
tmpdir = Path(tmpdir)
calibration_dir = tmpdir / robot_type
overrides_calibration_dir = [f"calibration_dir={calibration_dir}"]
robot_kwargs["calibration_dir"] = calibration_dir
robot = make_robot(robot_type, overrides=overrides_calibration_dir, mock=mock)
calibrate(robot, arms=robot.available_arms)
robot = make_robot(**robot_kwargs)
calib_cfg = CalibrateControlConfig(arms=robot.available_arms)
calibrate(robot, calib_cfg)
del robot
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot
def test_record_without_cameras(tmpdir, request, robot_type, mock):
robot_kwargs = {"robot_type": robot_type, "mock": mock}
# Avoid using cameras
overrides = ["~cameras"]
robot_kwargs["cameras"] = {}
if mock and robot_type != "aloha":
request.getfixturevalue("patch_builtins_input")
@@ -90,33 +105,38 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
# and avoid writing calibration files in user .cache/calibration folder
calibration_dir = Path(tmpdir) / robot_type
mock_calibration_dir(calibration_dir)
overrides.append(f"calibration_dir={calibration_dir}")
robot_kwargs["calibration_dir"] = calibration_dir
else:
# Use the default .cache/calibration folder when mock=False
pass
repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id
single_task = "Do something."
robot = make_robot(robot_type, overrides=overrides, mock=mock)
record(
robot,
fps=30,
root=root,
robot = make_robot(**robot_kwargs)
rec_cfg = RecordControlConfig(
repo_id=repo_id,
single_task=single_task,
warmup_time_s=1,
root=root,
fps=30,
warmup_time_s=0.1,
episode_time_s=1,
reset_time_s=0.1,
num_episodes=2,
run_compute_stats=False,
push_to_hub=False,
video=False,
play_sounds=False,
)
record(robot, rec_cfg)
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot
def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
tmpdir = Path(tmpdir)
robot_kwargs = {"robot_type": robot_type, "mock": mock}
if mock and robot_type != "aloha":
request.getfixturevalue("patch_builtins_input")
@@ -125,28 +145,24 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
# and avoid writing calibration files in user .cache/calibration folder
calibration_dir = tmpdir / robot_type
mock_calibration_dir(calibration_dir)
overrides = [f"calibration_dir={calibration_dir}"]
robot_kwargs["calibration_dir"] = calibration_dir
else:
# Use the default .cache/calibration folder when mock=False or for aloha
overrides = None
# Use the default .cache/calibration folder when mock=False
pass
env_name = "koch_real"
policy_name = "act_koch_real"
repo_id = "lerobot/debug"
repo_id = "lerobot_test/debug"
root = tmpdir / "data" / repo_id
single_task = "Do something."
robot = make_robot(robot_type, overrides=overrides, mock=mock)
dataset = record(
robot,
root,
repo_id,
single_task,
robot = make_robot(**robot_kwargs)
rec_cfg = RecordControlConfig(
repo_id=repo_id,
single_task=single_task,
root=root,
fps=1,
warmup_time_s=0.5,
warmup_time_s=0.1,
episode_time_s=1,
reset_time_s=1,
reset_time_s=0.1,
num_episodes=2,
push_to_hub=False,
# TODO(rcadene, aliberts): test video=True
@@ -155,56 +171,34 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
display_cameras=False,
play_sounds=False,
)
dataset = record(robot, rec_cfg)
assert dataset.meta.total_episodes == 2
assert len(dataset) == 2
replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False, local_files_only=True)
# TODO(rcadene, aliberts): rethink this design
if robot_type == "aloha":
env_name = "aloha_real"
policy_name = "act_aloha_real"
elif robot_type in ["koch", "koch_bimanual"]:
env_name = "koch_real"
policy_name = "act_koch_real"
elif robot_type == "so100":
env_name = "so100_real"
policy_name = "act_so100_real"
elif robot_type == "moss":
env_name = "moss_real"
policy_name = "act_moss_real"
else:
raise NotImplementedError(robot_type)
overrides = [
f"env={env_name}",
f"policy={policy_name}",
f"device={DEVICE}",
]
if robot_type == "koch_bimanual":
overrides += ["env.state_dim=12", "env.action_dim=12"]
overrides += ["wandb.enable=false"]
overrides += ["env.fps=1"]
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
overrides=overrides,
replay_cfg = ReplayControlConfig(
episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False, local_files_only=True
)
replay(robot, replay_cfg)
policy_cfg = ACTConfig()
policy = make_policy(policy_cfg, ds_meta=dataset.meta, device=DEVICE)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
out_dir = tmpdir / "logger"
logger = Logger(cfg, out_dir, wandb_job_name="debug")
logger.save_checkpoint(
0,
policy,
optimizer,
lr_scheduler,
identifier=0,
ds_cfg = DatasetConfig(repo_id, local_files_only=True)
train_cfg = TrainPipelineConfig(
dataset=ds_cfg,
policy=policy_cfg,
output_dir=out_dir,
device=DEVICE,
)
pretrained_policy_name_or_path = out_dir / "checkpoints/last/pretrained_model"
logger = Logger(train_cfg)
logger.save_checkpoint(
train_step=0,
identifier=0,
policy=policy,
)
pretrained_policy_path = out_dir / "checkpoints/last/pretrained_model"
# In `examples/9_use_aloha.md`, we advise using `num_image_writer_processes=1`
# during inference, to reach constent fps, so we test this here.
@@ -230,15 +224,14 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
eval_repo_id = "lerobot/eval_debug"
eval_root = tmpdir / "data" / eval_repo_id
dataset = record(
robot,
eval_root,
eval_repo_id,
single_task,
pretrained_policy_name_or_path,
warmup_time_s=1,
rec_eval_cfg = RecordControlConfig(
repo_id=eval_repo_id,
root=eval_root,
single_task=single_task,
fps=1,
warmup_time_s=0.1,
episode_time_s=1,
reset_time_s=1,
reset_time_s=0.1,
num_episodes=2,
run_compute_stats=False,
push_to_hub=False,
@@ -246,8 +239,14 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
display_cameras=False,
play_sounds=False,
num_image_writer_processes=num_image_writer_processes,
device=DEVICE,
use_amp=False,
)
rec_eval_cfg.policy = PreTrainedConfig.from_pretrained(pretrained_policy_path)
rec_eval_cfg.policy.pretrained_path = pretrained_policy_path
dataset = record(robot, rec_eval_cfg)
assert dataset.num_episodes == 2
assert len(dataset) == 2
@@ -257,6 +256,8 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
@require_robot
def test_resume_record(tmpdir, request, robot_type, mock):
robot_kwargs = {"robot_type": robot_type, "mock": mock}
if mock and robot_type != "aloha":
request.getfixturevalue("patch_builtins_input")
@@ -264,48 +265,50 @@ def test_resume_record(tmpdir, request, robot_type, mock):
# and avoid writing calibration files in user .cache/calibration folder
calibration_dir = tmpdir / robot_type
mock_calibration_dir(calibration_dir)
overrides = [f"calibration_dir={calibration_dir}"]
robot_kwargs["calibration_dir"] = calibration_dir
else:
# Use the default .cache/calibration folder when mock=False or for aloha
overrides = []
# Use the default .cache/calibration folder when mock=False
pass
robot = make_robot(robot_type, overrides=overrides, mock=mock)
robot = make_robot(**robot_kwargs)
repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id
single_task = "Do something."
record_kwargs = {
"robot": robot,
"root": root,
"repo_id": repo_id,
"single_task": single_task,
"fps": 1,
"warmup_time_s": 0,
"episode_time_s": 1,
"push_to_hub": False,
"video": False,
"display_cameras": False,
"play_sounds": False,
"run_compute_stats": False,
"local_files_only": True,
"num_episodes": 1,
}
rec_cfg = RecordControlConfig(
repo_id=repo_id,
root=root,
single_task=single_task,
fps=1,
warmup_time_s=0,
episode_time_s=1,
push_to_hub=False,
video=False,
display_cameras=False,
play_sounds=False,
run_compute_stats=False,
local_files_only=True,
num_episodes=1,
)
dataset = record(**record_kwargs)
dataset = record(robot, rec_cfg)
assert len(dataset) == 1, f"`dataset` should contain 1 frame, not {len(dataset)}"
with pytest.raises(FileExistsError):
# Dataset already exists, but resume=False by default
record(**record_kwargs)
record(robot, rec_cfg)
dataset = record(**record_kwargs, resume=True)
rec_cfg.resume = True
dataset = record(robot, rec_cfg)
assert len(dataset) == 2, f"`dataset` should contain 2 frames, not {len(dataset)}"
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
@require_robot
def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
robot_kwargs = {"robot_type": robot_type, "mock": mock}
if mock and robot_type != "aloha":
request.getfixturevalue("patch_builtins_input")
@@ -313,12 +316,13 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
# and avoid writing calibration files in user .cache/calibration folder
calibration_dir = tmpdir / robot_type
mock_calibration_dir(calibration_dir)
overrides = [f"calibration_dir={calibration_dir}"]
robot_kwargs["calibration_dir"] = calibration_dir
else:
# Use the default .cache/calibration folder when mock=False or for aloha
overrides = []
# Use the default .cache/calibration folder when mock=False
pass
robot = make_robot(**robot_kwargs)
robot = make_robot(robot_type, overrides=overrides, mock=mock)
with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
mock_events = {}
mock_events["exit_early"] = True
@@ -330,11 +334,10 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
root = Path(tmpdir) / "data" / repo_id
single_task = "Do something."
dataset = record(
robot,
root,
repo_id,
single_task,
rec_cfg = RecordControlConfig(
repo_id=repo_id,
root=root,
single_task=single_task,
fps=1,
warmup_time_s=0,
episode_time_s=1,
@@ -345,6 +348,7 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
play_sounds=False,
run_compute_stats=False,
)
dataset = record(robot, rec_cfg)
assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False"
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
@@ -354,6 +358,8 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
@require_robot
def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
robot_kwargs = {"robot_type": robot_type, "mock": mock}
if mock:
request.getfixturevalue("patch_builtins_input")
@@ -361,12 +367,13 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
# and avoid writing calibration files in user .cache/calibration folder
calibration_dir = tmpdir / robot_type
mock_calibration_dir(calibration_dir)
overrides = [f"calibration_dir={calibration_dir}"]
robot_kwargs["calibration_dir"] = calibration_dir
else:
# Use the default .cache/calibration folder when mock=False or for aloha
overrides = []
# Use the default .cache/calibration folder when mock=False
pass
robot = make_robot(**robot_kwargs)
robot = make_robot(robot_type, overrides=overrides, mock=mock)
with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
mock_events = {}
mock_events["exit_early"] = True
@@ -378,12 +385,11 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
root = Path(tmpdir) / "data" / repo_id
single_task = "Do something."
dataset = record(
robot,
fps=2,
rec_cfg = RecordControlConfig(
repo_id=repo_id,
root=root,
single_task=single_task,
repo_id=repo_id,
fps=2,
warmup_time_s=0,
episode_time_s=1,
num_episodes=1,
@@ -394,6 +400,8 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
run_compute_stats=False,
)
dataset = record(robot, rec_cfg)
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
@@ -403,6 +411,8 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
)
@require_robot
def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num_image_writer_processes):
robot_kwargs = {"robot_type": robot_type, "mock": mock}
if mock:
request.getfixturevalue("patch_builtins_input")
@@ -410,12 +420,13 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
# and avoid writing calibration files in user .cache/calibration folder
calibration_dir = tmpdir / robot_type
mock_calibration_dir(calibration_dir)
overrides = [f"calibration_dir={calibration_dir}"]
robot_kwargs["calibration_dir"] = calibration_dir
else:
# Use the default .cache/calibration folder when mock=False or for aloha
overrides = []
# Use the default .cache/calibration folder when mock=False
pass
robot = make_robot(**robot_kwargs)
robot = make_robot(robot_type, overrides=overrides, mock=mock)
with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
mock_events = {}
mock_events["exit_early"] = True
@@ -427,14 +438,14 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
root = Path(tmpdir) / "data" / repo_id
single_task = "Do something."
dataset = record(
robot,
root,
repo_id,
rec_cfg = RecordControlConfig(
repo_id=repo_id,
root=root,
single_task=single_task,
fps=1,
warmup_time_s=0,
episode_time_s=1,
reset_time_s=0.1,
num_episodes=2,
push_to_hub=False,
video=False,
@@ -444,5 +455,7 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
num_image_writer_processes=num_image_writer_processes,
)
dataset = record(robot, rec_cfg)
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert len(dataset) == 1, "`dataset` should contain only 1 frame"