[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
@@ -108,9 +108,7 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("once", UserWarning) # Apply filter only in this function
|
||||
|
||||
if not (
|
||||
hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")
|
||||
):
|
||||
if not (hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")):
|
||||
warnings.warn(
|
||||
"The environment does not have 'task_description' and 'task'. Some policies require these features.",
|
||||
UserWarning,
|
||||
@@ -124,9 +122,7 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
|
||||
)
|
||||
|
||||
|
||||
def add_envs_task(
|
||||
env: gym.vector.VectorEnv, observation: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Adds task feature to the observation dict with respect to the first environment attribute."""
|
||||
if hasattr(env.envs[0], "task_description"):
|
||||
observation["task"] = env.call("task_description")
|
||||
|
||||
@@ -490,7 +490,7 @@ class SACObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
||||
def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None:
|
||||
super(SACObservationEncoder, self).__init__()
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.input_normalization = input_normalizer
|
||||
self._init_image_layers()
|
||||
|
||||
@@ -24,8 +24,8 @@ from contextlib import nullcontext
|
||||
from copy import copy
|
||||
from functools import cache
|
||||
|
||||
import rerun as rr
|
||||
import numpy as np
|
||||
import rerun as rr
|
||||
import torch
|
||||
from deepdiff import DeepDiff
|
||||
from termcolor import colored
|
||||
|
||||
@@ -700,7 +700,7 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
|
||||
Reset the environment and clear any pending events
|
||||
"""
|
||||
with self.event_lock:
|
||||
self.events = {k: False for k in self.events}
|
||||
self.events = dict.fromkeys(self.events, False)
|
||||
return self.env.reset(**kwargs)
|
||||
|
||||
def close(self):
|
||||
|
||||
@@ -179,9 +179,7 @@ def run_server(
|
||||
]
|
||||
videos_info = [
|
||||
{
|
||||
"url": url_for(
|
||||
"static", filename=str(video_path).replace("\\", "/")
|
||||
),
|
||||
"url": url_for("static", filename=str(video_path).replace("\\", "/")),
|
||||
"filename": video_path.parent.name,
|
||||
}
|
||||
for video_path in video_paths
|
||||
|
||||
Reference in New Issue
Block a user