diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index 0daaaf9f..7a979b86 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -50,6 +50,8 @@ class AlohaEnv(EnvConfig): fps: int = 50 episode_length: int = 400 obs_type: str = "pixels_agent_pos" + observation_height: int = 480 + observation_width: int = 640 render_mode: str = "rgb_array" features: dict[str, PolicyFeature] = field( default_factory=lambda: { @@ -67,10 +69,14 @@ class AlohaEnv(EnvConfig): def __post_init__(self): if self.obs_type == "pixels": - self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3)) + self.features["top"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) + ) elif self.obs_type == "pixels_agent_pos": self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,)) - self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3)) + self.features["pixels/top"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) + ) @property def gym_kwargs(self) -> dict: @@ -91,6 +97,8 @@ class PushtEnv(EnvConfig): render_mode: str = "rgb_array" visualization_width: int = 384 visualization_height: int = 384 + observation_height: int = 384 + observation_width: int = 384 features: dict[str, PolicyFeature] = field( default_factory=lambda: { ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)), @@ -108,7 +116,9 @@ class PushtEnv(EnvConfig): def __post_init__(self): if self.obs_type == "pixels_agent_pos": - self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3)) + self.features["pixels"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) + ) elif self.obs_type == "environment_state_agent_pos": self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,)) @@ -255,6 +265,8 @@ class LiberoEnv(EnvConfig): camera_name: str = "agentview_image,robot0_eye_in_hand_image" init_states: bool = True camera_name_mapping: dict[str, str] | None = None + observation_height: int = 360 + observation_width: int = 360 features: dict[str, PolicyFeature] = field( default_factory=lambda: { ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)), @@ -272,18 +284,18 @@ class LiberoEnv(EnvConfig): def __post_init__(self): if self.obs_type == "pixels": self.features["pixels/agentview_image"] = PolicyFeature( - type=FeatureType.VISUAL, shape=(360, 360, 3) + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) ) self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature( - type=FeatureType.VISUAL, shape=(360, 360, 3) + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) ) elif self.obs_type == "pixels_agent_pos": self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,)) self.features["pixels/agentview_image"] = PolicyFeature( - type=FeatureType.VISUAL, shape=(360, 360, 3) + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) ) self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature( - type=FeatureType.VISUAL, shape=(360, 360, 3) + type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3) ) else: raise ValueError(f"Unsupported obs_type: {self.obs_type}")