Files
issacdataengine/nimbus/components/store/base_writer.py
2026-03-16 11:44:10 +00:00

164 lines
7.6 KiB
Python

import time
from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from copy import copy
from nimbus.components.data.iterator import Iterator
from nimbus.components.data.observation import Observations
from nimbus.components.data.scene import Scene
from nimbus.components.data.sequence import Sequence
from nimbus.daemon import ComponentStatus, StatusReporter
from nimbus.utils.flags import is_debug_mode
from nimbus.utils.utils import unpack_iter_data
def run_batch(func, args):
    """Sequentially invoke *func* once per argument tuple in *args*.

    Each element of *args* is unpacked into positional arguments. Used to
    flush a whole buffered batch of writes inside a single executor task.
    """
    for arg_tuple in args:
        func(*arg_tuple)
class BaseWriter(Iterator):
    """
    A base class for writing generated sequences and observations to disk. This class defines the structure for
    writing data and tracking the writing process. It manages the current scene, success and total case counts,
    and provides hooks for subclasses to implement specific data writing logic. The writer supports both synchronous
    and asynchronous batch writing modes, allowing for efficient data handling in various scenarios.
    Args:
        data_iter (Iterator): An iterator that provides data to be written, typically containing scenes,
            sequences, and observations.
        seq_output_dir (str): The directory where generated sequences will be saved. Can be None
            if sequence output is not needed.
        obs_output_dir (str): The directory where generated observations will be saved. Can be None
            if observation output is not needed.
        batch_async (bool): If True, the writer will use asynchronous batch writing to improve performance
            when handling large amounts of data. Default is True.
        async_threshold (int): The maximum number of asynchronous write operations that can be in progress
            at the same time. If the threshold is reached, the writer will wait for the oldest operation
            to complete before starting a new one. Default is 1.
        batch_size (int): The number of data items to write in each batch when using asynchronous writing.
            Default is 2, and it will be clamped to the range [1, 8] to prevent potential issues with too
            many concurrent operations (or a zero-sized executor).
    """

    def __init__(
        self,
        data_iter: Iterator[tuple[Scene, Sequence, Observations]],
        seq_output_dir: str,
        obs_output_dir: str,
        batch_async: bool = True,
        async_threshold: int = 1,
        batch_size: int = 2,
    ):
        super().__init__()
        assert (
            seq_output_dir is not None or obs_output_dir is not None
        ), "At least one output directory must be provided"
        self.data_iter = data_iter
        self.seq_output_dir = seq_output_dir
        self.obs_output_dir = obs_output_dir
        # Most recently seen scene (None until the first item arrives).
        self.scene = None
        self.async_mode = batch_async
        # Clamp to [1, 8]: values > 8 risk hanging the program with too many
        # concurrent writes, and values < 1 would make the executor sizing
        # below divide by zero (or make the flush condition degenerate).
        self.batch_size = max(1, min(batch_size, 8))
        if batch_async and batch_size > self.batch_size:
            self.logger.info("Batch size is larger than 8(probably cause program hang), batch size will be set to 8")
        self.async_threshold = async_threshold
        self.flush_executor = ThreadPoolExecutor(max_workers=max(1, 64 // self.batch_size))
        # Futures of in-flight batch flushes, oldest first.
        self.flush_threads = []
        # Items accumulated for the next async batch flush.
        self.data_buffer = []
        self.logger.info(
            f"Batch Async Write Mode: {self.async_mode}, async threshold: {self.async_threshold}, batch size:"
            f" {self.batch_size}"
        )
        # Per-scene counters, reset whenever the scene key changes.
        self.total_case = 0
        self.success_case = 0
        # (task_id, name, task_exec_num) of the scene the counters belong to.
        self.last_scene_key = None
        self.status_reporter = StatusReporter(self.__class__.__name__)

    def _next(self):
        """Pull one item from ``data_iter`` and write it to disk.

        Always returns None per item. In async mode items are buffered and
        flushed in batches on a thread pool; in sync mode ``flush_to_disk``
        is called inline. Raises StopIteration once the upstream iterator is
        exhausted and all pending async flushes have completed.
        """
        try:
            data = next(self.data_iter)
            scene, seq, obs = unpack_iter_data(data)
            new_key = (scene.task_id, scene.name, scene.task_exec_num) if scene is not None else None
            self.scene = scene
            if new_key != self.last_scene_key:
                if self.scene is not None and self.last_scene_key is not None:
                    # self.scene already points at the NEW scene here, so the
                    # finished scene's name must come from the previous key
                    # (its second element), not from self.scene.name.
                    self.logger.info(
                        f"Scene {self.last_scene_key[1]} generate finish, success rate:"
                        f" {self.success_case}/{self.total_case}"
                    )
                self.success_case = 0
                self.total_case = 0
                self.last_scene_key = new_key
            if self.scene is None:
                return None
            self.total_case += 1
            self.status_reporter.update_status(ComponentStatus.RUNNING)
            if seq is None and obs is None:
                # Upstream generation failed for this case: record it and move on.
                self.logger.info(f"generate failed, skip once! success rate: {self.success_case}/{self.total_case}")
                self.scene.update_generate_status(success=False)
                return None
            scene_name = self.scene.name
            io_start_time = time.time()
            if self.async_mode:
                # Copy the workflow now so later mutations by the pipeline do
                # not race with the background flush.
                cp_start_time = time.time()
                cp = copy(self.scene.wf)
                cp_end_time = time.time()
                if self.scene.wf is not None:
                    self.logger.info(f"Scene {scene_name} workflow copy time: {cp_end_time - cp_start_time:.2f}s")
                self.data_buffer.append((cp, scene_name, seq, obs))
                if len(self.data_buffer) >= self.batch_size:
                    # Drop completed futures, then enforce the concurrency cap
                    # by waiting on the oldest in-flight flush.
                    self.flush_threads = [t for t in self.flush_threads if not t.done()]
                    if len(self.flush_threads) >= self.async_threshold:
                        self.logger.info("Max async workers reached, waiting for the oldest thread to finish")
                        self.flush_threads[0].result()
                        self.flush_threads = self.flush_threads[1:]
                    to_flush_buffer = self.data_buffer.copy()
                    async_flush = self.flush_executor.submit(run_batch, self.flush_to_disk, to_flush_buffer)
                    if is_debug_mode():
                        async_flush.result()  # surface exceptions immediately in debug mode
                    self.flush_threads.append(async_flush)
                    self.data_buffer = []
                flush_length = len(obs) if obs is not None else len(seq)
            else:
                flush_length = self.flush_to_disk(self.scene.wf, scene_name, seq, obs)
            self.success_case += 1
            self.scene.update_generate_status(success=True)
            self.collect_io_frame_info(flush_length, time.time() - io_start_time)
            self.status_reporter.update_status(ComponentStatus.COMPLETED)
            return None
        except StopIteration:
            if self.async_mode:
                if len(self.data_buffer) > 0:
                    # Flush the partial final batch; rebind (not clear) the
                    # buffer so the submitted task keeps the old list intact
                    # and a repeated call cannot double-submit the same items.
                    async_flush = self.flush_executor.submit(run_batch, self.flush_to_disk, self.data_buffer)
                    self.flush_threads.append(async_flush)
                    self.data_buffer = []
                for thread in self.flush_threads:
                    thread.result()
            if self.scene is not None:
                self.logger.info(
                    f"Scene {self.scene.name} generate finish, success rate: {self.success_case}/{self.total_case}"
                )
            raise StopIteration("no data")
        except Exception as e:
            self.logger.exception(f"Error during data writing: {e}")
            # Bare raise preserves the original traceback/context.
            raise

    def __del__(self):
        # __init__ may have failed before these attributes were assigned
        # (e.g. the output-dir assert), so guard every access with getattr.
        threads = getattr(self, "flush_threads", [])
        for thread in threads:
            thread.result()
        self.logger.info(f"Writer {len(threads)} threads closed")
        # Release the pool's worker threads; previously the executor was
        # never shut down, leaking threads until interpreter exit.
        executor = getattr(self, "flush_executor", None)
        if executor is not None:
            executor.shutdown(wait=True)
        # Close the simulation app if it exists
        scene = getattr(self, "scene", None)
        if scene is not None and scene.simulation_app is not None:
            self.logger.info("Closing simulation app")
            scene.simulation_app.close()

    @abstractmethod
    def flush_to_disk(self, task, scene_name, seq, obs):
        """Write one case (workflow, scene name, sequence, observations) to disk.

        Must be implemented by subclasses; should return the number of frames
        written (used for IO statistics in synchronous mode).
        """
        raise NotImplementedError("This method should be overridden by subclasses")