init commit
This commit is contained in:
163
nimbus/components/store/base_writer.py
Normal file
163
nimbus/components/store/base_writer.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from copy import copy
|
||||
|
||||
from nimbus.components.data.iterator import Iterator
|
||||
from nimbus.components.data.observation import Observations
|
||||
from nimbus.components.data.scene import Scene
|
||||
from nimbus.components.data.sequence import Sequence
|
||||
from nimbus.daemon import ComponentStatus, StatusReporter
|
||||
from nimbus.utils.flags import is_debug_mode
|
||||
from nimbus.utils.utils import unpack_iter_data
|
||||
|
||||
|
||||
def run_batch(func, args):
|
||||
for arg in args:
|
||||
func(*arg)
|
||||
|
||||
|
||||
class BaseWriter(Iterator):
|
||||
"""
|
||||
A base class for writing generated sequences and observations to disk. This class defines the structure for
|
||||
writing data and tracking the writing process. It manages the current scene, success and total case counts,
|
||||
and provides hooks for subclasses to implement specific data writing logic. The writer supports both synchronous
|
||||
and asynchronous batch writing modes, allowing for efficient data handling in various scenarios.
|
||||
|
||||
Args:
|
||||
data_iter (Iterator): An iterator that provides data to be written, typically containing scenes,
|
||||
sequences, and observations.
|
||||
seq_output_dir (str): The directory where generated sequences will be saved. Can be None
|
||||
if sequence output is not needed.
|
||||
obs_output_dir (str): The directory where generated observations will be saved. Can be None
|
||||
if observation output is not needed.
|
||||
batch_async (bool): If True, the writer will use asynchronous batch writing to improve performance
|
||||
when handling large amounts of data. Default is True.
|
||||
async_threshold (int): The maximum number of asynchronous write operations that can be in progress
|
||||
at the same time. If the threshold is reached, the writer will wait for the oldest operation
|
||||
to complete before starting a new one. Default is 1.
|
||||
batch_size (int): The number of data items to write in each batch when using asynchronous writing.
|
||||
Default is 2, and it will be capped at 8 to prevent potential issues with too many concurrent operations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_iter: Iterator[tuple[Scene, Sequence, Observations]],
|
||||
seq_output_dir: str,
|
||||
obs_output_dir: str,
|
||||
batch_async: bool = True,
|
||||
async_threshold: int = 1,
|
||||
batch_size: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
assert (
|
||||
seq_output_dir is not None or obs_output_dir is not None
|
||||
), "At least one output directory must be provided"
|
||||
self.data_iter = data_iter
|
||||
self.seq_output_dir = seq_output_dir
|
||||
self.obs_output_dir = obs_output_dir
|
||||
self.scene = None
|
||||
self.async_mode = batch_async
|
||||
self.batch_size = batch_size if batch_size <= 8 else 8
|
||||
if batch_async and batch_size > self.batch_size:
|
||||
self.logger.info("Batch size is larger than 8(probably cause program hang), batch size will be set to 8")
|
||||
self.async_threshold = async_threshold
|
||||
self.flush_executor = ThreadPoolExecutor(max_workers=max(1, 64 // self.batch_size))
|
||||
self.flush_threads = []
|
||||
self.data_buffer = []
|
||||
self.logger.info(
|
||||
f"Batch Async Write Mode: {self.async_mode}, async threshold: {self.async_threshold}, batch size:"
|
||||
f" {self.batch_size}"
|
||||
)
|
||||
self.total_case = 0
|
||||
self.success_case = 0
|
||||
self.last_scene_key = None
|
||||
self.status_reporter = StatusReporter(self.__class__.__name__)
|
||||
|
||||
def _next(self):
|
||||
try:
|
||||
data = next(self.data_iter)
|
||||
scene, seq, obs = unpack_iter_data(data)
|
||||
|
||||
new_key = (scene.task_id, scene.name, scene.task_exec_num) if scene is not None else None
|
||||
|
||||
self.scene = scene
|
||||
|
||||
if new_key != self.last_scene_key:
|
||||
if self.scene is not None and self.last_scene_key is not None:
|
||||
self.logger.info(
|
||||
f"Scene {self.scene.name} generate finish, success rate: {self.success_case}/{self.total_case}"
|
||||
)
|
||||
self.success_case = 0
|
||||
self.total_case = 0
|
||||
self.last_scene_key = new_key
|
||||
|
||||
if self.scene is None:
|
||||
return None
|
||||
|
||||
self.total_case += 1
|
||||
|
||||
self.status_reporter.update_status(ComponentStatus.RUNNING)
|
||||
if seq is None and obs is None:
|
||||
self.logger.info(f"generate failed, skip once! success rate: {self.success_case}/{self.total_case}")
|
||||
self.scene.update_generate_status(success=False)
|
||||
return None
|
||||
scene_name = self.scene.name
|
||||
io_start_time = time.time()
|
||||
if self.async_mode:
|
||||
cp_start_time = time.time()
|
||||
cp = copy(self.scene.wf)
|
||||
cp_end_time = time.time()
|
||||
if self.scene.wf is not None:
|
||||
self.logger.info(f"Scene {scene_name} workflow copy time: {cp_end_time - cp_start_time:.2f}s")
|
||||
self.data_buffer.append((cp, scene_name, seq, obs))
|
||||
if len(self.data_buffer) >= self.batch_size:
|
||||
self.flush_threads = [t for t in self.flush_threads if not t.done()]
|
||||
|
||||
if len(self.flush_threads) >= self.async_threshold:
|
||||
self.logger.info("Max async workers reached, waiting for the oldest thread to finish")
|
||||
self.flush_threads[0].result()
|
||||
self.flush_threads = self.flush_threads[1:]
|
||||
|
||||
to_flush_buffer = self.data_buffer.copy()
|
||||
async_flush = self.flush_executor.submit(run_batch, self.flush_to_disk, to_flush_buffer)
|
||||
if is_debug_mode():
|
||||
async_flush.result() # surface exceptions immediately in debug mode
|
||||
self.flush_threads.append(async_flush)
|
||||
self.data_buffer = []
|
||||
flush_length = len(obs) if obs is not None else len(seq)
|
||||
else:
|
||||
flush_length = self.flush_to_disk(self.scene.wf, scene_name, seq, obs)
|
||||
self.success_case += 1
|
||||
self.scene.update_generate_status(success=True)
|
||||
self.collect_io_frame_info(flush_length, time.time() - io_start_time)
|
||||
self.status_reporter.update_status(ComponentStatus.COMPLETED)
|
||||
return None
|
||||
except StopIteration:
|
||||
if self.async_mode:
|
||||
if len(self.data_buffer) > 0:
|
||||
async_flush = self.flush_executor.submit(run_batch, self.flush_to_disk, self.data_buffer)
|
||||
self.flush_threads.append(async_flush)
|
||||
for thread in self.flush_threads:
|
||||
thread.result()
|
||||
if self.scene is not None:
|
||||
self.logger.info(
|
||||
f"Scene {self.scene.name} generate finish, success rate: {self.success_case}/{self.total_case}"
|
||||
)
|
||||
raise StopIteration("no data")
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error during data writing: {e}")
|
||||
raise e
|
||||
|
||||
def __del__(self):
|
||||
for thread in self.flush_threads:
|
||||
thread.result()
|
||||
self.logger.info(f"Writer {len(self.flush_threads)} threads closed")
|
||||
# Close the simulation app if it exists
|
||||
if self.scene is not None and self.scene.simulation_app is not None:
|
||||
self.logger.info("Closing simulation app")
|
||||
self.scene.simulation_app.close()
|
||||
|
||||
@abstractmethod
|
||||
def flush_to_disk(self, task, scene_name, seq, obs):
|
||||
raise NotImplementedError("This method should be overridden by subclasses")
|
||||
Reference in New Issue
Block a user