[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-28 17:20:38 +00:00
committed by Michel Aractingi
parent 8eb3c1510c
commit eb44a06a9b
16 changed files with 93 additions and 91 deletions

View File

@@ -115,11 +115,13 @@ class WandBLogger:
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
self._wandb.log_artifact(artifact)
def log_dict(self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None):
def log_dict(
self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None
):
if mode not in {"train", "eval"}:
raise ValueError(mode)
if step is None and custom_step_key is None:
raise ValueError("Either step or custom_step_key must be provided.")
raise ValueError("Either step or custom_step_key must be provided.")
# NOTE: This is not simple. Wandb step is it must always monotonically increase and it
# increases with each wandb.log call, but in the case of asynchronous RL for example,
@@ -142,10 +144,7 @@ class WandBLogger:
continue
# Do not log the custom step key itself.
if (
self._wandb_custom_step_key is not None
and k in self._wandb_custom_step_key
):
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
continue
if custom_step_key is not None:
@@ -160,7 +159,6 @@ class WandBLogger:
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):
if mode not in {"train", "eval"}:
raise ValueError(mode)