forked from tangger/lerobot
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Michel Aractingi
parent
cdcf346061
commit
1c8daf11fd
@@ -61,10 +61,16 @@ class AlohaEnv(EnvConfig):
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels":
|
||||
self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
|
||||
self.features["top"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(480, 640, 3)
|
||||
)
|
||||
elif self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
|
||||
self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
|
||||
self.features["agent_pos"] = PolicyFeature(
|
||||
type=FeatureType.STATE, shape=(14,)
|
||||
)
|
||||
self.features["pixels/top"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(480, 640, 3)
|
||||
)
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
@@ -102,9 +108,13 @@ class PushtEnv(EnvConfig):
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels_agent_pos":
|
||||
self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
|
||||
self.features["pixels"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(384, 384, 3)
|
||||
)
|
||||
elif self.obs_type == "environment_state_agent_pos":
|
||||
self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
|
||||
self.features["environment_state"] = PolicyFeature(
|
||||
type=FeatureType.ENV, shape=(16,)
|
||||
)
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
@@ -143,7 +153,9 @@ class XarmEnv(EnvConfig):
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
|
||||
self.features["agent_pos"] = PolicyFeature(
|
||||
type=FeatureType.STATE, shape=(4,)
|
||||
)
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
|
||||
@@ -32,7 +32,9 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||
|
||||
|
||||
def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None:
|
||||
def make_env(
|
||||
cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False
|
||||
) -> gym.vector.VectorEnv | None:
|
||||
"""Makes a gym vector environment according to the config.
|
||||
|
||||
Args:
|
||||
@@ -56,7 +58,9 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
|
||||
try:
|
||||
importlib.import_module(package_name)
|
||||
except ModuleNotFoundError as e:
|
||||
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
|
||||
print(
|
||||
f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`"
|
||||
)
|
||||
raise e
|
||||
|
||||
gym_handle = f"{package_name}/{cfg.task}"
|
||||
@@ -64,7 +68,10 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
|
||||
# batched version of the env that returns an observation of shape (b, c)
|
||||
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
||||
env = env_cls(
|
||||
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
|
||||
[
|
||||
lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs)
|
||||
for _ in range(n_envs)
|
||||
]
|
||||
)
|
||||
|
||||
return env
|
||||
|
||||
@@ -50,7 +50,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
assert c < h and c < w, (
|
||||
f"expect channel last images, but instead got {img.shape=}"
|
||||
)
|
||||
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
@@ -83,7 +85,9 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||
for key, ft in env_cfg.features.items():
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
if len(ft.shape) != 3:
|
||||
raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})")
|
||||
raise ValueError(
|
||||
f"Number of dimensions of {key} != 3 (shape={ft.shape})"
|
||||
)
|
||||
|
||||
shape = get_channel_first_image_shape(ft.shape)
|
||||
feature = PolicyFeature(type=ft.type, shape=shape)
|
||||
@@ -105,7 +109,9 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("once", UserWarning) # Apply filter only in this function
|
||||
|
||||
if not (hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")):
|
||||
if not (
|
||||
hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")
|
||||
):
|
||||
warnings.warn(
|
||||
"The environment does not have 'task_description' and 'task'. Some policies require these features.",
|
||||
UserWarning,
|
||||
@@ -119,7 +125,9 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
|
||||
)
|
||||
|
||||
|
||||
def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
|
||||
def add_envs_task(
|
||||
env: gym.vector.VectorEnv, observation: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Adds task feature to the observation dict with respect to the first environment attribute."""
|
||||
if hasattr(env.envs[0], "task_description"):
|
||||
observation["task"] = env.call("task_description")
|
||||
|
||||
Reference in New Issue
Block a user