739 lines
29 KiB
Python
739 lines
29 KiB
Python
import glob
|
|
import json
|
|
import os
|
|
import pickle
|
|
import random
|
|
import time
|
|
from collections import defaultdict
|
|
from copy import deepcopy
|
|
from datetime import datetime
|
|
|
|
import numpy as np
|
|
import yaml
|
|
from omni.isaac.core.utils.prims import get_prim_at_path
|
|
from omni.isaac.core.utils.transformations import (
|
|
get_relative_transform,
|
|
pose_from_tf_matrix,
|
|
)
|
|
from omni.physx import acquire_physx_interface
|
|
from tqdm import tqdm
|
|
from yaml import Loader
|
|
|
|
from deps.world_toolkit.world_recorder import WorldRecorder
|
|
from workflows.simbox.utils.task_config_parser import TaskConfigParser
|
|
|
|
from .base import NimbusWorkFlow
|
|
from .simbox.core.controllers import get_controller_cls
|
|
from .simbox.core.loggers.lmdb_logger import LmdbLogger
|
|
from .simbox.core.loggers.utils import log_dual_obs
|
|
from .simbox.core.skills import get_skill_cls
|
|
from .simbox.core.tasks import get_task_cls
|
|
from .simbox.core.utils.collision_utils import filter_collisions
|
|
from .simbox.core.utils.utils import set_random_seed
|
|
|
|
|
|
# pylint: disable=unused-argument
|
|
@NimbusWorkFlow.register("SimBoxDualWorkFlow")
|
|
class SimBoxDualWorkFlow(NimbusWorkFlow):
|
|
def __init__(
|
|
self,
|
|
world,
|
|
task_cfg_path: str,
|
|
scene_info: str = "dining_room_scene_info",
|
|
random_seed: int = None,
|
|
):
|
|
self.scene_info = scene_info
|
|
self.step_replay = False
|
|
self.random_seed = random_seed
|
|
super().__init__(world, task_cfg_path)
|
|
|
|
def parse_task_cfgs(self, task_cfg_path: str) -> list:
|
|
task_cfgs = TaskConfigParser(task_cfg_path).parse_tasks()
|
|
# Merge robot configs for each task
|
|
for task_cfg in task_cfgs:
|
|
self._merge_robot_configs(task_cfg)
|
|
return task_cfgs
|
|
|
|
def _merge_robot_configs(self, task_cfg: dict):
|
|
"""Merge robot configs from robot_config_file into task_cfg['robots']."""
|
|
robots = task_cfg.get("robots", [])
|
|
|
|
for robot in robots:
|
|
robot_config_file = robot.get("robot_config_file")
|
|
if robot_config_file:
|
|
with open(robot_config_file, "r", encoding="utf-8") as f:
|
|
robot_base_cfg = yaml.load(f, Loader=Loader)
|
|
|
|
# Merge: robot_base_cfg as base, task_cfg['robots'][i] overrides
|
|
merged_cfg = deepcopy(robot_base_cfg)
|
|
merged_cfg.update(robot)
|
|
robot.clear()
|
|
robot.update(merged_cfg)
|
|
|
|
def reset(self, need_preload: bool = True):
|
|
# source code noted this as debug, so it could be removed later
|
|
from omni.isaac.core.utils.viewports import set_camera_view
|
|
|
|
set_camera_view(eye=[1.3, 0.7, 2.7], target=[0.0, 0, 1.5], camera_prim_path="/OmniverseKit_Persp")
|
|
# Modify config
|
|
arena_file_path = self.task_cfg.get("arena_file", None)
|
|
with open(arena_file_path, "r", encoding="utf-8") as arena_file:
|
|
arena = yaml.load(arena_file, Loader=Loader)
|
|
|
|
if "involved_scenes" in arena:
|
|
arena["involved_scenes"] = self.scene_info
|
|
|
|
self.task_cfg["arena"] = arena
|
|
|
|
for obj_cfg in self.task_cfg["objects"]:
|
|
if obj_cfg["target_class"] == "ArticulatedObject":
|
|
asset_root = self.task_cfg["asset_root"]
|
|
art_paths = glob.glob(os.path.join(asset_root, obj_cfg["art_cat"], "*"))
|
|
art_paths.sort()
|
|
path = random.choice(art_paths)
|
|
info_name = obj_cfg["info_name"]
|
|
info_path = f"{path}/Kps/{info_name}/info.json"
|
|
with open(info_path, "r", encoding="utf-8") as f:
|
|
info = json.load(f)
|
|
scale = info["object_scale"][:3]
|
|
|
|
obj_cfg["path"] = path.replace(f"{asset_root}/", "", 1)
|
|
obj_cfg["category"] = path.split("/")[-2]
|
|
obj_cfg["obj_info_path"] = info_path.replace(f"{asset_root}/", "", 1)
|
|
obj_cfg["scale"] = scale
|
|
self.task_cfg["data"]["collect_info"] = obj_cfg["category"]
|
|
|
|
self.task_cfg.pop("arena_file", None)
|
|
self.task_cfg.pop("camera_file", None)
|
|
self.task_cfg.pop("logger_file", None)
|
|
# Modify config done
|
|
if self.task_cfg.get("fluid", None):
|
|
# for fluid manipulation, only gpu mode is supportive
|
|
physx_interface = acquire_physx_interface()
|
|
physx_interface.overwrite_gpu_setting(1)
|
|
|
|
self.task = get_task_cls(self.task_cfg["task"])(self.task_cfg)
|
|
self.stage = self.world.stage
|
|
self.stage.SetDefaultPrim(self.stage.GetPrimAtPath("/World"))
|
|
self.world.add_task(self.task)
|
|
|
|
# # Add hidden ground plane for physics simulation
|
|
# from omni.isaac.core.objects import GroundPlane
|
|
# plane = GroundPlane(
|
|
# prim_path="/World/GroundPlane",
|
|
# z_position=0.0,
|
|
# visible=False,
|
|
# )
|
|
|
|
prim_paths = [] # do not collide with each other
|
|
global_collision_paths = [] # collide with everything
|
|
|
|
self.robots_prim_paths = []
|
|
for robot in self.task_cfg["robots"]:
|
|
robot_prim_path = self.task.root_prim_path + "/" + robot["name"]
|
|
prim_paths.append(robot_prim_path)
|
|
self.robots_prim_paths.append(robot_prim_path)
|
|
neglect_collision_names = self.task_cfg.get("neglect_collision_names", [])
|
|
candidates = self.task_cfg["objects"] + self.task_cfg["arena"]["fixtures"]
|
|
for candidate in candidates:
|
|
candidate_prim_path = self.task.root_prim_path + "/" + candidate["name"]
|
|
global_collision_paths.append(candidate_prim_path)
|
|
for neglect_collision_name in neglect_collision_names:
|
|
if neglect_collision_name in candidate["name"]:
|
|
prim_paths.append(candidate_prim_path)
|
|
global_collision_paths.remove(candidate_prim_path)
|
|
|
|
collision_root_path = "/World/collisions"
|
|
filter_collisions(
|
|
self.stage,
|
|
self.world.get_physics_context().prim_path,
|
|
collision_root_path,
|
|
prim_paths,
|
|
global_collision_paths,
|
|
)
|
|
self.world.reset()
|
|
self.world.step(render=True)
|
|
self.controllers = self._initialize_controllers(self.task, self.task_cfg, self.world)
|
|
self.skills = self._initialize_skills(self.task, self.task_cfg, self.controllers, self.world)
|
|
|
|
for _ in range(50):
|
|
self._init_static_objects(self.task)
|
|
self.world.step(render=False)
|
|
|
|
self.logger = LmdbLogger(
|
|
task_dir=self.task_cfg["data"]["task_dir"],
|
|
language_instruction=self.task.language_instruction,
|
|
detailed_language_instruction=self.task.detailed_language_instruction,
|
|
collect_info=self.task_cfg["data"]["collect_info"],
|
|
version=self.task_cfg["data"].get("version", "v1.0"),
|
|
)
|
|
|
|
if self.random_seed is not None:
|
|
seed = self.random_seed
|
|
else:
|
|
seed = time.time_ns() % (2**32)
|
|
set_random_seed(seed)
|
|
|
|
# while True:
|
|
# self.world.get_observations()
|
|
# # self._init_static_objects(self.task)
|
|
# self.world.step(render=True)
|
|
|
|
def _initialize_skills(self, task, task_cfg, controllers, world):
|
|
draw_points = False
|
|
if draw_points:
|
|
from omni.isaac.debug_draw import _debug_draw
|
|
|
|
draw = _debug_draw.acquire_debug_draw_interface()
|
|
else:
|
|
draw = None
|
|
|
|
# Initialize skills for each robot.
|
|
skills = []
|
|
for cfg_skill_dict in task_cfg["skills"]:
|
|
skill_dict = defaultdict(list)
|
|
for robot_name, robot_skill_list in cfg_skill_dict.items():
|
|
robot = task.robots[robot_name]
|
|
controller = controllers[robot_name]
|
|
|
|
for lr_skill_dict in robot_skill_list:
|
|
skill_sequence = [
|
|
[
|
|
get_skill_cls(skill_cfg["name"])(
|
|
robot,
|
|
controller[lr_name],
|
|
task,
|
|
skill_cfg,
|
|
world=world,
|
|
draw=draw,
|
|
)
|
|
for skill_cfg in lr_skill_list
|
|
]
|
|
for lr_name, lr_skill_list in lr_skill_dict.items()
|
|
]
|
|
skill_dict[robot_name].append(skill_sequence)
|
|
skills.append(skill_dict)
|
|
return skills
|
|
|
|
def _initialize_controllers(self, task, task_cfg, world):
|
|
"""Initialize controllers for each robot."""
|
|
controllers = {}
|
|
for robot in task_cfg["robots"]:
|
|
controllers[robot["name"]] = {}
|
|
for robot_file in robot["robot_file"]:
|
|
controller_name = "left" if "left" in robot_file else "right"
|
|
controllers[robot["name"]][controller_name] = get_controller_cls(robot["target_class"])(
|
|
name=robot["name"],
|
|
robot_file=robot_file,
|
|
constrain_grasp_approach=robot.get("constrain_grasp_approach", False),
|
|
collision_activation_distance=robot.get("collision_activation_distance", 0.03),
|
|
task=task,
|
|
world=world,
|
|
ignore_substring=robot.get("ignore_substring", ["material", "Plane", "conveyor", "scene", "table"]),
|
|
use_batch=robot.get("use_batch", False),
|
|
)
|
|
controllers[robot["name"]][controller_name].reset()
|
|
return controllers
|
|
|
|
def _initialize_world_recorder(self):
|
|
"""
|
|
Initialize WorldRecorder with appropriate mode based on configuration.
|
|
|
|
Supports two modes:
|
|
- step_replay=False: Records prim poses for fast geometric replay (compatible with old workflow)
|
|
- step_replay=True: Uses preprocessed joint position data for physics-accurate replay (new default)
|
|
"""
|
|
self.world_recorder = WorldRecorder(
|
|
self.world,
|
|
self.task.robots,
|
|
self.task.objects | self.task.distractors | self.task.visuals,
|
|
step_replay=self.step_replay,
|
|
)
|
|
self.world_recorder.reset()
|
|
|
|
def _reset_controllers(self, controllers):
|
|
"""Reset all controllers."""
|
|
for _, controller in controllers.items():
|
|
for _, ctrl in controller.items():
|
|
ctrl.reset()
|
|
|
|
def _init_static_objects(self, task):
|
|
for _, obj in task.objects.items():
|
|
try:
|
|
init_translation = obj.init_translation
|
|
init_orientation = obj.init_orientation
|
|
init_parent = obj.init_parent
|
|
if init_translation and init_orientation and init_parent:
|
|
parent_world_pose = get_relative_transform(
|
|
get_prim_at_path(task.root_prim_path + "/" + init_parent), get_prim_at_path(task.root_prim_path)
|
|
)
|
|
parent_translation, _ = pose_from_tf_matrix(parent_world_pose)
|
|
obj.set_local_pose(
|
|
translation=(parent_translation + init_translation), orientation=init_orientation
|
|
)
|
|
obj.set_angular_velocity(np.array([0.0, 0.0, 0.0]))
|
|
obj.set_linear_velocity(np.array([0.0, 0.0, 0.0]))
|
|
except Exception:
|
|
pass
|
|
|
|
def _randomization_layout_mem(self):
|
|
# Reset world
|
|
self.world.reset()
|
|
|
|
# Individual initialize
|
|
self.task.individual_randomize_from_mem()
|
|
self.task.post_reset()
|
|
|
|
self.world.step(render=False)
|
|
|
|
# Reset controllers
|
|
self._reset_controllers(self.controllers)
|
|
|
|
# Reset skills
|
|
del self.skills
|
|
self.skills = self._initialize_skills(self.task, self.task_cfg, self.controllers, self.world)
|
|
|
|
# Warmup
|
|
for _ in range(20):
|
|
self.world.get_observations()
|
|
self._init_static_objects(self.task)
|
|
self.world.step(render=False)
|
|
|
|
self._initialize_world_recorder()
|
|
|
|
self.logger.clear(
|
|
language_instruction=self.task.language_instruction,
|
|
detailed_language_instruction=self.task.detailed_language_instruction,
|
|
)
|
|
|
|
# episode_stats["current_times"] += 1
|
|
|
|
def _randomization_layout(self):
|
|
# Reset world
|
|
self.world.reset()
|
|
|
|
# Individual initialize
|
|
self.task.individual_randomize()
|
|
self.task.post_reset()
|
|
|
|
self.world.step(render=False)
|
|
|
|
# Reset controllers
|
|
if self.task_cfg.get("fluid", None):
|
|
# Fluid, Bug, Why !!!!!!
|
|
# For fluid manipulation, only delete controllers and reinitialize controllers can plan successfully
|
|
if hasattr(self, "controllers"):
|
|
del self.controllers
|
|
self.controllers = self._initialize_controllers(self.task, self.task_cfg, self.world)
|
|
|
|
# del self.controllers
|
|
# self.controllers = self._initialize_controllers(self.task, self.task_cfg, self.world)
|
|
self._reset_controllers(self.controllers)
|
|
|
|
# Reset skills
|
|
if hasattr(self, "skills"):
|
|
del self.skills
|
|
|
|
self.skills = self._initialize_skills(self.task, self.task_cfg, self.controllers, self.world)
|
|
|
|
# Warmup
|
|
for _ in range(20):
|
|
self.world.get_observations()
|
|
self._init_static_objects(self.task)
|
|
self.world.step(render=False)
|
|
|
|
if self.task_cfg.get("fluid", None):
|
|
self.task._set_fluid()
|
|
# Fluid need additional warmup
|
|
for _ in range(150):
|
|
self.world.step(render=False)
|
|
|
|
self._initialize_world_recorder()
|
|
|
|
self.logger.clear(
|
|
language_instruction=self.task.language_instruction,
|
|
detailed_language_instruction=self.task.detailed_language_instruction,
|
|
)
|
|
|
|
# episode_stats["current_times"] += 1
|
|
|
|
def randomization(self, layout_path=None) -> bool:
|
|
try:
|
|
if layout_path is None:
|
|
# Individual Reset
|
|
self.task.individual_reset()
|
|
self._randomization_layout()
|
|
else:
|
|
with open(layout_path, "rb") as f:
|
|
data = pickle.load(f)
|
|
self.data = data
|
|
self.randomization_from_mem(data)
|
|
return True
|
|
except Exception as e:
|
|
raise e
|
|
|
|
def update_skill_states(self, skills, episode_success, should_continue):
|
|
"""Update and manage skill states."""
|
|
current_skills = skills[0]
|
|
|
|
# Check if any skills remain
|
|
if not any(current_skills.values()):
|
|
skills.pop(0)
|
|
if skills:
|
|
should_continue = self.plan_first_skill(skills, should_continue)
|
|
return episode_success, should_continue
|
|
|
|
# Update each robot's skills
|
|
for _, skill_sequences in current_skills.items():
|
|
if not skill_sequences:
|
|
continue
|
|
|
|
# Update all skills first
|
|
for lr_skill_list in skill_sequences[0]:
|
|
if lr_skill_list:
|
|
start_lr_skill = lr_skill_list[0]
|
|
start_lr_skill.update() # Must update regardless of completion
|
|
if start_lr_skill.is_done():
|
|
if not start_lr_skill.is_success():
|
|
episode_success = False
|
|
should_continue = False
|
|
lr_skill_list.remove(start_lr_skill)
|
|
|
|
if lr_skill_list:
|
|
next_skill = lr_skill_list[0]
|
|
next_skill.simple_generate_manip_cmds()
|
|
if hasattr(next_skill, "visualize_target"):
|
|
next_skill.visualize_target(self.world)
|
|
if len(next_skill.manip_list) == 0:
|
|
should_continue = not next_skill.is_ready()
|
|
if hasattr(start_lr_skill, "visualize_target"):
|
|
start_lr_skill.visualize_target(self.world)
|
|
|
|
# Remove empty skill sequences
|
|
completed_skills = []
|
|
for lr_skill_list in skill_sequences[0]:
|
|
if not lr_skill_list:
|
|
completed_skills.append(lr_skill_list)
|
|
for completed_skill in completed_skills:
|
|
skill_sequences[0].remove(completed_skill)
|
|
|
|
# Move to next sequence if current is empty
|
|
if not skill_sequences[0]:
|
|
skill_sequences.pop(0)
|
|
if skill_sequences:
|
|
for skill in skill_sequences[0]:
|
|
skill[0].simple_generate_manip_cmds()
|
|
if len(skill[0].manip_list) == 0:
|
|
should_continue = not skill[0].is_ready()
|
|
return episode_success, should_continue
|
|
|
|
def plan_first_skill(self, skills, should_continue):
|
|
for _, robot_skill_list in skills[0].items():
|
|
for lr_skill_list in robot_skill_list[0]:
|
|
lr_skill_list[0].simple_generate_manip_cmds()
|
|
if hasattr(lr_skill_list[0], "visualize_target"):
|
|
lr_skill_list[0].visualize_target(self.world)
|
|
if len(lr_skill_list[0].manip_list) == 0:
|
|
should_continue = not lr_skill_list[0].is_ready()
|
|
return should_continue
|
|
|
|
def generate_seq(self) -> list:
|
|
end = False
|
|
|
|
# while True:
|
|
# obs = self.world.get_observations()
|
|
# # self._init_static_objects(self.task)
|
|
# self.world.step(render=True)
|
|
|
|
step_id = 0
|
|
episode_success = True
|
|
should_continue = True
|
|
max_episode_length = self.task_cfg["data"]["max_episode_length"]
|
|
episode_stats = {"succeed_times": 0, "current_times": 0}
|
|
|
|
should_continue = self.plan_first_skill(self.skills, should_continue)
|
|
|
|
# Warmup
|
|
for _ in range(10):
|
|
obs = self.world.get_observations()
|
|
# self._init_static_objects(self.task)
|
|
self.world.step(render=False)
|
|
|
|
while not (step_id >= max_episode_length or (not self.skills and not episode_success) or (not should_continue)):
|
|
obs = self.world.get_observations()
|
|
action_dict = {}
|
|
record_flag = True
|
|
if self.skills and should_continue:
|
|
# Process current skills
|
|
current_skills = self.skills[0]
|
|
for robot_name, skill_sequences in current_skills.items():
|
|
if skill_sequences and skill_sequences[0]:
|
|
action = [
|
|
skill[0].controller.forward(skill[0].manip_list[0])
|
|
for skill in skill_sequences[0]
|
|
if skill[0] and skill[0].is_ready()
|
|
]
|
|
|
|
feasible_labels = [skill[0].is_feasible() for skill in skill_sequences[0] if skill[0]]
|
|
record_labels = [skill[0].is_record() for skill in skill_sequences[0] if skill[0]]
|
|
|
|
if False in feasible_labels:
|
|
should_continue = False
|
|
if False in record_labels:
|
|
record_flag = False
|
|
|
|
if action:
|
|
action_dict[robot_name] = {
|
|
"joint_positions": np.concatenate([a["joint_positions"] for a in action]),
|
|
"joint_indices": np.concatenate([a["joint_indices"] for a in action]),
|
|
"raw_action": action,
|
|
}
|
|
elif not self.skills and episode_success:
|
|
print("Task is successful")
|
|
end = True
|
|
for j_idx in range(1, 7):
|
|
self.world.step(render=False)
|
|
obs = self.world.get_observations()
|
|
log_dual_obs(self.logger, obs, action_dict, self.controllers, step_idx=step_id + j_idx)
|
|
self.world_recorder.record()
|
|
|
|
episode_stats["succeed_times"] += 1
|
|
should_continue = False
|
|
|
|
if record_flag:
|
|
log_dual_obs(self.logger, obs, action_dict, self.controllers, step_idx=step_id)
|
|
self.world_recorder.record()
|
|
self.task.apply_action(action_dict)
|
|
self.world.step(render=False)
|
|
|
|
step_id += 1
|
|
if self.skills:
|
|
episode_success, should_continue = self.update_skill_states(
|
|
self.skills, episode_success, should_continue
|
|
)
|
|
|
|
if end:
|
|
if self.step_replay:
|
|
return [None] * step_id
|
|
else:
|
|
# Prim poses mode: return recorded poses for compatibility
|
|
return self.world_recorder.prim_poses
|
|
else:
|
|
return []
|
|
|
|
def recover_seq(self, seq_path):
|
|
data = self.data
|
|
return self.recover_seq_from_mem(data)
|
|
|
|
def _record_rgb_depth(self, step_idx: int):
|
|
for key, value in self.task.cameras.items():
|
|
for robot_name, _ in self.task.robots.items():
|
|
if robot_name in key:
|
|
rgb_img = value.get_observations()["color_image"]
|
|
# Special processing if enabled
|
|
camera2env_pose = value.get_observations()["camera2env_pose"]
|
|
save_camera_name = key.replace(f"{robot_name}_", "")
|
|
self.logger.add_color_image(robot_name, "images.rgb." + save_camera_name, rgb_img)
|
|
self.logger.add_scalar_data(robot_name, "camera2env_pose." + save_camera_name, camera2env_pose)
|
|
if step_idx == 0:
|
|
save_camera_name = key.replace(f"{robot_name}_", "")
|
|
self.logger.add_json_data(
|
|
robot_name, f"{save_camera_name}_camera_params", value.get_observations()["camera_params"]
|
|
)
|
|
|
|
# depth_img = get_src(value, "depth")
|
|
# depth_img = np.nan_to_num(depth_img, nan=0.0, posinf=0.0, neginf=0.0)
|
|
|
|
# # Initialize lists for new camera keys
|
|
# if key not in self.rgb:
|
|
# self.rgb[key] = []
|
|
# if key not in self.depth:
|
|
# self.depth[key] = []
|
|
|
|
# # Append current frame to the corresponding camera's list
|
|
# self.rgb[key].append(rgb_img)
|
|
# self.depth[key].append(depth_img)
|
|
|
|
def seq_replay(self, sequence: list) -> int:
|
|
"""
|
|
Replay recorded sequence with mode-specific data preparation.
|
|
|
|
Returns:
|
|
int: Number of steps replayed
|
|
"""
|
|
if not self.step_replay:
|
|
self.world_recorder.prim_poses = sequence
|
|
|
|
# warmup before replay formally
|
|
self.world_recorder.warmup()
|
|
|
|
# Get total steps from WorldRecorder
|
|
total_steps = self.world_recorder.get_total_steps()
|
|
step_idx = 0
|
|
|
|
# Unified replay loop - WorldRecorder handles rendering internally
|
|
with tqdm(total=total_steps, desc="Replay Progress") as pbar:
|
|
while not self.world_recorder.replay():
|
|
# Record RGB/depth at current step
|
|
self._record_rgb_depth(step_idx)
|
|
step_idx += 1
|
|
pbar.update(1)
|
|
|
|
self.length = total_steps
|
|
print("Replay finished.")
|
|
return total_steps
|
|
|
|
def get_task_name(self):
|
|
return self.task_cfg["task"]
|
|
|
|
def save_seq(self, save_path: str) -> int:
|
|
ser_bytes = self.dump_plan_info()
|
|
timestamp = datetime.now().strftime("%Y-%m-%d_%H_%M_%S_%f")
|
|
save_path = os.path.join(save_path, "plan")
|
|
os.makedirs(save_path, exist_ok=True)
|
|
path = os.path.join(save_path, f"{timestamp}.pkl")
|
|
with open(path, "wb") as f:
|
|
f.write(ser_bytes)
|
|
return self.world_recorder.get_total_steps()
|
|
|
|
def save(self, save_path: str) -> int:
|
|
os.makedirs(save_path, exist_ok=True)
|
|
timestamp = datetime.now().strftime("%Y-%m-%d_%H_%M_%S_%f")
|
|
self.logger.save(save_path, timestamp, save_img=True)
|
|
|
|
return self.length
|
|
|
|
def plan_with_render(self):
|
|
end = False
|
|
|
|
step_id = 0
|
|
length = 0
|
|
episode_success = True
|
|
should_continue = True
|
|
max_episode_length = self.task_cfg["data"]["max_episode_length"]
|
|
episode_stats = {"succeed_times": 0, "current_times": 0}
|
|
|
|
should_continue = self.plan_first_skill(self.skills, should_continue)
|
|
|
|
# Warmup
|
|
for _ in range(10):
|
|
obs = self.world.get_observations()
|
|
# self._init_static_objects(self.task)
|
|
self.world.step(render=True)
|
|
|
|
# while True:
|
|
# obs = self.world.get_observations()
|
|
# # self._init_static_objects(self.task)
|
|
# self.world.step(render=True)
|
|
|
|
while not (step_id >= max_episode_length or (not self.skills and not episode_success) or (not should_continue)):
|
|
obs = self.world.get_observations()
|
|
action_dict = {}
|
|
record_flag = True
|
|
if self.skills and should_continue:
|
|
# Process current skills
|
|
current_skills = self.skills[0]
|
|
for robot_name, skill_sequences in current_skills.items():
|
|
if skill_sequences and skill_sequences[0]:
|
|
action = [
|
|
skill[0].controller.forward(skill[0].manip_list[0])
|
|
for skill in skill_sequences[0]
|
|
if skill[0] and skill[0].is_ready()
|
|
]
|
|
|
|
feasible_labels = [skill[0].is_feasible() for skill in skill_sequences[0] if skill[0]]
|
|
record_labels = [skill[0].is_record() for skill in skill_sequences[0] if skill[0]]
|
|
|
|
if False in feasible_labels:
|
|
should_continue = False
|
|
if False in record_labels:
|
|
record_flag = False
|
|
|
|
if action:
|
|
action_dict[robot_name] = {
|
|
"joint_positions": np.concatenate([a["joint_positions"] for a in action]),
|
|
"joint_indices": np.concatenate([a["joint_indices"] for a in action]),
|
|
"raw_action": action,
|
|
}
|
|
elif not self.skills and episode_success:
|
|
print("Task is successful")
|
|
end = True
|
|
for j_idx in range(1, 7):
|
|
self.world.step(render=True)
|
|
obs = self.world.get_observations()
|
|
log_dual_obs(self.logger, obs, action_dict, self.controllers, step_idx=step_id + j_idx)
|
|
self._record_rgb_depth(step_id + j_idx)
|
|
self.world_recorder.record()
|
|
length = step_id + 6
|
|
episode_stats["succeed_times"] += 1
|
|
should_continue = False
|
|
|
|
if record_flag:
|
|
log_dual_obs(self.logger, obs, action_dict, self.controllers, step_idx=step_id)
|
|
self._record_rgb_depth(step_id)
|
|
self.task.apply_action(action_dict)
|
|
self.world.step(render=True)
|
|
|
|
step_id += 1
|
|
if self.skills:
|
|
episode_success, should_continue = self.update_skill_states(
|
|
self.skills, episode_success, should_continue
|
|
)
|
|
|
|
self.length = length
|
|
if end:
|
|
return length
|
|
else:
|
|
return 0
|
|
|
|
def _dump_task_cfg(self, task_cfg):
|
|
task_cfg_copy = deepcopy(task_cfg)
|
|
return pickle.dumps(task_cfg_copy)
|
|
|
|
def dump_plan_info(self) -> bytes:
|
|
logger_ser = self.logger.dump()
|
|
cfg_ser = self._dump_task_cfg(self.task_cfg)
|
|
ser = pickle.dumps((cfg_ser, self.world_recorder.dumps(), logger_ser))
|
|
return ser
|
|
|
|
def dedump_plan_info(self, ser_obj: bytes) -> object:
|
|
res = pickle.loads(ser_obj)
|
|
return res
|
|
|
|
def randomization_from_mem(self, data) -> bool:
|
|
try:
|
|
cfg_ser, _, _ = data
|
|
task_cfg = pickle.loads(cfg_ser)
|
|
self.task_cfg = task_cfg
|
|
self.task.cfg = task_cfg
|
|
|
|
# Individual Reset
|
|
self.task.individual_reset_from_mem()
|
|
self._randomization_layout_mem()
|
|
return True
|
|
except Exception as e:
|
|
raise e
|
|
|
|
def recover_seq_from_mem(self, data) -> list:
|
|
"""
|
|
Recover sequence from memory based on WorldRecorder mode.
|
|
|
|
Returns:
|
|
- step_replay=False: Returns prim_poses list
|
|
- step_replay=True: Returns placeholder list (replay data is in WorldRecorder)
|
|
"""
|
|
try:
|
|
_, wr_ser, logger_ser = data
|
|
self.logger.dedump(logger_ser)
|
|
|
|
if wr_ser:
|
|
self.world_recorder.loads(wr_ser)
|
|
|
|
if self.step_replay:
|
|
return [None] * self.world_recorder.num_steps
|
|
else:
|
|
return self.world_recorder.prim_poses
|
|
|
|
except Exception as e:
|
|
raise e
|