chore: enable pyugrade ruff lint (#2084)

This commit is contained in:
Steven Palma
2025-09-29 13:28:53 +02:00
committed by GitHub
parent 90684a9690
commit c378a325f0
18 changed files with 33 additions and 43 deletions

View File

@@ -66,15 +66,13 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
for key, param in policy.named_parameters():
if param.requires_grad:
grad_stats[f"{key}_mean"] = param.grad.mean()
grad_stats[f"{key}_std"] = (
param.grad.std() if param.grad.numel() > 1 else torch.tensor(float(0.0))
)
grad_stats[f"{key}_std"] = param.grad.std() if param.grad.numel() > 1 else torch.tensor(0.0)
optimizer.step()
param_stats = {}
for key, param in policy.named_parameters():
param_stats[f"{key}_mean"] = param.mean()
param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0))
param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(0.0)
optimizer.zero_grad()
policy.reset()

View File

@@ -770,7 +770,7 @@ class MockStepWithNonSerializableParam(ProcessorStep):
# Add type validation for multiplier
if isinstance(multiplier, str):
raise ValueError(f"multiplier must be a number, got string '{multiplier}'")
if not isinstance(multiplier, (int, float)):
if not isinstance(multiplier, (int | float)):
raise TypeError(f"multiplier must be a number, got {type(multiplier).__name__}")
self.multiplier = float(multiplier)
self.env = env # Non-serializable parameter (like gym.Env)
@@ -1623,7 +1623,7 @@ def test_override_with_callables():
# Define a transform function
def double_values(x):
if isinstance(x, (int, float)):
if isinstance(x, (int | float)):
return x * 2
elif isinstance(x, torch.Tensor):
return x * 2

View File

@@ -121,7 +121,7 @@ def get_tensors_memory_consumption(obj, visited_addresses):
if isinstance(obj, torch.Tensor):
return get_tensor_memory_consumption(obj)
elif isinstance(obj, (list, tuple)):
elif isinstance(obj, (list | tuple)):
for item in obj:
total_size += get_tensors_memory_consumption(item, visited_addresses)
elif isinstance(obj, dict):