[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-24 13:16:38 +00:00
committed by Michel Aractingi
parent cdcf346061
commit 1c8daf11fd
95 changed files with 1592 additions and 491 deletions

View File

@@ -69,7 +69,9 @@ class HubMixin:
if push_to_hub:
if repo_id is None:
repo_id = save_directory.name # Defaults to `save_directory` name
return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)
return self.push_to_hub(
repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs
)
return None
def _save_pretrained(self, save_directory: Path) -> None:
@@ -175,7 +177,9 @@ class HubMixin:
The url of the commit of your object in the given repository.
"""
api = HfApi(token=token)
repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id
repo_id = api.create_repo(
repo_id=repo_id, private=private, exist_ok=True
).repo_id
if commit_message is None:
if "Policy" in self.__class__.__name__:

View File

@@ -20,7 +20,16 @@ from typing import TypeVar
import imageio
JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
JsonLike = (
str
| int
| float
| bool
| None
| list["JsonLike"]
| dict[str, "JsonLike"]
| tuple["JsonLike", ...]
)
T = TypeVar("T", bound=JsonLike)
@@ -76,7 +85,9 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
# Check length
if len(target) != len(source):
raise ValueError(f"List length mismatch: expected {len(target)}, got {len(source)}")
raise ValueError(
f"List length mismatch: expected {len(target)}, got {len(source)}"
)
# Recursively update each element.
for i in range(len(target)):
@@ -88,10 +99,14 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
# which we'll convert back to a tuple.
elif isinstance(target, tuple):
if not isinstance(source, list):
raise TypeError(f"Type mismatch: expected list (for tuple), got {type(source)}")
raise TypeError(
f"Type mismatch: expected list (for tuple), got {type(source)}"
)
if len(target) != len(source):
raise ValueError(f"Tuple length mismatch: expected {len(target)}, got {len(source)}")
raise ValueError(
f"Tuple length mismatch: expected {len(target)}, got {len(source)}"
)
# Convert each element, forming a new tuple.
converted_items = []
@@ -105,7 +120,9 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
else:
# Check the exact type. If these must match 1:1, do:
if type(target) is not type(source):
raise TypeError(f"Type mismatch: expected {type(target)}, got {type(source)}")
raise TypeError(
f"Type mismatch: expected {type(target)}, got {type(source)}"
)
return source
# Perform the in-place/recursive deserialization

View File

@@ -107,13 +107,17 @@ class MetricsTracker:
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames
def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
def __getattr__(
self, name: str
) -> int | dict[str, AverageMeter] | AverageMeter | Any:
if name in self.__dict__:
return self.__dict__[name]
elif name in self.metrics:
return self.metrics[name]
else:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)
def __setattr__(self, name: str, value: Any) -> None:
if name in self.__dict__:
@@ -121,7 +125,9 @@ class MetricsTracker:
elif name in self.metrics:
self.metrics[name].update(value)
else:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)
def step(self) -> None:
"""

View File

@@ -42,7 +42,11 @@ def deserialize_python_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> Non
"""
Restores the rng state for `random` from a dictionary produced by `serialize_python_rng_state()`.
"""
py_state = (rng_state_dict["py_rng_version"].item(), tuple(rng_state_dict["py_rng_state"].tolist()), None)
py_state = (
rng_state_dict["py_rng_version"].item(),
tuple(rng_state_dict["py_rng_state"].tolist()),
None,
)
random.setstate(py_state)
@@ -119,7 +123,9 @@ def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
"""
py_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("py")}
np_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("np")}
torch_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("torch")}
torch_rng_state_dict = {
k: v for k, v in rng_state_dict.items() if k.startswith("torch")
}
deserialize_python_rng_state(py_rng_state_dict)
deserialize_numpy_rng_state(np_rng_state_dict)

View File

@@ -48,7 +48,9 @@ def auto_select_torch_device() -> torch.device:
logging.info("Metal backend detected, using cuda.")
return torch.device("mps")
else:
logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
logging.warning(
"No accelerated backend detected. Using default cpu, this will be slow."
)
return torch.device("cpu")
@@ -96,7 +98,9 @@ def is_torch_device_available(try_device: str) -> bool:
elif try_device == "cpu":
return True
else:
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
raise ValueError(
f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu."
)
def is_amp_available(device: str):
@@ -219,7 +223,9 @@ def say(text, blocking=False):
if blocking:
subprocess.run(cmd, check=True)
else:
subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)
subprocess.Popen(
cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0
)
def log_say(text, play_sounds, blocking=False):

View File

@@ -26,7 +26,9 @@ from lerobot.common.constants import PRETRAINED_MODEL_DIR
from lerobot.configs.train import TrainPipelineConfig
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
def cfg_to_group(
cfg: TrainPipelineConfig, return_list: bool = False
) -> list[str] | str:
"""Return a group name for logging. Optionally returns group name as list."""
lst = [
f"policy:{cfg.policy.type}",
@@ -93,7 +95,9 @@ class WandBLogger:
mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
)
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
logging.info(
f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}"
)
self._wandb = wandb
def log_policy(self, checkpoint_dir: Path):
@@ -105,7 +109,9 @@ class WandBLogger:
artifact_name = f"{self._group}-{step_id}"
artifact_name = get_safe_wandb_artifact_name(artifact_name)
artifact = self._wandb.Artifact(artifact_name, type="model")
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
artifact.add_file(
checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE
)
self._wandb.log_artifact(artifact)
def log_dict(self, d: dict, step: int, mode: str = "train"):