Added support for checkpointing the policy. We can save and load the policy state dict, optimizers state, optimization step and interaction step

Added functions for converting the replay buffer from and to LeRobotDataset. When we want to save the replay buffer, we convert it first to LeRobotDataset format and save it locally and vice-versa.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi
2025-01-30 17:39:41 +00:00
parent 03616db82c
commit aebea08a99
5 changed files with 215 additions and 91 deletions

View File

@@ -29,6 +29,7 @@ from torch import Tensor
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.utils import get_device_from_parameters
class SACPolicy(
@@ -44,7 +45,6 @@ class SACPolicy(
self,
config: SACConfig | None = None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
device: str = "cpu",
):
super().__init__()
@@ -92,7 +92,6 @@ class SACPolicy(
for _ in range(config.num_critics)
]
),
device=device,
)
self.critic_target = CriticEnsemble(
@@ -106,7 +105,6 @@ class SACPolicy(
for _ in range(config.num_critics)
]
),
device=device,
)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
@@ -115,7 +113,6 @@ class SACPolicy(
encoder=encoder_actor,
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
action_dim=config.output_shapes["action"][0],
device=device,
encoder_is_shared=config.shared_encoder,
**config.policy_kwargs,
)
@@ -123,13 +120,22 @@ class SACPolicy(
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
# TODO (azouitine): Handle the case where the temparameter is a fixed
self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
# it triggers "can't optimize a non-leaf Tensor"
self.log_alpha = torch.zeros(1, requires_grad=True, device=torch.device("cuda:0"))
self.temperature = self.log_alpha.exp().item()
def reset(self):
"""Reset the policy"""
pass
def to(self, *args, **kwargs):
"""Override .to(device) method to involve moving the log_alpha fixed_std"""
if self.actor.fixed_std is not None:
self.actor.fixed_std = self.actor.fixed_std.to(*args, **kwargs)
self.log_alpha = self.log_alpha.to(*args, **kwargs)
super().to(*args, **kwargs)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select action for inference/evaluation"""
@@ -308,17 +314,12 @@ class CriticEnsemble(nn.Module):
encoder: Optional[nn.Module],
network_list: nn.Module,
init_final: Optional[float] = None,
device: str = "cpu",
):
super().__init__()
self.device = torch.device(device)
self.encoder = encoder
self.network_list = network_list
self.init_final = init_final
# for network in network_list:
# network.to(self.device)
# Find the last Linear layer's output dimension
for layer in reversed(network_list[0].net):
if isinstance(layer, nn.Linear):
@@ -329,29 +330,28 @@ class CriticEnsemble(nn.Module):
self.output_layers = []
if init_final is not None:
for _ in network_list:
output_layer = nn.Linear(out_features, 1, device=device)
output_layer = nn.Linear(out_features, 1)
nn.init.uniform_(output_layer.weight, -init_final, init_final)
nn.init.uniform_(output_layer.bias, -init_final, init_final)
self.output_layers.append(output_layer)
else:
self.output_layers = []
for _ in network_list:
output_layer = nn.Linear(out_features, 1, device=device)
output_layer = nn.Linear(out_features, 1)
orthogonal_init()(output_layer.weight)
self.output_layers.append(output_layer)
self.output_layers = nn.ModuleList(self.output_layers)
self.to(self.device)
def forward(
self,
observations: dict[str, torch.Tensor],
actions: torch.Tensor,
) -> torch.Tensor:
device = get_device_from_parameters(self)
# Move each tensor in observations to device
observations = {k: v.to(self.device) for k, v in observations.items()}
actions = actions.to(self.device)
observations = {k: v.to(device) for k, v in observations.items()}
actions = actions.to(device)
obs_enc = observations if self.encoder is None else self.encoder(observations)
@@ -375,17 +375,15 @@ class Policy(nn.Module):
fixed_std: Optional[torch.Tensor] = None,
init_final: Optional[float] = None,
use_tanh_squash: bool = False,
device: str = "cpu",
encoder_is_shared: bool = False,
):
super().__init__()
self.device = torch.device(device)
self.encoder = encoder
self.network = network
self.action_dim = action_dim
self.log_std_min = log_std_min
self.log_std_max = log_std_max
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
self.fixed_std = fixed_std
self.use_tanh_squash = use_tanh_squash
self.parameters_to_optimize = []
@@ -417,8 +415,6 @@ class Policy(nn.Module):
orthogonal_init()(self.std_layer.weight)
self.parameters_to_optimize += list(self.std_layer.parameters())
self.to(self.device)
def forward(
self,
observations: torch.Tensor,
@@ -460,7 +456,8 @@ class Policy(nn.Module):
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
"""Get encoded features from observations"""
observations = observations.to(self.device)
device = get_device_from_parameters(self)
observations = observations.to(device)
if self.encoder is not None:
with torch.inference_mode():
return self.encoder(observations)