forked from tangger/lerobot
269 lines
9.1 KiB
Python
269 lines
9.1 KiB
Python
from typing import Any
|
|
|
|
import einops
|
|
import gymnasium as gym
|
|
import numpy as np
|
|
import torch
|
|
from mani_skill.utils.wrappers.record import RecordEpisode
|
|
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
|
|
|
from lerobot.common.envs.configs import ManiskillEnvConfig
|
|
|
|
|
|
def preprocess_maniskill_observation(
    observations: dict[str, Any],
) -> dict[str, torch.Tensor]:
    """Convert a ManiSkill observation to LeRobot format.

    Args:
        observations: Nested observation dict from a batched ManiSkill env. This function
            reads ``observations["agent"]["qpos"]``, ``observations["agent"]["qvel"]`` and
            ``observations["extra"]["tcp_pose"]`` (torch tensors concatenated along their last
            axis), plus ``observations["sensor_data"]["base_camera"]["rgb"]``. The RGB entry
            must be a ``torch.uint8`` tensor of shape ``(batch, height, width, channels)``.

    Returns:
        A dict with two entries:

        - ``"observation.image"``: contiguous float32 tensor of shape
          ``(batch, channels, height, width)`` with values in ``[0, 1]``.
        - ``"observation.state"``: ``[qpos, qvel, tcp_pose]`` concatenated along the last axis.

    Raises:
        AssertionError: If the image does not look channel-last (channel count not smaller
            than both spatial dims) or its dtype is not ``torch.uint8``.
    """
    # TODO: merge all tensors under the "agent" and "extra" keys (only qpos, qvel and
    # tcp_pose are used today). Sensor parameters are not kept; of the sensor data, only the
    # base camera RGB image is kept.
    q_pos = observations["agent"]["qpos"]
    q_vel = observations["agent"]["qvel"]
    tcp_pos = observations["extra"]["tcp_pose"]
    img = observations["sensor_data"]["base_camera"]["rgb"]

    _, h, w, c = img.shape
    # Heuristic channel-last check: the channel count should be smaller than both spatial dims.
    assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
    # sanity check that images are uint8
    assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"

    # Channel-last -> channel-first, then uint8 [0, 255] -> float32 [0, 1].
    # ``permute`` is the torch-native equivalent of the "b h w c -> b c h w" rearrangement.
    img = img.permute(0, 3, 1, 2).contiguous().to(torch.float32) / 255

    state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)

    return {
        "observation.image": img,
        "observation.state": state,
    }
|
|
|
|
|
|
class ManiSkillObservationWrapper(gym.ObservationWrapper):
    """Convert raw ManiSkill observations to LeRobot format and move them to a device.

    Each observation is passed through ``preprocess_maniskill_observation``. Every tensor in
    the result is then moved to ``self.device``.
    """

    def __init__(self, env, device: torch.device | str = "cuda"):
        """
        Args:
            env: The ManiSkill environment to wrap.
            device: Target device, either a ``torch.device`` or a string such as ``"cuda"`` or
                ``"cpu"``. Strings are converted with ``torch.device``.
        """
        super().__init__(env)
        if isinstance(device, str):
            device = torch.device(device)
        self.device = device

    def observation(self, observation):
        """Preprocess ``observation`` and move every resulting tensor to ``self.device``."""
        observation = preprocess_maniskill_observation(observation)
        return {k: v.to(self.device) for k, v in observation.items()}
|
|
|
|
|
|
class ManiSkillCompat(gym.Wrapper):
    """Adapt a batched ManiSkill env to the plain (unbatched) Gymnasium API.

    The inner action space carries a leading batch axis, which is squeezed away. ``step``
    turns the reward/terminated/truncated tensors into Python scalars with ``.item()``. That
    call only succeeds for single-element tensors, so this wrapper assumes one environment.
    """

    def __init__(self, env):
        super().__init__(env)
        inner_space = env.action_space
        action_dim = inner_space.shape[-1]
        # np.squeeze(axis=0) requires the leading (batch) axis to have size 1.
        self.action_space = gym.spaces.Box(
            low=np.squeeze(inner_space.low, axis=0),
            high=np.squeeze(inner_space.high, axis=0),
            shape=(action_dim,),
        )

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[Any, dict[str, Any]]:
        # NOTE(review): caller-supplied ``options`` are discarded; an empty dict is always
        # forwarded to the wrapped env. Confirm this is intended before passing reset options.
        return super().reset(seed=seed, options={})

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        return obs, reward.item(), terminated.item(), truncated.item(), info
|
|
|
|
|
|
class ManiSkillActionWrapper(gym.ActionWrapper):
    """Expose a ``(agent_action, telop_flag)`` tuple action space.

    Only the agent action is forwarded to the wrapped env. The ``Discrete(2)`` teleoperation
    flag is accepted and then dropped.
    """

    def __init__(self, env):
        super().__init__(env)
        agent_space = env.action_space
        telop_flag_space = gym.spaces.Discrete(2)
        self.action_space = gym.spaces.Tuple(spaces=(agent_space, telop_flag_space))

    def action(self, action):
        agent_action, _telop_flag = action
        return agent_action
