WIP: add multiprocess
This commit is contained in:
@@ -102,6 +102,7 @@ import argparse
|
|||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import shutil
|
import shutil
|
||||||
@@ -237,6 +238,35 @@ def is_headless():
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def worker_save_frame_in_threads(frame_queue, num_image_writers):
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
|
||||||
|
futures = []
|
||||||
|
while True:
|
||||||
|
# Blocks until a frame is available
|
||||||
|
frame_data = frame_queue.get()
|
||||||
|
|
||||||
|
# Exit if we send None to stop the worker
|
||||||
|
if frame_data is None:
|
||||||
|
# Wait for all submitted futures to complete before exiting
|
||||||
|
for _ in tqdm.tqdm(
|
||||||
|
concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images"
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
break
|
||||||
|
|
||||||
|
frame, key, frame_index, episode_index, videos_dir = frame_data
|
||||||
|
futures.append(executor.submit(save_image, frame, key, frame_index, episode_index, videos_dir))
|
||||||
|
|
||||||
|
|
||||||
|
def start_frame_worker_process(frame_queue, num_image_writers):
|
||||||
|
process = multiprocessing.Process(
|
||||||
|
target=worker_save_frame_in_threads,
|
||||||
|
args=(frame_queue, num_image_writers),
|
||||||
|
)
|
||||||
|
process.start()
|
||||||
|
return process
|
||||||
|
|
||||||
|
|
||||||
########################################################################################
|
########################################################################################
|
||||||
# Control modes
|
# Control modes
|
||||||
########################################################################################
|
########################################################################################
|
||||||
@@ -442,12 +472,14 @@ def record(
|
|||||||
|
|
||||||
timestamp = time.perf_counter() - start_warmup_t
|
timestamp = time.perf_counter() - start_warmup_t
|
||||||
|
|
||||||
# Save images using threads to reach high fps (30 and more)
|
# Save images using a worker process with writer threads to reach high fps (30 and more)
|
||||||
# Using `with` to exist smoothly if an execption is raised.
|
|
||||||
futures = []
|
|
||||||
num_image_writers = num_image_writers_per_camera * len(robot.cameras)
|
num_image_writers = num_image_writers_per_camera * len(robot.cameras)
|
||||||
num_image_writers = max(num_image_writers, 1)
|
num_image_writers = max(num_image_writers, 1)
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
|
frame_queue = multiprocessing.Queue()
|
||||||
|
frame_worker = start_frame_worker_process(frame_queue, num_image_writers)
|
||||||
|
|
||||||
|
# Using `try` to exist smoothly if an exception is raised
|
||||||
|
try:
|
||||||
# Start recording all episodes
|
# Start recording all episodes
|
||||||
while episode_index < num_episodes:
|
while episode_index < num_episodes:
|
||||||
logging.info(f"Recording episode {episode_index}")
|
logging.info(f"Recording episode {episode_index}")
|
||||||
@@ -468,11 +500,7 @@ def record(
|
|||||||
not_image_keys = [key for key in observation if "image" not in key]
|
not_image_keys = [key for key in observation if "image" not in key]
|
||||||
|
|
||||||
for key in image_keys:
|
for key in image_keys:
|
||||||
futures += [
|
frame_queue.put((observation[key], key, frame_index, episode_index, videos_dir))
|
||||||
executor.submit(
|
|
||||||
save_image, observation[key], key, frame_index, episode_index, videos_dir
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
if not is_headless():
|
if not is_headless():
|
||||||
image_keys = [key for key in observation if "image" in key]
|
image_keys = [key for key in observation if "image" in key]
|
||||||
@@ -615,11 +643,15 @@ def record(
|
|||||||
listener.stop()
|
listener.stop()
|
||||||
|
|
||||||
logging.info("Waiting for threads writing the images on disk to terminate...")
|
logging.info("Waiting for threads writing the images on disk to terminate...")
|
||||||
for _ in tqdm.tqdm(
|
# Send None to stop the worker process
|
||||||
concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images"
|
frame_queue.put(None)
|
||||||
):
|
frame_worker.join()
|
||||||
pass
|
|
||||||
break
|
except Exception:
|
||||||
|
traceback.print_exc()
|
||||||
|
# Send None to safely exit worker process and threads
|
||||||
|
frame_queue.put(None)
|
||||||
|
frame_worker.join()
|
||||||
|
|
||||||
robot.disconnect()
|
robot.disconnect()
|
||||||
if not is_headless():
|
if not is_headless():
|
||||||
|
|||||||
@@ -92,12 +92,6 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
|||||||
if mock:
|
if mock:
|
||||||
request.getfixturevalue("patch_builtins_input")
|
request.getfixturevalue("patch_builtins_input")
|
||||||
|
|
||||||
if robot_type == "aloha":
|
|
||||||
pytest.skip("TODO(rcadene): enable test once aloha_real and act_aloha_real are merged")
|
|
||||||
|
|
||||||
env_name = "koch_real"
|
|
||||||
policy_name = "act_koch_real"
|
|
||||||
|
|
||||||
root = Path(tmpdir)
|
root = Path(tmpdir)
|
||||||
repo_id = "lerobot/debug"
|
repo_id = "lerobot/debug"
|
||||||
|
|
||||||
@@ -117,13 +111,28 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
|||||||
|
|
||||||
replay(robot, episode=0, fps=30, root=root, repo_id=repo_id)
|
replay(robot, episode=0, fps=30, root=root, repo_id=repo_id)
|
||||||
|
|
||||||
|
# TODO(rcadene, aliberts): rethink this design
|
||||||
|
if robot_type == "aloha":
|
||||||
|
env_name = "aloha_real"
|
||||||
|
policy_name = "act_aloha_real"
|
||||||
|
elif robot_type in ["koch", "koch_bimanual"]:
|
||||||
|
env_name = "koch_real"
|
||||||
|
policy_name = "act_koch_real"
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(robot_type)
|
||||||
|
|
||||||
|
overrides = [
|
||||||
|
f"env={env_name}",
|
||||||
|
f"policy={policy_name}",
|
||||||
|
f"device={DEVICE}",
|
||||||
|
]
|
||||||
|
|
||||||
|
if robot_type == "koch_bimanual":
|
||||||
|
overrides += ["env.state_dim=12", "env.action_dim=12"]
|
||||||
|
|
||||||
cfg = init_hydra_config(
|
cfg = init_hydra_config(
|
||||||
DEFAULT_CONFIG_PATH,
|
DEFAULT_CONFIG_PATH,
|
||||||
overrides=[
|
overrides=overrides,
|
||||||
f"env={env_name}",
|
|
||||||
f"policy={policy_name}",
|
|
||||||
f"device={DEVICE}",
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||||
|
|||||||
Reference in New Issue
Block a user