init commit

This commit is contained in:
zyhe
2026-03-16 11:44:10 +00:00
commit 94384a93c9
552 changed files with 363038 additions and 0 deletions

66
nimbus/data_engine.py Normal file
View File

@@ -0,0 +1,66 @@
from time import time
from nimbus.dist_sim.head_node import HeadNode
from nimbus.scheduler.sches import gen_pipe, gen_scheduler
from nimbus.utils.logging import configure_logging
from nimbus.utils.random import set_all_seeds
from nimbus.utils.types import (
NAME,
SAFE_THRESHOLD,
STAGE_PIPE,
WORKER_SCHEDULE,
StageInput,
)
from nimbus.utils.utils import consume_stage
class DataEngine:
def __init__(self, config, master_seed=None):
if master_seed is not None:
master_seed = int(master_seed)
set_all_seeds(master_seed)
exp_name = config[NAME]
configure_logging(exp_name, config=config)
self._sche_list = gen_scheduler(config)
self._stage_input = StageInput()
def run(self):
for stage in self._sche_list:
self._stage_input = stage.run(self._stage_input)
consume_stage(self._stage_input)
class DistPipeDataEngine:
def __init__(self, config, master_seed=None):
self._sche_list = gen_scheduler(config)
self.config = config
self._stage_input = StageInput()
exp_name = config[NAME]
self.logger = configure_logging(exp_name, config=config)
master_seed = int(master_seed) if master_seed is not None else None
self.pipe_list = gen_pipe(config, self._sche_list, exp_name, master_seed=master_seed)
self.head_nodes = {}
def run(self):
self.logger.info("[DistPipeDataEngine]: %s", self.pipe_list)
st_time = time()
cur_pipe_queue = None
pre_worker_num = 0
worker_schedule = self.config[STAGE_PIPE].get(WORKER_SCHEDULE, False)
for idx, pipe in enumerate(self.pipe_list):
self.head_nodes[idx] = HeadNode(
cur_pipe_queue,
pipe,
pre_worker_num,
self.config[STAGE_PIPE][SAFE_THRESHOLD],
worker_schedule,
self.logger,
idx,
)
self.head_nodes[idx].run()
cur_pipe_queue = self.head_nodes[idx].result_queue()
pre_worker_num = len(pipe)
for _, value in self.head_nodes.items():
value.wait_stop()
et_time = time()
self.logger.info("execution duration: %s", et_time - st_time)