[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Michel Aractingi
parent
cdcf346061
commit
1c8daf11fd
@@ -69,7 +69,9 @@ class HubMixin:
|
||||
if push_to_hub:
|
||||
if repo_id is None:
|
||||
repo_id = save_directory.name # Defaults to `save_directory` name
|
||||
return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)
|
||||
return self.push_to_hub(
|
||||
repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs
|
||||
)
|
||||
return None
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
@@ -175,7 +177,9 @@ class HubMixin:
|
||||
The url of the commit of your object in the given repository.
|
||||
"""
|
||||
api = HfApi(token=token)
|
||||
repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id
|
||||
repo_id = api.create_repo(
|
||||
repo_id=repo_id, private=private, exist_ok=True
|
||||
).repo_id
|
||||
|
||||
if commit_message is None:
|
||||
if "Policy" in self.__class__.__name__:
|
||||
|
||||
@@ -20,7 +20,16 @@ from typing import TypeVar
|
||||
|
||||
import imageio
|
||||
|
||||
JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
|
||||
JsonLike = (
|
||||
str
|
||||
| int
|
||||
| float
|
||||
| bool
|
||||
| None
|
||||
| list["JsonLike"]
|
||||
| dict[str, "JsonLike"]
|
||||
| tuple["JsonLike", ...]
|
||||
)
|
||||
T = TypeVar("T", bound=JsonLike)
|
||||
|
||||
|
||||
@@ -76,7 +85,9 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
|
||||
|
||||
# Check length
|
||||
if len(target) != len(source):
|
||||
raise ValueError(f"List length mismatch: expected {len(target)}, got {len(source)}")
|
||||
raise ValueError(
|
||||
f"List length mismatch: expected {len(target)}, got {len(source)}"
|
||||
)
|
||||
|
||||
# Recursively update each element.
|
||||
for i in range(len(target)):
|
||||
@@ -88,10 +99,14 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
|
||||
# which we'll convert back to a tuple.
|
||||
elif isinstance(target, tuple):
|
||||
if not isinstance(source, list):
|
||||
raise TypeError(f"Type mismatch: expected list (for tuple), got {type(source)}")
|
||||
raise TypeError(
|
||||
f"Type mismatch: expected list (for tuple), got {type(source)}"
|
||||
)
|
||||
|
||||
if len(target) != len(source):
|
||||
raise ValueError(f"Tuple length mismatch: expected {len(target)}, got {len(source)}")
|
||||
raise ValueError(
|
||||
f"Tuple length mismatch: expected {len(target)}, got {len(source)}"
|
||||
)
|
||||
|
||||
# Convert each element, forming a new tuple.
|
||||
converted_items = []
|
||||
@@ -105,7 +120,9 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
|
||||
else:
|
||||
# Check the exact type. If these must match 1:1, do:
|
||||
if type(target) is not type(source):
|
||||
raise TypeError(f"Type mismatch: expected {type(target)}, got {type(source)}")
|
||||
raise TypeError(
|
||||
f"Type mismatch: expected {type(target)}, got {type(source)}"
|
||||
)
|
||||
return source
|
||||
|
||||
# Perform the in-place/recursive deserialization
|
||||
|
||||
@@ -107,13 +107,17 @@ class MetricsTracker:
|
||||
self.episodes = self.samples / self._avg_samples_per_ep
|
||||
self.epochs = self.samples / self._num_frames
|
||||
|
||||
def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
|
||||
def __getattr__(
|
||||
self, name: str
|
||||
) -> int | dict[str, AverageMeter] | AverageMeter | Any:
|
||||
if name in self.__dict__:
|
||||
return self.__dict__[name]
|
||||
elif name in self.metrics:
|
||||
return self.metrics[name]
|
||||
else:
|
||||
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
||||
raise AttributeError(
|
||||
f"'{self.__class__.__name__}' object has no attribute '{name}'"
|
||||
)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
if name in self.__dict__:
|
||||
@@ -121,7 +125,9 @@ class MetricsTracker:
|
||||
elif name in self.metrics:
|
||||
self.metrics[name].update(value)
|
||||
else:
|
||||
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
||||
raise AttributeError(
|
||||
f"'{self.__class__.__name__}' object has no attribute '{name}'"
|
||||
)
|
||||
|
||||
def step(self) -> None:
|
||||
"""
|
||||
|
||||
@@ -42,7 +42,11 @@ def deserialize_python_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> Non
|
||||
"""
|
||||
Restores the rng state for `random` from a dictionary produced by `serialize_python_rng_state()`.
|
||||
"""
|
||||
py_state = (rng_state_dict["py_rng_version"].item(), tuple(rng_state_dict["py_rng_state"].tolist()), None)
|
||||
py_state = (
|
||||
rng_state_dict["py_rng_version"].item(),
|
||||
tuple(rng_state_dict["py_rng_state"].tolist()),
|
||||
None,
|
||||
)
|
||||
random.setstate(py_state)
|
||||
|
||||
|
||||
@@ -119,7 +123,9 @@ def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
|
||||
"""
|
||||
py_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("py")}
|
||||
np_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("np")}
|
||||
torch_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("torch")}
|
||||
torch_rng_state_dict = {
|
||||
k: v for k, v in rng_state_dict.items() if k.startswith("torch")
|
||||
}
|
||||
|
||||
deserialize_python_rng_state(py_rng_state_dict)
|
||||
deserialize_numpy_rng_state(np_rng_state_dict)
|
||||
|
||||
@@ -48,7 +48,9 @@ def auto_select_torch_device() -> torch.device:
|
||||
logging.info("Metal backend detected, using cuda.")
|
||||
return torch.device("mps")
|
||||
else:
|
||||
logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
|
||||
logging.warning(
|
||||
"No accelerated backend detected. Using default cpu, this will be slow."
|
||||
)
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
@@ -96,7 +98,9 @@ def is_torch_device_available(try_device: str) -> bool:
|
||||
elif try_device == "cpu":
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
|
||||
raise ValueError(
|
||||
f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu."
|
||||
)
|
||||
|
||||
|
||||
def is_amp_available(device: str):
|
||||
@@ -219,7 +223,9 @@ def say(text, blocking=False):
|
||||
if blocking:
|
||||
subprocess.run(cmd, check=True)
|
||||
else:
|
||||
subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)
|
||||
subprocess.Popen(
|
||||
cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0
|
||||
)
|
||||
|
||||
|
||||
def log_say(text, play_sounds, blocking=False):
|
||||
|
||||
@@ -26,7 +26,9 @@ from lerobot.common.constants import PRETRAINED_MODEL_DIR
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
|
||||
def cfg_to_group(
|
||||
cfg: TrainPipelineConfig, return_list: bool = False
|
||||
) -> list[str] | str:
|
||||
"""Return a group name for logging. Optionally returns group name as list."""
|
||||
lst = [
|
||||
f"policy:{cfg.policy.type}",
|
||||
@@ -93,7 +95,9 @@ class WandBLogger:
|
||||
mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online",
|
||||
)
|
||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
logging.info(
|
||||
f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}"
|
||||
)
|
||||
self._wandb = wandb
|
||||
|
||||
def log_policy(self, checkpoint_dir: Path):
|
||||
@@ -105,7 +109,9 @@ class WandBLogger:
|
||||
artifact_name = f"{self._group}-{step_id}"
|
||||
artifact_name = get_safe_wandb_artifact_name(artifact_name)
|
||||
artifact = self._wandb.Artifact(artifact_name, type="model")
|
||||
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
|
||||
artifact.add_file(
|
||||
checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE
|
||||
)
|
||||
self._wandb.log_artifact(artifact)
|
||||
|
||||
def log_dict(self, d: dict, step: int, mode: str = "train"):
|
||||
|
||||
Reference in New Issue
Block a user