[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-24 13:41:27 +00:00
committed by Michel Aractingi
parent 2abbd60a0d
commit 0ea27704f6
123 changed files with 1161 additions and 3425 deletions

View File

@@ -61,16 +61,10 @@ class AlohaEnv(EnvConfig):
def __post_init__(self):
if self.obs_type == "pixels":
self.features["top"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(480, 640, 3)
)
self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
elif self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(
type=FeatureType.STATE, shape=(14,)
)
self.features["pixels/top"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(480, 640, 3)
)
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
@property
def gym_kwargs(self) -> dict:
@@ -108,13 +102,9 @@ class PushtEnv(EnvConfig):
def __post_init__(self):
if self.obs_type == "pixels_agent_pos":
self.features["pixels"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(384, 384, 3)
)
self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
elif self.obs_type == "environment_state_agent_pos":
self.features["environment_state"] = PolicyFeature(
type=FeatureType.ENV, shape=(16,)
)
self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
@property
def gym_kwargs(self) -> dict:
@@ -153,9 +143,7 @@ class XarmEnv(EnvConfig):
def __post_init__(self):
if self.obs_type == "pixels_agent_pos":
self.features["agent_pos"] = PolicyFeature(
type=FeatureType.STATE, shape=(4,)
)
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
@property
def gym_kwargs(self) -> dict:

View File

@@ -32,9 +32,7 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
raise ValueError(f"Policy type '{env_type}' is not available.")
def make_env(
cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False
) -> gym.vector.VectorEnv | None:
def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None:
"""Makes a gym vector environment according to the config.
Args:
@@ -58,9 +56,7 @@ def make_env(
try:
importlib.import_module(package_name)
except ModuleNotFoundError as e:
print(
f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`"
)
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
raise e
gym_handle = f"{package_name}/{cfg.task}"
@@ -68,18 +64,13 @@ def make_env(
# batched version of the env that returns an observation of shape (b, c)
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
env = env_cls(
[
lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs)
for _ in range(n_envs)
]
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
)
return env
def make_maniskill_env(
cfg: DictConfig, n_envs: int | None = None
) -> gym.vector.VectorEnv | None:
def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
"""Make ManiSkill3 gym environment"""
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
@@ -96,9 +87,7 @@ def make_maniskill_env(
# state should have the size of 25
# env = ConvertToLeRobotEnv(env, n_envs)
# env = PixelWrapper(cfg, env, n_envs)
env._max_episode_steps = env.max_episode_steps = (
50 # gym_utils.find_max_episode_steps_value(env)
)
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
env.unwrapped.metadata["render_fps"] = 20
return env
@@ -125,11 +114,7 @@ class PixelWrapper(gym.Wrapper):
def _get_obs(self, obs):
frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
self._frames.append(frame)
return {
"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(
self.env.device
)
}
return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)}
def reset(self, seed):
obs, info = self.env.reset() # (seed=seed)
@@ -164,9 +149,7 @@ class ConvertToLeRobotEnv(gym.Wrapper):
images = torch.concat(images, axis=-1)
# flatten the rest of the data which should just be state data
observation = common.flatten_state_dict(
observation, use_torch=True, device=self.base_env.device
)
observation = common.flatten_state_dict(observation, use_torch=True, device=self.base_env.device)
ret = dict()
ret["state"] = observation
ret["pixels"] = images

View File

@@ -50,9 +50,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, (
f"expect channel last images, but instead got {img.shape=}"
)
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
@@ -85,9 +83,7 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
for key, ft in env_cfg.features.items():
if ft.type is FeatureType.VISUAL:
if len(ft.shape) != 3:
raise ValueError(
f"Number of dimensions of {key} != 3 (shape={ft.shape})"
)
raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})")
shape = get_channel_first_image_shape(ft.shape)
feature = PolicyFeature(type=ft.type, shape=shape)