forked from tangger/lerobot
WIP: add multiprocess
@@ -102,6 +102,7 @@ import argparse
import concurrent.futures
import json
import logging
import multiprocessing
import os
import platform
import shutil
@@ -237,6 +238,35 @@ def is_headless():
        return True


def worker_save_frame_in_threads(frame_queue, num_image_writers):
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
        futures = []
        while True:
            # Blocks until a frame is available
            frame_data = frame_queue.get()

            # Exit if we send None to stop the worker
            if frame_data is None:
                # Wait for all submitted futures to complete before exiting
                for _ in tqdm.tqdm(
                    concurrent.futures.as_completed(futures), total=len(futures), desc="Writing images"
                ):
                    pass
                break

            frame, key, frame_index, episode_index, videos_dir = frame_data
            futures.append(executor.submit(save_image, frame, key, frame_index, episode_index, videos_dir))


def start_frame_worker_process(frame_queue, num_image_writers):
    process = multiprocessing.Process(
        target=worker_save_frame_in_threads,
        args=(frame_queue, num_image_writers),
    )
    process.start()
    return process


########################################################################################
# Control modes
########################################################################################
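The two functions above form a producer/consumer pair: the main process pushes raw frames onto a multiprocessing.Queue, a single worker process drains the queue and fans the writes out to a thread pool, and a None sentinel tells the worker to flush pending writes and exit. Below is a minimal, self-contained sketch of that pattern, not the commit's actual code: save_image here is an illustrative stand-in (the real one lives in lerobot), and the key, frame size, and output directory are made up for the example.

import concurrent.futures
import multiprocessing
from pathlib import Path

import numpy as np
from PIL import Image


def save_image(frame, key, frame_index, episode_index, videos_dir):
    # Stand-in writer: one PNG per frame, loosely mirroring the real directory layout.
    path = Path(videos_dir) / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
    path.parent.mkdir(parents=True, exist_ok=True)
    Image.fromarray(frame).save(path)


def worker(frame_queue, num_image_writers):
    # Same shape as worker_save_frame_in_threads: drain the queue into a thread pool,
    # then wait for pending writes when the None sentinel arrives.
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
        futures = []
        while True:
            frame_data = frame_queue.get()  # blocks until a frame (or the sentinel) is available
            if frame_data is None:
                concurrent.futures.wait(futures)
                break
            futures.append(executor.submit(save_image, *frame_data))


if __name__ == "__main__":
    frame_queue = multiprocessing.Queue()
    frame_worker = multiprocessing.Process(target=worker, args=(frame_queue, 4))
    frame_worker.start()

    for frame_index in range(10):
        frame = np.zeros((48, 64, 3), dtype=np.uint8)  # fake camera frame
        frame_queue.put((frame, "observation.images.cam", frame_index, 0, "outputs/example_images"))

    frame_queue.put(None)  # sentinel: worker finishes pending writes, then exits
    frame_worker.join()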
@@ -442,12 +472,14 @@ def record(
        timestamp = time.perf_counter() - start_warmup_t

    # Save images using threads to reach high fps (30 and more)
    # Using `with` to exit smoothly if an exception is raised.
    futures = []
    # Save images using a worker process with writer threads to reach high fps (30 and more)
    num_image_writers = num_image_writers_per_camera * len(robot.cameras)
    num_image_writers = max(num_image_writers, 1)
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
    frame_queue = multiprocessing.Queue()
    frame_worker = start_frame_worker_process(frame_queue, num_image_writers)

    # Using `try` to exit smoothly if an exception is raised
    try:
        # Start recording all episodes
        while episode_index < num_episodes:
            logging.info(f"Recording episode {episode_index}")
@@ -468,11 +500,7 @@ def record(
                not_image_keys = [key for key in observation if "image" not in key]

                for key in image_keys:
                    futures += [
                        executor.submit(
                            save_image, observation[key], key, frame_index, episode_index, videos_dir
                        )
                    ]
                    frame_queue.put((observation[key], key, frame_index, episode_index, videos_dir))

                if not is_headless():
                    image_keys = [key for key in observation if "image" in key]
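This hunk swaps the per-frame executor.submit(save_image, ...) in the main process for a frame_queue.put(...). The point is control-loop latency: putting a frame on a multiprocessing.Queue returns quickly (in CPython, serialization happens in a background feeder thread), while encoding and writing an image can take a substantial share of the roughly 33 ms frame budget at 30 fps. A rough, standalone way to see the cost of the put itself, with a made-up frame size and key name:

import multiprocessing
import time

import numpy as np

if __name__ == "__main__":
    frame_queue = multiprocessing.Queue()
    frame = np.zeros((480, 640, 3), dtype=np.uint8)  # one fake 480p RGB frame (~0.9 MB)

    t0 = time.perf_counter()
    frame_queue.put((frame, "observation.images.cam", 0, 0, "outputs/example_images"))
    dt_ms = (time.perf_counter() - t0) * 1e3
    print(f"queue.put returned in {dt_ms:.3f} ms (frame budget at 30 fps is ~33 ms)")

    frame_queue.get()  # drain the queue so its feeder thread can flush and the process exits cleanly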
@@ -615,11 +643,15 @@ def record(
                    listener.stop()

                logging.info("Waiting for threads writing the images on disk to terminate...")
                for _ in tqdm.tqdm(
                    concurrent.futures.as_completed(futures), total=len(futures), desc="Writing images"
                ):
                    pass
                break
                # Send None to stop the worker process
                frame_queue.put(None)
                frame_worker.join()

    except Exception:
        traceback.print_exc()
        # Send None to safely exit worker process and threads
        frame_queue.put(None)
        frame_worker.join()

    robot.disconnect()
    if not is_headless():
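In the diff, the None sentinel and the join appear twice, once on the normal exit path and once in the except branch. A hedged alternative (not what this commit does) is to put the cleanup in a single finally block so it runs on every exit path, including KeyboardInterrupt. A self-contained sketch of that shape, with a trivial stand-in worker and illustrative data:

import multiprocessing
import traceback


def worker(frame_queue):
    # Minimal stand-in for worker_save_frame_in_threads: consume until the None sentinel.
    while True:
        if frame_queue.get() is None:
            break


if __name__ == "__main__":
    frame_queue = multiprocessing.Queue()
    frame_worker = multiprocessing.Process(target=worker, args=(frame_queue,))
    frame_worker.start()
    try:
        for frame_index in range(5):
            frame_queue.put(("fake_frame", frame_index))  # stand-in for real frame data
    except Exception:
        traceback.print_exc()
    finally:
        # Runs on success, exception, and KeyboardInterrupt alike:
        # send the sentinel exactly once and wait for the worker to drain the queue.
        frame_queue.put(None)
        frame_worker.join()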
@@ -92,12 +92,6 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
    if mock:
        request.getfixturevalue("patch_builtins_input")

    if robot_type == "aloha":
        pytest.skip("TODO(rcadene): enable test once aloha_real and act_aloha_real are merged")

    env_name = "koch_real"
    policy_name = "act_koch_real"

    root = Path(tmpdir)
    repo_id = "lerobot/debug"
@@ -117,13 +111,28 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
    replay(robot, episode=0, fps=30, root=root, repo_id=repo_id)

    # TODO(rcadene, aliberts): rethink this design
    if robot_type == "aloha":
        env_name = "aloha_real"
        policy_name = "act_aloha_real"
    elif robot_type in ["koch", "koch_bimanual"]:
        env_name = "koch_real"
        policy_name = "act_koch_real"
    else:
        raise NotImplementedError(robot_type)

    overrides = [
        f"env={env_name}",
        f"policy={policy_name}",
        f"device={DEVICE}",
    ]

    if robot_type == "koch_bimanual":
        overrides += ["env.state_dim=12", "env.action_dim=12"]

    cfg = init_hydra_config(
        DEFAULT_CONFIG_PATH,
        overrides=[
            f"env={env_name}",
            f"policy={policy_name}",
            f"device={DEVICE}",
        ],
        overrides=overrides,
    )

    policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
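The test now builds the Hydra overrides list per robot_type before calling init_hydra_config, and extends it for koch_bimanual, whose two-arm setup uses 12-dimensional state and action spaces. A hedged sketch of the same idea with the robot-specific values pulled into a lookup table; names and values are copied from the diff, but this refactor is illustrative, not part of the commit:

# Illustrative alternative to the per-robot_type branching in the test.
ENV_AND_POLICY = {
    "aloha": ("aloha_real", "act_aloha_real"),
    "koch": ("koch_real", "act_koch_real"),
    "koch_bimanual": ("koch_real", "act_koch_real"),
}


def build_overrides(robot_type, device):
    if robot_type not in ENV_AND_POLICY:
        raise NotImplementedError(robot_type)
    env_name, policy_name = ENV_AND_POLICY[robot_type]
    overrides = [f"env={env_name}", f"policy={policy_name}", f"device={device}"]
    if robot_type == "koch_bimanual":
        # Two arms, so the state and action spaces double to 12 dimensions.
        overrides += ["env.state_dim=12", "env.action_dim=12"]
    return overrides


print(build_overrides("koch_bimanual", "cpu"))
# ['env=koch_real', 'policy=act_koch_real', 'device=cpu', 'env.state_dim=12', 'env.action_dim=12']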