forked from tangger/lerobot
chore: enable pyugrade ruff lint (#2084)
This commit is contained in:
@@ -201,7 +201,7 @@ exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"]
|
||||
# N: pep8-naming
|
||||
# TODO: Uncomment rules when ready to use
|
||||
select = [
|
||||
"E", "W", "F", "I", "B", "C4", "T20", "N" # "SIM", "A", "S", "D", "RUF", "UP"
|
||||
"E", "W", "F", "I", "B", "C4", "T20", "N", "UP" # "SIM", "A", "S", "D", "RUF"
|
||||
]
|
||||
ignore = [
|
||||
"E501", # Line too long
|
||||
|
||||
@@ -1421,7 +1421,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
"""Keys to access image and video stream from cameras."""
|
||||
keys = []
|
||||
for key, feats in self.features.items():
|
||||
if isinstance(feats, (datasets.Image, VideoFrame)):
|
||||
if isinstance(feats, (datasets.Image | VideoFrame)):
|
||||
keys.append(key)
|
||||
return keys
|
||||
|
||||
|
||||
@@ -120,7 +120,7 @@ class SharpnessJitter(Transform):
|
||||
self.sharpness = self._check_input(sharpness)
|
||||
|
||||
def _check_input(self, sharpness):
|
||||
if isinstance(sharpness, (int, float)):
|
||||
if isinstance(sharpness, (int | float)):
|
||||
if sharpness < 0:
|
||||
raise ValueError("If sharpness is a single number, it must be non negative.")
|
||||
sharpness = [1.0 - sharpness, 1.0 + sharpness]
|
||||
|
||||
@@ -21,7 +21,7 @@ from collections import deque
|
||||
from collections.abc import Iterable, Iterator
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any, Deque, Generic, TypeVar
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
@@ -207,13 +207,13 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
||||
"""
|
||||
serialized_dict = {}
|
||||
for key, value in flatten_dict(stats).items():
|
||||
if isinstance(value, (torch.Tensor, np.ndarray)):
|
||||
if isinstance(value, (torch.Tensor | np.ndarray)):
|
||||
serialized_dict[key] = value.tolist()
|
||||
elif isinstance(value, list) and isinstance(value[0], (int, float, list)):
|
||||
elif isinstance(value, list) and isinstance(value[0], (int | float | list)):
|
||||
serialized_dict[key] = value
|
||||
elif isinstance(value, np.generic):
|
||||
serialized_dict[key] = value.item()
|
||||
elif isinstance(value, (int, float)):
|
||||
elif isinstance(value, (int | float)):
|
||||
serialized_dict[key] = value
|
||||
else:
|
||||
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
|
||||
@@ -1179,7 +1179,7 @@ def item_to_torch(item: dict) -> dict:
|
||||
dict: Dictionary with all tensor-like items converted to torch.Tensor.
|
||||
"""
|
||||
for key, val in item.items():
|
||||
if isinstance(val, (np.ndarray, list)) and key not in ["task"]:
|
||||
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
|
||||
# Convert numpy arrays and lists to torch tensors
|
||||
item[key] = torch.tensor(val)
|
||||
return item
|
||||
@@ -1253,8 +1253,8 @@ class Backtrackable(Generic[T]):
|
||||
raise ValueError("lookahead must be > 0")
|
||||
|
||||
self._source: Iterator[T] = iter(iterable)
|
||||
self._back_buf: Deque[T] = deque(maxlen=history)
|
||||
self._ahead_buf: Deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
|
||||
self._back_buf: deque[T] = deque(maxlen=history)
|
||||
self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
|
||||
self._cursor: int = 0
|
||||
self._history = history
|
||||
self._lookahead = lookahead
|
||||
|
||||
@@ -35,7 +35,7 @@ def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
|
||||
"""Normalize camera_name into a non-empty list of strings."""
|
||||
if isinstance(camera_name, str):
|
||||
cams = [c.strip() for c in camera_name.split(",") if c.strip()]
|
||||
elif isinstance(camera_name, (list, tuple)):
|
||||
elif isinstance(camera_name, (list | tuple)):
|
||||
cams = [str(c).strip() for c in camera_name if str(c).strip()]
|
||||
else:
|
||||
raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")
|
||||
|
||||
@@ -183,10 +183,10 @@ def _(env: Mapping) -> None:
|
||||
|
||||
@close_envs.register
|
||||
def _(envs: Sequence) -> None:
|
||||
if isinstance(envs, (str, bytes)):
|
||||
if isinstance(envs, (str | bytes)):
|
||||
return
|
||||
for v in envs:
|
||||
if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str, bytes)):
|
||||
if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str | bytes)):
|
||||
close_envs(v)
|
||||
elif hasattr(v, "close"):
|
||||
_close_single_env(v)
|
||||
|
||||
@@ -342,7 +342,7 @@ class MotorsBus(abc.ABC):
|
||||
raise TypeError(motors)
|
||||
|
||||
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
|
||||
if isinstance(values, (int, float)):
|
||||
if isinstance(values, (int | float)):
|
||||
return dict.fromkeys(self.ids, values)
|
||||
elif isinstance(values, dict):
|
||||
return {self.motors[motor].id: val for motor, val in values.items()}
|
||||
@@ -669,7 +669,7 @@ class MotorsBus(abc.ABC):
|
||||
"""
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str, int)):
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
@@ -697,7 +697,7 @@ class MotorsBus(abc.ABC):
|
||||
"""
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str, int)):
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
@@ -733,7 +733,7 @@ class MotorsBus(abc.ABC):
|
||||
"""
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str, int)):
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
|
||||
@@ -260,13 +260,11 @@ class GPT(nn.Module):
|
||||
param_dict = dict(self.named_parameters())
|
||||
inter_params = decay & no_decay
|
||||
union_params = decay | no_decay
|
||||
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
|
||||
str(inter_params)
|
||||
assert len(inter_params) == 0, (
|
||||
f"parameters {str(inter_params)} made it into both decay/no_decay sets!"
|
||||
)
|
||||
assert len(param_dict.keys() - union_params) == 0, (
|
||||
"parameters {} were not separated into either decay/no_decay set!".format(
|
||||
str(param_dict.keys() - union_params),
|
||||
)
|
||||
f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!"
|
||||
)
|
||||
|
||||
decay = [param_dict[pn] for pn in sorted(decay)]
|
||||
|
||||
@@ -176,7 +176,7 @@ class ReplayBuffer:
|
||||
self.complementary_info[key] = torch.empty(
|
||||
(self.capacity, *value_shape), device=self.storage_device
|
||||
)
|
||||
elif isinstance(value, (int, float)):
|
||||
elif isinstance(value, (int | float)):
|
||||
# Handle scalar values similar to reward
|
||||
self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
|
||||
else:
|
||||
@@ -223,7 +223,7 @@ class ReplayBuffer:
|
||||
value = complementary_info[key]
|
||||
if isinstance(value, torch.Tensor):
|
||||
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
|
||||
elif isinstance(value, (int, float)):
|
||||
elif isinstance(value, (int | float)):
|
||||
self.complementary_info[key][self.position] = value
|
||||
|
||||
self.position = (self.position + 1) % self.capacity
|
||||
|
||||
@@ -137,7 +137,7 @@ class WandBLogger:
|
||||
self._wandb.define_metric(new_custom_key, hidden=True)
|
||||
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str)):
|
||||
if not isinstance(v, (int | float | str)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
|
||||
)
|
||||
|
||||
@@ -267,11 +267,7 @@ def record_loop(
|
||||
for t in teleop
|
||||
if isinstance(
|
||||
t,
|
||||
(
|
||||
so100_leader.SO100Leader,
|
||||
so101_leader.SO101Leader,
|
||||
koch_leader.KochLeader,
|
||||
),
|
||||
(so100_leader.SO100Leader | so101_leader.SO101Leader | koch_leader.KochLeader),
|
||||
)
|
||||
),
|
||||
None,
|
||||
|
||||
@@ -18,7 +18,6 @@ import logging
|
||||
import threading
|
||||
from collections import deque
|
||||
from pprint import pformat
|
||||
from typing import Deque
|
||||
|
||||
import serial
|
||||
|
||||
@@ -60,7 +59,7 @@ class HomunculusArm(Teleoperator):
|
||||
self.n: int = n
|
||||
self.alpha: float = 2 / (n + 1)
|
||||
# one deque *per joint* so we can inspect raw history if needed
|
||||
self._buffers: dict[str, Deque[int]] = {
|
||||
self._buffers: dict[str, deque[int]] = {
|
||||
joint: deque(maxlen=n)
|
||||
for joint in (
|
||||
"shoulder_pitch",
|
||||
|
||||
@@ -18,7 +18,6 @@ import logging
|
||||
import threading
|
||||
from collections import deque
|
||||
from pprint import pformat
|
||||
from typing import Deque
|
||||
|
||||
import serial
|
||||
|
||||
@@ -97,7 +96,7 @@ class HomunculusGlove(Teleoperator):
|
||||
self.n: int = n
|
||||
self.alpha: float = 2 / (n + 1)
|
||||
# one deque *per joint* so we can inspect raw history if needed
|
||||
self._buffers: dict[str, Deque[int]] = {joint: deque(maxlen=n) for joint in self.joints}
|
||||
self._buffers: dict[str, deque[int]] = {joint: deque(maxlen=n) for joint in self.joints}
|
||||
# running EMA value per joint – lazily initialised on first read
|
||||
self._ema: dict[str, float | None] = dict.fromkeys(self._buffers)
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
|
||||
for key, val in transition["complementary_info"].items():
|
||||
if isinstance(val, torch.Tensor):
|
||||
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
|
||||
elif isinstance(val, (int, float, bool)):
|
||||
elif isinstance(val, (int | float | bool)):
|
||||
transition["complementary_info"][key] = torch.tensor(val, device=device)
|
||||
else:
|
||||
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
|
||||
|
||||
@@ -35,7 +35,7 @@ def _is_scalar(x):
|
||||
return (
|
||||
isinstance(x, float)
|
||||
or isinstance(x, numbers.Real)
|
||||
or isinstance(x, (np.integer, np.floating))
|
||||
or isinstance(x, (np.integer | np.floating))
|
||||
or (isinstance(x, np.ndarray) and x.ndim == 0)
|
||||
)
|
||||
|
||||
|
||||
@@ -66,15 +66,13 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
for key, param in policy.named_parameters():
|
||||
if param.requires_grad:
|
||||
grad_stats[f"{key}_mean"] = param.grad.mean()
|
||||
grad_stats[f"{key}_std"] = (
|
||||
param.grad.std() if param.grad.numel() > 1 else torch.tensor(float(0.0))
|
||||
)
|
||||
grad_stats[f"{key}_std"] = param.grad.std() if param.grad.numel() > 1 else torch.tensor(0.0)
|
||||
|
||||
optimizer.step()
|
||||
param_stats = {}
|
||||
for key, param in policy.named_parameters():
|
||||
param_stats[f"{key}_mean"] = param.mean()
|
||||
param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0))
|
||||
param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(0.0)
|
||||
|
||||
optimizer.zero_grad()
|
||||
policy.reset()
|
||||
|
||||
@@ -770,7 +770,7 @@ class MockStepWithNonSerializableParam(ProcessorStep):
|
||||
# Add type validation for multiplier
|
||||
if isinstance(multiplier, str):
|
||||
raise ValueError(f"multiplier must be a number, got string '{multiplier}'")
|
||||
if not isinstance(multiplier, (int, float)):
|
||||
if not isinstance(multiplier, (int | float)):
|
||||
raise TypeError(f"multiplier must be a number, got {type(multiplier).__name__}")
|
||||
self.multiplier = float(multiplier)
|
||||
self.env = env # Non-serializable parameter (like gym.Env)
|
||||
@@ -1623,7 +1623,7 @@ def test_override_with_callables():
|
||||
|
||||
# Define a transform function
|
||||
def double_values(x):
|
||||
if isinstance(x, (int, float)):
|
||||
if isinstance(x, (int | float)):
|
||||
return x * 2
|
||||
elif isinstance(x, torch.Tensor):
|
||||
return x * 2
|
||||
|
||||
@@ -121,7 +121,7 @@ def get_tensors_memory_consumption(obj, visited_addresses):
|
||||
|
||||
if isinstance(obj, torch.Tensor):
|
||||
return get_tensor_memory_consumption(obj)
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
elif isinstance(obj, (list | tuple)):
|
||||
for item in obj:
|
||||
total_size += get_tensors_memory_consumption(item, visited_addresses)
|
||||
elif isinstance(obj, dict):
|
||||
|
||||
Reference in New Issue
Block a user