From 9edae4a8de4eb141af04189b14d14ff9391cafd7 Mon Sep 17 00:00:00 2001
From: Adil Zouitine
Date: Tue, 7 Jan 2025 17:07:55 +0100
Subject: [PATCH] Correct losses and factorisation

---
 .../common/policies/sac/configuration_sac.py |   3 +-
 lerobot/common/policies/sac/modeling_sac.py  | 117 ++++-------------
 lerobot/scripts/train.py                     |   3 +-
 3 files changed, 30 insertions(+), 93 deletions(-)

diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py
index 8ad07a73..561bc788 100644
--- a/lerobot/common/policies/sac/configuration_sac.py
+++ b/lerobot/common/policies/sac/configuration_sac.py
@@ -56,7 +56,8 @@ class SACConfig:
     state_encoder_hidden_dim = 256
     latent_dim = 256
     target_entropy = None
-    backup_entropy = False
+    # backup_entropy = False
+    use_backup_entropy = True
     critic_network_kwargs = {
         "hidden_dims": [256, 256],
         "activate_final": True,
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index ff021956..50f93dea 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -98,7 +98,10 @@ class SACPolicy(
         )
         if config.target_entropy is None:
             config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2  # (-dim(A)/2)
-        self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
+        # TODO: fix device handling later
+        # TODO: handle the case where the temperature is a fixed constant
+        self.log_alpha = torch.zeros(1, requires_grad=True, device="cpu")
+        self.temperature = self.log_alpha.exp().item()
 
     def reset(self):
         """
@@ -144,6 +147,9 @@ class SACPolicy(
 
         Returns a dictionary with loss as a tensor, and other information as native floats.
         """
+        # Refresh the cached temperature: the optimizer may have updated
+        # log_alpha since the previous training step.
+        self.temperature = self.log_alpha.exp().item()
         batch = self.normalize_inputs(batch)
         # batch shape is (b, 2, ...) where index 1 returns the current observation and
         # the next observation for calculating the right td index.
@@ -156,18 +162,11 @@ class SACPolicy(
             observations[k] = batch[k][:, 0]
             next_observations[k] = batch[k][:, 1]
 
-        # perform image augmentation
-
-        # reward bias from HIL-SERL code base
-        # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
-
-        # calculate critics loss
-        # 1- compute actions from policy
         with torch.no_grad():
-            action_preds, log_probs, _ = self.actor(next_observations)
+            next_action_preds, next_log_probs, _ = self.actor(next_observations)
 
             # 2- compute q targets
-            q_targets = self.critic_forward(next_observations, action_preds, use_target=True)
+            q_targets = self.critic_forward(next_observations, next_action_preds, use_target=True)
 
             # subsample critics to prevent overfitting if use high UTD (update to date)
             if self.config.num_subsample_critics is not None:
@@ -177,14 +176,8 @@ class SACPolicy(
             # critics subsample size
             min_q, _ = q_targets.min(dim=0)  # Get values from min operation
-            # breakpoint()
             if self.config.use_backup_entropy:
-                min_q -= (
-                    self.temperature()
-                    * log_probs
-                    * ~batch["observation.state_is_pad"][:, 0]
-                    * ~batch["action_is_pad"][:, 0]
-                )  # shape: [batch_size, horizon]
+                min_q -= self.temperature * next_log_probs
 
             td_target = rewards + self.config.discount * min_q * ~batch["next.done"]
 
         # 3- compute predicted qs
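[Editor's note — illustration, not part of the patch] The hunk above strips the padding masks and the LagrangeMultiplier call out of the entropy backup, leaving the standard soft Bellman target. A minimal standalone sketch of that computation, assuming q_targets has shape [num_critics, batch] and the remaining tensors have shape [batch]:

    import torch

    def soft_td_target(q_targets, next_log_probs, rewards, dones, temperature, discount=0.99):
        # Pessimistic value estimate: minimum over the (possibly subsampled) critic ensemble.
        min_q, _ = q_targets.min(dim=0)
        # Entropy backup: subtract the temperature-scaled next-action log-probability.
        min_q = min_q - temperature * next_log_probs
        # Bootstrap only through non-terminal transitions.
        return rewards + discount * min_q * (~dones)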
@@ -192,43 +185,28 @@ class SACPolicy(
 
         # 4- Calculate loss
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
+        td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
+        # Compute the mean loss over the batch for each critic, then sum over critics for the final loss
         critics_loss = (
             F.mse_loss(
-                q_preds,
-                einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
+                input=q_preds,
+                target=td_target_duplicate,
                 reduction="none",
-            ).sum(0)  # sum over ensemble
-            # `q_preds_ensemble` depends on the first observation and the actions.
-            * ~batch["observation.state_is_pad"][:, 0]  # shape: [batch_size, horizon+1]
-            * ~batch["action_is_pad"][:, 0]  # shape: [batch_size, horizon]
-            # q_targets depends on the reward and the next observations.
-            * ~batch["next.reward_is_pad"][:, 0]  # shape: [batch_size, horizon]
-            * ~batch["observation.state_is_pad"][:, 1]  # shape: [batch_size, horizon+1]
-        ).mean()
+            ).mean(1)
+        ).sum()
 
-        # calculate actors loss
-        # 1- temperature
-        temperature = self.temperature()
-        # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
+        temperature = self.temperature
         actions, log_probs, _ = self.actor(observations)
-        # 3- get q-value predictions
         with torch.inference_mode():
             q_preds = self.critic_forward(observations, actions, use_target=False)
-        # q_preds_min = torch.min(q_preds, axis=0)
         min_q_preds = q_preds.min(dim=0)[0]
-        actor_loss = (
-            -(min_q_preds - temperature * log_probs).mean()
-            * ~batch["observation.state_is_pad"][:, 0]  # shape: [batch_size, horizon+1]
-            * ~batch["action_is_pad"][:, 0]  # shape: [batch_size, horizon]
-        ).mean()
+        actor_loss = ((temperature * log_probs) - min_q_preds).mean()
 
         # calculate temperature loss
-        # 1- calculate entropy
         with torch.no_grad():
-            actions, log_probs, _ = self.actor(observations)
-        entropy = -log_probs.mean()
-        temperature_loss = self.temperature(lhs=entropy, rhs=self.config.target_entropy)
+            _, log_probs, _ = self.actor(observations)
+        temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
 
         loss = critics_loss + actor_loss + temperature_loss
@@ -239,14 +217,14 @@ class SACPolicy(
             "min_q_predicts": min_q_preds.min().item(),
             "max_q_predicts": min_q_preds.max().item(),
             "temperature_loss": temperature_loss.item(),
-            "temperature": temperature.item(),
+            "temperature": temperature,
             "mean_log_probs": log_probs.mean().item(),
             "min_log_probs": log_probs.min().item(),
             "max_log_probs": log_probs.max().item(),
             "td_target_mean": td_target.mean().item(),
             "td_target_max": td_target.max().item(),
             "action_mean": actions.mean().item(),
-            "entropy": entropy.item(),
+            "entropy": (-log_probs).mean().item(),
             "loss": loss,
         }
@@ -312,7 +290,7 @@ class Critic(nn.Module):
         encoder: Optional[nn.Module],
         network: nn.Module,
         init_final: Optional[float] = None,
-        device: str = "cuda",
+        device: str = "cpu",
     ):
         super().__init__()
         self.device = torch.device(device)
@@ -365,7 +343,7 @@ class Policy(nn.Module):
         fixed_std: Optional[torch.Tensor] = None,
         init_final: Optional[float] = None,
         use_tanh_squash: bool = False,
-        device: str = "cuda",
+        device: str = "cpu",
     ):
         super().__init__()
         self.device = torch.device(device)
@@ -438,6 +416,7 @@ class Policy(nn.Module):
             actions = x_t  # No Tanh; raw Gaussian sample
 
         log_probs = log_probs.sum(-1)  # Sum over action dimensions
+        means = torch.tanh(means) if self.use_tanh_squash else means
         return actions, log_probs, means
 
     def get_features(self, observations: torch.Tensor) -> torch.Tensor:
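[Editor's note — illustration, not part of the patch] The actor and temperature hunks above implement the standard SAC objectives under the new exp(log_alpha) parameterization. A standalone sketch with hypothetical argument names; like the patch, it assumes the Q-values come from an inference-mode critic pass and the temperature term uses log-probs from a no_grad actor pass:

    import torch

    def actor_and_temperature_losses(min_q_preds, log_probs, log_probs_detached, log_alpha, target_entropy):
        # Actor: trade expected Q-value against entropy; min_q_preds carries no
        # gradient, so only the actor parameters are updated by this loss.
        temperature = log_alpha.exp().item()
        actor_loss = ((temperature * log_probs) - min_q_preds).mean()
        # Temperature: log_probs_detached carries no gradient, so only log_alpha
        # is updated; alpha grows whenever the policy entropy falls below the target.
        temperature_loss = (-log_alpha.exp() * (log_probs_detached + target_entropy)).mean()
        return actor_loss, temperature_loss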
@@ -522,55 +501,11 @@ class SACObservationEncoder(nn.Module):
         return self.config.latent_dim
 
 
-class LagrangeMultiplier(nn.Module):
-    def __init__(self, init_value: float = 1.0, constraint_shape: Sequence[int] = (), device: str = "cuda"):
-        super().__init__()
-        self.device = torch.device(device)
-
-        # Parameterize log(alpha) directly to ensure positivity
-        log_alpha = torch.log(torch.tensor(init_value, dtype=torch.float32, device=self.device))
-        self.log_alpha = nn.Parameter(torch.full(constraint_shape, log_alpha))
-
-    def forward(
-        self,
-        lhs: Optional[Union[torch.Tensor, float, int]] = None,
-        rhs: Optional[Union[torch.Tensor, float, int]] = None,
-    ) -> torch.Tensor:
-        # Compute alpha = exp(log_alpha)
-        alpha = self.log_alpha.exp()
-
-        # Return alpha directly if no constraints provided
-        if lhs is None:
-            return alpha
-
-        # Convert inputs to tensors and move to device
-        lhs = (
-            torch.tensor(lhs, device=self.device)
-            if not isinstance(lhs, torch.Tensor)
-            else lhs.to(self.device)
-        )
-        if rhs is not None:
-            rhs = (
-                torch.tensor(rhs, device=self.device)
-                if not isinstance(rhs, torch.Tensor)
-                else rhs.to(self.device)
-            )
-        else:
-            rhs = torch.zeros_like(lhs, device=self.device)
-
-        # Compute the difference and apply the multiplier
-        diff = lhs - rhs
-
-        assert diff.shape == alpha.shape, f"Shape mismatch: {diff.shape} vs {alpha.shape}"
-
-        return alpha * diff
-
-
 def orthogonal_init():
     return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
 
 
-def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList:
+def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cpu") -> nn.ModuleList:
     """Creates an ensemble of critic networks"""
     assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
     return nn.ModuleList(critics).to(device)
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index a4eb3528..7df18596 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -99,7 +99,8 @@ def make_optimizer_and_scheduler(cfg, policy):
         [
             {"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
             {"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
-            {"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
+            # Wrap log_alpha in a list: it is a bare torch tensor, not an nn.Module with .parameters()
+            {"params": [policy.log_alpha], "lr": policy.config.temperature_lr},
         ]
     )
     lr_scheduler = None
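[Editor's note — illustration, not part of the patch] On the train.py change: a bare leaf tensor such as log_alpha has no .parameters() method, but torch.optim accepts it when passed inside an iterable. A minimal sketch of the wiring, with a hypothetical learning rate and stand-in log-probs:

    import torch

    log_alpha = torch.zeros(1, requires_grad=True)
    optimizer = torch.optim.Adam([{"params": [log_alpha], "lr": 3e-4}])

    log_probs = torch.randn(8)      # stand-in for policy log-probs
    target_entropy = -3.0
    temperature_loss = (-log_alpha.exp() * (log_probs + target_entropy)).mean()
    optimizer.zero_grad()
    temperature_loss.backward()
    optimizer.step()                # updates log_alpha in place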