[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-28 17:20:38 +00:00
committed by Michel Aractingi
parent 8eb3c1510c
commit eb44a06a9b
16 changed files with 93 additions and 91 deletions

View File

@@ -1,7 +1,6 @@
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
from typing import Any
import einops
import gymnasium as gym
@@ -10,10 +9,8 @@ import torch
from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
from lerobot.common.envs.configs import EnvConfig, ManiskillEnvConfig
from lerobot.common.envs.configs import ManiskillEnvConfig
from lerobot.configs import parser
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_ROBOT
def preprocess_maniskill_observation(
@@ -53,9 +50,6 @@ def preprocess_maniskill_observation(
return return_observations
class ManiSkillObservationWrapper(gym.ObservationWrapper):
def __init__(self, env, device: torch.device = "cuda"):
super().__init__(env)
@@ -122,6 +116,7 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
class BatchCompatibleWrapper(gym.ObservationWrapper):
"""Ensures observations are batch-compatible by adding a batch dimension if necessary."""
def __init__(self, env):
super().__init__(env)
@@ -136,6 +131,7 @@ class BatchCompatibleWrapper(gym.ObservationWrapper):
class TimeLimitWrapper(gym.Wrapper):
"""Adds a time limit to the environment based on fps and control_time."""
def __init__(self, env, control_time_s, fps):
super().__init__(env)
self.control_time_s = control_time_s
@@ -146,10 +142,10 @@ class TimeLimitWrapper(gym.Wrapper):
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
self.current_step += 1
if self.current_step >= self.max_episode_steps:
terminated = True
return obs, reward, terminated, truncated, info
def reset(self, *, seed=None, options=None):
@@ -190,18 +186,18 @@ def make_maniskill(
save_video=True,
video_fps=30,
)
# Add observation and image processing
env = ManiSkillObservationWrapper(env, device=cfg.device)
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
env._max_episode_steps = env.max_episode_steps = cfg.episode_length
env.unwrapped.metadata["render_fps"] = cfg.fps
# Add compatibility wrappers
env = ManiSkillCompat(env)
env = ManiSkillActionWrapper(env)
env = ManiSkillMultiplyActionWrapper(env, multiply_factor=0.03) # Scale actions for better control
return env
@@ -210,29 +206,29 @@ def main(cfg: ManiskillEnvConfig):
"""Main function to run the ManiSkill environment."""
# Create the ManiSkill environment
env = make_maniskill(cfg, n_envs=1)
# Reset the environment
obs, info = env.reset()
# Run a simple interaction loop
sum_reward = 0
for i in range(100):
# Sample a random action
action = env.action_space.sample()
# Step the environment
start_time = time.perf_counter()
obs, reward, terminated, truncated, info = env.step(action)
step_time = time.perf_counter() - start_time
sum_reward += reward
# Log information
# Reset if episode terminated
if terminated or truncated:
logging.info(f"Step {i}, reward: {sum_reward}, step time: {step_time}s")
sum_reward = 0
obs, info = env.reset()
# Close the environment
env.close()
@@ -243,6 +239,10 @@ def main(cfg: ManiskillEnvConfig):
if __name__ == "__main__":
import draccus
config = ManiskillEnvConfig()
draccus.set_config_type("json")
draccus.dump(config=config, stream=open(file='run_config.json', mode='w'), )
draccus.dump(
config=config,
stream=open(file="run_config.json", mode="w"),
)