[WIP] Non functional yet

Add ManiSkill environment configuration and wrappers

- Introduced `VideoRecordConfig` for video recording settings.
- Added `ManiskillEnvConfig` to encapsulate environment-specific configurations.
- Implemented various wrappers for the ManiSkill environment, including observation and action scaling.
- Enhanced the `make_maniskill` function to create a wrapped ManiSkill environment with video recording and observation processing.
- Updated the `actor_server` and `learner_server` to utilize the new configuration structure.
- Refactored the training pipeline to accommodate the new environment and policy configurations.
This commit is contained in:
AdilZouitine
2025-03-26 08:15:05 +00:00
committed by Michel Aractingi
parent 114ec644d0
commit 056f79d358
9 changed files with 667 additions and 436 deletions

View File

@@ -1,4 +1,7 @@
from typing import Any
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
import einops
import gymnasium as gym
@@ -6,7 +9,11 @@ import numpy as np
import torch
from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
from omegaconf import DictConfig
from lerobot.common.envs.configs import EnvConfig, ManiskillEnvConfig
from lerobot.configs import parser
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_ROBOT
def preprocess_maniskill_observation(
@@ -46,9 +53,14 @@ def preprocess_maniskill_observation(
return return_observations
class ManiSkillObservationWrapper(gym.ObservationWrapper):
def __init__(self, env, device: torch.device = "cuda"):
super().__init__(env)
if isinstance(device, str):
device = torch.device(device)
self.device = device
def observation(self, observation):
@@ -108,76 +120,129 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
return obs, reward, terminated, truncated, info
class BatchCompatibleWrapper(gym.ObservationWrapper):
"""Ensures observations are batch-compatible by adding a batch dimension if necessary."""
def __init__(self, env):
super().__init__(env)
def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
for key in observation:
if "image" in key and observation[key].dim() == 3:
observation[key] = observation[key].unsqueeze(0)
if "state" in key and observation[key].dim() == 1:
observation[key] = observation[key].unsqueeze(0)
return observation
class TimeLimitWrapper(gym.Wrapper):
"""Adds a time limit to the environment based on fps and control_time."""
def __init__(self, env, control_time_s, fps):
super().__init__(env)
self.control_time_s = control_time_s
self.fps = fps
self.max_episode_steps = int(self.control_time_s * self.fps)
self.current_step = 0
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
self.current_step += 1
if self.current_step >= self.max_episode_steps:
terminated = True
return obs, reward, terminated, truncated, info
def reset(self, *, seed=None, options=None):
self.current_step = 0
return super().reset(seed=seed, options=options)
def make_maniskill(
cfg: DictConfig,
cfg: ManiskillEnvConfig,
n_envs: int | None = None,
) -> gym.Env:
"""
Factory function to create a ManiSkill environment with standard wrappers.
Args:
task: Name of the ManiSkill task
obs_mode: Observation mode (rgb, rgbd, etc)
control_mode: Control mode for the robot
render_mode: Rendering mode
sensor_configs: Camera sensor configurations
cfg: Configuration for the ManiSkill environment
n_envs: Number of parallel environments
Returns:
A wrapped ManiSkill environment
"""
env = gym.make(
cfg.env.task,
obs_mode=cfg.env.obs,
control_mode=cfg.env.control_mode,
render_mode=cfg.env.render_mode,
sensor_configs={"width": cfg.env.image_size, "height": cfg.env.image_size},
cfg.task,
obs_mode=cfg.obs_type,
control_mode=cfg.control_mode,
render_mode=cfg.render_mode,
sensor_configs={"width": cfg.image_size, "height": cfg.image_size},
num_envs=n_envs,
)
if cfg.env.video_record.enabled:
# Add video recording if enabled
if cfg.video_record.enabled:
env = RecordEpisode(
env,
output_dir=cfg.env.video_record.record_dir,
output_dir=cfg.video_record.record_dir,
save_trajectory=True,
trajectory_name=cfg.env.video_record.trajectory_name,
trajectory_name=cfg.video_record.trajectory_name,
save_video=True,
video_fps=30,
)
env = ManiSkillObservationWrapper(env, device=cfg.env.device)
# Add observation and image processing
env = ManiSkillObservationWrapper(env, device=cfg.device)
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
env.unwrapped.metadata["render_fps"] = 20
env._max_episode_steps = env.max_episode_steps = cfg.episode_length
env.unwrapped.metadata["render_fps"] = cfg.fps
# Add compatibility wrappers
env = ManiSkillCompat(env)
env = ManiSkillActionWrapper(env)
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03)
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03) # Scale actions for better control
return env
if __name__ == "__main__":
import argparse
import hydra
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="lerobot/configs/env/maniskill_example.yaml")
args = parser.parse_args()
# Initialize config
with hydra.initialize(version_base=None, config_path="../../configs"):
cfg = hydra.compose(config_name="env/maniskill_example.yaml")
env = make_maniskill(
task=cfg.env.task,
obs_mode=cfg.env.obs,
control_mode=cfg.env.control_mode,
render_mode=cfg.env.render_mode,
sensor_configs={"width": cfg.env.render_size, "height": cfg.env.render_size},
)
print("env done")
@parser.wrap()
def main(cfg: ManiskillEnvConfig):
"""Main function to run the ManiSkill environment."""
# Create the ManiSkill environment
env = make_maniskill(cfg, n_envs=1)
# Reset the environment
obs, info = env.reset()
random_action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(random_action)
# Run a simple interaction loop
sum_reward = 0
for i in range(100):
# Sample a random action
action = env.action_space.sample()
# Step the environment
start_time = time.perf_counter()
obs, reward, terminated, truncated, info = env.step(action)
step_time = time.perf_counter() - start_time
sum_reward += reward
# Log information
# Reset if episode terminated
if terminated or truncated:
logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
sum_reward = 0
obs, info = env.reset()
# Close the environment
env.close()
# if __name__ == "__main__":
# logging.basicConfig(level=logging.INFO)
# main()
if __name__ == "__main__":
import draccus
config = ManiskillEnvConfig()
draccus.set_config_type("json")
draccus.dump(config=config, stream=open(file='run_config.json', mode='w'), )