WIP: add multiprocess

This commit is contained in:
Remi Cadene
2024-09-28 15:00:38 +02:00
parent 9b76ee9eb0
commit 77ba43d25b
2 changed files with 66 additions and 25 deletions

View File

@@ -102,6 +102,7 @@ import argparse
import concurrent.futures
import json
import logging
import multiprocessing
import os
import platform
import shutil
@@ -237,6 +238,35 @@ def is_headless():
return True
def worker_save_frame_in_threads(frame_queue, num_image_writers):
with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
futures = []
while True:
# Blocks until a frame is available
frame_data = frame_queue.get()
# Exit if we send None to stop the worker
if frame_data is None:
# Wait for all submitted futures to complete before exiting
for _ in tqdm.tqdm(
concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images"
):
pass
break
frame, key, frame_index, episode_index, videos_dir = frame_data
futures.append(executor.submit(save_image, frame, key, frame_index, episode_index, videos_dir))
def start_frame_worker_process(frame_queue, num_image_writers):
process = multiprocessing.Process(
target=worker_save_frame_in_threads,
args=(frame_queue, num_image_writers),
)
process.start()
return process
########################################################################################
# Control modes
########################################################################################
@@ -442,12 +472,14 @@ def record(
timestamp = time.perf_counter() - start_warmup_t
# Save images using threads to reach high fps (30 and more)
# Using `with` to exist smoothly if an execption is raised.
futures = []
# Save images using a worker process with writer threads to reach high fps (30 and more)
num_image_writers = num_image_writers_per_camera * len(robot.cameras)
num_image_writers = max(num_image_writers, 1)
with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
frame_queue = multiprocessing.Queue()
frame_worker = start_frame_worker_process(frame_queue, num_image_writers)
# Using `try` to exist smoothly if an exception is raised
try:
# Start recording all episodes
while episode_index < num_episodes:
logging.info(f"Recording episode {episode_index}")
@@ -468,11 +500,7 @@ def record(
not_image_keys = [key for key in observation if "image" not in key]
for key in image_keys:
futures += [
executor.submit(
save_image, observation[key], key, frame_index, episode_index, videos_dir
)
]
frame_queue.put((observation[key], key, frame_index, episode_index, videos_dir))
if not is_headless():
image_keys = [key for key in observation if "image" in key]
@@ -615,11 +643,15 @@ def record(
listener.stop()
logging.info("Waiting for threads writing the images on disk to terminate...")
for _ in tqdm.tqdm(
concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images"
):
pass
break
# Send None to stop the worker process
frame_queue.put(None)
frame_worker.join()
except Exception:
traceback.print_exc()
# Send None to safely exit worker process and threads
frame_queue.put(None)
frame_worker.join()
robot.disconnect()
if not is_headless():

View File

@@ -92,12 +92,6 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
if mock:
request.getfixturevalue("patch_builtins_input")
if robot_type == "aloha":
pytest.skip("TODO(rcadene): enable test once aloha_real and act_aloha_real are merged")
env_name = "koch_real"
policy_name = "act_koch_real"
root = Path(tmpdir)
repo_id = "lerobot/debug"
@@ -117,13 +111,28 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
replay(robot, episode=0, fps=30, root=root, repo_id=repo_id)
# TODO(rcadene, aliberts): rethink this design
if robot_type == "aloha":
env_name = "aloha_real"
policy_name = "act_aloha_real"
elif robot_type in ["koch", "koch_bimanual"]:
env_name = "koch_real"
policy_name = "act_koch_real"
else:
raise NotImplementedError(robot_type)
overrides = [
f"env={env_name}",
f"policy={policy_name}",
f"device={DEVICE}",
]
if robot_type == "koch_bimanual":
overrides += ["env.state_dim=12", "env.action_dim=12"]
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
overrides=[
f"env={env_name}",
f"policy={policy_name}",
f"device={DEVICE}",
],
overrides=overrides,
)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)