From 904aaa497c2267ab481078fc95eb13560ecf5bad Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Mon, 14 Oct 2024 15:15:28 +0200 Subject: [PATCH] Refactor -> control_loop() --- lerobot/common/datasets/populate_dataset.py | 2 +- .../{control_robot.py => control_utils.py} | 122 +++++++++++------- lerobot/scripts/control_robot.py | 61 ++++----- 3 files changed, 99 insertions(+), 86 deletions(-) rename lerobot/common/robot_devices/{control_robot.py => control_utils.py} (82%) diff --git a/lerobot/common/datasets/populate_dataset.py b/lerobot/common/datasets/populate_dataset.py index e4fd6c1c..df5d20e5 100644 --- a/lerobot/common/datasets/populate_dataset.py +++ b/lerobot/common/datasets/populate_dataset.py @@ -38,7 +38,7 @@ def safe_stop_image_writer(func): try: return func(*args, **kwargs) except Exception as e: - image_writer = kwargs["dataset"].get("image_writer") + image_writer = kwargs.get("dataset", {}).get("image_writer") if image_writer is not None: print("Waiting for image writer to terminate...") stop_image_writer(image_writer, timeout=20) diff --git a/lerobot/common/robot_devices/control_robot.py b/lerobot/common/robot_devices/control_utils.py similarity index 82% rename from lerobot/common/robot_devices/control_robot.py rename to lerobot/common/robot_devices/control_utils.py index f7c91c75..3b880fd3 100644 --- a/lerobot/common/robot_devices/control_robot.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -19,7 +19,7 @@ from lerobot.common.datasets.populate_dataset import add_frame, safe_stop_image_ from lerobot.common.policies.factory import make_policy from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.robot_devices.utils import busy_wait -from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, log_say, set_global_seed +from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed from lerobot.scripts.eval import get_pretrained_policy_path @@ -184,45 +184,28 @@ def init_policy(pretrained_policy_name_or_path, policy_overrides, fps): return policy, fps, device, use_amp -def warmup_record(robot, events, enable_teloperation, warmup_time_s, display_cameras, play_sounds, fps): - # TODO(rcadene): refactor warmup_record and reset_environment - timestamp = 0 - start_warmup_t = time.perf_counter() - - if warmup_time_s > 0: - log_say("Warming up (no data recording)", play_sounds) - - while timestamp < warmup_time_s: - start_loop_t = time.perf_counter() - - if enable_teloperation: - observation, _ = robot.teleop_step(record_data=True) - else: - observation = robot.capture_observation() - - if display_cameras and not is_headless(): - image_keys = [key for key in observation if "image" in key] - for key in image_keys: - cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) - cv2.waitKey(1) - - dt_s = time.perf_counter() - start_loop_t - busy_wait(1 / fps - dt_s) - - dt_s = time.perf_counter() - start_loop_t - log_control_info(robot, dt_s, fps=fps) - - timestamp = time.perf_counter() - start_warmup_t - if events is not None and events["exit_early"]: - events["exit_early"] = False - break - - -@safe_stop_image_writer -def record_episode( - dataset, +def warmup_record( robot, events, + enable_teloperation, + warmup_time_s, + display_cameras, + fps, +): + control_loop( + robot=robot, + control_time_s=warmup_time_s, + display_cameras=display_cameras, + events=events, + fps=fps, + teleoperate=enable_teloperation, + ) + + +def record_episode( + robot, + dataset, + events, episode_time_s, display_cameras, policy, @@ -230,24 +213,65 @@ def record_episode( use_amp, fps, ): + control_loop( + robot=robot, + control_time_s=episode_time_s, + display_cameras=display_cameras, + dataset=dataset, + events=events, + policy=policy, + device=device, + use_amp=use_amp, + fps=fps, + teleoperate=policy is None, + ) + + +@safe_stop_image_writer +def control_loop( + robot, + control_time_s, + teleoperate=False, + display_cameras=False, + dataset=None, + events=None, + policy=None, + device=None, + use_amp=None, + fps=None, +): + # TODO(rcadene): Add option to record logs + if not robot.is_connected: + robot.connect() + + if events is None: + events = {} + + if teleoperate and policy is not None: + raise ValueError("When `teleoperate` is True, `policy` should be None.") + + if dataset is not None and fps is not None and dataset["fps"] != fps: + raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).") + timestamp = 0 start_episode_t = time.perf_counter() - while timestamp < episode_time_s: + while timestamp < control_time_s: start_loop_t = time.perf_counter() - if policy is None: + if teleoperate: observation, action = robot.teleop_step(record_data=True) else: observation = robot.capture_observation() - pred_action = predict_action(observation, policy, device, use_amp) + if policy is not None: + pred_action = predict_action(observation, policy, device, use_amp) + # Action can eventually be clipped using `max_relative_target`, + # so action actually sent is saved in the dataset. + action = robot.send_action(pred_action) + action = {"action": action} - # Action can eventually be clipped using `max_relative_target`, - # so action actually sent is saved in the dataset. - action = robot.send_action(pred_action) - action = {"action": action} - - add_frame(dataset, observation, action) + if dataset is not None: + add_frame(dataset, observation, action) if display_cameras and not is_headless(): image_keys = [key for key in observation if "image" in key] @@ -262,7 +286,7 @@ def record_episode( log_control_info(robot, dt_s, fps=fps) timestamp = time.perf_counter() - start_episode_t - if events is not None and events["exit_early"]: + if events["exit_early"]: events["exit_early"] = False break @@ -282,7 +306,7 @@ def reset_environment(robot, events, reset_time_s): time.sleep(1) timestamp = time.perf_counter() - start_vencod_t pbar.update(1) - if events is not None and events["exit_early"]: + if events["exit_early"]: events["exit_early"] = False break diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index 47e1f2b6..e1a40972 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -111,7 +111,8 @@ from lerobot.common.datasets.populate_dataset import ( init_dataset, save_current_episode, ) -from lerobot.common.robot_devices.control_robot import ( +from lerobot.common.robot_devices.control_utils import ( + control_loop, has_method, init_keyboard_listener, init_policy, @@ -177,25 +178,16 @@ def calibrate(robot: Robot, arms: list[str] | None): @safe_disconnect -def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | None = None): - # TODO(rcadene): Add option to record logs - if not robot.is_connected: - robot.connect() - - start_teleop_t = time.perf_counter() - while True: - start_loop_t = time.perf_counter() - robot.teleop_step() - - if fps is not None: - dt_s = time.perf_counter() - start_loop_t - busy_wait(1 / fps - dt_s) - - dt_s = time.perf_counter() - start_loop_t - log_control_info(robot, dt_s, fps=fps) - - if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s: - break +def teleoperate( + robot: Robot, fps: int | None = None, teleop_time_s: float | None = None, display_cameras: bool = False +): + control_loop( + robot, + control_time_s=teleop_time_s, + fps=fps, + teleoperate=True, + display_cameras=display_cameras, + ) @safe_disconnect @@ -254,7 +246,8 @@ def record( # 2. give times to the robot devices to connect and start synchronizing, # 3. place the cameras windows on screen enable_teleoperation = policy is None - warmup_record(robot, events, enable_teleoperation, warmup_time_s, display_cameras, play_sounds, fps) + log_say("Warmup record", play_sounds) + warmup_record(robot, events, enable_teleoperation, warmup_time_s, display_cameras, fps) if has_method(robot, "teleop_safety_stop"): robot.teleop_safety_stop() @@ -277,32 +270,21 @@ def record( fps=fps, ) - # In case stop recording is requested during `record_episode` - if events is not None and events["stop_recording"]: - save_current_episode(dataset) - break - # Execute a few seconds without recording to give time to manually reset the environment # Current code logic doesn't allow to teleoperate during this time. # TODO(rcadene): add an option to enable teleoperation during reset # Skip reset for the last episode to be recorded - if episode_index < num_episodes - 1: + if not events["stop_recording"] and ( + (episode_index < num_episodes - 1) or events["rerecord_episode"] + ): log_say("Reset the environment", play_sounds) reset_environment(robot, events, reset_time_s) - # In case stop recording is requested during `reset_environment` - if events is not None and events["stop_recording"]: - save_current_episode(dataset) - break - - if events is not None and events["rerecord_episode"]: + if events["rerecord_episode"]: log_say("Re-record episode", play_sounds) events["rerecord_episode"] = False events["exit_early"] = False delete_current_episode(dataset) - # Force reset - log_say("Reset the environment", play_sounds) - reset_environment(robot, events, reset_time_s) continue # Increment by one dataset["current_episode_index"] @@ -320,6 +302,7 @@ def record( def replay( robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug", play_sounds=True ): + # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset # TODO(rcadene): Add option to record logs local_dir = Path(root) / repo_id if not local_dir.exists(): @@ -378,6 +361,12 @@ if __name__ == "__main__": parser_teleop.add_argument( "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" ) + parser_teleop.add_argument( + "--display-cameras", + type=int, + default=1, + help="Display all cameras on screen (set to 1 to display or 0).", + ) parser_record = subparsers.add_parser("record", parents=[base_parser]) parser_record.add_argument(