chore: enable pyupgrade ruff lint (#2084)

Author: Steven Palma
Date: 2025-09-29 13:28:53 +02:00
Committed by: GitHub
Parent: 90684a9690
Commit: c378a325f0
18 changed files with 33 additions and 43 deletions
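The newly enabled UP rules (pyupgrade) rewrite pre-3.10 typing idioms into their modern equivalents, which is what every hunk below does. A minimal sketch of the three patterns touched by this commit, using made-up names rather than code from the repository, and assuming Python >= 3.10:

    from collections import deque

    def describe(value: int | float | str) -> str:
        # UP038-style rewrite: a PEP 604 union works as the classinfo argument to isinstance()
        if isinstance(value, int | float):
            # UP032-style rewrite: f-string instead of str.format()
            return f"number: {value}"
        return f"text: {value}"

    # UP006/UP035-style rewrite: builtin deque is subscriptable (PEP 585), so typing.Deque is not needed
    history: deque[int] = deque(maxlen=10)

Once "UP" is in the ruff select list, most of these rewrites can typically be applied automatically with `ruff check --fix`.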


@@ -201,7 +201,7 @@ exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"]
 # N: pep8-naming
 # TODO: Uncomment rules when ready to use
 select = [
-    "E", "W", "F", "I", "B", "C4", "T20", "N" # "SIM", "A", "S", "D", "RUF", "UP"
+    "E", "W", "F", "I", "B", "C4", "T20", "N", "UP" # "SIM", "A", "S", "D", "RUF"
 ]
 ignore = [
     "E501", # Line too long


@@ -1421,7 +1421,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         """Keys to access image and video stream from cameras."""
         keys = []
         for key, feats in self.features.items():
-            if isinstance(feats, (datasets.Image, VideoFrame)):
+            if isinstance(feats, (datasets.Image | VideoFrame)):
                 keys.append(key)
         return keys
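One reading note on the new form: `(datasets.Image | VideoFrame)` is a parenthesized PEP 604 union, not a one-element tuple, and on Python 3.10+ isinstance accepts a union directly, so behaviour is unchanged. A quick check with builtin types (illustrative values only):

    x = 3.5
    assert isinstance(x, (int, float)) == isinstance(x, int | float) == isinstance(x, (int | float))
    # the outer parentheses in the last call are redundant but harmless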


@@ -120,7 +120,7 @@ class SharpnessJitter(Transform):
         self.sharpness = self._check_input(sharpness)
 
     def _check_input(self, sharpness):
-        if isinstance(sharpness, (int, float)):
+        if isinstance(sharpness, (int | float)):
             if sharpness < 0:
                 raise ValueError("If sharpness is a single number, it must be non negative.")
             sharpness = [1.0 - sharpness, 1.0 + sharpness]


@@ -21,7 +21,7 @@ from collections import deque
 from collections.abc import Iterable, Iterator
 from pathlib import Path
 from pprint import pformat
-from typing import Any, Deque, Generic, TypeVar
+from typing import Any, Generic, TypeVar
 
 import datasets
 import numpy as np
@@ -207,13 +207,13 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
     """
     serialized_dict = {}
     for key, value in flatten_dict(stats).items():
-        if isinstance(value, (torch.Tensor, np.ndarray)):
+        if isinstance(value, (torch.Tensor | np.ndarray)):
             serialized_dict[key] = value.tolist()
-        elif isinstance(value, list) and isinstance(value[0], (int, float, list)):
+        elif isinstance(value, list) and isinstance(value[0], (int | float | list)):
             serialized_dict[key] = value
         elif isinstance(value, np.generic):
             serialized_dict[key] = value.item()
-        elif isinstance(value, (int, float)):
+        elif isinstance(value, (int | float)):
             serialized_dict[key] = value
         else:
             raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
@@ -1179,7 +1179,7 @@ def item_to_torch(item: dict) -> dict:
         dict: Dictionary with all tensor-like items converted to torch.Tensor.
     """
     for key, val in item.items():
-        if isinstance(val, (np.ndarray, list)) and key not in ["task"]:
+        if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
             # Convert numpy arrays and lists to torch tensors
             item[key] = torch.tensor(val)
     return item
@@ -1253,8 +1253,8 @@ class Backtrackable(Generic[T]):
             raise ValueError("lookahead must be > 0")
 
         self._source: Iterator[T] = iter(iterable)
-        self._back_buf: Deque[T] = deque(maxlen=history)
-        self._ahead_buf: Deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
+        self._back_buf: deque[T] = deque(maxlen=history)
+        self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
         self._cursor: int = 0
         self._history = history
         self._lookahead = lookahead
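Dropping `Deque` from the typing import works because, since Python 3.9 (PEP 585), collections.deque itself accepts subscripted type parameters. A small sketch with generic names, not the class above:

    from collections import deque

    back_buffer: deque[int] = deque(maxlen=4)  # no `from typing import Deque` needed on Python >= 3.9
    back_buffer.extend(range(10))
    print(list(back_buffer))  # keeps only the last 4 items: [6, 7, 8, 9]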


@@ -35,7 +35,7 @@ def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
     """Normalize camera_name into a non-empty list of strings."""
     if isinstance(camera_name, str):
         cams = [c.strip() for c in camera_name.split(",") if c.strip()]
-    elif isinstance(camera_name, (list, tuple)):
+    elif isinstance(camera_name, (list | tuple)):
         cams = [str(c).strip() for c in camera_name if str(c).strip()]
     else:
         raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")


@@ -183,10 +183,10 @@ def _(env: Mapping) -> None:
 
 @close_envs.register
 def _(envs: Sequence) -> None:
-    if isinstance(envs, (str, bytes)):
+    if isinstance(envs, (str | bytes)):
         return
     for v in envs:
-        if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str, bytes)):
+        if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str | bytes)):
             close_envs(v)
         elif hasattr(v, "close"):
             _close_single_env(v)


@@ -342,7 +342,7 @@ class MotorsBus(abc.ABC):
             raise TypeError(motors)
 
     def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
-        if isinstance(values, (int, float)):
+        if isinstance(values, (int | float)):
            return dict.fromkeys(self.ids, values)
        elif isinstance(values, dict):
            return {self.motors[motor].id: val for motor, val in values.items()}
@@ -669,7 +669,7 @@ class MotorsBus(abc.ABC):
         """
         if motors is None:
             motors = list(self.motors)
-        elif isinstance(motors, (str, int)):
+        elif isinstance(motors, (str | int)):
             motors = [motors]
         elif not isinstance(motors, list):
             raise TypeError(motors)
@@ -697,7 +697,7 @@ class MotorsBus(abc.ABC):
         """
         if motors is None:
             motors = list(self.motors)
-        elif isinstance(motors, (str, int)):
+        elif isinstance(motors, (str | int)):
             motors = [motors]
         elif not isinstance(motors, list):
             raise TypeError(motors)
@@ -733,7 +733,7 @@ class MotorsBus(abc.ABC):
         """
         if motors is None:
             motors = list(self.motors)
-        elif isinstance(motors, (str, int)):
+        elif isinstance(motors, (str | int)):
             motors = [motors]
         elif not isinstance(motors, list):
             raise TypeError(motors)


@@ -260,13 +260,11 @@ class GPT(nn.Module):
         param_dict = dict(self.named_parameters())
         inter_params = decay & no_decay
         union_params = decay | no_decay
-        assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
-            str(inter_params)
+        assert len(inter_params) == 0, (
+            f"parameters {str(inter_params)} made it into both decay/no_decay sets!"
         )
         assert len(param_dict.keys() - union_params) == 0, (
-            "parameters {} were not separated into either decay/no_decay set!".format(
-                str(param_dict.keys() - union_params),
-            )
+            f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!"
         )
 
         decay = [param_dict[pn] for pn in sorted(decay)]
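The converted asserts keep explicit str(...) calls inside the f-strings; those are redundant, since an f-string replacement field already falls back to str() for objects without a custom __format__ (ruff flags this separately as RUF010, and the "RUF" group is still commented out in the config above). A tiny illustration with a throwaway value:

    leftover = {"weight", "bias"}
    assert f"parameters {str(leftover)} ..." == f"parameters {leftover} ..."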


@@ -176,7 +176,7 @@ class ReplayBuffer:
                     self.complementary_info[key] = torch.empty(
                         (self.capacity, *value_shape), device=self.storage_device
                     )
-                elif isinstance(value, (int, float)):
+                elif isinstance(value, (int | float)):
                     # Handle scalar values similar to reward
                     self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
                 else:
@@ -223,7 +223,7 @@ class ReplayBuffer:
                 value = complementary_info[key]
                 if isinstance(value, torch.Tensor):
                     self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
-                elif isinstance(value, (int, float)):
+                elif isinstance(value, (int | float)):
                     self.complementary_info[key][self.position] = value
 
         self.position = (self.position + 1) % self.capacity


@@ -137,7 +137,7 @@ class WandBLogger:
             self._wandb.define_metric(new_custom_key, hidden=True)
 
         for k, v in d.items():
-            if not isinstance(v, (int, float, str)):
+            if not isinstance(v, (int | float | str)):
                 logging.warning(
                     f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
                 )


@@ -267,11 +267,7 @@ def record_loop(
                 for t in teleop
                 if isinstance(
                     t,
-                    (
-                        so100_leader.SO100Leader,
-                        so101_leader.SO101Leader,
-                        koch_leader.KochLeader,
-                    ),
+                    (so100_leader.SO100Leader | so101_leader.SO101Leader | koch_leader.KochLeader),
                 )
             ),
             None,


@@ -18,7 +18,6 @@ import logging
 import threading
 from collections import deque
 from pprint import pformat
-from typing import Deque
 
 import serial
@@ -60,7 +59,7 @@ class HomunculusArm(Teleoperator):
         self.n: int = n
         self.alpha: float = 2 / (n + 1)
         # one deque *per joint* so we can inspect raw history if needed
-        self._buffers: dict[str, Deque[int]] = {
+        self._buffers: dict[str, deque[int]] = {
             joint: deque(maxlen=n)
             for joint in (
                 "shoulder_pitch",


@@ -18,7 +18,6 @@ import logging
 import threading
 from collections import deque
 from pprint import pformat
-from typing import Deque
 
 import serial
@@ -97,7 +96,7 @@ class HomunculusGlove(Teleoperator):
         self.n: int = n
         self.alpha: float = 2 / (n + 1)
         # one deque *per joint* so we can inspect raw history if needed
-        self._buffers: dict[str, Deque[int]] = {joint: deque(maxlen=n) for joint in self.joints}
+        self._buffers: dict[str, deque[int]] = {joint: deque(maxlen=n) for joint in self.joints}
 
         # running EMA value per joint lazily initialised on first read
         self._ema: dict[str, float | None] = dict.fromkeys(self._buffers)


@@ -63,7 +63,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
         for key, val in transition["complementary_info"].items():
             if isinstance(val, torch.Tensor):
                 transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
-            elif isinstance(val, (int, float, bool)):
+            elif isinstance(val, (int | float | bool)):
                 transition["complementary_info"][key] = torch.tensor(val, device=device)
             else:
                 raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")


@@ -35,7 +35,7 @@ def _is_scalar(x):
     return (
         isinstance(x, float)
         or isinstance(x, numbers.Real)
-        or isinstance(x, (np.integer, np.floating))
+        or isinstance(x, (np.integer | np.floating))
         or (isinstance(x, np.ndarray) and x.ndim == 0)
     )


@@ -66,15 +66,13 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
     for key, param in policy.named_parameters():
         if param.requires_grad:
             grad_stats[f"{key}_mean"] = param.grad.mean()
-            grad_stats[f"{key}_std"] = (
-                param.grad.std() if param.grad.numel() > 1 else torch.tensor(float(0.0))
-            )
+            grad_stats[f"{key}_std"] = param.grad.std() if param.grad.numel() > 1 else torch.tensor(0.0)
 
     optimizer.step()
     param_stats = {}
     for key, param in policy.named_parameters():
         param_stats[f"{key}_mean"] = param.mean()
-        param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0))
+        param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(0.0)
 
     optimizer.zero_grad()
     policy.reset()
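Removing the float(0.0) wrapper (the UP018 "unnecessary native-literal call" pattern) does not change the resulting tensor: torch.tensor(0.0) still produces a float32 scalar under PyTorch's default dtype. A one-line sanity check, assuming torch is installed and the default dtype is unchanged:

    import torch

    assert torch.tensor(0.0).dtype == torch.tensor(float(0.0)).dtype == torch.float32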


@@ -770,7 +770,7 @@ class MockStepWithNonSerializableParam(ProcessorStep):
         # Add type validation for multiplier
         if isinstance(multiplier, str):
             raise ValueError(f"multiplier must be a number, got string '{multiplier}'")
-        if not isinstance(multiplier, (int, float)):
+        if not isinstance(multiplier, (int | float)):
             raise TypeError(f"multiplier must be a number, got {type(multiplier).__name__}")
         self.multiplier = float(multiplier)
         self.env = env  # Non-serializable parameter (like gym.Env)
@@ -1623,7 +1623,7 @@ def test_override_with_callables():
     # Define a transform function
     def double_values(x):
-        if isinstance(x, (int, float)):
+        if isinstance(x, (int | float)):
             return x * 2
         elif isinstance(x, torch.Tensor):
             return x * 2


@@ -121,7 +121,7 @@ def get_tensors_memory_consumption(obj, visited_addresses):
     if isinstance(obj, torch.Tensor):
         return get_tensor_memory_consumption(obj)
-    elif isinstance(obj, (list, tuple)):
+    elif isinstance(obj, (list | tuple)):
         for item in obj:
             total_size += get_tensors_memory_consumption(item, visited_addresses)
     elif isinstance(obj, dict):