diff --git a/lerobot/record.py b/lerobot/record.py index 004e6502..e8f81286 100644 --- a/lerobot/record.py +++ b/lerobot/record.py @@ -195,7 +195,7 @@ def record_loop( @draccus.wrap() -def record(cfg: RecordConfig): +def record(cfg: RecordConfig) -> LeRobotDataset: init_logging() logging.info(pformat(asdict(cfg))) if cfg.display_data: diff --git a/lerobot/teleoperate.py b/lerobot/teleoperate.py index b9614a9b..0c421e2b 100644 --- a/lerobot/teleoperate.py +++ b/lerobot/teleoperate.py @@ -51,10 +51,13 @@ class TeleoperateConfig: display_data: bool = False -def teleop_loop(teleop: Teleoperator, robot: Robot, display_data: bool = False): +def teleop_loop( + teleop: Teleoperator, robot: Robot, display_data: bool = False, duration: float | None = None +): display_len = max(len(key) for key in robot.action_features) + start = time.perf_counter() while True: - start = time.perf_counter() + loop_start = time.perf_counter() action = teleop.get_action() if display_data: observation = robot.get_observation() @@ -68,18 +71,22 @@ def teleop_loop(teleop: Teleoperator, robot: Robot, display_data: bool = False): rr.log(f"action_{act}", rr.Scalar(val)) robot.send_action(action) - loop_s = time.perf_counter() - start + loop_s = time.perf_counter() - loop_start print("\n" + "-" * (display_len + 10)) print(f"{'NAME':<{display_len}} | {'NORM':>7}") for motor, value in action.items(): print(f"{motor:<{display_len}} | {value:>7.2f}") print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)") + + if duration is not None and time.perf_counter() - start >= duration: + return + move_cursor_up(len(action) + 5) @draccus.wrap() -def control_robot(cfg: TeleoperateConfig): +def teleoperate(cfg: TeleoperateConfig): init_logging() logging.info(pformat(asdict(cfg))) if cfg.display_data: @@ -92,7 +99,7 @@ def control_robot(cfg: TeleoperateConfig): robot.connect() try: - teleop_loop(teleop, robot, display_data=cfg.display_data) + teleop_loop(teleop, robot, display_data=cfg.display_data, duration=cfg.teleop_time_s) except KeyboardInterrupt: pass finally: @@ -103,4 +110,4 @@ def control_robot(cfg: TeleoperateConfig): if __name__ == "__main__": - control_robot() + teleoperate() diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py new file mode 100644 index 00000000..6cd80c35 --- /dev/null +++ b/tests/test_control_robot.py @@ -0,0 +1,82 @@ +import time + +from lerobot.record import DatasetRecordConfig, RecordConfig, record +from lerobot.replay import DatasetReplayConfig, ReplayConfig, replay +from lerobot.teleoperate import TeleoperateConfig, teleoperate +from tests.fixtures.constants import DUMMY_REPO_ID +from tests.mocks.mock_robot import MockRobotConfig +from tests.mocks.mock_teleop import MockTeleopConfig + + +def test_teleoperate(): + robot_cfg = MockRobotConfig() + teleop_cfg = MockTeleopConfig() + expected_duration = 0.1 + cfg = TeleoperateConfig( + robot=robot_cfg, + teleop=teleop_cfg, + teleop_time_s=expected_duration, + ) + start = time.perf_counter() + teleoperate(cfg) + actual_duration = time.perf_counter() - start + + assert actual_duration <= expected_duration * 1.1 + + +def test_record(tmp_path): + robot_cfg = MockRobotConfig() + teleop_cfg = MockTeleopConfig() + dataset_cfg = DatasetRecordConfig( + repo_id=DUMMY_REPO_ID, + single_task="Dummy task", + root=tmp_path / "record", + num_episodes=1, + episode_time_s=0.1, + push_to_hub=False, + ) + cfg = RecordConfig( + robot=robot_cfg, + dataset=dataset_cfg, + teleop=teleop_cfg, + play_sounds=False, + ) + + dataset = record(cfg) + + assert dataset.fps == 30 + assert dataset.meta.total_episodes == dataset.num_episodes == 1 + assert dataset.meta.total_frames == dataset.num_frames == 3 + assert dataset.meta.total_tasks == 1 + + +def test_record_and_replay(tmp_path): + robot_cfg = MockRobotConfig() + teleop_cfg = MockTeleopConfig() + record_dataset_cfg = DatasetRecordConfig( + repo_id=DUMMY_REPO_ID, + single_task="Dummy task", + root=tmp_path / "record_and_replay", + num_episodes=1, + episode_time_s=0.1, + push_to_hub=False, + ) + record_cfg = RecordConfig( + robot=robot_cfg, + dataset=record_dataset_cfg, + teleop=teleop_cfg, + play_sounds=False, + ) + replay_dataset_cfg = DatasetReplayConfig( + repo_id=DUMMY_REPO_ID, + episode=0, + root=tmp_path / "record_and_replay", + ) + replay_cfg = ReplayConfig( + robot=robot_cfg, + dataset=replay_dataset_cfg, + play_sounds=False, + ) + + record(record_cfg) + replay(replay_cfg)