|
|
|
|
|
|
class ManiSkillMultiplyActionWrapper(gym.Wrapper):
    """Rescale the agent action space by a constant factor.

    The exposed agent ``Box`` is the wrapped env's agent ``Box`` with its bounds multiplied by
    ``multiply_factor``. Actions received in ``step`` are divided by the same factor before
    being forwarded, so the wrapped env sees actions in its original range.

    Args:
        env: Environment whose ``action_space`` is a ``Tuple`` with a ``Box`` as its first
            element (as produced by ``ManiSkillActionWrapper``).
        multiply_factor: Non-zero scale applied to the agent action bounds.

    Raises:
        ValueError: If ``multiply_factor`` is zero (``step`` divides by it).
    """

    def __init__(self, env, multiply_factor: float = 1):
        super().__init__(env)
        if multiply_factor == 0:
            raise ValueError("multiply_factor must be non-zero because actions are divided by it")
        self.multiply_factor = multiply_factor
        inner_agent_space: gym.spaces.Box = env.action_space[0]
        # Build a fresh Box instead of assigning to ``inner_agent_space.low``/``.high``:
        # that object belongs to the wrapped env, and mutating it would silently rescale the
        # wrapped env's action space as well.
        action_space_agent = gym.spaces.Box(
            low=inner_agent_space.low * multiply_factor,
            high=inner_agent_space.high * multiply_factor,
            shape=inner_agent_space.shape,
            dtype=inner_agent_space.dtype,
        )
        self.action_space = gym.spaces.Tuple(spaces=(action_space_agent, gym.spaces.Discrete(2)))

    def step(self, action):
        """Undo the scaling and forward ``(action, telop)`` to the wrapped env.

        ``action`` may be an ``(agent_action, telop)`` tuple or a bare agent action. A bare
        action is treated as ``telop = 0``.
        """
        if isinstance(action, tuple):
            action, telop = action
        else:
            telop = 0
        action = action / self.multiply_factor
        obs, reward, terminated, truncated, info = self.env.step((action, telop))
        return obs, reward, terminated, truncated, info
|
|
|
|
|
|
class BatchCompatibleWrapper(gym.ObservationWrapper):
    """Give unbatched observation tensors a leading batch axis of size 1.

    An entry is unsqueezed at dim 0 when either:

    - its key contains ``"image"`` and the tensor is 3-D, or
    - its key contains ``"state"`` and the tensor is 1-D.

    The incoming dict is updated in place and returned.
    """

    def __init__(self, env):
        super().__init__(env)

    def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        for name in list(observation):
            value = observation[name]
            is_unbatched_image = "image" in name and value.dim() == 3
            is_unbatched_state = "state" in name and value.dim() == 1
            if is_unbatched_image or is_unbatched_state:
                observation[name] = value.unsqueeze(0)
        return observation
|
|
|
|
|
|
class TimeLimitWrapper(gym.Wrapper):
    """Adds a time limit to the environment based on fps and control_time.

    The limit is ``int(control_time_s * fps)`` steps. Once that many ``step`` calls have
    happened since the last ``reset``, every further step reports ``terminated=True``.
    """

    def __init__(self, env, control_time_s, fps):
        super().__init__(env)
        self.control_time_s = control_time_s
        self.fps = fps
        self.max_episode_steps = int(control_time_s * fps)
        self.current_step = 0

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.current_step += 1
        time_is_up = self.current_step >= self.max_episode_steps
        # NOTE(review): Gymnasium convention reports a time limit via ``truncated``, not
        # ``terminated``. Kept as-is because callers may rely on ``terminated``; confirm.
        terminated = True if time_is_up else terminated
        return obs, reward, terminated, truncated, info

    def reset(self, *, seed=None, options=None):
        self.current_step = 0
        return super().reset(seed=seed, options=options)
|
|
|
|
|
|
class ManiskillMockGripperWrapper(gym.Wrapper):
    """Append a mock gripper component to the agent action space.

    The exposed agent ``Box`` has one extra trailing component bounded by
    ``[0, nb_discrete_actions - 1]``. ``step`` strips that last component before forwarding,
    so the wrapped env never receives it.
    """

    def __init__(self, env, nb_discrete_actions: int = 3):
        super().__init__(env)
        agent_box, telop_space = env.action_space[0], env.action_space[1]
        gripper_low, gripper_high = [0], [nb_discrete_actions - 1]
        action_space_agent = gym.spaces.Box(
            low=np.concatenate([agent_box.low, gripper_low]),
            high=np.concatenate([agent_box.high, gripper_high]),
            shape=(agent_box.shape[0] + 1,),
        )
        self.action_space = gym.spaces.Tuple((action_space_agent, telop_space))

    def step(self, action):
        # A bare (non-tuple) action is treated as having teleop flag 0.
        action_agent, telop_action = action if isinstance(action, tuple) else (action, 0)
        # Drop the trailing mock-gripper component added in __init__.
        return self.env.step((action_agent[:-1], telop_action))
