Added support for checkpointing the policy. We can save and load the policy state dict, the optimizer states, the optimization step, and the interaction step.

Added functions for converting the replay buffer to and from LeRobotDataset. When we want to save the replay buffer, we first convert it to the LeRobotDataset format and save it locally; loading performs the reverse conversion.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi
2025-01-30 17:39:41 +00:00
parent 03616db82c
commit aebea08a99
5 changed files with 215 additions and 91 deletions

View File

@@ -174,18 +174,32 @@ class Logger:
self,
save_dir: Path,
train_step: int,
optimizer: Optimizer,
optimizer: Optimizer | dict,
scheduler: LRScheduler | None,
interaction_step: int | None = None,
):
"""Checkpoint the global training_step, optimizer state, scheduler state, and random state.
All of these are saved as "training_state.pth" under the checkpoint directory.
"""
# In Sac, for example, we have a dictionary of torch.optim.Optimizer
if type(optimizer) is dict:
optimizer_state_dict = {}
for k in optimizer:
optimizer_state_dict[k] = optimizer[k].state_dict()
else:
optimizer_state_dict = optimizer.state_dict()
training_state = {
"step": train_step,
"optimizer": optimizer.state_dict(),
"optimizer": optimizer_state_dict,
**get_global_random_state(),
}
# Interaction step is related to the distributed training code
# In that setup, we have two kinds of steps, the online step of the env and the optimization step
# We need to save both in order to resume the optimization properly and not break the logs dependent on the interaction step
if interaction_step is not None:
training_state["interaction_step"] = interaction_step
if scheduler is not None:
training_state["scheduler"] = scheduler.state_dict()
torch.save(training_state, save_dir / self.training_state_file_name)
@@ -197,6 +211,7 @@ class Logger:
optimizer: Optimizer,
scheduler: LRScheduler | None,
identifier: str,
interaction_step: int | None = None,
):
"""Checkpoint the model weights and the training state."""
checkpoint_dir = self.checkpoints_dir / str(identifier)
@@ -208,16 +223,24 @@ class Logger:
self.save_model(
checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
)
self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)
self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler, interaction_step)
os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)
def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
def load_last_training_state(self, optimizer: Optimizer | dict, scheduler: LRScheduler | None) -> int:
"""
Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
random state, and return the global training step.
"""
training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
optimizer.load_state_dict(training_state["optimizer"])
# For the case where the optimizer is a dictionary of optimizers (e.g., sac)
if type(training_state["optimizer"]) is dict:
assert set(training_state["optimizer"].keys()) == set(optimizer.keys()), (
"Optimizer dictionaries do not have the same keys during resume!"
)
for k, v in training_state["optimizer"].items():
optimizer[k].load_state_dict(v)
else:
optimizer.load_state_dict(training_state["optimizer"])
if scheduler is not None:
scheduler.load_state_dict(training_state["scheduler"])
elif "scheduler" in training_state:
@@ -228,7 +251,7 @@ class Logger:
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
return training_state["step"]
def log_dict(self, d, step:int | None = None, mode="train", custom_step_key: str | None = None):
def log_dict(self, d, step: int | None = None, mode="train", custom_step_key: str | None = None):
"""Log a dictionary of metrics to WandB."""
assert mode in {"train", "eval"}
# TODO(alexander-soare): Add local text log.
@@ -236,10 +259,9 @@ class Logger:
raise ValueError("Either step or custom_step_key must be provided.")
if self._wandb is not None:
# NOTE: This is not simple. The WandB step must always monotonically increase and it
# NOTE: This is not simple. The WandB step must always monotonically increase and it
# increases with each wandb.log call, but in the case of asynchronous RL for example,
# multiple time steps are possible — for example, the interaction step with the environment,
# multiple time steps are possible — for example, the interaction step with the environment,
# the training step, the evaluation step, etc. So we need to define a custom step key
# to log the correct step for each metric.
if custom_step_key is not None and self._wandb_custom_step_key is None:
@@ -247,7 +269,7 @@ class Logger:
# custom step.
self._wandb_custom_step_key = f"{mode}/{custom_step_key}"
self._wandb.define_metric(self._wandb_custom_step_key, hidden=True)
for k, v in d.items():
if not isinstance(v, (int, float, str, wandb.Table)):
logging.warning(
@@ -267,8 +289,6 @@ class Logger:
self._wandb.log({f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):
assert mode in {"train", "eval"}
assert self._wandb is not None

View File

@@ -29,6 +29,7 @@ from torch import Tensor
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.utils import get_device_from_parameters
class SACPolicy(
@@ -44,7 +45,6 @@ class SACPolicy(
self,
config: SACConfig | None = None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
device: str = "cpu",
):
super().__init__()
@@ -92,7 +92,6 @@ class SACPolicy(
for _ in range(config.num_critics)
]
),
device=device,
)
self.critic_target = CriticEnsemble(
@@ -106,7 +105,6 @@ class SACPolicy(
for _ in range(config.num_critics)
]
),
device=device,
)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
@@ -115,7 +113,6 @@ class SACPolicy(
encoder=encoder_actor,
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
action_dim=config.output_shapes["action"][0],
device=device,
encoder_is_shared=config.shared_encoder,
**config.policy_kwargs,
)
@@ -123,13 +120,22 @@ class SACPolicy(
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
# TODO (azouitine): Handle the case where the temperature parameter is fixed
self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
# it triggers "can't optimize a non-leaf Tensor"
self.log_alpha = torch.zeros(1, requires_grad=True, device=torch.device("cuda:0"))
self.temperature = self.log_alpha.exp().item()
def reset(self):
"""Reset the policy"""
pass
def to(self, *args, **kwargs):
"""Override .to(device) method to involve moving the log_alpha fixed_std"""
if self.actor.fixed_std is not None:
self.actor.fixed_std = self.actor.fixed_std.to(*args, **kwargs)
self.log_alpha = self.log_alpha.to(*args, **kwargs)
super().to(*args, **kwargs)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select action for inference/evaluation"""
@@ -308,17 +314,12 @@ class CriticEnsemble(nn.Module):
encoder: Optional[nn.Module],
network_list: nn.Module,
init_final: Optional[float] = None,
device: str = "cpu",
):
super().__init__()
self.device = torch.device(device)
self.encoder = encoder
self.network_list = network_list
self.init_final = init_final
# for network in network_list:
# network.to(self.device)
# Find the last Linear layer's output dimension
for layer in reversed(network_list[0].net):
if isinstance(layer, nn.Linear):
@@ -329,29 +330,28 @@ class CriticEnsemble(nn.Module):
self.output_layers = []
if init_final is not None:
for _ in network_list:
output_layer = nn.Linear(out_features, 1, device=device)
output_layer = nn.Linear(out_features, 1)
nn.init.uniform_(output_layer.weight, -init_final, init_final)
nn.init.uniform_(output_layer.bias, -init_final, init_final)
self.output_layers.append(output_layer)
else:
self.output_layers = []
for _ in network_list:
output_layer = nn.Linear(out_features, 1, device=device)
output_layer = nn.Linear(out_features, 1)
orthogonal_init()(output_layer.weight)
self.output_layers.append(output_layer)
self.output_layers = nn.ModuleList(self.output_layers)
self.to(self.device)
def forward(
self,
observations: dict[str, torch.Tensor],
actions: torch.Tensor,
) -> torch.Tensor:
device = get_device_from_parameters(self)
# Move each tensor in observations to device
observations = {k: v.to(self.device) for k, v in observations.items()}
actions = actions.to(self.device)
observations = {k: v.to(device) for k, v in observations.items()}
actions = actions.to(device)
obs_enc = observations if self.encoder is None else self.encoder(observations)
@@ -375,17 +375,15 @@ class Policy(nn.Module):
fixed_std: Optional[torch.Tensor] = None,
init_final: Optional[float] = None,
use_tanh_squash: bool = False,
device: str = "cpu",
encoder_is_shared: bool = False,
):
super().__init__()
self.device = torch.device(device)
self.encoder = encoder
self.network = network
self.action_dim = action_dim
self.log_std_min = log_std_min
self.log_std_max = log_std_max
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
self.fixed_std = fixed_std
self.use_tanh_squash = use_tanh_squash
self.parameters_to_optimize = []
@@ -417,8 +415,6 @@ class Policy(nn.Module):
orthogonal_init()(self.std_layer.weight)
self.parameters_to_optimize += list(self.std_layer.parameters())
self.to(self.device)
def forward(
self,
observations: torch.Tensor,
@@ -460,7 +456,8 @@ class Policy(nn.Module):
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
"""Get encoded features from observations"""
observations = observations.to(self.device)
device = get_device_from_parameters(self)
observations = observations.to(device)
if self.encoder is not None:
with torch.inference_mode():
return self.encoder(observations)