Add num_workers >=1 capabilities (default to 1)
This commit is contained in:
@@ -9,8 +9,9 @@ def get_arm_id(name, arm_type):
|
|||||||
|
|
||||||
|
|
||||||
class Robot(Protocol):
|
class Robot(Protocol):
|
||||||
def init_teleop(self): ...
|
def connect(self): ...
|
||||||
def run_calibration(self): ...
|
def activate_calibration(self): ...
|
||||||
def teleop_step(self, record_data=False): ...
|
def teleop_step(self, record_data=False): ...
|
||||||
def capture_observation(self): ...
|
def capture_observation(self): ...
|
||||||
def send_action(self, action): ...
|
def send_action(self, action): ...
|
||||||
|
def disconnect(self): ...
|
||||||
|
|||||||
@@ -238,7 +238,7 @@ def is_headless():
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def worker_save_frame_in_threads(frame_queue, num_image_writers):
|
def loop_to_save_frame_in_threads(frame_queue, num_image_writers):
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
|
with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
|
||||||
futures = []
|
futures = []
|
||||||
while True:
|
while True:
|
||||||
@@ -258,13 +258,26 @@ def worker_save_frame_in_threads(frame_queue, num_image_writers):
|
|||||||
futures.append(executor.submit(save_image, frame, key, frame_index, episode_index, videos_dir))
|
futures.append(executor.submit(save_image, frame, key, frame_index, episode_index, videos_dir))
|
||||||
|
|
||||||
|
|
||||||
def start_frame_worker_process(frame_queue, num_image_writers):
|
def start_frame_workers(frame_queue, num_image_writers, num_workers=1):
|
||||||
process = multiprocessing.Process(
|
workers = []
|
||||||
target=worker_save_frame_in_threads,
|
for _ in range(num_workers):
|
||||||
args=(frame_queue, num_image_writers),
|
worker = multiprocessing.Process(
|
||||||
)
|
target=loop_to_save_frame_in_threads,
|
||||||
process.start()
|
args=(frame_queue, num_image_writers),
|
||||||
return process
|
)
|
||||||
|
worker.start()
|
||||||
|
workers.append(worker)
|
||||||
|
return workers
|
||||||
|
|
||||||
|
|
||||||
|
def stop_workers(workers, frame_queue):
|
||||||
|
# Send None to each process to signal it to stop
|
||||||
|
for _ in workers:
|
||||||
|
frame_queue.put(None)
|
||||||
|
|
||||||
|
# Wait for all processes to terminate
|
||||||
|
for process in workers:
|
||||||
|
process.join()
|
||||||
|
|
||||||
|
|
||||||
########################################################################################
|
########################################################################################
|
||||||
@@ -476,7 +489,7 @@ def record(
|
|||||||
num_image_writers = num_image_writers_per_camera * len(robot.cameras)
|
num_image_writers = num_image_writers_per_camera * len(robot.cameras)
|
||||||
num_image_writers = max(num_image_writers, 1)
|
num_image_writers = max(num_image_writers, 1)
|
||||||
frame_queue = multiprocessing.Queue()
|
frame_queue = multiprocessing.Queue()
|
||||||
frame_worker = start_frame_worker_process(frame_queue, num_image_writers)
|
frame_workers = start_frame_workers(frame_queue, num_image_writers)
|
||||||
|
|
||||||
# Using `try` to exist smoothly if an exception is raised
|
# Using `try` to exist smoothly if an exception is raised
|
||||||
try:
|
try:
|
||||||
@@ -643,15 +656,11 @@ def record(
|
|||||||
listener.stop()
|
listener.stop()
|
||||||
|
|
||||||
logging.info("Waiting for threads writing the images on disk to terminate...")
|
logging.info("Waiting for threads writing the images on disk to terminate...")
|
||||||
# Send None to stop the worker process
|
stop_workers(frame_workers, frame_queue)
|
||||||
frame_queue.put(None)
|
|
||||||
frame_worker.join()
|
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
# Send None to safely exit worker process and threads
|
stop_workers(frame_workers, frame_queue)
|
||||||
frame_queue.put(None)
|
|
||||||
frame_worker.join()
|
|
||||||
|
|
||||||
robot.disconnect()
|
robot.disconnect()
|
||||||
if not is_headless():
|
if not is_headless():
|
||||||
|
|||||||
Reference in New Issue
Block a user