Add new test_control_robot

This commit is contained in:
Simon Alibert
2025-05-08 17:38:16 +02:00
parent f9db727647
commit 95f9b45418
3 changed files with 96 additions and 7 deletions

View File

@@ -195,7 +195,7 @@ def record_loop(
@draccus.wrap()
def record(cfg: RecordConfig):
def record(cfg: RecordConfig) -> LeRobotDataset:
init_logging()
logging.info(pformat(asdict(cfg)))
if cfg.display_data:

View File

@@ -51,10 +51,13 @@ class TeleoperateConfig:
display_data: bool = False
def teleop_loop(teleop: Teleoperator, robot: Robot, display_data: bool = False):
def teleop_loop(
teleop: Teleoperator, robot: Robot, display_data: bool = False, duration: float | None = None
):
display_len = max(len(key) for key in robot.action_features)
start = time.perf_counter()
while True:
start = time.perf_counter()
loop_start = time.perf_counter()
action = teleop.get_action()
if display_data:
observation = robot.get_observation()
@@ -68,18 +71,22 @@ def teleop_loop(teleop: Teleoperator, robot: Robot, display_data: bool = False):
rr.log(f"action_{act}", rr.Scalar(val))
robot.send_action(action)
loop_s = time.perf_counter() - start
loop_s = time.perf_counter() - loop_start
print("\n" + "-" * (display_len + 10))
print(f"{'NAME':<{display_len}} | {'NORM':>7}")
for motor, value in action.items():
print(f"{motor:<{display_len}} | {value:>7.2f}")
print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
if duration is not None and time.perf_counter() - start >= duration:
return
move_cursor_up(len(action) + 5)
@draccus.wrap()
def control_robot(cfg: TeleoperateConfig):
def teleoperate(cfg: TeleoperateConfig):
init_logging()
logging.info(pformat(asdict(cfg)))
if cfg.display_data:
@@ -92,7 +99,7 @@ def control_robot(cfg: TeleoperateConfig):
robot.connect()
try:
teleop_loop(teleop, robot, display_data=cfg.display_data)
teleop_loop(teleop, robot, display_data=cfg.display_data, duration=cfg.teleop_time_s)
except KeyboardInterrupt:
pass
finally:
@@ -103,4 +110,4 @@ def control_robot(cfg: TeleoperateConfig):
if __name__ == "__main__":
control_robot()
teleoperate()

View File

@@ -0,0 +1,82 @@
import time
from lerobot.record import DatasetRecordConfig, RecordConfig, record
from lerobot.replay import DatasetReplayConfig, ReplayConfig, replay
from lerobot.teleoperate import TeleoperateConfig, teleoperate
from tests.fixtures.constants import DUMMY_REPO_ID
from tests.mocks.mock_robot import MockRobotConfig
from tests.mocks.mock_teleop import MockTeleopConfig
def test_teleoperate():
robot_cfg = MockRobotConfig()
teleop_cfg = MockTeleopConfig()
expected_duration = 0.1
cfg = TeleoperateConfig(
robot=robot_cfg,
teleop=teleop_cfg,
teleop_time_s=expected_duration,
)
start = time.perf_counter()
teleoperate(cfg)
actual_duration = time.perf_counter() - start
assert actual_duration <= expected_duration * 1.1
def test_record(tmp_path):
robot_cfg = MockRobotConfig()
teleop_cfg = MockTeleopConfig()
dataset_cfg = DatasetRecordConfig(
repo_id=DUMMY_REPO_ID,
single_task="Dummy task",
root=tmp_path / "record",
num_episodes=1,
episode_time_s=0.1,
push_to_hub=False,
)
cfg = RecordConfig(
robot=robot_cfg,
dataset=dataset_cfg,
teleop=teleop_cfg,
play_sounds=False,
)
dataset = record(cfg)
assert dataset.fps == 30
assert dataset.meta.total_episodes == dataset.num_episodes == 1
assert dataset.meta.total_frames == dataset.num_frames == 3
assert dataset.meta.total_tasks == 1
def test_record_and_replay(tmp_path):
robot_cfg = MockRobotConfig()
teleop_cfg = MockTeleopConfig()
record_dataset_cfg = DatasetRecordConfig(
repo_id=DUMMY_REPO_ID,
single_task="Dummy task",
root=tmp_path / "record_and_replay",
num_episodes=1,
episode_time_s=0.1,
push_to_hub=False,
)
record_cfg = RecordConfig(
robot=robot_cfg,
dataset=record_dataset_cfg,
teleop=teleop_cfg,
play_sounds=False,
)
replay_dataset_cfg = DatasetReplayConfig(
repo_id=DUMMY_REPO_ID,
episode=0,
root=tmp_path / "record_and_replay",
)
replay_cfg = ReplayConfig(
robot=robot_cfg,
dataset=replay_dataset_cfg,
play_sounds=False,
)
record(record_cfg)
replay(replay_cfg)