diff --git a/pyproject.toml b/pyproject.toml index 44e29043b..12bb552fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -201,7 +201,7 @@ exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"] # N: pep8-naming # TODO: Uncomment rules when ready to use select = [ - "E", "W", "F", "I", "B", "C4", "T20", "N" # "SIM", "A", "S", "D", "RUF", "UP" + "E", "W", "F", "I", "B", "C4", "T20", "N", "UP" # "SIM", "A", "S", "D", "RUF" ] ignore = [ "E501", # Line too long diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 691d86af7..b661b21b0 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -1421,7 +1421,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): """Keys to access image and video stream from cameras.""" keys = [] for key, feats in self.features.items(): - if isinstance(feats, (datasets.Image, VideoFrame)): + if isinstance(feats, (datasets.Image | VideoFrame)): keys.append(key) return keys diff --git a/src/lerobot/datasets/transforms.py b/src/lerobot/datasets/transforms.py index f992275b7..f7072c72f 100644 --- a/src/lerobot/datasets/transforms.py +++ b/src/lerobot/datasets/transforms.py @@ -120,7 +120,7 @@ class SharpnessJitter(Transform): self.sharpness = self._check_input(sharpness) def _check_input(self, sharpness): - if isinstance(sharpness, (int, float)): + if isinstance(sharpness, (int | float)): if sharpness < 0: raise ValueError("If sharpness is a single number, it must be non negative.") sharpness = [1.0 - sharpness, 1.0 + sharpness] diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 81b361ab6..a2f285014 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -21,7 +21,7 @@ from collections import deque from collections.abc import Iterable, Iterator from pathlib import Path from pprint import pformat -from typing import Any, Deque, Generic, TypeVar +from typing import Any, Generic, TypeVar import datasets import numpy as np @@ -207,13 +207,13 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict: """ serialized_dict = {} for key, value in flatten_dict(stats).items(): - if isinstance(value, (torch.Tensor, np.ndarray)): + if isinstance(value, (torch.Tensor | np.ndarray)): serialized_dict[key] = value.tolist() - elif isinstance(value, list) and isinstance(value[0], (int, float, list)): + elif isinstance(value, list) and isinstance(value[0], (int | float | list)): serialized_dict[key] = value elif isinstance(value, np.generic): serialized_dict[key] = value.item() - elif isinstance(value, (int, float)): + elif isinstance(value, (int | float)): serialized_dict[key] = value else: raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.") @@ -1179,7 +1179,7 @@ def item_to_torch(item: dict) -> dict: dict: Dictionary with all tensor-like items converted to torch.Tensor. """ for key, val in item.items(): - if isinstance(val, (np.ndarray, list)) and key not in ["task"]: + if isinstance(val, (np.ndarray | list)) and key not in ["task"]: # Convert numpy arrays and lists to torch tensors item[key] = torch.tensor(val) return item @@ -1253,8 +1253,8 @@ class Backtrackable(Generic[T]): raise ValueError("lookahead must be > 0") self._source: Iterator[T] = iter(iterable) - self._back_buf: Deque[T] = deque(maxlen=history) - self._ahead_buf: Deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque() + self._back_buf: deque[T] = deque(maxlen=history) + self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque() self._cursor: int = 0 self._history = history self._lookahead = lookahead diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index 466796975..99ec6712f 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -35,7 +35,7 @@ def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]: """Normalize camera_name into a non-empty list of strings.""" if isinstance(camera_name, str): cams = [c.strip() for c in camera_name.split(",") if c.strip()] - elif isinstance(camera_name, (list, tuple)): + elif isinstance(camera_name, (list | tuple)): cams = [str(c).strip() for c in camera_name if str(c).strip()] else: raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}") diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index 023ceea67..b5cfc7e26 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -183,10 +183,10 @@ def _(env: Mapping) -> None: @close_envs.register def _(envs: Sequence) -> None: - if isinstance(envs, (str, bytes)): + if isinstance(envs, (str | bytes)): return for v in envs: - if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str, bytes)): + if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str | bytes)): close_envs(v) elif hasattr(v, "close"): _close_single_env(v) diff --git a/src/lerobot/motors/motors_bus.py b/src/lerobot/motors/motors_bus.py index 8603d81a9..17eaa8063 100644 --- a/src/lerobot/motors/motors_bus.py +++ b/src/lerobot/motors/motors_bus.py @@ -342,7 +342,7 @@ class MotorsBus(abc.ABC): raise TypeError(motors) def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]: - if isinstance(values, (int, float)): + if isinstance(values, (int | float)): return dict.fromkeys(self.ids, values) elif isinstance(values, dict): return {self.motors[motor].id: val for motor, val in values.items()} @@ -669,7 +669,7 @@ class MotorsBus(abc.ABC): """ if motors is None: motors = list(self.motors) - elif isinstance(motors, (str, int)): + elif isinstance(motors, (str | int)): motors = [motors] elif not isinstance(motors, list): raise TypeError(motors) @@ -697,7 +697,7 @@ class MotorsBus(abc.ABC): """ if motors is None: motors = list(self.motors) - elif isinstance(motors, (str, int)): + elif isinstance(motors, (str | int)): motors = [motors] elif not isinstance(motors, list): raise TypeError(motors) @@ -733,7 +733,7 @@ class MotorsBus(abc.ABC): """ if motors is None: motors = list(self.motors) - elif isinstance(motors, (str, int)): + elif isinstance(motors, (str | int)): motors = [motors] elif not isinstance(motors, list): raise TypeError(motors) diff --git a/src/lerobot/policies/vqbet/vqbet_utils.py b/src/lerobot/policies/vqbet/vqbet_utils.py index 44b7d5f0b..7b13577f6 100644 --- a/src/lerobot/policies/vqbet/vqbet_utils.py +++ b/src/lerobot/policies/vqbet/vqbet_utils.py @@ -260,13 +260,11 @@ class GPT(nn.Module): param_dict = dict(self.named_parameters()) inter_params = decay & no_decay union_params = decay | no_decay - assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format( - str(inter_params) + assert len(inter_params) == 0, ( + f"parameters {str(inter_params)} made it into both decay/no_decay sets!" ) assert len(param_dict.keys() - union_params) == 0, ( - "parameters {} were not separated into either decay/no_decay set!".format( - str(param_dict.keys() - union_params), - ) + f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!" ) decay = [param_dict[pn] for pn in sorted(decay)] diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py index d30b65082..917e4e2cc 100644 --- a/src/lerobot/rl/buffer.py +++ b/src/lerobot/rl/buffer.py @@ -176,7 +176,7 @@ class ReplayBuffer: self.complementary_info[key] = torch.empty( (self.capacity, *value_shape), device=self.storage_device ) - elif isinstance(value, (int, float)): + elif isinstance(value, (int | float)): # Handle scalar values similar to reward self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device) else: @@ -223,7 +223,7 @@ class ReplayBuffer: value = complementary_info[key] if isinstance(value, torch.Tensor): self.complementary_info[key][self.position].copy_(value.squeeze(dim=0)) - elif isinstance(value, (int, float)): + elif isinstance(value, (int | float)): self.complementary_info[key][self.position] = value self.position = (self.position + 1) % self.capacity diff --git a/src/lerobot/rl/wandb_utils.py b/src/lerobot/rl/wandb_utils.py index b13254421..01cef9487 100644 --- a/src/lerobot/rl/wandb_utils.py +++ b/src/lerobot/rl/wandb_utils.py @@ -137,7 +137,7 @@ class WandBLogger: self._wandb.define_metric(new_custom_key, hidden=True) for k, v in d.items(): - if not isinstance(v, (int, float, str)): + if not isinstance(v, (int | float | str)): logging.warning( f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.' ) diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index d097a9d2f..ddb21e917 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -267,11 +267,7 @@ def record_loop( for t in teleop if isinstance( t, - ( - so100_leader.SO100Leader, - so101_leader.SO101Leader, - koch_leader.KochLeader, - ), + (so100_leader.SO100Leader | so101_leader.SO101Leader | koch_leader.KochLeader), ) ), None, diff --git a/src/lerobot/teleoperators/homunculus/homunculus_arm.py b/src/lerobot/teleoperators/homunculus/homunculus_arm.py index 4eca4b9e2..21d73de2e 100644 --- a/src/lerobot/teleoperators/homunculus/homunculus_arm.py +++ b/src/lerobot/teleoperators/homunculus/homunculus_arm.py @@ -18,7 +18,6 @@ import logging import threading from collections import deque from pprint import pformat -from typing import Deque import serial @@ -60,7 +59,7 @@ class HomunculusArm(Teleoperator): self.n: int = n self.alpha: float = 2 / (n + 1) # one deque *per joint* so we can inspect raw history if needed - self._buffers: dict[str, Deque[int]] = { + self._buffers: dict[str, deque[int]] = { joint: deque(maxlen=n) for joint in ( "shoulder_pitch", diff --git a/src/lerobot/teleoperators/homunculus/homunculus_glove.py b/src/lerobot/teleoperators/homunculus/homunculus_glove.py index 52fd19def..251ecf56d 100644 --- a/src/lerobot/teleoperators/homunculus/homunculus_glove.py +++ b/src/lerobot/teleoperators/homunculus/homunculus_glove.py @@ -18,7 +18,6 @@ import logging import threading from collections import deque from pprint import pformat -from typing import Deque import serial @@ -97,7 +96,7 @@ class HomunculusGlove(Teleoperator): self.n: int = n self.alpha: float = 2 / (n + 1) # one deque *per joint* so we can inspect raw history if needed - self._buffers: dict[str, Deque[int]] = {joint: deque(maxlen=n) for joint in self.joints} + self._buffers: dict[str, deque[int]] = {joint: deque(maxlen=n) for joint in self.joints} # running EMA value per joint – lazily initialised on first read self._ema: dict[str, float | None] = dict.fromkeys(self._buffers) diff --git a/src/lerobot/utils/transition.py b/src/lerobot/utils/transition.py index e874bd096..fe3620861 100644 --- a/src/lerobot/utils/transition.py +++ b/src/lerobot/utils/transition.py @@ -63,7 +63,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr for key, val in transition["complementary_info"].items(): if isinstance(val, torch.Tensor): transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking) - elif isinstance(val, (int, float, bool)): + elif isinstance(val, (int | float | bool)): transition["complementary_info"][key] = torch.tensor(val, device=device) else: raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]") diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index ae070b7c4..d0201ecbf 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -35,7 +35,7 @@ def _is_scalar(x): return ( isinstance(x, float) or isinstance(x, numbers.Real) - or isinstance(x, (np.integer, np.floating)) + or isinstance(x, (np.integer | np.floating)) or (isinstance(x, np.ndarray) and x.ndim == 0) ) diff --git a/tests/artifacts/policies/save_policy_to_safetensors.py b/tests/artifacts/policies/save_policy_to_safetensors.py index e130ae144..64b125cc9 100644 --- a/tests/artifacts/policies/save_policy_to_safetensors.py +++ b/tests/artifacts/policies/save_policy_to_safetensors.py @@ -66,15 +66,13 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): for key, param in policy.named_parameters(): if param.requires_grad: grad_stats[f"{key}_mean"] = param.grad.mean() - grad_stats[f"{key}_std"] = ( - param.grad.std() if param.grad.numel() > 1 else torch.tensor(float(0.0)) - ) + grad_stats[f"{key}_std"] = param.grad.std() if param.grad.numel() > 1 else torch.tensor(0.0) optimizer.step() param_stats = {} for key, param in policy.named_parameters(): param_stats[f"{key}_mean"] = param.mean() - param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0)) + param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(0.0) optimizer.zero_grad() policy.reset() diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 904fd6fc1..76f2b1c26 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -770,7 +770,7 @@ class MockStepWithNonSerializableParam(ProcessorStep): # Add type validation for multiplier if isinstance(multiplier, str): raise ValueError(f"multiplier must be a number, got string '{multiplier}'") - if not isinstance(multiplier, (int, float)): + if not isinstance(multiplier, (int | float)): raise TypeError(f"multiplier must be a number, got {type(multiplier).__name__}") self.multiplier = float(multiplier) self.env = env # Non-serializable parameter (like gym.Env) @@ -1623,7 +1623,7 @@ def test_override_with_callables(): # Define a transform function def double_values(x): - if isinstance(x, (int, float)): + if isinstance(x, (int | float)): return x * 2 elif isinstance(x, torch.Tensor): return x * 2 diff --git a/tests/utils/test_replay_buffer.py b/tests/utils/test_replay_buffer.py index ddf0771f1..b9d3a1ac0 100644 --- a/tests/utils/test_replay_buffer.py +++ b/tests/utils/test_replay_buffer.py @@ -121,7 +121,7 @@ def get_tensors_memory_consumption(obj, visited_addresses): if isinstance(obj, torch.Tensor): return get_tensor_memory_consumption(obj) - elif isinstance(obj, (list, tuple)): + elif isinstance(obj, (list | tuple)): for item in obj: total_size += get_tensors_memory_consumption(item, visited_addresses) elif isinstance(obj, dict):