forked from tangger/lerobot
Merge remote-tracking branch 'origin/main' into user/aliberts/2025_02_10_dataset_v2.1
This commit is contained in:
@@ -4,3 +4,14 @@ OBS_ROBOT = "observation.state"
|
||||
OBS_IMAGE = "observation.image"
|
||||
OBS_IMAGES = "observation.images"
|
||||
ACTION = "action"
|
||||
|
||||
# files & directories
|
||||
CHECKPOINTS_DIR = "checkpoints"
|
||||
LAST_CHECKPOINT_LINK = "last"
|
||||
PRETRAINED_MODEL_DIR = "pretrained_model"
|
||||
TRAINING_STATE_DIR = "training_state"
|
||||
RNG_STATE = "rng_state.safetensors"
|
||||
TRAINING_STEP = "training_step.json"
|
||||
OPTIMIZER_STATE = "optimizer_state.safetensors"
|
||||
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
|
||||
SCHEDULER_STATE = "scheduler_state.json"
|
||||
|
||||
@@ -1,240 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py
|
||||
|
||||
# TODO(rcadene, alexander-soare): clean this file
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import asdict
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
import torch
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from termcolor import colored
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.utils.utils import get_global_random_state
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode
|
||||
|
||||
PRETRAINED_MODEL = "pretrained_model"
|
||||
TRAINING_STATE = "training_state.pth"
|
||||
|
||||
|
||||
def log_output_dir(out_dir):
    """Log the output directory of the run at INFO level, with a bold yellow label."""
    label = colored("Output dir:", "yellow", attrs=["bold"])
    logging.info(label + f" {out_dir}")
|
||||
|
||||
|
||||
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
    """Build a group name for logging from the policy type, dataset repo id and seed.

    An ``env:<type>`` entry is appended when ``cfg.env`` is not None.

    Args:
        cfg: Training configuration to summarise.
        return_list: When True, return the individual ``key:value`` entries
            instead of joining them.

    Returns:
        The list of entries, or the entries joined with ``-``.
    """
    parts = [
        f"policy:{cfg.policy.type}",
        f"dataset:{cfg.dataset.repo_id}",
        f"seed:{cfg.seed}",
    ]
    if cfg.env is not None:
        parts += [f"env:{cfg.env.type}"]
    if return_list:
        return parts
    return "-".join(parts)
|
||||
|
||||
|
||||
def get_wandb_run_id_from_filesystem(checkpoint_dir: Path) -> str:
    """Recover the WandB run ID of a previous run from the local wandb directory.

    Looks for exactly one ``run-<id>.wandb`` entry under
    ``<checkpoint_dir>/../wandb/latest-run``.

    Args:
        checkpoint_dir: The checkpoints directory of the run being resumed.

    Returns:
        The ``<id>`` part of the ``run-<id>.wandb`` file name.

    Raises:
        RuntimeError: If zero or several ``run-*`` entries are found, or if the
            single entry is not named ``run-<id>.wandb``.
    """
    paths = glob(str(checkpoint_dir / "../wandb/latest-run/run-*"))
    if len(paths) != 1:
        raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
    # `Path.name` is separator-agnostic (unlike splitting on "/"), and the dot
    # before "wandb" is escaped so that it only matches a literal extension.
    match = re.search(r"run-([^.]+)\.wandb", Path(paths[0]).name)
    if match is None:
        raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
    return match.group(1)
