chore: enable pyugrade ruff lint (#2084)

This commit is contained in:
Steven Palma
2025-09-29 13:28:53 +02:00
committed by GitHub
parent 90684a9690
commit c378a325f0
18 changed files with 33 additions and 43 deletions

View File

@@ -201,7 +201,7 @@ exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"]
# N: pep8-naming # N: pep8-naming
# TODO: Uncomment rules when ready to use # TODO: Uncomment rules when ready to use
select = [ select = [
"E", "W", "F", "I", "B", "C4", "T20", "N" # "SIM", "A", "S", "D", "RUF", "UP" "E", "W", "F", "I", "B", "C4", "T20", "N", "UP" # "SIM", "A", "S", "D", "RUF"
] ]
ignore = [ ignore = [
"E501", # Line too long "E501", # Line too long

View File

@@ -1421,7 +1421,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
"""Keys to access image and video stream from cameras.""" """Keys to access image and video stream from cameras."""
keys = [] keys = []
for key, feats in self.features.items(): for key, feats in self.features.items():
if isinstance(feats, (datasets.Image, VideoFrame)): if isinstance(feats, (datasets.Image | VideoFrame)):
keys.append(key) keys.append(key)
return keys return keys

View File

@@ -120,7 +120,7 @@ class SharpnessJitter(Transform):
self.sharpness = self._check_input(sharpness) self.sharpness = self._check_input(sharpness)
def _check_input(self, sharpness): def _check_input(self, sharpness):
if isinstance(sharpness, (int, float)): if isinstance(sharpness, (int | float)):
if sharpness < 0: if sharpness < 0:
raise ValueError("If sharpness is a single number, it must be non negative.") raise ValueError("If sharpness is a single number, it must be non negative.")
sharpness = [1.0 - sharpness, 1.0 + sharpness] sharpness = [1.0 - sharpness, 1.0 + sharpness]

View File

@@ -21,7 +21,7 @@ from collections import deque
from collections.abc import Iterable, Iterator from collections.abc import Iterable, Iterator
from pathlib import Path from pathlib import Path
from pprint import pformat from pprint import pformat
from typing import Any, Deque, Generic, TypeVar from typing import Any, Generic, TypeVar
import datasets import datasets
import numpy as np import numpy as np
@@ -207,13 +207,13 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
""" """
serialized_dict = {} serialized_dict = {}
for key, value in flatten_dict(stats).items(): for key, value in flatten_dict(stats).items():
if isinstance(value, (torch.Tensor, np.ndarray)): if isinstance(value, (torch.Tensor | np.ndarray)):
serialized_dict[key] = value.tolist() serialized_dict[key] = value.tolist()
elif isinstance(value, list) and isinstance(value[0], (int, float, list)): elif isinstance(value, list) and isinstance(value[0], (int | float | list)):
serialized_dict[key] = value serialized_dict[key] = value
elif isinstance(value, np.generic): elif isinstance(value, np.generic):
serialized_dict[key] = value.item() serialized_dict[key] = value.item()
elif isinstance(value, (int, float)): elif isinstance(value, (int | float)):
serialized_dict[key] = value serialized_dict[key] = value
else: else:
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.") raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
@@ -1179,7 +1179,7 @@ def item_to_torch(item: dict) -> dict:
dict: Dictionary with all tensor-like items converted to torch.Tensor. dict: Dictionary with all tensor-like items converted to torch.Tensor.
""" """
for key, val in item.items(): for key, val in item.items():
if isinstance(val, (np.ndarray, list)) and key not in ["task"]: if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
# Convert numpy arrays and lists to torch tensors # Convert numpy arrays and lists to torch tensors
item[key] = torch.tensor(val) item[key] = torch.tensor(val)
return item return item
@@ -1253,8 +1253,8 @@ class Backtrackable(Generic[T]):
raise ValueError("lookahead must be > 0") raise ValueError("lookahead must be > 0")
self._source: Iterator[T] = iter(iterable) self._source: Iterator[T] = iter(iterable)
self._back_buf: Deque[T] = deque(maxlen=history) self._back_buf: deque[T] = deque(maxlen=history)
self._ahead_buf: Deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque() self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
self._cursor: int = 0 self._cursor: int = 0
self._history = history self._history = history
self._lookahead = lookahead self._lookahead = lookahead

View File

@@ -35,7 +35,7 @@ def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
"""Normalize camera_name into a non-empty list of strings.""" """Normalize camera_name into a non-empty list of strings."""
if isinstance(camera_name, str): if isinstance(camera_name, str):
cams = [c.strip() for c in camera_name.split(",") if c.strip()] cams = [c.strip() for c in camera_name.split(",") if c.strip()]
elif isinstance(camera_name, (list, tuple)): elif isinstance(camera_name, (list | tuple)):
cams = [str(c).strip() for c in camera_name if str(c).strip()] cams = [str(c).strip() for c in camera_name if str(c).strip()]
else: else:
raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}") raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")

View File

@@ -183,10 +183,10 @@ def _(env: Mapping) -> None:
@close_envs.register @close_envs.register
def _(envs: Sequence) -> None: def _(envs: Sequence) -> None:
if isinstance(envs, (str, bytes)): if isinstance(envs, (str | bytes)):
return return
for v in envs: for v in envs:
if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str, bytes)): if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str | bytes)):
close_envs(v) close_envs(v)
elif hasattr(v, "close"): elif hasattr(v, "close"):
_close_single_env(v) _close_single_env(v)

