forked from tangger/lerobot
WIP: add multiprocess
This commit is contained in:
@@ -102,6 +102,7 @@ import argparse
|
||||
import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
@@ -237,6 +238,35 @@ def is_headless():
|
||||
return True
|
||||
|
||||
|
||||
def worker_save_frame_in_threads(frame_queue, num_image_writers):
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
|
||||
futures = []
|
||||
while True:
|
||||
# Blocks until a frame is available
|
||||
frame_data = frame_queue.get()
|
||||
|
||||
# Exit if we send None to stop the worker
|
||||
if frame_data is None:
|
||||
# Wait for all submitted futures to complete before exiting
|
||||
for _ in tqdm.tqdm(
|
||||
concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images"
|
||||
):
|
||||
pass
|
||||
break
|
||||
|
||||
frame, key, frame_index, episode_index, videos_dir = frame_data
|
||||
futures.append(executor.submit(save_image, frame, key, frame_index, episode_index, videos_dir))
|
||||
|
||||
|
||||
def start_frame_worker_process(frame_queue, num_image_writers):
|
||||
process = multiprocessing.Process(
|
||||
target=worker_save_frame_in_threads,
|
||||
args=(frame_queue, num_image_writers),
|
||||
)
|
||||
process.start()
|
||||
return process
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Control modes
|
||||
########################################################################################
|
||||
@@ -442,12 +472,14 @@ def record(
|
||||
|
||||
timestamp = time.perf_counter() - start_warmup_t
|
||||
|
||||
# Save images using threads to reach high fps (30 and more)
|
||||
# Using `with` to exist smoothly if an execption is raised.
|
||||
futures = []
|
||||
# Save images using a worker process with writer threads to reach high fps (30 and more)
|
||||
num_image_writers = num_image_writers_per_camera * len(robot.cameras)
|
||||
num_image_writers = max(num_image_writers, 1)
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
|
||||
frame_queue = multiprocessing.Queue()
|
||||
frame_worker = start_frame_worker_process(frame_queue, num_image_writers)
|
||||
|
||||
# Using `try` to exist smoothly if an exception is raised
|
||||
try:
|
||||
# Start recording all episodes
|
||||
while episode_index < num_episodes:
|
||||
logging.info(f"Recording episode {episode_index}")
|
||||
@@ -468,11 +500,7 @@ def record(
|
||||
not_image_keys = [key for key in observation if "image" not in key]
|
||||
|
||||
for key in image_keys:
|
||||
futures += [
|
||||
executor.submit(
|
||||
save_image, observation[key], key, frame_index, episode_index, videos_dir
|
||||
)
|
||||
]
|
||||
frame_queue.put((observation[key], key, frame_index, episode_index, videos_dir))
|
||||
|
||||
if not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
@@ -615,11 +643,15 @@ def record(
|
||||
listener.stop()
|
||||
|
||||
logging.info("Waiting for threads writing the images on disk to terminate...")
|
||||
for _ in tqdm.tqdm(
|
||||
concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images"
|
||||
):
|
||||
pass
|
||||
break
|
||||
# Send None to stop the worker process
|
||||
frame_queue.put(None)
|
||||
frame_worker.join()
|
||||
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
# Send None to safely exit worker process and threads
|
||||
frame_queue.put(None)
|
||||
frame_worker.join()
|
||||
|
||||
robot.disconnect()
|
||||
if not is_headless():
|
||||
|
||||
Reference in New Issue
Block a user