Potential fixes for SAC instability and NAN bug

This commit is contained in:
joeclinton1
2025-01-03 21:12:59 +00:00
committed by Ke-Wang1017
parent f99e670976
commit db3925df28

View File

@@ -18,8 +18,7 @@
# TODO: (1) better device management
from collections import deque
from copy import deepcopy
from typing import Callable, Optional, Sequence, Tuple
from typing import Callable, Optional, Sequence, Tuple, Union
import einops
import numpy as np
@@ -72,8 +71,8 @@ class SACPolicy(
encoder=encoder_critic,
network=MLP(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs
)
**config.critic_network_kwargs,
),
)
critic_nets.append(critic_net)
@@ -83,8 +82,8 @@ class SACPolicy(
encoder=encoder_critic,
network=MLP(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs
)
**config.critic_network_kwargs,
),
)
target_critic_nets.append(target_critic_net)
@@ -93,15 +92,12 @@ class SACPolicy(
self.actor = Policy(
encoder=encoder_actor,
network=MLP(
input_dim=encoder_actor.output_dim,
**config.actor_network_kwargs
),
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
action_dim=config.output_shapes["action"][0],
**config.policy_kwargs
**config.policy_kwargs,
)
if config.target_entropy is None:
config.target_entropy = -np.prod(config.output_shapes["action"][0])/2 # (-dim(A)/2)
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
def reset(self):
@@ -126,7 +122,9 @@ class SACPolicy(
actions = self.unnormalize_outputs({"action": actions})["action"]
return actions
def critic_forward(self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False) -> Tensor:
def critic_forward(
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
) -> Tensor:
"""Forward pass through a critic network ensemble
Args:
@@ -141,7 +139,6 @@ class SACPolicy(
q_values = torch.stack([critic(observations, actions) for critic in critics])
return q_values
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
"""Run the batch through the model and compute the loss.
@@ -175,17 +172,22 @@ class SACPolicy(
# subsample critics to prevent overfitting if use high UTD (update to date)
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[:self.config.num_subsample_critics]
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
# breakpoint()
if self.config.use_backup_entropy:
min_q -= self.temperature() * log_probs * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon]
min_q -= (
self.temperature()
* log_probs
* ~batch["observation.state_is_pad"][:, 0]
* ~batch["action_is_pad"][:, 0]
) # shape: [batch_size, horizon]
td_target = rewards + self.config.discount * min_q * ~batch["next.done"]
# td_target -= self.config.discount * self.temperature() * log_probs \
# * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon]
# td_target -= self.config.discount * self.temperature() * log_probs \
# * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon]
# print(f"Target Q-values: mean={td_target.mean():.3f}, max={td_target.max():.3f}")
# 3- compute predicted qs
@@ -195,17 +197,17 @@ class SACPolicy(
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
critics_loss = (
F.mse_loss(
q_preds,
einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
reduction="none",
).sum(0) # sum over ensemble
# `q_preds_ensemble` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][:,0] # shape: [batch_size, horizon+1]
* ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon]
# q_targets depends on the reward and the next observations.
* ~batch["next.reward_is_pad"][:,0] # shape: [batch_size, horizon]
* ~batch["observation.state_is_pad"][:,1] # shape: [batch_size, horizon+1]
).mean()
q_preds,
einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
reduction="none",
).sum(0) # sum over ensemble
# `q_preds_ensemble` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][:, 0] # shape: [batch_size, horizon+1]
* ~batch["action_is_pad"][:, 0] # shape: [batch_size, horizon]
# q_targets depends on the reward and the next observations.
* ~batch["next.reward_is_pad"][:, 0] # shape: [batch_size, horizon]
* ~batch["observation.state_is_pad"][:, 1] # shape: [batch_size, horizon+1]
).mean()
# calculate actors loss
# 1- temperature
@@ -213,8 +215,8 @@ class SACPolicy(
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
actions, log_probs, _ = self.actor(observations)
# 3- get q-value predictions
# with torch.inference_mode():
q_preds = self.critic_forward(observations, actions, use_target=False)
with torch.inference_mode():
q_preds = self.critic_forward(observations, actions, use_target=False)
# q_preds_min = torch.min(q_preds, axis=0)
min_q_preds = q_preds.min(dim=0)[0]
# print(f"Q-values stats: mean={min_q_preds.mean():.3f}, min={min_q_preds.min():.3f}, max={min_q_preds.max():.3f}")
@@ -222,56 +224,53 @@ class SACPolicy(
# breakpoint()
actor_loss = (
-(min_q_preds - temperature * log_probs).mean()
* ~batch["observation.state_is_pad"][:,0] # shape: [batch_size, horizon+1]
* ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon]
* ~batch["observation.state_is_pad"][:, 0] # shape: [batch_size, horizon+1]
* ~batch["action_is_pad"][:, 0] # shape: [batch_size, horizon]
).mean()
# calculate temperature loss
# 1- calculate entropy
with torch.no_grad():
actions, log_probs, _ = self.actor(observations)
entropy = -log_probs.mean()
temperature_loss = self.temperature(
lhs=entropy,
rhs=self.config.target_entropy
)
temperature_loss = self.temperature(lhs=entropy, rhs=self.config.target_entropy)
loss = critics_loss + actor_loss + temperature_loss
return {
"critics_loss": critics_loss.item(),
"actor_loss": actor_loss.item(),
"mean_q_predicts": min_q_preds.mean().item(),
"min_q_predicts":min_q_preds.min().item(),
"max_q_predicts":min_q_preds.max().item(),
"temperature_loss": temperature_loss.item(),
"temperature": temperature.item(),
"mean_log_probs": log_probs.mean().item(),
"min_log_probs": log_probs.min().item(),
"max_log_probs": log_probs.max().item(),
"td_target_mean": td_target.mean().item(),
"td_target_mean": td_target.max().item(),
"action_mean": actions.mean().item(),
"entropy": entropy.item(),
"loss": loss,
}
"critics_loss": critics_loss.item(),
"actor_loss": actor_loss.item(),
"mean_q_predicts": min_q_preds.mean().item(),
"min_q_predicts": min_q_preds.min().item(),
"max_q_predicts": min_q_preds.max().item(),
"temperature_loss": temperature_loss.item(),
"temperature": temperature.item(),
"mean_log_probs": log_probs.mean().item(),
"min_log_probs": log_probs.min().item(),
"max_log_probs": log_probs.max().item(),
"td_target_mean": td_target.mean().item(),
"td_target_max": td_target.max().item(),
"action_mean": actions.mean().item(),
"entropy": entropy.item(),
"loss": loss,
}
def update(self):
# TODO: implement UTD update
# First update only critics for utd_ratio-1 times
#for critic_step in range(self.config.utd_ratio - 1):
# only update critic and critic target
# for critic_step in range(self.config.utd_ratio - 1):
# only update critic and critic target
# Then update critic, critic target, actor and temperature
"""Update target networks with exponential moving average"""
with torch.no_grad():
for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight +
target_param.data * (1.0 - self.config.critic_target_update_weight)
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
class MLP(nn.Module):
def __init__(
self,
@@ -296,13 +295,15 @@ class MLP(nn.Module):
# Rest of the layers
for i in range(1, len(hidden_dims)):
layers.append(nn.Linear(hidden_dims[i-1], hidden_dims[i]))
layers.append(nn.Linear(hidden_dims[i - 1], hidden_dims[i]))
if i + 1 < len(hidden_dims) or activate_final:
if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(hidden_dims[i]))
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
layers.append(
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
)
self.net = nn.Sequential(*layers)
@@ -316,7 +317,7 @@ class Critic(nn.Module):
encoder: Optional[nn.Module],
network: nn.Module,
init_final: Optional[float] = None,
device: str = "cuda"
device: str = "cuda",
):
super().__init__()
self.device = torch.device(device)
@@ -347,9 +348,7 @@ class Critic(nn.Module):
actions: torch.Tensor,
) -> torch.Tensor:
# Move each tensor in observations to device
observations = {
k: v.to(self.device) for k, v in observations.items()
}
observations = {k: v.to(self.device) for k, v in observations.items()}
actions = actions.to(self.device)
obs_enc = observations if self.encoder is None else self.encoder(observations)
@@ -371,7 +370,7 @@ class Policy(nn.Module):
fixed_std: Optional[torch.Tensor] = None,
init_final: Optional[float] = None,
use_tanh_squash: bool = False,
device: str = "cuda"
device: str = "cuda",
):
super().__init__()
self.device = torch.device(device)
@@ -412,7 +411,6 @@ class Policy(nn.Module):
self,
observations: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Encode observations if encoder exists
obs_enc = observations if self.encoder is None else self.encoder(observations)
@@ -423,23 +421,28 @@ class Policy(nn.Module):
# Compute standard deviations
if self.fixed_std is None:
log_std = self.std_layer(outputs)
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
if self.use_tanh_squash:
log_std = torch.tanh(log_std)
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
else:
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
else:
log_std = self.fixed_std.expand_as(means)
# uses tahn activation function to squash the action to be in the range of [-1, 1]
# uses tanh activation function to squash the action to be in the range of [-1, 1]
normal = torch.distributions.Normal(means, torch.exp(log_std))
x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1))
x_t = torch.clamp(x_t, -2.0, 2.0)
log_probs = normal.log_prob(x_t)
x_t = normal.rsample() # Reparameterization trick (mean + std * N(0,1))
log_probs = normal.log_prob(x_t) # Base log probability before Tanh
if self.use_tanh_squash:
actions = torch.tanh(x_t)
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
log_probs = log_probs.sum(-1) # sum over action dim
means = torch.tanh(means)
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh
else:
actions = x_t # No Tanh; raw Gaussian sample
log_probs = log_probs.sum(-1) # Sum over action dimensions
return actions, log_probs, means
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
@@ -495,9 +498,7 @@ class SACObservationEncoder(nn.Module):
)
if "observation.environment_state" in config.input_shapes:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
config.input_shapes["observation.environment_state"][0], config.latent_dim
),
nn.Linear(config.input_shapes["observation.environment_state"][0], config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
@@ -527,48 +528,47 @@ class SACObservationEncoder(nn.Module):
class LagrangeMultiplier(nn.Module):
def __init__(
self,
init_value: float = 1.0,
constraint_shape: Sequence[int] = (),
device: str = "cuda"
):
def __init__(self, init_value: float = 1.0, constraint_shape: Sequence[int] = (), device: str = "cuda"):
super().__init__()
self.device = torch.device(device)
# init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
init_value = torch.tensor(init_value, device=self.device)
# Initialize the Lagrange multiplier as a parameter
self.lagrange = nn.Parameter(
torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device)
)
# Parameterize log(alpha) directly to ensure positivity
log_alpha = torch.log(torch.tensor(init_value, dtype=torch.float32, device=self.device))
self.log_alpha = nn.Parameter(torch.full(constraint_shape, log_alpha))
def forward(
self,
lhs: Optional[torch.Tensor | float | int] = None,
rhs: Optional[torch.Tensor | float | int] = None
lhs: Optional[Union[torch.Tensor, float, int]] = None,
rhs: Optional[Union[torch.Tensor, float, int]] = None,
) -> torch.Tensor:
# Get the multiplier value based on parameterization
# multiplier = torch.nn.functional.softplus(self.lagrange)
log_multiplier = torch.log(self.lagrange)
# Compute alpha = exp(log_alpha)
alpha = self.log_alpha.exp()
# Return the raw multiplier if no constraint values provided
# Return alpha directly if no constraints provided
if lhs is None:
return log_multiplier.exp()
return alpha
# Convert inputs to tensors and move to device
lhs = torch.tensor(lhs, device=self.device) if not isinstance(lhs, torch.Tensor) else lhs.to(self.device)
lhs = (
torch.tensor(lhs, device=self.device)
if not isinstance(lhs, torch.Tensor)
else lhs.to(self.device)
)
if rhs is not None:
rhs = torch.tensor(rhs, device=self.device) if not isinstance(rhs, torch.Tensor) else rhs.to(self.device)
rhs = (
torch.tensor(rhs, device=self.device)
if not isinstance(rhs, torch.Tensor)
else rhs.to(self.device)
)
else:
rhs = torch.zeros_like(lhs, device=self.device)
# Compute the difference and apply the multiplier
diff = lhs - rhs
assert diff.shape == log_multiplier.shape, f"Shape mismatch: {diff.shape} vs {log_multiplier.shape}"
assert diff.shape == alpha.shape, f"Shape mismatch: {diff.shape} vs {alpha.shape}"
return log_multiplier.exp() * diff # numerically better
return alpha * diff
def orthogonal_init():
@@ -580,6 +580,7 @@ def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: s
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
return nn.ModuleList(critics).to(device)
# borrowed from tdmpc
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.