View File

@@ -342,7 +342,7 @@ class MotorsBus(abc.ABC):
raise TypeError(motors) raise TypeError(motors)
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]: def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
if isinstance(values, (int, float)): if isinstance(values, (int | float)):
return dict.fromkeys(self.ids, values) return dict.fromkeys(self.ids, values)
elif isinstance(values, dict): elif isinstance(values, dict):
return {self.motors[motor].id: val for motor, val in values.items()} return {self.motors[motor].id: val for motor, val in values.items()}
@@ -669,7 +669,7 @@ class MotorsBus(abc.ABC):
""" """
if motors is None: if motors is None:
motors = list(self.motors) motors = list(self.motors)
elif isinstance(motors, (str, int)): elif isinstance(motors, (str | int)):
motors = [motors] motors = [motors]
elif not isinstance(motors, list): elif not isinstance(motors, list):
raise TypeError(motors) raise TypeError(motors)
@@ -697,7 +697,7 @@ class MotorsBus(abc.ABC):
""" """
if motors is None: if motors is None:
motors = list(self.motors) motors = list(self.motors)
elif isinstance(motors, (str, int)): elif isinstance(motors, (str | int)):
motors = [motors] motors = [motors]
elif not isinstance(motors, list): elif not isinstance(motors, list):
raise TypeError(motors) raise TypeError(motors)
@@ -733,7 +733,7 @@ class MotorsBus(abc.ABC):
""" """
if motors is None: if motors is None:
motors = list(self.motors) motors = list(self.motors)
elif isinstance(motors, (str, int)): elif isinstance(motors, (str | int)):
motors = [motors] motors = [motors]
elif not isinstance(motors, list): elif not isinstance(motors, list):
raise TypeError(motors) raise TypeError(motors)

View File

@@ -260,13 +260,11 @@ class GPT(nn.Module):
param_dict = dict(self.named_parameters()) param_dict = dict(self.named_parameters())
inter_params = decay & no_decay inter_params = decay & no_decay
union_params = decay | no_decay union_params = decay | no_decay
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format( assert len(inter_params) == 0, (
str(inter_params) f"parameters {str(inter_params)} made it into both decay/no_decay sets!"
) )
assert len(param_dict.keys() - union_params) == 0, ( assert len(param_dict.keys() - union_params) == 0, (
"parameters {} were not separated into either decay/no_decay set!".format( f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!"
str(param_dict.keys() - union_params),
)
) )
decay = [param_dict[pn] for pn in sorted(decay)] decay = [param_dict[pn] for pn in sorted(decay)]

View File

@@ -176,7 +176,7 @@ class ReplayBuffer:
self.complementary_info[key] = torch.empty( self.complementary_info[key] = torch.empty(
(self.capacity, *value_shape), device=self.storage_device (self.capacity, *value_shape), device=self.storage_device
) )
elif isinstance(value, (int, float)): elif isinstance(value, (int | float)):
# Handle scalar values similar to reward # Handle scalar values similar to reward
self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device) self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
else: else:
@@ -223,7 +223,7 @@ class ReplayBuffer:
value = complementary_info[key] value = complementary_info[key]
if isinstance(value, torch.Tensor): if isinstance(value, torch.Tensor):
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0)) self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
elif isinstance(value, (int, float)): elif isinstance(value, (int | float)):
self.complementary_info[key][self.position] = value self.complementary_info[key][self.position] = value
self.position = (self.position + 1) % self.capacity self.position = (self.position + 1) % self.capacity

View File

@@ -137,7 +137,7 @@ class WandBLogger:
self._wandb.define_metric(new_custom_key, hidden=True) self._wandb.define_metric(new_custom_key, hidden=True)
for k, v in d.items(): for k, v in d.items():
if not isinstance(v, (int, float, str)): if not isinstance(v, (int | float | str)):
logging.warning( logging.warning(
f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.' f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
) )

View File

@@ -267,11 +267,7 @@ def record_loop(
for t in teleop for t in teleop
if isinstance( if isinstance(
t, t,
( (so100_leader.SO100Leader | so101_leader.SO101Leader | koch_leader.KochLeader),
so100_leader.SO100Leader,
so101_leader.SO101Leader,
koch_leader.KochLeader,
),
) )
), ),
None, None,

View File

