forked from tangger/lerobot
Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_10_dataset_v2.1
This commit is contained in:
@@ -58,7 +58,7 @@ def deserialize_json_into_object(fpath: Path, obj: T) -> T:
|
||||
# Check that they have exactly the same set of keys.
|
||||
if target.keys() != source.keys():
|
||||
raise ValueError(
|
||||
f"Dictionary keys do not match.\n" f"Expected: {target.keys()}, got: {source.keys()}"
|
||||
f"Dictionary keys do not match.\nExpected: {target.keys()}, got: {source.keys()}"
|
||||
)
|
||||
|
||||
# Recursively update each key.
|
||||
|
||||
@@ -102,7 +102,7 @@ class WandBLogger:
|
||||
self._wandb.log_artifact(artifact)
|
||||
|
||||
def log_dict(self, d: dict, step: int, mode: str = "train"):
|
||||
if mode in {"train", "eval"}:
|
||||
if mode not in {"train", "eval"}:
|
||||
raise ValueError(mode)
|
||||
|
||||
for k, v in d.items():
|
||||
@@ -114,7 +114,7 @@ class WandBLogger:
|
||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
if mode in {"train", "eval"}:
|
||||
if mode not in {"train", "eval"}:
|
||||
raise ValueError(mode)
|
||||
|
||||
wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")
|
||||
|
||||
Reference in New Issue
Block a user