[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
@@ -108,9 +108,7 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
|
|||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("once", UserWarning) # Apply filter only in this function
|
warnings.simplefilter("once", UserWarning) # Apply filter only in this function
|
||||||
|
|
||||||
if not (
|
if not (hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")):
|
||||||
hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")
|
|
||||||
):
|
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The environment does not have 'task_description' and 'task'. Some policies require these features.",
|
"The environment does not have 'task_description' and 'task'. Some policies require these features.",
|
||||||
UserWarning,
|
UserWarning,
|
||||||
@@ -124,9 +122,7 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def add_envs_task(
|
def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
|
||||||
env: gym.vector.VectorEnv, observation: dict[str, Any]
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Adds task feature to the observation dict with respect to the first environment attribute."""
|
"""Adds task feature to the observation dict with respect to the first environment attribute."""
|
||||||
if hasattr(env.envs[0], "task_description"):
|
if hasattr(env.envs[0], "task_description"):
|
||||||
observation["task"] = env.call("task_description")
|
observation["task"] = env.call("task_description")
|
||||||
|
|||||||
@@ -490,7 +490,7 @@ class SACObservationEncoder(nn.Module):
|
|||||||
"""Encode image and/or state vector observations."""
|
"""Encode image and/or state vector observations."""
|
||||||
|
|
||||||
def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None:
|
def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None:
|
||||||
super(SACObservationEncoder, self).__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.input_normalization = input_normalizer
|
self.input_normalization = input_normalizer
|
||||||
self._init_image_layers()
|
self._init_image_layers()
|
||||||
|
|||||||
@@ -24,8 +24,8 @@ from contextlib import nullcontext
|
|||||||
from copy import copy
|
from copy import copy
|
||||||
from functools import cache
|
from functools import cache
|
||||||
|
|
||||||
import rerun as rr
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import rerun as rr
|
||||||
import torch
|
import torch
|
||||||
from deepdiff import DeepDiff
|
from deepdiff import DeepDiff
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|||||||
@@ -700,7 +700,7 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
|
|||||||
Reset the environment and clear any pending events
|
Reset the environment and clear any pending events
|
||||||
"""
|
"""
|
||||||
with self.event_lock:
|
with self.event_lock:
|
||||||
self.events = {k: False for k in self.events}
|
self.events = dict.fromkeys(self.events, False)
|
||||||
return self.env.reset(**kwargs)
|
return self.env.reset(**kwargs)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
|||||||
@@ -179,9 +179,7 @@ def run_server(
|
|||||||
]
|
]
|
||||||
videos_info = [
|
videos_info = [
|
||||||
{
|
{
|
||||||
"url": url_for(
|
"url": url_for("static", filename=str(video_path).replace("\\", "/")),
|
||||||
"static", filename=str(video_path).replace("\\", "/")
|
|
||||||
),
|
|
||||||
"filename": video_path.parent.name,
|
"filename": video_path.parent.name,
|
||||||
}
|
}
|
||||||
for video_path in video_paths
|
for video_path in video_paths
|
||||||
|
|||||||
Reference in New Issue
Block a user