diff --git a/lerobot/common/robot_devices/robots/utils.py b/lerobot/common/robot_devices/robots/utils.py index 122155f78..157afd8bb 100644 --- a/lerobot/common/robot_devices/robots/utils.py +++ b/lerobot/common/robot_devices/robots/utils.py @@ -9,8 +9,9 @@ def get_arm_id(name, arm_type): class Robot(Protocol): - def init_teleop(self): ... - def run_calibration(self): ... + def connect(self): ... + def activate_calibration(self): ... def teleop_step(self, record_data=False): ... def capture_observation(self): ... def send_action(self, action): ... + def disconnect(self): ... diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index a0621457b..ca486105d 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -238,7 +238,7 @@ def is_headless(): return True -def worker_save_frame_in_threads(frame_queue, num_image_writers): +def loop_to_save_frame_in_threads(frame_queue, num_image_writers): with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor: futures = [] while True: @@ -258,13 +258,26 @@ def worker_save_frame_in_threads(frame_queue, num_image_writers): futures.append(executor.submit(save_image, frame, key, frame_index, episode_index, videos_dir)) -def start_frame_worker_process(frame_queue, num_image_writers): - process = multiprocessing.Process( - target=worker_save_frame_in_threads, - args=(frame_queue, num_image_writers), - ) - process.start() - return process +def start_frame_workers(frame_queue, num_image_writers, num_workers=1): + workers = [] + for _ in range(num_workers): + worker = multiprocessing.Process( + target=loop_to_save_frame_in_threads, + args=(frame_queue, num_image_writers), + ) + worker.start() + workers.append(worker) + return workers + + +def stop_workers(workers, frame_queue): + # Send None to each process to signal it to stop + for _ in workers: + frame_queue.put(None) + + # Wait for all processes to terminate + for process in workers: + process.join() ######################################################################################## @@ -476,7 +489,7 @@ def record( num_image_writers = num_image_writers_per_camera * len(robot.cameras) num_image_writers = max(num_image_writers, 1) frame_queue = multiprocessing.Queue() - frame_worker = start_frame_worker_process(frame_queue, num_image_writers) + frame_workers = start_frame_workers(frame_queue, num_image_writers) # Using `try` to exist smoothly if an exception is raised try: @@ -643,15 +656,11 @@ def record( listener.stop() logging.info("Waiting for threads writing the images on disk to terminate...") - # Send None to stop the worker process - frame_queue.put(None) - frame_worker.join() + stop_workers(frame_workers, frame_queue) except Exception: traceback.print_exc() - # Send None to safely exit worker process and threads - frame_queue.put(None) - frame_worker.join() + stop_workers(frame_workers, frame_queue) robot.disconnect() if not is_headless():