diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 9f0a60df..fb2e5542 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -89,12 +89,18 @@ class SACPolicy(
         Returns a dictionary with loss as a tensor, and other information as native floats.
         """

-        observation_batch =
-        next_obaservation_batch =
-        action_batch =
-        reward_batch =
-        dones_batch =
-
+        batch = self.normalize_inputs(batch)
+        # batch shape is (b, 2, ...): along dim 1, index 0 is the current observation and
+        # index 1 is the next observation, used for calculating the TD target.
+        actions = batch["action"][:, 0]
+        rewards = batch["next.reward"][:, 0]
+        observations = {}
+        next_observations = {}
+        for k in batch:
+            if k.startswith("observation."):
+                observations[k] = batch[k][:, 0]
+                next_observations[k] = batch[k][:, 1]
+
         # perform image augmentation

         # reward bias
@@ -104,34 +110,51 @@ class SACPolicy(

         # calculate critics loss
         # 1- compute actions from policy
-        next_actions = ..
+        action_preds, log_probs = self.actor_network(observations)

         # 2- compute q targets
-        q_targets = self.target_qs(next_obaservation_batch, next_actions)
+        q_targets = self.target_qs(next_observations, action_preds)

         # critics subsample size
         min_q = q_targets.min(dim=0)

         # backup entropy
-        td_target = reward_batch + self.discount * min_q
+        td_target = rewards + self.discount * min_q

         # 3- compute predicted qs
-        q_preds = self.critic_ensemble(observation_batch, action_batch)
+        q_preds = self.critic_ensemble(observations, actions)

         # 4- Calculate loss
-        critics_loss = F.mse_loss(q_preds,
-                                  einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])) # dones masks
-
+        # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
+        critics_loss = (
+            F.mse_loss(
+                q_preds,
+                einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]),
+                reduction="none",
+            ).sum(0)  # sum over ensemble
+            # `q_preds` depends on the first observation and the actions.
+            * ~batch["observation.state_is_pad"][0]
+            * ~batch["action_is_pad"]
+            # q_targets depends on the reward and the next observations.
+            * ~batch["next.reward_is_pad"]
+            * ~batch["observation.state_is_pad"][1:]
+        ).sum(0).mean()
+
         # calculate actors loss
         # 1- temperature
         temperature = self.temperature()

         # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
-        actions, log_probs = self.actor_network(observation_batch)
+        actions, log_probs = self.actor_network(observations)

         # 3- get q-value predictions
         with torch.no_grad():
-            q_preds = self.critic_ensemble(observation_batch, actions, return_type="mean")
-        actor_loss = -(q_preds - temperature * log_probs).mean()
+            q_preds = self.critic_ensemble(observations, actions, return_type="mean")
+        actor_loss = (
+            -(q_preds - temperature * log_probs).mean()
+            * ~batch["observation.state_is_pad"][0]
+            * ~batch["action_is_pad"]
+        ).mean()
+
         # calculate temperature loss
         # 1- calculate entropy
@@ -141,8 +164,8 @@ class SACPolicy(
         loss = critics_loss + actor_loss + temperature_loss

         return {
-            "Q_value_loss": critics_loss.item(),
-            "pi_loss": actor_loss.item(),
+            "critics_loss": critics_loss.item(),
+            "actor_loss": actor_loss.item(),
             "temperature_loss": temperature_loss.item(),
             "temperature": temperature.item(),
             "entropy": entropy.item(),
@@ -153,4 +176,12 @@ class SACPolicy(
     def update(self):
         self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
         #for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
-        #    target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)
\ No newline at end of file
+        #    target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)
+
+class SACObservationEncoder(nn.Module):
+    """Encode image and/or state vector observations."""
+
+    def __init__(self, config: SACConfig):
+
+        super().__init__()
+        self.config = config
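For reference, here is a minimal sketch of the masked TD-target / critic-loss computation this diff introduces: the target is r + discount * min over the target critics, each ensemble member regresses onto that shared target, and padded transitions are masked out. The function name, the tensor shapes, and the explicit `done` term are illustrative assumptions, not the exact conventions of `modeling_sac.py`.

```python
import einops
import torch
import torch.nn.functional as F


def critic_loss_sketch(q_preds, q_targets, rewards, done, obs_is_pad, action_is_pad, discount=0.99):
    # Clipped double-Q backup: take the minimum over the target-critic ensemble.
    min_q = q_targets.min(dim=0).values                       # (batch,)
    td_target = rewards + discount * (1.0 - done) * min_q     # (batch,)
    # Per-ensemble-member squared TD error against the shared (detached) target.
    td_error = F.mse_loss(
        q_preds,
        einops.repeat(td_target.detach(), "b -> e b", e=q_preds.shape[0]),
        reduction="none",
    ).sum(dim=0)                                               # sum over ensemble -> (batch,)
    # Zero out transitions whose observation or action came from padding.
    valid = ~(obs_is_pad | action_is_pad)
    return (td_error * valid).mean()


# Example with random tensors: 2 critics, batch of 8.
e, b = 2, 8
loss = critic_loss_sketch(
    q_preds=torch.randn(e, b),
    q_targets=torch.randn(e, b),
    rewards=torch.randn(b),
    done=torch.zeros(b),
    obs_is_pad=torch.zeros(b, dtype=torch.bool),
    action_is_pad=torch.zeros(b, dtype=torch.bool),
)
print(loss)
```

Summing the per-member errors before averaging over the batch mirrors the `.sum(0) ... .mean()` pattern in the diff, so every critic in the ensemble is trained on the same target.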
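The temperature hunk is elided above. In standard SAC the temperature alpha is learned by minimizing E[-alpha * (log_prob + target_entropy)]; the sketch below assumes that objective with a learnable `log_alpha` scalar, which may differ from whatever `self.temperature()` wraps in this implementation.

```python
import torch


def temperature_loss_sketch(log_alpha: torch.Tensor, log_probs: torch.Tensor, target_entropy: float) -> torch.Tensor:
    # Pushes alpha up when the policy entropy (-log_probs) falls below the target, down otherwise.
    return (-log_alpha.exp() * (log_probs + target_entropy).detach()).mean()


log_alpha = torch.zeros(1, requires_grad=True)  # alpha = exp(0) = 1.0 initially
loss = temperature_loss_sketch(log_alpha, log_probs=torch.randn(8), target_entropy=-4.0)
loss.backward()
print(loss.item(), log_alpha.grad)
```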
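`update()` calls `lerp_` on the target critic module, and the commented-out loop shows the intended parameter-wise Polyak update. A self-contained sketch of that soft update, with `tau` and `soft_update` as stand-in names for `critic_target_update_weight` and the method above:

```python
import torch
import torch.nn as nn


@torch.no_grad()
def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    # target <- (1 - tau) * target + tau * source, parameter by parameter.
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.lerp_(param.data, tau)


# Example: keep a slowly-moving target copy of a critic head.
critic = nn.Linear(4, 1)
critic_target = nn.Linear(4, 1)
critic_target.load_state_dict(critic.state_dict())
soft_update(critic_target, critic, tau=0.005)
```

Applying `Tensor.lerp_` per parameter avoids relying on `lerp_` being defined on the module object itself.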