[HIL-SERL] Review feedback modifications (#1112)

This commit is contained in:
Adil Zouitine
2025-05-15 15:24:41 +02:00
committed by AdilZouitine
parent 5902f8fcc7
commit a5f758d7c6
17 changed files with 504 additions and 180 deletions

View File

@@ -68,6 +68,7 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
)
# TODO: add observation processor wrapper and remove preprocess_observation in the codebase
# https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/wrappers/vector/vectorize_observation.py#L19
# env = ObservationProcessorWrapper(env=env)
return env

View File

@@ -81,35 +81,6 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
return return_observations
class ObservationProcessorWrapper(gym.vector.VectorEnvWrapper):
    """Vector-env wrapper that passes observations through ``preprocess_observation``.

    Both ``reset`` and ``step`` forward to the wrapped environment and replace
    the raw observations in the result with the processed ones; every other
    returned value is passed through untouched.
    """

    def __init__(self, env: gym.vector.VectorEnv):
        super().__init__(env)

    def _observations(self, observations: dict[str, Any]) -> dict[str, Any]:
        """Return ``observations`` after running them through ``preprocess_observation``."""
        processed = preprocess_observation(observations)
        return processed

    def reset(
        self,
        *,
        seed: int | list[int] | None = None,
        options: dict[str, Any] | None = None,
    ):
        """Reset the wrapped env and return ``(processed_observations, infos)``."""
        raw_observations, infos = self.env.reset(seed=seed, options=options)
        processed = self._observations(raw_observations)
        return processed, infos

    def step(self, actions):
        """Step the wrapped env and return its 5-tuple with processed observations."""
        raw_observations, rewards, terminations, truncations, infos = self.env.step(actions)
        processed = self._observations(raw_observations)
        return processed, rewards, terminations, truncations, infos
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
# (need to also refactor preprocess_observation and externalize normalization from policies)