|
||||
|
||||
|
||||
class Logger:
    """Primary logger object. Logs either locally or using wandb.

    The logger creates the following directory structure:

    provided_log_dir
    ├── checkpoints
    │   ├── specific_checkpoint_name
    │   │   ├── pretrained_model # Hugging Face pretrained model directory
    │   │   │   ├── ...
    │   │   └── training_state.pth # optimizer, scheduler, and random states + training step
    │   ├── another_specific_checkpoint_name
    │   │   ├── ...
    │   ├── ...
    │   └── last # a softlink to the last logged checkpoint
    """

    # Names of the per-checkpoint sub-directory / file, shared with the helpers below.
    pretrained_model_dir_name = PRETRAINED_MODEL
    training_state_file_name = TRAINING_STATE

    def __init__(self, cfg: TrainPipelineConfig):
        """Create the output directory and, if enabled in `cfg`, start a WandB run.

        WandB is used only when both `cfg.wandb.enable` and `cfg.wandb.project`
        are truthy; otherwise `self._wandb` is None and nothing is synced.
        """
        self._cfg = cfg
        self.log_dir = cfg.output_dir
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.job_name = cfg.job_name
        self.checkpoints_dir = self.get_checkpoints_dir(self.log_dir)
        self.last_checkpoint_dir = self.get_last_checkpoint_dir(self.log_dir)
        self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(self.log_dir)

        # Set up WandB.
        self._group = cfg_to_group(cfg)
        run_offline = not cfg.wandb.enable or not cfg.wandb.project
        if run_offline:
            logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
            self._wandb = None
        else:
            os.environ["WANDB_SILENT"] = "true"
            # Imported lazily so that wandb is only required when it is enabled.
            import wandb

            wandb_run_id = None
            if cfg.resume:
                wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir)

            wandb.init(
                id=wandb_run_id,
                project=cfg.wandb.project,
                entity=cfg.wandb.entity,
                name=self.job_name,
                notes=cfg.wandb.notes,
                tags=cfg_to_group(cfg, return_list=True),
                dir=self.log_dir,
                config=asdict(self._cfg),
                # TODO(rcadene): try set to True
                save_code=False,
                # TODO(rcadene): split train and eval, and run async eval with job_type="eval"
                job_type="train_eval",
                resume="must" if cfg.resume else None,
            )
            print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
            logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
            self._wandb = wandb

    @classmethod
    def get_checkpoints_dir(cls, log_dir: str | Path) -> Path:
        """Given the log directory, get the sub-directory in which checkpoints will be saved."""
        return Path(log_dir) / "checkpoints"

    @classmethod
    def get_last_checkpoint_dir(cls, log_dir: str | Path) -> Path:
        """Given the log directory, get the sub-directory in which the last checkpoint will be saved."""
        return cls.get_checkpoints_dir(log_dir) / "last"

    @classmethod
    def get_last_pretrained_model_dir(cls, log_dir: str | Path) -> Path:
        """
        Given the log directory, get the sub-directory in which the last checkpoint's pretrained weights will
        be saved.
        """
        return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name

    def save_model(self, save_dir: Path, policy: PreTrainedPolicy, wandb_artifact_name: str | None = None):
        """Save the weights of the Policy model using PyTorchModelHubMixin.

        The weights are saved in a folder called "pretrained_model" under the checkpoint directory.

        Optionally also upload the model to WandB.

        Any existing "last" link is removed at the end; `save_checkpoint`
        re-creates it once the training state has been written.
        """
        self.checkpoints_dir.mkdir(parents=True, exist_ok=True)
        register_features_types()
        policy.save_pretrained(save_dir)
        # Also save the full config for the env configuration.
        self._cfg.save_pretrained(save_dir)
        if self._wandb and not self._cfg.wandb.disable_artifact:
            # note wandb artifact does not accept ":" or "/" in its name
            artifact = self._wandb.Artifact(wandb_artifact_name, type="model")
            artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
            self._wandb.log_artifact(artifact)
        # `exists()` follows symlinks and is False for a dangling link (e.g. when the
        # checkpoint it pointed to was deleted). Check `is_symlink()` as well so the
        # stale link is removed and the later `symlink_to` does not hit FileExistsError.
        if self.last_checkpoint_dir.is_symlink() or self.last_checkpoint_dir.exists():
            os.remove(self.last_checkpoint_dir)

    def save_training_state(
        self,
        save_dir: Path,
        train_step: int,
        optimizer: Optimizer | None = None,
        scheduler: LRScheduler | None = None,
    ):
        """Checkpoint the global training_step, optimizer state, scheduler state, and random state.

        All of these are saved as "training_state.pth" under the checkpoint directory.
        Optimizer and scheduler entries are only written when the respective object is given.
        """
        training_state = {"step": train_step}
        training_state.update(get_global_random_state())
        if optimizer is not None:
            training_state["optimizer"] = optimizer.state_dict()
        if scheduler is not None:
            training_state["scheduler"] = scheduler.state_dict()
        torch.save(training_state, save_dir / self.training_state_file_name)

    def save_checkpoint(
        self,
        train_step: int,
        identifier: str,
        policy: PreTrainedPolicy,
        optimizer: Optimizer | None = None,
        scheduler: LRScheduler | None = None,
    ):
        """Checkpoint the model weights and the training state.

        The checkpoint is written to `checkpoints/<identifier>` and the "last"
        symlink is pointed at it using a path relative to the checkpoints directory.
        """
        checkpoint_dir = self.checkpoints_dir / str(identifier)
        wandb_artifact_name = (
            None
            if self._wandb is None
            else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}"
        )
        self.save_model(
            checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
        )
        self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)

        relative_target = checkpoint_dir.relative_to(self.last_checkpoint_dir.parent)
        self.last_checkpoint_dir.symlink_to(relative_target)

    def log_dict(self, d: dict, step: int, mode: str = "train"):
        """Log the int/float/str entries of `d` to WandB under the `<mode>/` prefix.

        Entries of any other type are skipped with a warning. Does nothing when
        WandB is not enabled.
        """
        assert mode in {"train", "eval"}
        # TODO(alexander-soare): Add local text log.
        if self._wandb is not None:
            for k, v in d.items():
                if not isinstance(v, (int, float, str)):
                    logging.warning(
                        f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
                    )
                    continue
                self._wandb.log({f"{mode}/{k}": v}, step=step)

    def log_video(self, video_path: str, step: int, mode: str = "train"):
        """Log the mp4 at `video_path` to WandB as `<mode>/video`. Requires WandB to be enabled."""
        assert mode in {"train", "eval"}
        assert self._wandb is not None
        wandb_video = self._wandb.Video(video_path, fps=self._cfg.env.fps, format="mp4")
        self._wandb.log({f"{mode}/video": wandb_video}, step=step)
|
||||
|
||||
|
||||
def register_features_types():
    """Register draccus encoders/decoders for `FeatureType` and `NormalizationMode`.

    Both enums are encoded as their member name and decoded by looking the
    name up on the enum class.
    """

    def _register_by_name(enum_cls):
        # Decode: member name -> enum member. Encode: enum member -> member name.
        draccus.decode.register(enum_cls, lambda x: enum_cls[x])
        draccus.encode.register(enum_cls, lambda x: x.name)

    _register_by_name(FeatureType)
    _register_by_name(NormalizationMode)
