Correct losses and factorisation

Adil Zouitine
2025-01-07 17:07:55 +01:00
parent 89d8189d8b
commit 9edae4a8de
3 changed files with 30 additions and 93 deletions

@@ -56,7 +56,8 @@ class SACConfig:
state_encoder_hidden_dim = 256
latent_dim = 256
target_entropy = None
backup_entropy = False
# backup_entropy = False
use_backup_entropy = True
critic_network_kwargs = {
"hidden_dims": [256, 256],
"activate_final": True,

@@ -98,7 +98,10 @@ class SACPolicy(
)
if config.target_entropy is None:
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
# TODO: fix device placement later
# TODO: Handle the case where the temperature parameter is fixed
self.log_alpha = torch.zeros(1, requires_grad=True, device="cpu")
self.temperature = self.log_alpha.exp().item()
def reset(self):
"""
@@ -144,6 +147,9 @@ class SACPolicy(
Returns a dictionary with loss as a tensor, and other information as native floats.
"""
# We have to refresh the cached temperature value because log_alpha may have been updated in the previous optimization step
self.temperature = self.log_alpha.exp().item()
batch = self.normalize_inputs(batch)
# batch shape is (b, 2, ...): along the second dimension, index 0 is the current
# observation and index 1 is the next observation used to compute the TD target.
@@ -156,18 +162,11 @@ class SACPolicy(
observations[k] = batch[k][:, 0]
next_observations[k] = batch[k][:, 1]
# perform image augmentation
# reward bias from HIL-SERL code base
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
# calculate critics loss
# 1- compute actions from policy
with torch.no_grad():
action_preds, log_probs, _ = self.actor(next_observations)
next_action_preds, next_log_probs, _ = self.actor(next_observations)
# 2- compute q targets
q_targets = self.critic_forward(next_observations, action_preds, use_target=True)
q_targets = self.critic_forward(next_observations, next_action_preds, use_target=True)
# subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
if self.config.num_subsample_critics is not None:
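The subsampling branch itself is cut off by the hunk boundary; as a rough sketch of the usual REDQ-style trick the comment refers to (illustrative shapes, not necessarily the repository's exact implementation):

import torch

# With a large critic ensemble and a high update-to-data ratio, taking the min
# over a random subset of target critics reduces overestimation and overfitting.
num_critics, batch_size, num_subsample_critics = 10, 256, 2
q_targets = torch.randn(num_critics, batch_size)             # stand-in target Q-values

idx = torch.randperm(num_critics)[:num_subsample_critics]    # random critic subset
min_q, _ = q_targets[idx].min(dim=0)                         # shape: [batch_size]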
@@ -177,14 +176,8 @@ class SACPolicy(
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
# breakpoint()
if self.config.use_backup_entropy:
min_q -= (
self.temperature()
* log_probs
* ~batch["observation.state_is_pad"][:, 0]
* ~batch["action_is_pad"][:, 0]
) # shape: [batch_size, horizon]
min_q -= self.temperature * next_log_probs
td_target = rewards + self.config.discount * min_q * ~batch["next.done"]
# 3- compute predicted qs
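The corrected target above is the standard soft Bellman backup: the entropy bonus uses the next-step log-probabilities and the detached scalar temperature. A self-contained sketch with made-up shapes (not the exact repository code):

import torch

batch_size, discount, temperature = 256, 0.99, 0.2
rewards = torch.randn(batch_size)
done = torch.zeros(batch_size, dtype=torch.bool)
min_q = torch.randn(batch_size)             # min over (subsampled) target critics
next_log_probs = torch.randn(batch_size)    # log pi(a'|s') for freshly sampled a'

# Optional "backup entropy": fold the entropy bonus into the backup itself.
use_backup_entropy = True
if use_backup_entropy:
    min_q = min_q - temperature * next_log_probs

# y = r + gamma * (1 - done) * (min_i Q_i'(s', a') - alpha * log pi(a'|s'))
td_target = rewards + discount * min_q * ~done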
@@ -192,43 +185,28 @@ class SACPolicy(
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# Compute the mean loss over the batch for each critic, then sum over the ensemble to get the final loss
critics_loss = (
F.mse_loss(
q_preds,
einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
input=q_preds,
target=td_target_duplicate,
reduction="none",
).sum(0) # sum over ensemble
# `q_preds_ensemble` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][:, 0] # shape: [batch_size, horizon+1]
* ~batch["action_is_pad"][:, 0] # shape: [batch_size, horizon]
# q_targets depends on the reward and the next observations.
* ~batch["next.reward_is_pad"][:, 0] # shape: [batch_size, horizon]
* ~batch["observation.state_is_pad"][:, 1] # shape: [batch_size, horizon+1]
).mean()
).mean(1)
).sum()
# calculate actor loss
# 1- temperature
temperature = self.temperature()
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
temperature = self.temperature
actions, log_probs, _ = self.actor(observations)
# 3- get q-value predictions
with torch.inference_mode():
q_preds = self.critic_forward(observations, actions, use_target=False)
# q_preds_min = torch.min(q_preds, axis=0)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = (
-(min_q_preds - temperature * log_probs).mean()
* ~batch["observation.state_is_pad"][:, 0] # shape: [batch_size, horizon+1]
* ~batch["action_is_pad"][:, 0] # shape: [batch_size, horizon]
).mean()
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
# calculate temperature loss
# 1- calculate entropy
with torch.no_grad():
actions, log_probs, _ = self.actor(observations)
entropy = -log_probs.mean()
temperature_loss = self.temperature(lhs=entropy, rhs=self.config.target_entropy)
_, log_probs, _ = self.actor(observations)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
loss = critics_loss + actor_loss + temperature_loss
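Pulling the hunk together, the three losses now follow the textbook SAC recipe: an ensemble MSE critic loss against a shared TD target (mean over the batch per critic, summed over critics), an actor loss E[alpha * log pi - min Q], and a temperature loss that drives the policy entropy toward target_entropy. A condensed, self-contained sketch with illustrative shapes and names (not the repository's exact interfaces):

import torch
import torch.nn.functional as F
import einops

num_critics, batch_size, target_entropy = 2, 256, -3.0
q_preds = torch.randn(num_critics, batch_size, requires_grad=True)  # ensemble Q(s, a)
td_target = torch.randn(batch_size)                                 # detached backup target
min_q_preds = torch.randn(batch_size)                                # min_i Q_i(s, pi(s)), no grad to critics
log_probs = torch.randn(batch_size, requires_grad=True)              # log pi(a|s), grads flow to the actor
log_alpha = torch.zeros(1, requires_grad=True)
temperature = log_alpha.exp().item()                                 # detached scalar

# Critic loss: per-critic mean over the batch, then sum over the ensemble.
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
critics_loss = F.mse_loss(input=q_preds, target=td_target_duplicate, reduction="none").mean(1).sum()

# Actor loss: maximize min Q plus entropy, i.e. minimize alpha * log pi - min Q.
actor_loss = (temperature * log_probs - min_q_preds).mean()

# Temperature loss: push E[-log pi] toward target_entropy; log_probs carries no actor grad here.
temperature_loss = (-log_alpha.exp() * (log_probs.detach() + target_entropy)).mean()

loss = critics_loss + actor_loss + temperature_loss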
@@ -239,14 +217,14 @@ class SACPolicy(
"min_q_predicts": min_q_preds.min().item(),
"max_q_predicts": min_q_preds.max().item(),
"temperature_loss": temperature_loss.item(),
"temperature": temperature.item(),
"temperature": temperature,
"mean_log_probs": log_probs.mean().item(),
"min_log_probs": log_probs.min().item(),
"max_log_probs": log_probs.max().item(),
"td_target_mean": td_target.mean().item(),
"td_target_max": td_target.max().item(),
"action_mean": actions.mean().item(),
"entropy": entropy.item(),
"entropy": log_probs.mean().item(),
"loss": loss,
}
@@ -312,7 +290,7 @@ class Critic(nn.Module):
encoder: Optional[nn.Module],
network: nn.Module,
init_final: Optional[float] = None,
device: str = "cuda",
device: str = "cpu",
):
super().__init__()
self.device = torch.device(device)
@@ -365,7 +343,7 @@ class Policy(nn.Module):
fixed_std: Optional[torch.Tensor] = None,
init_final: Optional[float] = None,
use_tanh_squash: bool = False,
device: str = "cuda",
device: str = "cpu",
):
super().__init__()
self.device = torch.device(device)
@@ -438,6 +416,7 @@ class Policy(nn.Module):
actions = x_t # No Tanh; raw Gaussian sample
log_probs = log_probs.sum(-1) # Sum over action dimensions
means = torch.tanh(means) if self.use_tanh_squash else means
return actions, log_probs, means
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
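For reference, the branch above distinguishes raw Gaussian samples from tanh-squashed ones; when the squash is enabled, the log-probability needs the change-of-variables correction before summing over action dimensions. A minimal sketch of that pattern (assumed shapes, not the repository's exact forward pass):

import torch

means = torch.zeros(256, 4)
log_stds = torch.full((256, 4), -1.0)
use_tanh_squash = True

dist = torch.distributions.Normal(means, log_stds.exp())
x_t = dist.rsample()              # reparameterization trick keeps gradients
log_probs = dist.log_prob(x_t)    # per-dimension log-probabilities

if use_tanh_squash:
    actions = torch.tanh(x_t)
    # Change-of-variables correction for the tanh bijection (numerically stabilized).
    log_probs -= torch.log(1.0 - actions.pow(2) + 1e-6)
else:
    actions = x_t                 # No Tanh; raw Gaussian sample

log_probs = log_probs.sum(-1)     # Sum over action dimensions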
@@ -522,55 +501,11 @@ class SACObservationEncoder(nn.Module):
return self.config.latent_dim
class LagrangeMultiplier(nn.Module):
def __init__(self, init_value: float = 1.0, constraint_shape: Sequence[int] = (), device: str = "cuda"):
super().__init__()
self.device = torch.device(device)
# Parameterize log(alpha) directly to ensure positivity
log_alpha = torch.log(torch.tensor(init_value, dtype=torch.float32, device=self.device))
self.log_alpha = nn.Parameter(torch.full(constraint_shape, log_alpha))
def forward(
self,
lhs: Optional[Union[torch.Tensor, float, int]] = None,
rhs: Optional[Union[torch.Tensor, float, int]] = None,
) -> torch.Tensor:
# Compute alpha = exp(log_alpha)
alpha = self.log_alpha.exp()
# Return alpha directly if no constraints provided
if lhs is None:
return alpha
# Convert inputs to tensors and move to device
lhs = (
torch.tensor(lhs, device=self.device)
if not isinstance(lhs, torch.Tensor)
else lhs.to(self.device)
)
if rhs is not None:
rhs = (
torch.tensor(rhs, device=self.device)
if not isinstance(rhs, torch.Tensor)
else rhs.to(self.device)
)
else:
rhs = torch.zeros_like(lhs, device=self.device)
# Compute the difference and apply the multiplier
diff = lhs - rhs
assert diff.shape == alpha.shape, f"Shape mismatch: {diff.shape} vs {alpha.shape}"
return alpha * diff
def orthogonal_init():
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList:
def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cpu") -> nn.ModuleList:
"""Creates an ensemble of critic networks"""
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
return nn.ModuleList(critics).to(device)

@@ -99,7 +99,8 @@ def make_optimizer_and_scheduler(cfg, policy):
[
{"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
{"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
{"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
# We wrap the policy's log temperature in a list because it is a plain torch tensor, not an nn.Module
{"params": [policy.log_alpha], "lr": policy.config.temperature_lr},
]
)
lr_scheduler = None
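The param-group change works because torch.optim accepts any iterable of tensors for a group's "params": a bare leaf tensor only needs to be wrapped in a list. A minimal illustration (module sizes and learning rates are placeholders):

import torch
from torch import nn

# Illustration: mixing nn.Module parameters and a bare leaf tensor in one optimizer.
actor = nn.Linear(8, 4)
critic_ensemble = nn.ModuleList([nn.Linear(12, 1) for _ in range(2)])
log_alpha = torch.zeros(1, requires_grad=True)   # plain tensor, not an nn.Module

optimizer = torch.optim.Adam(
    [
        {"params": actor.parameters(), "lr": 3e-4},
        {"params": critic_ensemble.parameters(), "lr": 3e-4},
        # Wrap the bare tensor in a list so it forms a valid parameter group.
        {"params": [log_alpha], "lr": 3e-4},
    ]
)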