164 lines
7.6 KiB
Python
164 lines
7.6 KiB
Python
import time
|
|
from abc import abstractmethod
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from copy import copy
|
|
|
|
from nimbus.components.data.iterator import Iterator
|
|
from nimbus.components.data.observation import Observations
|
|
from nimbus.components.data.scene import Scene
|
|
from nimbus.components.data.sequence import Sequence
|
|
from nimbus.daemon import ComponentStatus, StatusReporter
|
|
from nimbus.utils.flags import is_debug_mode
|
|
from nimbus.utils.utils import unpack_iter_data
|
|
|
|
|
|
def run_batch(func, args):
    """Invoke ``func`` once per entry of ``args``, unpacking each entry as positional arguments.

    Entries are processed strictly in order. An exception raised by ``func``
    propagates immediately, so the remaining entries are not processed.

    Args:
        func: Callable to invoke.
        args: Iterable whose items are themselves iterables of positional arguments.

    Returns:
        None. Return values of ``func`` are discarded.
    """
    for positional in args:
        func(*positional)
|
|
|
|
|
|
class BaseWriter(Iterator):
    """
    A base class for writing generated sequences and observations to disk. This class defines the structure for
    writing data and tracking the writing process. It manages the current scene, success and total case counts,
    and provides hooks for subclasses to implement specific data writing logic. The writer supports both synchronous
    and asynchronous batch writing modes, allowing for efficient data handling in various scenarios.

    Args:
        data_iter (Iterator): An iterator that provides data to be written, typically containing scenes,
            sequences, and observations.
        seq_output_dir (str): The directory where generated sequences will be saved. Can be None
            if sequence output is not needed.
        obs_output_dir (str): The directory where generated observations will be saved. Can be None
            if observation output is not needed.
        batch_async (bool): If True, the writer will use asynchronous batch writing to improve performance
            when handling large amounts of data. Default is True.
        async_threshold (int): The maximum number of asynchronous write operations that can be in progress
            at the same time. If the threshold is reached, the writer will wait for the oldest operation
            to complete before starting a new one. Default is 1.
        batch_size (int): The number of data items to write in each batch when using asynchronous writing.
            Default is 2. Values are clamped to the range 1..8: the upper cap prevents potential issues with
            too many concurrent operations, the lower bound keeps the worker-count computation well defined.
    """

    def __init__(
        self,
        data_iter: Iterator[tuple[Scene, Sequence, Observations]],
        seq_output_dir: str,
        obs_output_dir: str,
        batch_async: bool = True,
        async_threshold: int = 1,
        batch_size: int = 2,
    ):
        super().__init__()
        assert (
            seq_output_dir is not None or obs_output_dir is not None
        ), "At least one output directory must be provided"
        self.data_iter = data_iter
        self.seq_output_dir = seq_output_dir
        self.obs_output_dir = obs_output_dir
        self.scene = None
        self.async_mode = batch_async
        # Clamp to 1..8. The lower bound avoids a ZeroDivisionError in the worker-count
        # computation below when a caller passes batch_size=0.
        self.batch_size = max(1, min(batch_size, 8))
        if batch_async and batch_size > self.batch_size:
            self.logger.info("Batch size is larger than 8(probably cause program hang), batch size will be set to 8")
        self.async_threshold = async_threshold
        # NOTE(review): 64 looks like a budget of concurrently buffered items shared by all workers — confirm.
        self.flush_executor = ThreadPoolExecutor(max_workers=max(1, 64 // self.batch_size))
        # Futures of submitted batches that have not been waited on yet (oldest first).
        self.flush_threads = []
        # Items accumulated for the next asynchronous batch: (workflow copy, scene name, seq, obs).
        self.data_buffer = []
        self.logger.info(
            f"Batch Async Write Mode: {self.async_mode}, async threshold: {self.async_threshold}, batch size:"
            f" {self.batch_size}"
        )
        self.total_case = 0
        self.success_case = 0
        # (task_id, name, task_exec_num) of the scene the counters currently refer to.
        self.last_scene_key = None
        self.status_reporter = StatusReporter(self.__class__.__name__)

    def _next(self):
        """Pull one item from ``data_iter`` and write it (or queue it for asynchronous writing).

        Per-scene success/total counters are reset whenever the scene key
        ``(task_id, name, task_exec_num)`` changes.

        Returns:
            None in every non-raising path.

        Raises:
            StopIteration: When ``data_iter`` is exhausted. In asynchronous mode all buffered and
                in-flight writes are completed first.
            Exception: Any other error is logged and re-raised unchanged.
        """
        try:
            data = next(self.data_iter)
            scene, seq, obs = unpack_iter_data(data)

            new_key = (scene.task_id, scene.name, scene.task_exec_num) if scene is not None else None
            self.scene = scene

            if new_key != self.last_scene_key:
                if self.last_scene_key is not None:
                    # The counters still describe the scene that just ended, so report them under
                    # that scene's name (second field of the key) instead of the incoming scene's.
                    finished_scene_name = self.last_scene_key[1]
                    self.logger.info(
                        f"Scene {finished_scene_name} generate finish, "
                        f"success rate: {self.success_case}/{self.total_case}"
                    )
                self.success_case = 0
                self.total_case = 0
                self.last_scene_key = new_key

            if self.scene is None:
                return None

            self.total_case += 1

            self.status_reporter.update_status(ComponentStatus.RUNNING)
            if seq is None and obs is None:
                self.logger.info(f"generate failed, skip once! success rate: {self.success_case}/{self.total_case}")
                self.scene.update_generate_status(success=False)
                return None

            scene_name = self.scene.name
            io_start_time = time.time()
            if self.async_mode:
                # Hand the worker a copy of the workflow object rather than the scene's own reference.
                cp_start_time = time.time()
                cp = copy(self.scene.wf)
                cp_end_time = time.time()
                if self.scene.wf is not None:
                    self.logger.info(f"Scene {scene_name} workflow copy time: {cp_end_time - cp_start_time:.2f}s")
                self.data_buffer.append((cp, scene_name, seq, obs))
                if len(self.data_buffer) >= self.batch_size:
                    self._submit_buffered_batch()
                # The write has only been queued, so derive the frame count from the data itself.
                flush_length = len(obs) if obs is not None else len(seq)
            else:
                flush_length = self.flush_to_disk(self.scene.wf, scene_name, seq, obs)

            self.success_case += 1
            self.scene.update_generate_status(success=True)
            self.collect_io_frame_info(flush_length, time.time() - io_start_time)
            self.status_reporter.update_status(ComponentStatus.COMPLETED)
            return None
        except StopIteration:
            if self.async_mode:
                self._drain_async_writes()
            if self.scene is not None:
                self.logger.info(
                    f"Scene {self.scene.name} generate finish, success rate: {self.success_case}/{self.total_case}"
                )
            raise StopIteration("no data")
        except Exception as e:
            self.logger.exception(f"Error during data writing: {e}")
            raise

    def _submit_buffered_batch(self):
        """Submit the current buffer as one asynchronous batch, respecting ``async_threshold``."""
        still_running = []
        for future in self.flush_threads:
            if not future.done():
                still_running.append(future)
                continue
            # A finished future is about to be forgotten; report its failure instead of losing it.
            error = future.exception()
            if error is not None:
                self.logger.error(f"Async flush failed: {error}")
        self.flush_threads = still_running

        # Guard on a non-empty list so a threshold <= 0 cannot index an empty list.
        if self.flush_threads and len(self.flush_threads) >= self.async_threshold:
            self.logger.info("Max async workers reached, waiting for the oldest thread to finish")
            self.flush_threads.pop(0).result()

        async_flush = self.flush_executor.submit(run_batch, self.flush_to_disk, self.data_buffer)
        # Rebind rather than clear: the worker now owns the submitted list.
        self.data_buffer = []
        if is_debug_mode():
            async_flush.result()  # surface exceptions immediately in debug mode
        self.flush_threads.append(async_flush)

    def _drain_async_writes(self):
        """Flush whatever is still buffered and block until every outstanding write has finished."""
        if self.data_buffer:
            async_flush = self.flush_executor.submit(run_batch, self.flush_to_disk, self.data_buffer)
            # Reset the buffer so a repeated call after exhaustion does not write the tail batch twice.
            self.data_buffer = []
            self.flush_threads.append(async_flush)
        # Pop as we go: completed futures are not waited on again, and if one raises the
        # remaining ones stay registered for __del__.
        while self.flush_threads:
            self.flush_threads.pop(0).result()

    def __del__(self):
        """Wait for outstanding writes, release the executor and close the scene's simulation app."""
        if not hasattr(self, "flush_threads"):
            # __init__ failed before the writer state was set up; nothing to clean up.
            return
        for thread in self.flush_threads:
            try:
                thread.result()
            except Exception as e:
                # Keep going: a failed write must not prevent the remaining cleanup.
                self.logger.error(f"Async flush failed: {e}")
        self.logger.info(f"Writer {len(self.flush_threads)} threads closed")
        # All futures have completed above, so this only releases the idle worker threads.
        self.flush_executor.shutdown(wait=False)
        # Close the simulation app if it exists
        if self.scene is not None and self.scene.simulation_app is not None:
            self.logger.info("Closing simulation app")
            self.scene.simulation_app.close()

    @abstractmethod
    def flush_to_disk(self, task, scene_name, seq, obs):
        """Write one item to disk; must be implemented by subclasses.

        Args:
            task: The scene's ``wf`` object (a shallow copy of it in asynchronous mode); may be None.
            scene_name: Name of the scene the data belongs to.
            seq: The generated sequence, or None.
            obs: The generated observations, or None.

        Returns:
            In synchronous mode the return value is forwarded to ``collect_io_frame_info`` as the
            flushed length; in asynchronous mode it is discarded.
        """
        raise NotImplementedError("This method should be overridden by subclasses")
|