@@ -18,7 +18,6 @@ import logging
import threading import threading
from collections import deque from collections import deque
from pprint import pformat from pprint import pformat
from typing import Deque
import serial import serial
@@ -60,7 +59,7 @@ class HomunculusArm(Teleoperator):
self.n: int = n self.n: int = n
self.alpha: float = 2 / (n + 1) self.alpha: float = 2 / (n + 1)
# one deque *per joint* so we can inspect raw history if needed # one deque *per joint* so we can inspect raw history if needed
self._buffers: dict[str, Deque[int]] = { self._buffers: dict[str, deque[int]] = {
joint: deque(maxlen=n) joint: deque(maxlen=n)
for joint in ( for joint in (
"shoulder_pitch", "shoulder_pitch",

View File

@@ -18,7 +18,6 @@ import logging
import threading import threading
from collections import deque from collections import deque
from pprint import pformat from pprint import pformat
from typing import Deque
import serial import serial
@@ -97,7 +96,7 @@ class HomunculusGlove(Teleoperator):
self.n: int = n self.n: int = n
self.alpha: float = 2 / (n + 1) self.alpha: float = 2 / (n + 1)
# one deque *per joint* so we can inspect raw history if needed # one deque *per joint* so we can inspect raw history if needed
self._buffers: dict[str, Deque[int]] = {joint: deque(maxlen=n) for joint in self.joints} self._buffers: dict[str, deque[int]] = {joint: deque(maxlen=n) for joint in self.joints}
# running EMA value per joint lazily initialised on first read # running EMA value per joint lazily initialised on first read
self._ema: dict[str, float | None] = dict.fromkeys(self._buffers) self._ema: dict[str, float | None] = dict.fromkeys(self._buffers)

View File

@@ -63,7 +63,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
for key, val in transition["complementary_info"].items(): for key, val in transition["complementary_info"].items():
if isinstance(val, torch.Tensor): if isinstance(val, torch.Tensor):
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking) transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
elif isinstance(val, (int, float, bool)): elif isinstance(val, (int | float | bool)):
transition["complementary_info"][key] = torch.tensor(val, device=device) transition["complementary_info"][key] = torch.tensor(val, device=device)
else: else:
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]") raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")

View File

@@ -35,7 +35,7 @@ def _is_scalar(x):
return ( return (
isinstance(x, float) isinstance(x, float)
or isinstance(x, numbers.Real) or isinstance(x, numbers.Real)
or isinstance(x, (np.integer, np.floating)) or isinstance(x, (np.integer | np.floating))
or (isinstance(x, np.ndarray) and x.ndim == 0) or (isinstance(x, np.ndarray) and x.ndim == 0)
) )

View File

@@ -66,15 +66,13 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
for key, param in policy.named_parameters(): for key, param in policy.named_parameters():
if param.requires_grad: if param.requires_grad:
grad_stats[f"{key}_mean"] = param.grad.mean() grad_stats[f"{key}_mean"] = param.grad.mean()
grad_stats[f"{key}_std"] = ( grad_stats[f"{key}_std"] = param.grad.std() if param.grad.numel() > 1 else torch.tensor(0.0)
param.grad.std() if param.grad.numel() > 1 else torch.tensor(float(0.0))
)
optimizer.step() optimizer.step()
param_stats = {} param_stats = {}
for key, param in policy.named_parameters(): for key, param in policy.named_parameters():
param_stats[f"{key}_mean"] = param.mean() param_stats[f"{key}_mean"] = param.mean()
param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0)) param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(0.0)
optimizer.zero_grad() optimizer.zero_grad()
policy.reset() policy.reset()

View File

@@ -770,7 +770,7 @@ class MockStepWithNonSerializableParam(ProcessorStep):
# Add type validation for multiplier # Add type validation for multiplier
if isinstance(multiplier, str): if isinstance(multiplier, str):
raise ValueError(f"multiplier must be a number, got string '{multiplier}'") raise ValueError(f"multiplier must be a number, got string '{multiplier}'")
if not isinstance(multiplier, (int, float)): if not isinstance(multiplier, (int | float)):
raise TypeError(f"multiplier must be a number, got {type(multiplier).__name__}") raise TypeError(f"multiplier must be a number, got {type(multiplier).__name__}")
self.multiplier = float(multiplier) self.multiplier = float(multiplier)
self.env = env # Non-serializable parameter (like gym.Env) self.env = env # Non-serializable parameter (like gym.Env)
@@ -1623,7 +1623,7 @@ def test_override_with_callables():
# Define a transform function # Define a transform function
def double_values(x): def double_values(x):
if isinstance(x, (int, float)): if isinstance(x, (int | float)):
return x * 2 return x * 2
elif isinstance(x, torch.Tensor): elif isinstance(x, torch.Tensor):
return x * 2 return x * 2

View File

@@ -121,7 +121,7 @@ def get_tensors_memory_consumption(obj, visited_addresses):
if isinstance(obj, torch.Tensor): if isinstance(obj, torch.Tensor):
return get_tensor_memory_consumption(obj) return get_tensor_memory_consumption(obj)
elif isinstance(obj, (list, tuple)): elif isinstance(obj, (list | tuple)):
for item in obj: for item in obj:
total_size += get_tensors_memory_consumption(item, visited_addresses) total_size += get_tensors_memory_consumption(item, visited_addresses)
elif isinstance(obj, dict): elif isinstance(obj, dict):