[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-04-18 13:33:36 +00:00
parent dcd850feab
commit fb92935601
5 changed files with 6 additions and 12 deletions

View File

@@ -108,9 +108,7 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("once", UserWarning) # Apply filter only in this function warnings.simplefilter("once", UserWarning) # Apply filter only in this function
if not ( if not (hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")):
hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")
):
warnings.warn( warnings.warn(
"The environment does not have 'task_description' and 'task'. Some policies require these features.", "The environment does not have 'task_description' and 'task'. Some policies require these features.",
UserWarning, UserWarning,
@@ -124,9 +122,7 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
) )
def add_envs_task( def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
env: gym.vector.VectorEnv, observation: dict[str, Any]
) -> dict[str, Any]:
"""Adds task feature to the observation dict with respect to the first environment attribute.""" """Adds task feature to the observation dict with respect to the first environment attribute."""
if hasattr(env.envs[0], "task_description"): if hasattr(env.envs[0], "task_description"):
observation["task"] = env.call("task_description") observation["task"] = env.call("task_description")

View File

@@ -490,7 +490,7 @@ class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations.""" """Encode image and/or state vector observations."""
def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None: def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None:
super(SACObservationEncoder, self).__init__() super().__init__()
self.config = config self.config = config
self.input_normalization = input_normalizer self.input_normalization = input_normalizer
self._init_image_layers() self._init_image_layers()

View File

@@ -24,8 +24,8 @@ from contextlib import nullcontext
from copy import copy from copy import copy
from functools import cache from functools import cache
import rerun as rr
import numpy as np import numpy as np
import rerun as rr
import torch import torch
from deepdiff import DeepDiff from deepdiff import DeepDiff
from termcolor import colored from termcolor import colored

View File

@@ -700,7 +700,7 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
Reset the environment and clear any pending events Reset the environment and clear any pending events
""" """
with self.event_lock: with self.event_lock:
self.events = {k: False for k in self.events} self.events = dict.fromkeys(self.events, False)
return self.env.reset(**kwargs) return self.env.reset(**kwargs)
def close(self): def close(self):

View File

@@ -179,9 +179,7 @@ def run_server(
] ]
videos_info = [ videos_info = [
{ {
"url": url_for( "url": url_for("static", filename=str(video_path).replace("\\", "/")),
"static", filename=str(video_path).replace("\\", "/")
),
"filename": video_path.parent.name, "filename": video_path.parent.name,
} }
for video_path in video_paths for video_path in video_paths