From 77ba43d25b34a515d739e6a298368b9e109229ae Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Sat, 28 Sep 2024 15:00:38 +0200 Subject: [PATCH] WIP: add multiprocess --- lerobot/scripts/control_robot.py | 60 ++++++++++++++++++++++++-------- tests/test_control_robot.py | 31 +++++++++++------ 2 files changed, 66 insertions(+), 25 deletions(-) diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index c07602a19..a0621457b 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -102,6 +102,7 @@ import argparse import concurrent.futures import json import logging +import multiprocessing import os import platform import shutil @@ -237,6 +238,35 @@ def is_headless(): return True +def worker_save_frame_in_threads(frame_queue, num_image_writers): + with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor: + futures = [] + while True: + # Blocks until a frame is available + frame_data = frame_queue.get() + + # Exit if we send None to stop the worker + if frame_data is None: + # Wait for all submitted futures to complete before exiting + for _ in tqdm.tqdm( + concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images" + ): + pass + break + + frame, key, frame_index, episode_index, videos_dir = frame_data + futures.append(executor.submit(save_image, frame, key, frame_index, episode_index, videos_dir)) + + +def start_frame_worker_process(frame_queue, num_image_writers): + process = multiprocessing.Process( + target=worker_save_frame_in_threads, + args=(frame_queue, num_image_writers), + ) + process.start() + return process + + ######################################################################################## # Control modes ######################################################################################## @@ -442,12 +472,14 @@ def record( timestamp = time.perf_counter() - start_warmup_t - # Save images using threads to reach high fps (30 and more) - # 
Using `with` to exist smoothly if an execption is raised. - futures = [] + # Save images using a worker process with writer threads to reach high fps (30 and more) num_image_writers = num_image_writers_per_camera * len(robot.cameras) num_image_writers = max(num_image_writers, 1) - with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor: + frame_queue = multiprocessing.Queue() + frame_worker = start_frame_worker_process(frame_queue, num_image_writers) + + # Using `try` to exit smoothly if an exception is raised + try: # Start recording all episodes while episode_index < num_episodes: logging.info(f"Recording episode {episode_index}") @@ -468,11 +500,7 @@ def record( not_image_keys = [key for key in observation if "image" not in key] for key in image_keys: - futures += [ - executor.submit( - save_image, observation[key], key, frame_index, episode_index, videos_dir - ) - ] + frame_queue.put((observation[key], key, frame_index, episode_index, videos_dir)) if not is_headless(): image_keys = [key for key in observation if "image" in key] @@ -615,11 +643,15 @@ def record( listener.stop() logging.info("Waiting for threads writing the images on disk to terminate...") - for _ in tqdm.tqdm( - concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images" - ): - pass - break + # Send None to stop the worker process + frame_queue.put(None) + frame_worker.join() + + except Exception: + traceback.print_exc() + # Send None to safely exit worker process and threads + frame_queue.put(None) + frame_worker.join() robot.disconnect() if not is_headless(): diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index fc679528e..452c709f6 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -92,12 +92,6 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): if mock: request.getfixturevalue("patch_builtins_input") - if robot_type == "aloha": - pytest.skip("TODO(rcadene): 
enable test once aloha_real and act_aloha_real are merged") - - env_name = "koch_real" - policy_name = "act_koch_real" - root = Path(tmpdir) repo_id = "lerobot/debug" @@ -117,13 +111,28 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): replay(robot, episode=0, fps=30, root=root, repo_id=repo_id) + # TODO(rcadene, aliberts): rethink this design + if robot_type == "aloha": + env_name = "aloha_real" + policy_name = "act_aloha_real" + elif robot_type in ["koch", "koch_bimanual"]: + env_name = "koch_real" + policy_name = "act_koch_real" + else: + raise NotImplementedError(robot_type) + + overrides = [ + f"env={env_name}", + f"policy={policy_name}", + f"device={DEVICE}", + ] + + if robot_type == "koch_bimanual": + overrides += ["env.state_dim=12", "env.action_dim=12"] + cfg = init_hydra_config( DEFAULT_CONFIG_PATH, - overrides=[ - f"env={env_name}", - f"policy={policy_name}", - f"device={DEVICE}", - ], + overrides=overrides, ) policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)