|
||||
@@ -14,15 +14,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from lerobot.common.logger import TRAINING_STATE
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
@@ -40,22 +36,5 @@ def make_optimizer_and_scheduler(
|
||||
"""
|
||||
params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
|
||||
optimizer = cfg.optimizer.build(params)
|
||||
lr_scheduler = cfg.scheduler.build(optimizer, cfg.offline.steps) if cfg.scheduler is not None else None
|
||||
lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
|
||||
return optimizer, lr_scheduler
|
||||
|
||||
|
||||
def load_training_state(
    checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None
) -> tuple[int, Optimizer, LRScheduler | None]:
    """
    Given the checkpoint directory, load the optimizer state, scheduler state, and random state.

    Args:
        checkpoint_dir: Directory containing the saved training state file.
        optimizer: Optimizer whose state is restored in place.
        scheduler: Scheduler whose state is restored in place, or None.

    Returns:
        A tuple `(step, optimizer, scheduler)`: the global training step stored in the
        checkpoint, followed by the same optimizer and scheduler objects that were passed in.

    Raises:
        ValueError: If the checkpoint holds a scheduler state but `scheduler` is None.
    """
    # TODO(aliberts): use safetensors instead as weights_only=False is unsafe
    training_state = torch.load(checkpoint_dir / TRAINING_STATE, weights_only=False)
    optimizer.load_state_dict(training_state["optimizer"])
    if scheduler is not None:
        scheduler.load_state_dict(training_state["scheduler"])
    elif "scheduler" in training_state:
        raise ValueError("The checkpoint contains a scheduler state_dict, but no LRScheduler was provided.")
    # Small HACK to get the expected keys: use `get_global_random_state`.
    set_global_random_state({k: training_state[k] for k in get_global_random_state()})
    return training_state["step"], optimizer, scheduler
|
||||
|
||||
@@ -1,8 +1,32 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from lerobot.common.constants import (
|
||||
OPTIMIZER_PARAM_GROUPS,
|
||||
OPTIMIZER_STATE,
|
||||
)
|
||||
from lerobot.common.datasets.utils import flatten_dict, unflatten_dict, write_json
|
||||
from lerobot.common.utils.io_utils import deserialize_json_into_object
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -68,3 +92,27 @@ class SGDConfig(OptimizerConfig):
|
||||
kwargs = asdict(self)
|
||||
kwargs.pop("grad_clip_norm")
|
||||
return torch.optim.SGD(params, **kwargs)
|
||||
|
||||
|
||||
def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
    """Write the optimizer's state dict into `save_dir` as two files.

    The "param_groups" entry is written as JSON to `OPTIMIZER_PARAM_GROUPS`;
    the rest of the state dict is flattened and written with safetensors to
    `OPTIMIZER_STATE`.
    """
    full_state = optimizer.state_dict()
    # Split off the param groups first so that only the remainder is flattened.
    groups = full_state.pop("param_groups")
    save_file(flatten_dict(full_state), save_dir / OPTIMIZER_STATE)
    write_json(groups, save_dir / OPTIMIZER_PARAM_GROUPS)
|
||||
|
||||
|
||||
def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
    """Restore an optimizer's state from the files written by `save_optimizer_state`.

    Args:
        optimizer: Optimizer to restore in place. Its current state dict is used as
            the template for deserializing the JSON param groups.
        save_dir: Directory containing `OPTIMIZER_STATE` and `OPTIMIZER_PARAM_GROUPS`.

    Returns:
        The same optimizer object, after `load_state_dict` has been called on it.
    """
    current_state_dict = optimizer.state_dict()
    flat_state = load_file(save_dir / OPTIMIZER_STATE)
    state = unflatten_dict(flat_state)
    # The saved file may hold no per-parameter entries at all (presumably when the
    # optimizer had not accumulated any state when it was saved); in that case there
    # is no "state" key after unflattening, so fall back to an empty mapping.
    saved_param_state = state.get("state", {})
    # Parameter indices come back as strings from the flattened keys; the optimizer
    # state dict is keyed by int.
    loaded_state_dict = {"state": {int(k): v for k, v in saved_param_state.items()}}

    if "param_groups" in current_state_dict:
        param_groups = deserialize_json_into_object(
            save_dir / OPTIMIZER_PARAM_GROUPS, current_state_dict["param_groups"]
        )
        loaded_state_dict["param_groups"] = param_groups

    optimizer.load_state_dict(loaded_state_dict)
    return optimizer
|
||||
|
||||
@@ -1,11 +1,31 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||
|
||||
from lerobot.common.constants import SCHEDULER_STATE
|
||||
from lerobot.common.datasets.utils import write_json
|
||||
from lerobot.common.utils.io_utils import deserialize_json_into_object
|
||||
|
||||
|
||||
@dataclass
|
||||
class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
@@ -89,3 +109,14 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
return cosine_decay_schedule(current_step)
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
    """Write the scheduler's `state_dict()` as JSON to `save_dir / SCHEDULER_STATE`."""
    write_json(scheduler.state_dict(), save_dir / SCHEDULER_STATE)
|
||||
|
||||
|
||||
def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
    """Restore the scheduler from the JSON file at `save_dir / SCHEDULER_STATE`.

    The scheduler's current state dict serves as the template that the JSON
    content is deserialized into. Returns the same scheduler object.
    """
    state_path = save_dir / SCHEDULER_STATE
    template = scheduler.state_dict()
    restored = deserialize_json_into_object(state_path, template)
    scheduler.load_state_dict(restored)
    return scheduler
|
||||
|
||||
@@ -144,7 +144,7 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
@@ -169,11 +169,11 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||
)
|
||||
loss_dict["kld_loss"] = mean_kld.item()
|
||||
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
|
||||
loss = l1_loss + mean_kld * self.config.kl_weight
|
||||
else:
|
||||
loss_dict["loss"] = l1_loss
|
||||
loss = l1_loss
|
||||
|
||||
return loss_dict
|
||||
return loss, loss_dict
|
||||
|
||||
|
||||
class ACTTemporalEnsembler:
|
||||
|
||||
@@ -143,7 +143,7 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
@@ -153,7 +153,8 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
return {"loss": loss}
|
||||
# no output_dict so returning None
|
||||
return loss, None
|
||||
|
||||
|
||||
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
|
||||
|
||||
@@ -163,12 +163,17 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# TODO(aliberts, rcadene): split into 'forward' and 'compute_loss'?
|
||||
@abc.abstractmethod
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict:
|
||||
"""Run the batch through the model and compute the loss for training or validation.
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
|
||||
"""_summary_
|
||||
|
||||
Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
|
||||
other items should be logging-friendly, native Python types.
|
||||
Args:
|
||||
batch (dict[str, Tensor]): _description_
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, dict | None]: The loss and potentially other information. Apart from the loss which
|
||||
is a Tensor, all other items should be logging-friendly, native Python types.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -302,7 +302,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
|
||||
return G
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss.
|
||||
|
||||
Returns a dictionary with loss as a tensor, and other information as native floats.
|
||||
@@ -495,7 +495,6 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
"Q_value_loss": q_value_loss.item(),
|
||||
"V_value_loss": v_value_loss.item(),
|
||||
"pi_loss": pi_loss.item(),
|
||||
"loss": loss,
|
||||
"sum_loss": loss.item() * self.config.horizon,
|
||||
}
|
||||
)
|
||||
@@ -505,7 +504,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1:
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
return info
|
||||
return loss, info
|
||||
|
||||
def update(self):
|
||||
"""Update the target model's parameters with an EMA step."""
|
||||
|
||||
@@ -156,7 +156,7 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
@@ -170,16 +170,16 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
loss, n_different_codes, n_different_combinations, recon_l1_error = (
|
||||
self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
|
||||
)
|
||||
return {
|
||||
"loss": loss,
|
||||
return loss, {
|
||||
"n_different_codes": n_different_codes,
|
||||
"n_different_combinations": n_different_combinations,
|
||||
"recon_l1_error": recon_l1_error,
|
||||
}
|
||||
# if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction head / offset prediction head parts.
|
||||
_, loss_dict = self.vqbet(batch, rollout=False)
|
||||
loss = loss_dict.pop("loss")
|
||||
|
||||
return loss_dict
|
||||
return loss, loss_dict
|
||||
|
||||
|
||||
class SpatialSoftmax(nn.Module):
|
||||
@@ -342,7 +342,7 @@ class VQBeTModel(nn.Module):
|
||||
torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
|
||||
)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
|
||||
def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
|
||||
# Input validation.
|
||||
assert set(batch).issuperset({"observation.state", "observation.images"})
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
@@ -482,7 +482,7 @@ class VQBeTHead(nn.Module):
|
||||
param.requires_grad = False
|
||||
return loss, n_different_codes, n_different_combinations, recon_l1_error
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
def forward(self, x, **kwargs) -> dict:
|
||||
# N is the batch size, and T is the number of action query tokens, which are processed through the same GPT
|
||||
N, T, _ = x.shape
|
||||
# we calculate N and T side parallely. Thus, the dimensions would be
|
||||
|
||||
@@ -13,10 +13,16 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import TypeVar
|
||||
|
||||
import imageio
|
||||
|
||||
JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
|
||||
T = TypeVar("T", bound=JsonLike)
|
||||
|
||||
|
||||
def write_video(video_path, stacked_frames, fps):
|
||||
# Filter out DeprecationWarnings raised from pkg_resources
|
||||
@@ -25,3 +31,81 @@ def write_video(video_path, stacked_frames, fps):
|
||||
"ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning
|
||||
)
|
||||
imageio.mimsave(video_path, stacked_frames, fps=fps)
|
||||
|
||||
|
||||
def deserialize_json_into_object(fpath: Path, obj: T) -> T:
|
||||
"""
|
||||
Loads the JSON data from `fpath` and recursively fills `obj` with the
|
||||
corresponding values (strictly matching structure and types).
|
||||
Tuples in `obj` are expected to be lists in the JSON data, which will be
|
||||
converted back into tuples.
|
||||
"""
|
||||
with open(fpath, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
def _deserialize(target, source):
|
||||
"""
|
||||
Recursively overwrite the structure in `target` with data from `source`,
|
||||
performing strict checks on structure and type.
|
||||
Returns the updated version of `target` (especially important for tuples).
|
||||
"""
|
||||
|
||||
# If the target is a dictionary, source must be a dictionary as well.
|
||||
if isinstance(target, dict):
|
||||
if not isinstance(source, dict):
|
||||
raise TypeError(f"Type mismatch: expected dict, got {type(source)}")
|
||||
|
||||
# Check that they have exactly the same set of keys.
|
||||
if target.keys() != source.keys():
|
||||
raise ValueError(
|
||||
f"Dictionary keys do not match.\n" f"Expected: {target.keys()}, got: {source.keys()}"
|
||||
)
|
||||
|
||||
# Recursively update each key.
|
||||
for k in target:
|
||||
target[k] = _deserialize(target[k], source[k])
|
||||
|
||||
return target
|
||||
|
||||
# If the target is a list, source must be a list as well.
|
||||
elif isinstance(target, list):
|
||||
if not isinstance(source, list):
|
||||
raise TypeError(f"Type mismatch: expected list, got {type(source)}")
|
||||
|
||||
# Check length
|
||||
if len(target) != len(source):
|
||||
raise ValueError(f"List length mismatch: expected {len(target)}, got {len(source)}")
|
||||
|
||||
# Recursively update each element.
|
||||
for i in range(len(target)):
|
||||
target[i] = _deserialize(target[i], source[i])
|
||||
|
||||
return target
|
||||
|
||||
# If the target is a tuple, the source must be a list in JSON,
|
||||
# which we'll convert back to a tuple.
|
||||
elif isinstance(target, tuple):
|
||||
if not isinstance(source, list):
|
||||
raise TypeError(f"Type mismatch: expected list (for tuple), got {type(source)}")
|
||||
|
||||
if len(target) != len(source):
|
||||
raise ValueError(f"Tuple length mismatch: expected {len(target)}, got {len(source)}")
|
||||
|
||||
# Convert each element, forming a new tuple.
|
||||
converted_items = []
|
||||
for t_item, s_item in zip(target, source, strict=False):
|
||||
converted_items.append(_deserialize(t_item, s_item))
|
||||
|
||||
# Return a brand new tuple (tuples are immutable in Python).
|
||||
return tuple(converted_items)
|
||||
|
||||
# Otherwise, we're dealing with a "primitive" (int, float, str, bool, None).
|
||||
else:
|
||||
# Check the exact type. If these must match 1:1, do:
|
||||
if type(target) is not type(source):
|
||||
raise TypeError(f"Type mismatch: expected {type(target)}, got {type(source)}")
|
||||
return source
|
||||
|
||||
# Perform the in-place/recursive deserialization
|
||||
updated_obj = _deserialize(obj, data)
|
||||
return updated_obj
|
||||
|
||||
163
lerobot/common/utils/logging_utils.py
Normal file
163
lerobot/common/utils/logging_utils.py
Normal file
@@ -0,0 +1,163 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any
|
||||
|
||||
from lerobot.common.utils.utils import format_big_number
|
||||
|
||||
|
||||
class AverageMeter:
    """
    Computes and stores the average and current value
    Adapted from https://github.com/pytorch/examples/blob/main/imagenet/main.py

    Attributes:
        name: Label used when the meter is printed.
        fmt: Format spec (including the leading ":") applied to the average in `__str__`.
        val: The most recent value passed to `update`.
        sum: Weighted sum of all values since the last `reset`.
        count: Total weight accumulated since the last `reset`.
        avg: `sum / count`, or 0.0 while `count` is zero.
    """

    def __init__(self, name: str, fmt: str = ":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self) -> None:
        """Clear the current value and all accumulated statistics."""
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0.0

    def update(self, val: float, n: int = 1) -> None:
        """Record `val` as the current value, counted `n` times towards the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against a zero total weight (e.g. `update(x, n=0)` on a fresh meter),
        # which would otherwise raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0.0

    def __str__(self):
        fmtstr = "{name}:{avg" + self.fmt + "}"
        return fmtstr.format(name=self.name, avg=self.avg)
|
||||
|
||||
|
||||
class MetricsTracker:
|
||||
"""
|
||||
A helper class to track and log metrics over time.
|
||||
|
||||
Usage pattern:
|
||||
|
||||
```python
|
||||
# initialize, potentially with non-zero initial step (e.g. if resuming run)
|
||||
metrics = {"loss": AverageMeter("loss", ":.3f")}
|
||||
train_metrics = MetricsTracker(cfg, dataset, metrics, initial_step=step)
|
||||
|
||||
# update metrics derived from step (samples, episodes, epochs) at each training step
|
||||
train_metrics.step()
|
||||
|
||||
# update various metrics
|
||||
loss = policy.forward(batch)
|
||||
train_metrics.loss = loss
|
||||
|
||||
# display current metrics
|
||||
logging.info(train_metrics)
|
||||
|
||||
# export for wandb
|
||||
wandb.log(train_metrics.to_dict())
|
||||
|
||||
# reset averages after logging
|
||||
train_metrics.reset_averages()
|
||||
```
|
||||
"""
|
||||
|
||||
__keys__ = [
|
||||
"_batch_size",
|
||||
"_num_frames",
|
||||
"_avg_samples_per_ep",
|
||||
"metrics",
|
||||
"steps",
|
||||
"samples",
|
||||
"episodes",
|
||||
"epochs",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_frames: int,
|
||||
num_episodes: int,
|
||||
metrics: dict[str, AverageMeter],
|
||||
initial_step: int = 0,
|
||||
):
|
||||
self.__dict__.update({k: None for k in self.__keys__})
|
||||
self._batch_size = batch_size
|
||||
self._num_frames = num_frames
|
||||
self._avg_samples_per_ep = num_frames / num_episodes
|
||||
self.metrics = metrics
|
||||
|
||||
self.steps = initial_step
|
||||
# A sample is an (observation,action) pair, where observation and action
|
||||
# can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
|
||||
self.samples = self.steps * self._batch_size
|
||||
self.episodes = self.samples / self._avg_samples_per_ep
|
||||
self.epochs = self.samples / self._num_frames
|
||||
|
||||
def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
|
||||
if name in self.__dict__:
|
||||
return self.__dict__[name]
|
||||
elif name in self.metrics:
|
||||
return self.metrics[name]
|
||||
else:
|
||||
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
if name in self.__dict__:
|
||||
super().__setattr__(name, value)
|
||||
elif name in self.metrics:
|
||||
self.metrics[name].update(value)
|
||||
else:
|
||||
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
||||
|
||||
def step(self) -> None:
|
||||
"""
|
||||
Updates metrics that depend on 'step' for one step.
|
||||
"""
|
||||
self.steps += 1
|
||||
self.samples += self._batch_size
|
||||
self.episodes = self.samples / self._avg_samples_per_ep
|
||||
self.epochs = self.samples / self._num_frames
|
||||
|
||||
def __str__(self) -> str:
    """Return the counters followed by every meter, separated by single spaces."""
    counters = [
        f"step:{format_big_number(self.steps)}",
        # number of samples seen during training
        f"smpl:{format_big_number(self.samples)}",
        # number of episodes seen during training
        f"ep:{format_big_number(self.episodes)}",
        # number of times all unique samples have been seen
        f"epch:{self.epochs:.2f}",
    ]
    meters = [str(meter) for meter in self.metrics.values()]
    return " ".join(counters + meters)
|
||||
|
||||
def to_dict(self, use_avg: bool = True) -> dict[str, int | float]:
|
||||
"""
|
||||
Returns the current metric values (or averages if `use_avg=True`) as a dict.
|
||||
"""
|
||||
return {
|
||||
"steps": self.steps,
|
||||
"samples": self.samples,
|
||||
"episodes": self.episodes,
|
||||
"epochs": self.epochs,
|
||||
**{k: m.avg if use_avg else m.val for k, m in self.metrics.items()},
|
||||
}
|
||||
|
||||
def reset_averages(self) -> None:
    """Call `reset()` on every meter held in `metrics`."""
    meters = self.metrics.values()
    for meter in meters:
        meter.reset()
|
||||
191
lerobot/common/utils/random_utils.py
Normal file
191
lerobot/common/utils/random_utils.py
Normal file
@@ -0,0 +1,191 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from lerobot.common.constants import RNG_STATE
|
||||
from lerobot.common.datasets.utils import flatten_dict, unflatten_dict
|
||||
|
||||
|
||||
def serialize_python_rng_state() -> dict[str, torch.Tensor]:
    """
    Captures the rng state of `random` as a flat dict[str, torch.Tensor] that can be saved with
    `safetensors.save_file()` or `torch.save()`.

    Only the first two elements of `random.getstate()` are stored; the third one is not.
    """
    version, internal_state, _unused = random.getstate()
    version_tensor = torch.tensor([version], dtype=torch.int64)
    state_tensor = torch.tensor(internal_state, dtype=torch.int64)
    return {"py_rng_version": version_tensor, "py_rng_state": state_tensor}
|
||||
|
||||
|
||||
def deserialize_python_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """
    Applies to `random` the rng state contained in a dict built by `serialize_python_rng_state()`.
    """
    version = rng_state_dict["py_rng_version"].item()
    internal_state = tuple(rng_state_dict["py_rng_state"].tolist())
    # The third element of the state tuple is not part of the serialized dict; it is set to None.
    random.setstate((version, internal_state, None))
|
||||
|
||||
|
||||
def serialize_numpy_rng_state() -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Returns the rng state for `numpy` in the form of a flat dict[str, torch.Tensor] to be saved using
|
||||
`safetensors.save_file()` or `torch.save()`.
|
||||
"""
|
||||
np_state = np.random.get_state()
|
||||
# Ensure no breaking changes from numpy
|
||||
assert np_state[0] == "MT19937"
|
||||
return {
|
||||
"np_rng_state_values": torch.tensor(np_state[1], dtype=torch.int64),
|
||||
"np_rng_state_index": torch.tensor([np_state[2]], dtype=torch.int64),
|
||||
"np_rng_has_gauss": torch.tensor([np_state[3]], dtype=torch.int64),
|
||||
"np_rng_cached_gaussian": torch.tensor([np_state[4]], dtype=torch.float32),
|
||||
}
|
||||
|
||||
|
||||
def deserialize_numpy_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """
    Applies to `numpy` the rng state contained in a dict built by `serialize_numpy_rng_state()`.
    """
    values = rng_state_dict["np_rng_state_values"].numpy()
    index = rng_state_dict["np_rng_state_index"].item()
    has_gauss = rng_state_dict["np_rng_has_gauss"].item()
    cached_gaussian = rng_state_dict["np_rng_cached_gaussian"].item()
    # The generator name is not serialized; it is always "MT19937".
    np.random.set_state(("MT19937", values, index, has_gauss, cached_gaussian))
|
||||
|
||||
|
||||
def serialize_torch_rng_state() -> dict[str, torch.Tensor]:
    """
    Captures the rng state of `torch` as a flat dict[str, torch.Tensor] that can be saved with
    `safetensors.save_file()` or `torch.save()`.

    The "torch_cuda_rng_state" entry is only present when CUDA is available.
    """
    state = {"torch_rng_state": torch.get_rng_state()}
    if not torch.cuda.is_available():
        return state
    state["torch_cuda_rng_state"] = torch.cuda.get_rng_state()
    return state
|
||||
|
||||
|
||||
def deserialize_torch_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """
    Applies to `torch` the rng state contained in a dict built by `serialize_torch_rng_state()`.

    The CUDA state is only restored when CUDA is available and the dict contains it.
    """
    torch.set_rng_state(rng_state_dict["torch_rng_state"])
    if not torch.cuda.is_available():
        return
    if "torch_cuda_rng_state" in rng_state_dict:
        torch.cuda.set_rng_state(rng_state_dict["torch_cuda_rng_state"])
|
||||
|
||||
|
||||
def serialize_rng_state() -> dict[str, torch.Tensor]:
    """
    Captures the rng state of `random`, `numpy`, and `torch` as a single flat
    dict[str, torch.Tensor] that can be saved with `safetensors.save_file()` or `torch.save()`.
    """
    merged: dict[str, torch.Tensor] = {}
    # Same order as the individual serializers are listed: python, numpy, torch.
    merged.update(serialize_python_rng_state())
    merged.update(serialize_numpy_rng_state())
    merged.update(serialize_torch_rng_state())
    return merged
|
||||
|
||||
|
||||
def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
    """
    Applies to `random`, `numpy`, and `torch` the rng state contained in a dict built by
    `serialize_rng_state()`.
    """

    def entries_with_prefix(prefix: str) -> dict[str, torch.Tensor]:
        # Entries are routed to each library by the prefix of their key.
        return {key: tensor for key, tensor in rng_state_dict.items() if key.startswith(prefix)}

    python_entries = entries_with_prefix("py")
    numpy_entries = entries_with_prefix("np")
    torch_entries = entries_with_prefix("torch")

    deserialize_python_rng_state(python_entries)
    deserialize_numpy_rng_state(numpy_entries)
    deserialize_torch_rng_state(torch_entries)
|
||||
|
||||
|
||||
def save_rng_state(save_dir: Path) -> None:
    """Serialize the current rng states and write them to `save_dir / RNG_STATE`."""
    payload = flatten_dict(serialize_rng_state())
    save_file(payload, save_dir / RNG_STATE)
|
||||
|
||||
|
||||
def load_rng_state(save_dir: Path) -> None:
    """Read `save_dir / RNG_STATE` and restore the rng states it contains."""
    stored = load_file(save_dir / RNG_STATE)
    deserialize_rng_state(unflatten_dict(stored))
|
||||
|
||||
|
||||
def get_rng_state() -> dict[str, Any]:
    """Collect the current random state of `random`, `numpy`, and `torch` in a dict.

    The "torch_cuda_random_state" entry is only added when CUDA is available.
    """
    snapshot = dict(
        random_state=random.getstate(),
        numpy_random_state=np.random.get_state(),
        torch_random_state=torch.random.get_rng_state(),
    )
    if not torch.cuda.is_available():
        return snapshot
    snapshot["torch_cuda_random_state"] = torch.cuda.random.get_rng_state()
    return snapshot
|
||||
|
||||
|
||||
def set_rng_state(random_state_dict: dict[str, Any]):
    """Restore the random state of `random`, `numpy`, and `torch`.

    Args:
        random_state_dict: A dictionary of the form returned by `get_rng_state`.
    """
    random.setstate(random_state_dict["random_state"])
    np.random.set_state(random_state_dict["numpy_random_state"])
    torch.random.set_rng_state(random_state_dict["torch_random_state"])
    if not torch.cuda.is_available():
        return
    # With CUDA available, the dict is expected to carry the CUDA state as well.
    torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
|
||||
|
||||
|
||||
def set_seed(seed) -> None:
    """Seed `random`, `numpy` and `torch` (plus all CUDA devices when available)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
@contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]:
    """Set the seed when entering a context, and restore the prior random state at exit.

    The prior state is restored even when the body of the `with` block raises.

    Example usage:

    ```
    a = random.random()  # produces some random number
    with seeded_context(1337):
        b = random.random()  # produces some other random number
    c = random.random()  # produces yet another random number, but the same it would have if we never made `b`
    ```

    Args:
        seed: Seed applied for the duration of the context.
    """
    random_state_dict = get_rng_state()
    set_seed(seed)
    try:
        yield None
    finally:
        # An exception raised inside the `with` block is re-raised at the `yield`; without
        # `finally` the seeded state would leak out of the context.
        set_rng_state(random_state_dict)
|
||||
161
lerobot/common/utils/train_utils.py
Normal file
161
lerobot/common/utils/train_utils.py
Normal file
@@ -0,0 +1,161 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from termcolor import colored
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from lerobot.common.constants import (
|
||||
CHECKPOINTS_DIR,
|
||||
LAST_CHECKPOINT_LINK,
|
||||
PRETRAINED_MODEL_DIR,
|
||||
TRAINING_STATE_DIR,
|
||||
TRAINING_STEP,
|
||||
)
|
||||
from lerobot.common.datasets.utils import load_json, write_json
|
||||
from lerobot.common.optim.optimizers import load_optimizer_state, save_optimizer_state
|
||||
from lerobot.common.optim.schedulers import load_scheduler_state, save_scheduler_state
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.utils.random_utils import load_rng_state, save_rng_state
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
def log_output_dir(out_dir):
    """Log the output directory, with the label in bold yellow."""
    label = colored("Output dir:", "yellow", attrs=["bold"])
    logging.info(f"{label} {out_dir}")
|
||||
|
||||
|
||||
def get_step_identifier(step: int, total_steps: int) -> str:
    """Return `step` zero-padded to the width of `total_steps`, with a minimum of 6 digits."""
    width = len(str(total_steps))
    if width < 6:
        width = 6
    return format(step, f"0{width}d")
|
||||
|
||||
|
||||
def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Path:
    """Build the path of the checkpoint sub-directory for a given step.

    Args:
        output_dir: Directory under which `CHECKPOINTS_DIR` is located.
        total_steps: Passed to `get_step_identifier` along with `step`.
        step: Step number of the checkpoint.

    Returns:
        `output_dir / CHECKPOINTS_DIR / <step identifier>`.
    """
    subdir_name = get_step_identifier(step, total_steps)
    return output_dir.joinpath(CHECKPOINTS_DIR, subdir_name)
|
||||
|
||||
|
||||
def save_training_step(step: int, save_dir: Path) -> None:
    """Write `{"step": step}` as JSON to `save_dir / TRAINING_STEP`."""
    payload = {"step": step}
    write_json(payload, save_dir / TRAINING_STEP)
|
||||
|
||||
|
||||
def load_training_step(save_dir: Path) -> int:
    """Read `save_dir / TRAINING_STEP` and return the value stored under "step"."""
    return load_json(save_dir / TRAINING_STEP)["step"]
|
||||
|
||||
|
||||
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
    """Point the `LAST_CHECKPOINT_LINK` symlink, located next to `checkpoint_dir`, at it.

    Args:
        checkpoint_dir: The checkpoint directory the link should target.

    Returns:
        The path of the symlink.
    """
    last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
    # Remove a previous link first: `symlink_to` fails if the path already exists.
    if last_checkpoint_dir.is_symlink():
        last_checkpoint_dir.unlink()
    # The target is expressed relative to the directory that contains the link.
    relative_target = checkpoint_dir.relative_to(checkpoint_dir.parent)
    last_checkpoint_dir.symlink_to(relative_target)
    # The signature declares a `Path` return value; previously the function returned None.
    return last_checkpoint_dir
|
||||
|
||||
|
||||
def save_checkpoint(
    checkpoint_dir: Path,
    step: int,
    cfg: TrainPipelineConfig,
    policy: PreTrainedPolicy,
    optimizer: Optimizer,
    scheduler: LRScheduler | None = None,
) -> None:
    """Saves the policy, the train config and the training state, producing this directory structure:

    005000/  #  training step at checkpoint
    ├── pretrained_model/
    │   ├── config.json  # policy config
    │   ├── model.safetensors  # policy weights
    │   └── train_config.json  # train config
    └── training_state/
        ├── optimizer_param_groups.json  #  optimizer param groups
        ├── optimizer_state.safetensors  # optimizer state
        ├── rng_state.safetensors  # rng states
        ├── scheduler_state.json  # scheduler state
        └── training_step.json  # training step

    Args:
        checkpoint_dir (Path): The directory of this checkpoint.
        step (int): The training step at that checkpoint.
        cfg (TrainPipelineConfig): The training config used for this run.
        policy (PreTrainedPolicy): The policy to save.
        optimizer (Optimizer): The optimizer to save the state from.
        scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
    """
    # The policy and the train config are both written to the same sub-directory.
    model_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
    policy.save_pretrained(model_dir)
    cfg.save_pretrained(model_dir)
    save_training_state(checkpoint_dir, step, optimizer, scheduler)
|
||||
|
||||
|
||||
def save_training_state(
    checkpoint_dir: Path,
    train_step: int,
    optimizer: Optimizer | None = None,
    scheduler: LRScheduler | None = None,
) -> None:
    """
    Writes the training step, the rng state and, when provided, the optimizer and scheduler
    states to the `TRAINING_STATE_DIR` sub-directory of `checkpoint_dir`.

    Args:
        checkpoint_dir (Path): The checkpoint directory; the sub-directory is created if needed.
        train_step (int): Current training step.
        optimizer (Optimizer | None, optional): The optimizer from which to save the state.
            Skipped when None. Defaults to None.
        scheduler (LRScheduler | None, optional): The scheduler from which to save the state.
            Skipped when None. Defaults to None.
    """
    state_dir = checkpoint_dir / TRAINING_STATE_DIR
    state_dir.mkdir(parents=True, exist_ok=True)

    save_training_step(train_step, state_dir)
    save_rng_state(state_dir)

    if optimizer is not None:
        save_optimizer_state(optimizer, state_dir)

    if scheduler is not None:
        save_scheduler_state(scheduler, state_dir)
|
||||
|
||||
|
||||
def load_training_state(
    checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None
) -> tuple[int, Optimizer, LRScheduler | None]:
    """
    Restores the rng state and reads back the training step, the optimizer state and the
    scheduler state of a checkpoint. This is used to resume a training run.

    Args:
        checkpoint_dir (Path): The checkpoint directory. Should contain a 'training_state' dir.
        optimizer (Optimizer): The optimizer to load the state_dict to.
        scheduler (LRScheduler | None): The scheduler to load the state_dict to (can be None).

    Raises:
        NotADirectoryError: If 'checkpoint_dir' doesn't contain a 'training_state' dir

    Returns:
        tuple[int, Optimizer, LRScheduler | None]: training step, optimizer and scheduler with their
            state_dict loaded.
    """
    state_dir = checkpoint_dir / TRAINING_STATE_DIR
    if not state_dir.is_dir():
        raise NotADirectoryError(state_dir)

    load_rng_state(state_dir)
    resumed_step = load_training_step(state_dir)
    optimizer = load_optimizer_state(optimizer, state_dir)
    # A missing scheduler is passed through untouched.
    scheduler = None if scheduler is None else load_scheduler_state(scheduler, state_dir)

    return resumed_step, optimizer, scheduler
|
||||
@@ -17,14 +17,10 @@ import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
import platform
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
from copy import copy
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
@@ -106,59 +102,6 @@ def is_amp_available(device: str):
|
||||
raise ValueError(f"Unknown device '{device}.")
|
||||
|
||||
|
||||
def get_global_random_state() -> dict[str, Any]:
    """Collect the current random state of `random`, `numpy`, and `torch` in a dict.

    The "torch_cuda_random_state" entry is only added when CUDA is available.
    """
    snapshot = dict(
        random_state=random.getstate(),
        numpy_random_state=np.random.get_state(),
        torch_random_state=torch.random.get_rng_state(),
    )
    if not torch.cuda.is_available():
        return snapshot
    snapshot["torch_cuda_random_state"] = torch.cuda.random.get_rng_state()
    return snapshot
|
||||
|
||||
|
||||
def set_global_random_state(random_state_dict: dict[str, Any]):
    """Restore the random state of `random`, `numpy`, and `torch`.

    Args:
        random_state_dict: A dictionary of the form returned by `get_global_random_state`.
    """
    random.setstate(random_state_dict["random_state"])
    np.random.set_state(random_state_dict["numpy_random_state"])
    torch.random.set_rng_state(random_state_dict["torch_random_state"])
    if not torch.cuda.is_available():
        return
    # With CUDA available, the dict is expected to carry the CUDA state as well.
    torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
|
||||
|
||||
|
||||
def set_global_seed(seed):
    """Seed `random`, `numpy` and `torch` (plus all CUDA devices when available)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
@contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]:
    """Set the seed when entering a context, and restore the prior random state at exit.

    The prior state is restored even when the body of the `with` block raises.

    Example usage:

    ```
    a = random.random()  # produces some random number
    with seeded_context(1337):
        b = random.random()  # produces some other random number
    c = random.random()  # produces yet another random number, but the same it would have if we never made `b`
    ```

    Args:
        seed: Seed applied for the duration of the context.
    """
    random_state_dict = get_global_random_state()
    set_global_seed(seed)
    try:
        yield None
    finally:
        # An exception raised inside the `with` block is re-raised at the `yield`; without
        # `finally` the seeded state would leak out of the context.
        set_global_random_state(random_state_dict)
|
||||
|
||||
|
||||
def init_logging():
|
||||
def custom_format(record):
|
||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
121
lerobot/common/utils/wandb_utils.py
Normal file
121
lerobot/common/utils/wandb_utils.py
Normal file
@@ -0,0 +1,121 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from termcolor import colored
|
||||
|
||||
from lerobot.common.constants import PRETRAINED_MODEL_DIR
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
    """Build a group name for logging from the policy type, dataset repo id, seed and env type.

    Args:
        cfg: The training config to read the fields from.
        return_list: If True, return the individual parts as a list instead of joining them.

    Returns:
        The parts joined with "-", or the list of parts when `return_list` is True.
    """
    parts = [
        f"policy:{cfg.policy.type}",
        f"dataset:{cfg.dataset.repo_id}",
        f"seed:{cfg.seed}",
    ]
    # The env part is only present when an env is configured.
    if cfg.env is not None:
        parts.append(f"env:{cfg.env.type}")
    if return_list:
        return parts
    return "-".join(parts)
|
||||
|
||||
|
||||
def get_wandb_run_id_from_filesystem(log_dir: Path) -> str:
    """Recover the WandB run ID from the files found under `log_dir`.

    Looks for exactly one entry matching `wandb/latest-run/run-*` inside `log_dir` and
    extracts the ID from a file name of the form `run-<id>.wandb`.

    Args:
        log_dir: Directory that contains the `wandb` directory.

    Returns:
        The run ID, i.e. the part of the file name between `run-` and `.wandb`.

    Raises:
        RuntimeError: If there is not exactly one matching entry, or if its name does not
            have the expected form.
    """
    paths = glob(str(log_dir / "wandb/latest-run/run-*"))
    if len(paths) != 1:
        raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
    # `Path.name` yields the file name with any path separator. The dot before "wandb" is
    # escaped so that only a real `.wandb` suffix is accepted.
    match = re.search(r"run-([^\.]+)\.wandb", Path(paths[0]).name)
    if match is None:
        raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
    return match.group(1)
|
||||
|
||||
|
||||
def get_safe_wandb_artifact_name(name: str):
    """Replace every ":" and "/" in `name` with "_", since WandB artifact names don't accept them."""
    substitutions = str.maketrans({":": "_", "/": "_"})
    return name.translate(substitutions)
|
||||
|
||||
|
||||
class WandBLogger:
    """A helper class to log object using wandb."""

    def __init__(self, cfg: TrainPipelineConfig):
        """Start (or resume) a WandB run configured from `cfg`.

        Args:
            cfg: The training config. Its `wandb` section, output dir, job name, env and
                `resume` flag are read here.
        """
        self.cfg = cfg.wandb
        self.log_dir = cfg.output_dir
        self.job_name = cfg.job_name
        self.env_fps = cfg.env.fps if cfg.env else None
        self._group = cfg_to_group(cfg)

        # Set up WandB.
        os.environ["WANDB_SILENT"] = "True"
        # Imported here, after the environment variable above has been set.
        import wandb

        # When resuming, the ID of the previous run is recovered from the output dir.
        wandb_run_id = get_wandb_run_id_from_filesystem(self.log_dir) if cfg.resume else None
        wandb.init(
            id=wandb_run_id,
            project=self.cfg.project,
            entity=self.cfg.entity,
            name=self.job_name,
            notes=self.cfg.notes,
            tags=cfg_to_group(cfg, return_list=True),
            dir=self.log_dir,
            config=cfg.to_dict(),
            # TODO(rcadene): try set to True
            save_code=False,
            # TODO(rcadene): split train and eval, and run async eval with job_type="eval"
            job_type="train_eval",
            resume="must" if cfg.resume else None,
        )
        print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
        logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
        self._wandb = wandb

    def log_policy(self, checkpoint_dir: Path):
        """Checkpoints the policy to wandb.

        Does nothing when `disable_artifact` is set in the wandb config.

        Args:
            checkpoint_dir: The checkpoint directory; its name is used as the step identifier.
        """
        if self.cfg.disable_artifact:
            return

        step_id = checkpoint_dir.name
        artifact_name = f"{self._group}-{step_id}"
        artifact_name = get_safe_wandb_artifact_name(artifact_name)
        artifact = self._wandb.Artifact(artifact_name, type="model")
        artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
        self._wandb.log_artifact(artifact)

    def log_dict(self, d: dict, step: int, mode: str = "train"):
        """Log every int, float or str value of `d` under the key `"{mode}/{key}"`.

        Values of any other type are skipped with a warning.

        Args:
            d: The values to log, keyed by name.
            step: The step passed to wandb for these values.
            mode: Either "train" or "eval". Defaults to "train".

        Raises:
            ValueError: If `mode` is neither "train" nor "eval".
        """
        # Reject unknown modes. The check used to be inverted, which raised for the two
        # valid modes (including the default) and accepted everything else.
        if mode not in {"train", "eval"}:
            raise ValueError(mode)

        for k, v in d.items():
            if not isinstance(v, (int, float, str)):
                logging.warning(
                    f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
                )
                continue
            self._wandb.log({f"{mode}/{k}": v}, step=step)

    def log_video(self, video_path: str, step: int, mode: str = "train"):
        """Log the mp4 video at `video_path` under the key `"{mode}/video"`.

        Args:
            video_path: Path of the video file.
            step: The step passed to wandb for this video.
            mode: Either "train" or "eval". Defaults to "train".

        Raises:
            ValueError: If `mode` is neither "train" nor "eval".
        """
        # Same validation as `log_dict`; it was inverted here as well.
        if mode not in {"train", "eval"}:
            raise ValueError(mode)

        wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")
        self._wandb.log({f"{mode}/video": wandb_video}, step=step)
|
||||
Reference in New Issue
Block a user