|
|
|
|
|
|
def make_maniskill(
    cfg: ManiskillEnvConfig,
    n_envs: int | None = None,
    *,
    multiply_factor: float = 0.03,
) -> gym.Env:
    """Create a ManiSkill environment wrapped with the standard wrappers.

    Wrapper order, innermost first:

    1. ``RecordEpisode`` (only when ``cfg.video_record.enabled``)
    2. ``ManiSkillObservationWrapper``
    3. ``ManiSkillVectorEnv``
    4. ``ManiSkillCompat``
    5. ``ManiSkillActionWrapper``
    6. ``ManiSkillMultiplyActionWrapper``
    7. ``ManiskillMockGripperWrapper`` (only when ``cfg.mock_gripper``)

    Args:
        cfg: Configuration for the ManiSkill environment. Supplies the task, the
            obs/control/render modes, image size, device, fps, episode length, video
            recording settings and the mock-gripper flag.
        n_envs: Number of parallel environments, forwarded to ``gym.make`` as ``num_envs``.
            ``ManiSkillCompat`` squeezes the batch axis and calls ``.item()`` on
            reward/termination tensors, so presumably only a single env works end to end.
        multiply_factor: Scale applied to the agent action bounds by
            ``ManiSkillMultiplyActionWrapper`` (actions are divided by it before reaching the
            env). Defaults to 0.03, the previously hard-coded value.

    Returns:
        The wrapped ManiSkill environment.
    """
    env = gym.make(
        cfg.task,
        obs_mode=cfg.obs_type,
        control_mode=cfg.control_mode,
        render_mode=cfg.render_mode,
        sensor_configs={"width": cfg.image_size, "height": cfg.image_size},
        num_envs=n_envs,
    )

    # Add video recording if enabled
    if cfg.video_record.enabled:
        env = RecordEpisode(
            env,
            output_dir=cfg.video_record.record_dir,
            save_trajectory=True,
            trajectory_name=cfg.video_record.trajectory_name,
            save_video=True,
            # NOTE(review): hard-coded 30 fps, independent of cfg.fps set below; confirm intent.
            video_fps=30,
        )

    # Add observation and image processing
    env = ManiSkillObservationWrapper(env, device=cfg.device)
    env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
    env._max_episode_steps = env.max_episode_steps = cfg.episode_length
    env.unwrapped.metadata["render_fps"] = cfg.fps

    # Add compatibility wrappers
    env = ManiSkillCompat(env)
    env = ManiSkillActionWrapper(env)
    env = ManiSkillMultiplyActionWrapper(env, multiply_factor=multiply_factor)  # Scale actions for better control
    if cfg.mock_gripper:
        env = ManiskillMockGripperWrapper(env, nb_discrete_actions=3)

    return env
|
|
|
|
|
|
# @parser.wrap()
|
|
# def main(cfg: TrainPipelineConfig):
|
|
# """Main function to run the ManiSkill environment."""
|
|
# # Create the ManiSkill environment
|
|
# env = make_maniskill(cfg.env, n_envs=1)
|
|
|
|
# # Reset the environment
|
|
# obs, info = env.reset()
|
|
|
|
# # Run a simple interaction loop
|
|
# sum_reward = 0
|
|
# for i in range(100):
|
|
# # Sample a random action
|
|
# action = env.action_space.sample()
|
|
|
|
# # Step the environment
|
|
# start_time = time.perf_counter()
|
|
# obs, reward, terminated, truncated, info = env.step(action)
|
|
# step_time = time.perf_counter() - start_time
|
|
# sum_reward += reward
|
|
# # Log information
|
|
|
|
# # Reset if episode terminated
|
|
# if terminated or truncated:
|
|
# logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
|
|
# sum_reward = 0
|
|
# obs, info = env.reset()
|
|
|
|
# # Close the environment
|
|
# env.close()
|
|
|
|
|
|
# if __name__ == "__main__":
|
|
# logging.basicConfig(level=logging.INFO)
|
|
# main()
|
|
|
|
if __name__ == "__main__":
    import draccus

    # Write the default ManiskillEnvConfig to run_config.json in JSON format.
    config = ManiskillEnvConfig()
    draccus.set_config_type("json")
    # Use a context manager so the file is flushed and closed. The previous code passed a
    # bare open() to draccus.dump and never closed the handle.
    with open(file="run_config.json", mode="w", encoding="utf-8") as stream:
        draccus.dump(
            config=config,
            stream=stream,
        )
|