Files
issacdataengine/workflows/base.py
2026-03-16 11:44:10 +00:00

204 lines
6.5 KiB
Python

from abc import ABC, abstractmethod
from copy import deepcopy
from functools import wraps
class NimbusWorkFlow(ABC):
workflows = {}
# pylint: disable=W0613
def __init__(self, world, task_cfg_path: str, **kwargs):
"""Initialize the workflow.
Args:
world: The simulation world instance.
task_cfg_path (str): Path to the task configuration file.
Each workflow subclass is responsible for parsing this file.
**kwargs: Workflow-specific parameters.
Subclasses declare only the kwargs they need; unused ones are silently ignored.
"""
self.world = world
self.task_cfg_path = task_cfg_path
self.task_cfgs = self.parse_task_cfgs(task_cfg_path)
def init_task(self, index, need_preload: bool = True):
assert index < len(self.task_cfgs), "Index out of range for task configurations."
self.task_cfg = self.task_cfgs[index]
self.reset(need_preload)
def __copy__(self):
new_wf = type(self).__new__(type(self))
new_wf.__dict__.update(self.__dict__)
if hasattr(self, "logger"):
new_wf.logger = deepcopy(self.logger)
if hasattr(self, "recoder"):
new_wf.recoder = deepcopy(self.recoder)
return new_wf
@abstractmethod
def parse_task_cfgs(self, task_cfg_path) -> list:
"""
Parse the task configuration file.
Args:
task_cfg_path (str): Path to the task configuration file.
Returns:
list: List of task configurations.
"""
pass
@abstractmethod
def get_task_name(self) -> str:
"""Get the name of the current task.
Returns:
str: name of the current task
"""
pass
@abstractmethod
def reset(self, need_preload):
"""Reset the environment to the initial state of the current task.
Args:
need_preload (bool): Whether to preload objects in the environment. Defaults to True.
"""
pass
@abstractmethod
def randomization(self, layout_path=None) -> bool:
"""Randomize the environment layout in one task.
Args:
layout_path (str, optional): Path to the layout file. Defaults to None.
Returns:
bool: True if randomization is successful, False otherwise.
"""
pass
@abstractmethod
def generate_seq(self) -> list:
"""Generate a sequence of states for the current task.
Returns:
list: Sequence of states which be replayed for the current task.
If the sequence is not generated, return an empty list.
"""
pass
@abstractmethod
def seq_replay(self, sequence: list) -> int:
"""Replay the sequence and generate observations.
Args:
sequence (list): Sequence of states to be replayed.
Returns:
int: Length of the replayed sequence.
"""
pass
@abstractmethod
def save(self, save_path: str) -> int:
"""Save the all information.
Args:
save_path (str): Path to save the observations.
Returns:
int: Length of the saved observations.
"""
pass
# plan mode
def save_seq(self, save_path: str) -> int:
"""Save the generated sequence without observations.
Args:
save_path (str): Path to save the sequence.
Returns:
int: Length of the saved sequence.
"""
raise NotImplementedError("This method should be implemented in the subclass.")
# render mode
def recover_seq(self, seq_path: str) -> list:
"""Recover sequence from a sequence file.
Args:
seq_path (str): Path to the sequence file.
"""
raise NotImplementedError("This method should be implemented in the subclass.")
# plan with render mode
def generate_seq_with_obs(self) -> int:
"""Generate a sequence with observation for the current task.
(For debug or future RL)
Returns:
int: Length of the generated sequence.
"""
raise NotImplementedError("This method should be implemented in the subclass.")
# pipeline mode
def dump_plan_info(self) -> bytes:
"""Dump the layout and sequence plan information of the current task.
Returns:
bytes: Serialized plan information including layout and sequence data.
"""
raise NotImplementedError("This method should be implemented in the subclass.")
# pipeline mode
def dedump_plan_info(self, ser_obj: bytes) -> object:
"""Deserialize the layout and plan information of the current task.
Args:
ser_obj (bytes): Serialized plan information generated from dump_plan_info().
Returns:
object: Deserialized layout and sequence information.
This will be used as input for randomization_from_mem() and recover_seq_from_mem().
"""
raise NotImplementedError("This method should be implemented in the subclass.")
# pipeline mode
def randomization_from_mem(self, data: object) -> bool:
"""Perform randomization using in-memory plan data.
Args:
data (object): Deserialized layout and sequence information.
Returns:
bool: True if randomization succeeds, False otherwise.
"""
raise NotImplementedError("This method should be implemented in the subclass.")
# pipeline mode
def recover_seq_from_mem(self, data: object) -> list:
"""Recover sequence from in-memory plan data.
Args:
data (object): Deserialized layout and sequence information.
Returns:
list: Recovered sequence of states.
"""
raise NotImplementedError("This method should be implemented in the subclass.")
@classmethod
def register(cls, name: str):
"""
Register a workflow with its name(decorator).
Args:
name(str): name of the workflow
"""
def decorator(wfs_class):
cls.workflows[name] = wfs_class
@wraps(wfs_class)
def wrapped_function(*args, **kwargs):
return wfs_class(*args, **kwargs)
return wrapped_function
return decorator
def create_workflow(workflow_type: str, world, task_cfg_path: str, **kwargs):
wf_cls = NimbusWorkFlow.workflows[workflow_type]
return wf_cls(world, task_cfg_path, **kwargs)