Potential fixes for SAC instability and NaN bug

This commit is contained in:
joeclinton1
2025-01-03 21:12:59 +00:00
committed by Ke-Wang1017
parent f99e670976
commit db3925df28

View File

@@ -18,8 +18,7 @@
# TODO: (1) better device management # TODO: (1) better device management
from collections import deque from collections import deque
from copy import deepcopy from typing import Callable, Optional, Sequence, Tuple, Union
from typing import Callable, Optional, Sequence, Tuple
import einops import einops
import numpy as np import numpy as np
@@ -72,8 +71,8 @@ class SACPolicy(
encoder=encoder_critic, encoder=encoder_critic,
network=MLP( network=MLP(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs **config.critic_network_kwargs,
) ),
) )
critic_nets.append(critic_net) critic_nets.append(critic_net)
@@ -83,8 +82,8 @@ class SACPolicy(
encoder=encoder_critic, encoder=encoder_critic,
network=MLP( network=MLP(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs **config.critic_network_kwargs,
) ),
) )
target_critic_nets.append(target_critic_net) target_critic_nets.append(target_critic_net)
@@ -93,16 +92,13 @@ class SACPolicy(
self.actor = Policy( self.actor = Policy(
encoder=encoder_actor, encoder=encoder_actor,
network=MLP( network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
input_dim=encoder_actor.output_dim,
**config.actor_network_kwargs
),
action_dim=config.output_shapes["action"][0], action_dim=config.output_shapes["action"][0],
**config.policy_kwargs **config.policy_kwargs,
) )
if config.target_entropy is None: if config.target_entropy is None:
config.target_entropy = -np.prod(config.output_shapes["action"][0])/2 # (-dim(A)/2) config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
self.temperature = LagrangeMultiplier(init_value=config.temperature_init) self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
def reset(self): def reset(self):
""" """
@@ -125,15 +121,17 @@ class SACPolicy(
actions, _, _ = self.actor(batch) actions, _, _ = self.actor(batch)
actions = self.unnormalize_outputs({"action": actions})["action"] actions = self.unnormalize_outputs({"action": actions})["action"]
return actions return actions
def critic_forward(self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False) -> Tensor: def critic_forward(
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
) -> Tensor:
"""Forward pass through a critic network ensemble """Forward pass through a critic network ensemble
Args: Args:
observations: Dictionary of observations observations: Dictionary of observations
actions: Action tensor actions: Action tensor
use_target: If True, use target critics, otherwise use ensemble critics use_target: If True, use target critics, otherwise use ensemble critics
Returns: Returns:
Tensor of Q-values from all critics Tensor of Q-values from all critics
""" """
@@ -141,15 +139,14 @@ class SACPolicy(
q_values = torch.stack([critic(observations, actions) for critic in critics]) q_values = torch.stack([critic(observations, actions) for critic in critics])
return q_values return q_values
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
"""Run the batch through the model and compute the loss. """Run the batch through the model and compute the loss.
Returns a dictionary with loss as a tensor, and other information as native floats. Returns a dictionary with loss as a tensor, and other information as native floats.
""" """
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)
# batch shape is (b, 2, ...) where dim 1 holds the current observation (index 0) and # batch shape is (b, 2, ...) where dim 1 holds the current observation (index 0) and
# the next observation (index 1), which is needed to compute the TD target. # the next observation (index 1), which is needed to compute the TD target.
actions = batch["action"][:, 0] actions = batch["action"][:, 0]
rewards = batch["next.reward"][:, 0] rewards = batch["next.reward"][:, 0]
observations = {} observations = {}
@@ -158,12 +155,12 @@ class SACPolicy(
if k.startswith("observation."): if k.startswith("observation."):
observations[k] = batch[k][:, 0] observations[k] = batch[k][:, 0]
next_observations[k] = batch[k][:, 1] next_observations[k] = batch[k][:, 1]
# perform image augmentation # perform image augmentation
# reward bias from HIL-SERL code base # reward bias from HIL-SERL code base
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
# calculate critics loss # calculate critics loss
# 1- compute actions from policy # 1- compute actions from policy
with torch.no_grad(): with torch.no_grad():
@@ -175,103 +172,105 @@ class SACPolicy(
# subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio # subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
if self.config.num_subsample_critics is not None: if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics) indices = torch.randperm(self.config.num_critics)
indices = indices[:self.config.num_subsample_critics] indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices] q_targets = q_targets[indices]
# critics subsample size # critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation min_q, _ = q_targets.min(dim=0) # Get values from min operation
# breakpoint() # breakpoint()
if self.config.use_backup_entropy: if self.config.use_backup_entropy:
min_q -= self.temperature() * log_probs * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon] min_q -= (
self.temperature()
* log_probs
* ~batch["observation.state_is_pad"][:, 0]
* ~batch["action_is_pad"][:, 0]
) # shape: [batch_size, horizon]
td_target = rewards + self.config.discount * min_q * ~batch["next.done"] td_target = rewards + self.config.discount * min_q * ~batch["next.done"]
# td_target -= self.config.discount * self.temperature() * log_probs \ # td_target -= self.config.discount * self.temperature() * log_probs \
# * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon] # * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon]
# print(f"Target Q-values: mean={td_target.mean():.3f}, max={td_target.max():.3f}") # print(f"Target Q-values: mean={td_target.mean():.3f}, max={td_target.max():.3f}")
# 3- compute predicted qs # 3- compute predicted qs
q_preds = self.critic_forward(observations, actions, use_target=False) q_preds = self.critic_forward(observations, actions, use_target=False)
# 4- Calculate loss # 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
critics_loss = ( critics_loss = (
F.mse_loss( F.mse_loss(
q_preds, q_preds,
einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]), einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
reduction="none", reduction="none",
).sum(0) # sum over ensemble ).sum(0) # sum over ensemble
# `q_preds` depends on the first observation and the actions. # `q_preds` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][:,0] # shape: [batch_size, horizon+1] * ~batch["observation.state_is_pad"][:, 0] # shape: [batch_size, horizon+1]
* ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon] * ~batch["action_is_pad"][:, 0] # shape: [batch_size, horizon]
# q_targets depends on the reward and the next observations. # q_targets depends on the reward and the next observations.
* ~batch["next.reward_is_pad"][:,0] # shape: [batch_size, horizon] * ~batch["next.reward_is_pad"][:, 0] # shape: [batch_size, horizon]
* ~batch["observation.state_is_pad"][:,1] # shape: [batch_size, horizon+1] * ~batch["observation.state_is_pad"][:, 1] # shape: [batch_size, horizon+1]
).mean() ).mean()
# calculate actors loss # calculate actors loss
# 1- temperature # 1- temperature
temperature = self.temperature() temperature = self.temperature()
# 2- get actions (batch_size, action_dim) and log probs (batch_size,) # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
actions, log_probs, _ = self.actor(observations) actions, log_probs, _ = self.actor(observations)
# 3- get q-value predictions # 3- get q-value predictions
# with torch.inference_mode(): with torch.inference_mode():
q_preds = self.critic_forward(observations, actions, use_target=False) q_preds = self.critic_forward(observations, actions, use_target=False)
# q_preds_min = torch.min(q_preds, axis=0) # q_preds_min = torch.min(q_preds, axis=0)
min_q_preds = q_preds.min(dim=0)[0] min_q_preds = q_preds.min(dim=0)[0]
# print(f"Q-values stats: mean={min_q_preds.mean():.3f}, min={min_q_preds.min():.3f}, max={min_q_preds.max():.3f}") # print(f"Q-values stats: mean={min_q_preds.mean():.3f}, min={min_q_preds.min():.3f}, max={min_q_preds.max():.3f}")
# print(f"Log probs stats: mean={log_probs.mean():.3f}, min={log_probs.min():.3f}, max={log_probs.max():.3f}") # print(f"Log probs stats: mean={log_probs.mean():.3f}, min={log_probs.min():.3f}, max={log_probs.max():.3f}")
# breakpoint() # breakpoint()
actor_loss = ( actor_loss = (
-(min_q_preds - temperature * log_probs).mean() -(min_q_preds - temperature * log_probs).mean()
* ~batch["observation.state_is_pad"][:,0] # shape: [batch_size, horizon+1] * ~batch["observation.state_is_pad"][:, 0] # shape: [batch_size, horizon+1]
* ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon] * ~batch["action_is_pad"][:, 0] # shape: [batch_size, horizon]
).mean() ).mean()
# calculate temperature loss # calculate temperature loss
# 1- calculate entropy # 1- calculate entropy
with torch.no_grad(): with torch.no_grad():
actions, log_probs, _ = self.actor(observations) actions, log_probs, _ = self.actor(observations)
entropy = -log_probs.mean() entropy = -log_probs.mean()
temperature_loss = self.temperature( temperature_loss = self.temperature(lhs=entropy, rhs=self.config.target_entropy)
lhs=entropy,
rhs=self.config.target_entropy
)
loss = critics_loss + actor_loss + temperature_loss loss = critics_loss + actor_loss + temperature_loss
return { return {
"critics_loss": critics_loss.item(), "critics_loss": critics_loss.item(),
"actor_loss": actor_loss.item(), "actor_loss": actor_loss.item(),
"mean_q_predicts": min_q_preds.mean().item(), "mean_q_predicts": min_q_preds.mean().item(),
"min_q_predicts":min_q_preds.min().item(), "min_q_predicts": min_q_preds.min().item(),
"max_q_predicts":min_q_preds.max().item(), "max_q_predicts": min_q_preds.max().item(),
"temperature_loss": temperature_loss.item(), "temperature_loss": temperature_loss.item(),
"temperature": temperature.item(), "temperature": temperature.item(),
"mean_log_probs": log_probs.mean().item(), "mean_log_probs": log_probs.mean().item(),
"min_log_probs": log_probs.min().item(), "min_log_probs": log_probs.min().item(),
"max_log_probs": log_probs.max().item(), "max_log_probs": log_probs.max().item(),
"td_target_mean": td_target.mean().item(), "td_target_mean": td_target.mean().item(),
"td_target_mean": td_target.max().item(), "td_target_max": td_target.max().item(),
"action_mean": actions.mean().item(), "action_mean": actions.mean().item(),
"entropy": entropy.item(), "entropy": entropy.item(),
"loss": loss, "loss": loss,
} }
def update(self): def update(self):
# TODO: implement UTD update # TODO: implement UTD update
# First update only critics for utd_ratio-1 times # First update only critics for utd_ratio-1 times
#for critic_step in range(self.config.utd_ratio - 1): # for critic_step in range(self.config.utd_ratio - 1):
# only update critic and critic target # only update critic and critic target
# Then update critic, critic target, actor and temperature # Then update critic, critic target, actor and temperature
"""Update target networks with exponential moving average""" """Update target networks with exponential moving average"""
with torch.no_grad(): with torch.no_grad():
for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False): for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False): for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False):
target_param.data.copy_( target_param.data.copy_(
param.data * self.config.critic_target_update_weight + param.data * self.config.critic_target_update_weight
target_param.data * (1.0 - self.config.critic_target_update_weight) + target_param.data * (1.0 - self.config.critic_target_update_weight)
) )
class MLP(nn.Module): class MLP(nn.Module):
def __init__( def __init__(
self, self,
@@ -284,52 +283,54 @@ class MLP(nn.Module):
super().__init__() super().__init__()
self.activate_final = activate_final self.activate_final = activate_final
layers = [] layers = []
# First layer uses input_dim # First layer uses input_dim
layers.append(nn.Linear(input_dim, hidden_dims[0])) layers.append(nn.Linear(input_dim, hidden_dims[0]))
# Add activation after first layer # Add activation after first layer
if dropout_rate is not None and dropout_rate > 0: if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate)) layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(hidden_dims[0])) layers.append(nn.LayerNorm(hidden_dims[0]))
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)()) layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
# Rest of the layers # Rest of the layers
for i in range(1, len(hidden_dims)): for i in range(1, len(hidden_dims)):
layers.append(nn.Linear(hidden_dims[i-1], hidden_dims[i])) layers.append(nn.Linear(hidden_dims[i - 1], hidden_dims[i]))
if i + 1 < len(hidden_dims) or activate_final: if i + 1 < len(hidden_dims) or activate_final:
if dropout_rate is not None and dropout_rate > 0: if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate)) layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(hidden_dims[i])) layers.append(nn.LayerNorm(hidden_dims[i]))
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)()) layers.append(
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
)
self.net = nn.Sequential(*layers) self.net = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x) return self.net(x)
class Critic(nn.Module): class Critic(nn.Module):
def __init__( def __init__(
self, self,
encoder: Optional[nn.Module], encoder: Optional[nn.Module],
network: nn.Module, network: nn.Module,
init_final: Optional[float] = None, init_final: Optional[float] = None,
device: str = "cuda" device: str = "cuda",
): ):
super().__init__() super().__init__()
self.device = torch.device(device) self.device = torch.device(device)
self.encoder = encoder self.encoder = encoder
self.network = network self.network = network
self.init_final = init_final self.init_final = init_final
# Find the last Linear layer's output dimension # Find the last Linear layer's output dimension
for layer in reversed(network.net): for layer in reversed(network.net):
if isinstance(layer, nn.Linear): if isinstance(layer, nn.Linear):
out_features = layer.out_features out_features = layer.out_features
break break
# Output layer # Output layer
if init_final is not None: if init_final is not None:
self.output_layer = nn.Linear(out_features, 1) self.output_layer = nn.Linear(out_features, 1)
@@ -338,22 +339,20 @@ class Critic(nn.Module):
else: else:
self.output_layer = nn.Linear(out_features, 1) self.output_layer = nn.Linear(out_features, 1)
orthogonal_init()(self.output_layer.weight) orthogonal_init()(self.output_layer.weight)
self.to(self.device) self.to(self.device)
def forward( def forward(
self, self,
observations: dict[str, torch.Tensor], observations: dict[str, torch.Tensor],
actions: torch.Tensor, actions: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
# Move each tensor in observations to device # Move each tensor in observations to device
observations = { observations = {k: v.to(self.device) for k, v in observations.items()}
k: v.to(self.device) for k, v in observations.items()
}
actions = actions.to(self.device) actions = actions.to(self.device)
obs_enc = observations if self.encoder is None else self.encoder(observations) obs_enc = observations if self.encoder is None else self.encoder(observations)
inputs = torch.cat([obs_enc, actions], dim=-1) inputs = torch.cat([obs_enc, actions], dim=-1)
x = self.network(inputs) x = self.network(inputs)
value = self.output_layer(x) value = self.output_layer(x)
@@ -371,7 +370,7 @@ class Policy(nn.Module):
fixed_std: Optional[torch.Tensor] = None, fixed_std: Optional[torch.Tensor] = None,
init_final: Optional[float] = None, init_final: Optional[float] = None,
use_tanh_squash: bool = False, use_tanh_squash: bool = False,
device: str = "cuda" device: str = "cuda",
): ):
super().__init__() super().__init__()
self.device = torch.device(device) self.device = torch.device(device)
@@ -382,13 +381,13 @@ class Policy(nn.Module):
self.log_std_max = log_std_max self.log_std_max = log_std_max
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
self.use_tanh_squash = use_tanh_squash self.use_tanh_squash = use_tanh_squash
# Find the last Linear layer's output dimension # Find the last Linear layer's output dimension
for layer in reversed(network.net): for layer in reversed(network.net):
if isinstance(layer, nn.Linear): if isinstance(layer, nn.Linear):
out_features = layer.out_features out_features = layer.out_features
break break
# Mean layer # Mean layer
self.mean_layer = nn.Linear(out_features, action_dim) self.mean_layer = nn.Linear(out_features, action_dim)
if init_final is not None: if init_final is not None:
@@ -396,7 +395,7 @@ class Policy(nn.Module):
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final) nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
else: else:
orthogonal_init()(self.mean_layer.weight) orthogonal_init()(self.mean_layer.weight)
# Standard deviation layer or parameter # Standard deviation layer or parameter
if fixed_std is None: if fixed_std is None:
self.std_layer = nn.Linear(out_features, action_dim) self.std_layer = nn.Linear(out_features, action_dim)
@@ -405,43 +404,47 @@ class Policy(nn.Module):
nn.init.uniform_(self.std_layer.bias, -init_final, init_final) nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
else: else:
orthogonal_init()(self.std_layer.weight) orthogonal_init()(self.std_layer.weight)
self.to(self.device) self.to(self.device)
def forward( def forward(
self, self,
observations: torch.Tensor, observations: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Encode observations if encoder exists # Encode observations if encoder exists
obs_enc = observations if self.encoder is None else self.encoder(observations) obs_enc = observations if self.encoder is None else self.encoder(observations)
# Get network outputs # Get network outputs
outputs = self.network(obs_enc) outputs = self.network(obs_enc)
means = self.mean_layer(outputs) means = self.mean_layer(outputs)
# Compute standard deviations # Compute standard deviations
if self.fixed_std is None: if self.fixed_std is None:
log_std = self.std_layer(outputs) log_std = self.std_layer(outputs)
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
if self.use_tanh_squash: if self.use_tanh_squash:
log_std = torch.tanh(log_std) log_std = torch.tanh(log_std)
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
else:
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
else: else:
log_std = self.fixed_std.expand_as(means) log_std = self.fixed_std.expand_as(means)
# uses tahn activation function to squash the action to be in the range of [-1, 1] # uses tanh activation function to squash the action to be in the range of [-1, 1]
normal = torch.distributions.Normal(means, torch.exp(log_std)) normal = torch.distributions.Normal(means, torch.exp(log_std))
x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) x_t = normal.rsample() # Reparameterization trick (mean + std * N(0,1))
x_t = torch.clamp(x_t, -2.0, 2.0) log_probs = normal.log_prob(x_t) # Base log probability before Tanh
log_probs = normal.log_prob(x_t)
if self.use_tanh_squash: if self.use_tanh_squash:
actions = torch.tanh(x_t) actions = torch.tanh(x_t)
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh
log_probs = log_probs.sum(-1) # sum over action dim else:
means = torch.tanh(means) actions = x_t # No Tanh; raw Gaussian sample
log_probs = log_probs.sum(-1) # Sum over action dimensions
return actions, log_probs, means return actions, log_probs, means
def get_features(self, observations: torch.Tensor) -> torch.Tensor: def get_features(self, observations: torch.Tensor) -> torch.Tensor:
"""Get encoded features from observations""" """Get encoded features from observations"""
observations = observations.to(self.device) observations = observations.to(self.device)
@@ -495,9 +498,7 @@ class SACObservationEncoder(nn.Module):
) )
if "observation.environment_state" in config.input_shapes: if "observation.environment_state" in config.input_shapes:
self.env_state_enc_layers = nn.Sequential( self.env_state_enc_layers = nn.Sequential(
nn.Linear( nn.Linear(config.input_shapes["observation.environment_state"][0], config.latent_dim),
config.input_shapes["observation.environment_state"][0], config.latent_dim
),
nn.LayerNorm(config.latent_dim), nn.LayerNorm(config.latent_dim),
nn.Tanh(), nn.Tanh(),
) )
@@ -519,7 +520,7 @@ class SACObservationEncoder(nn.Module):
feat.append(self.state_enc_layers(obs_dict["observation.state"])) feat.append(self.state_enc_layers(obs_dict["observation.state"]))
# TODO(ke-wang): currently averages over all features; concatenating all features may be a better approach # TODO(ke-wang): currently averages over all features; concatenating all features may be a better approach
return torch.stack(feat, dim=0).mean(0) return torch.stack(feat, dim=0).mean(0)
@property @property
def output_dim(self) -> int: def output_dim(self) -> int:
"""Returns the dimension of the encoder output""" """Returns the dimension of the encoder output"""
@@ -527,48 +528,47 @@ class SACObservationEncoder(nn.Module):
class LagrangeMultiplier(nn.Module): class LagrangeMultiplier(nn.Module):
def __init__( def __init__(self, init_value: float = 1.0, constraint_shape: Sequence[int] = (), device: str = "cuda"):
self,
init_value: float = 1.0,
constraint_shape: Sequence[int] = (),
device: str = "cuda"
):
super().__init__() super().__init__()
self.device = torch.device(device) self.device = torch.device(device)
# init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
init_value = torch.tensor(init_value, device=self.device)
# Parameterize log(alpha) directly to ensure positivity
# Initialize the Lagrange multiplier as a parameter log_alpha = torch.log(torch.tensor(init_value, dtype=torch.float32, device=self.device))
self.lagrange = nn.Parameter( self.log_alpha = nn.Parameter(torch.full(constraint_shape, log_alpha))
torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device)
)
def forward( def forward(
self, self,
lhs: Optional[torch.Tensor | float | int] = None, lhs: Optional[Union[torch.Tensor, float, int]] = None,
rhs: Optional[torch.Tensor | float | int] = None rhs: Optional[Union[torch.Tensor, float, int]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
# Get the multiplier value based on parameterization # Compute alpha = exp(log_alpha)
# multiplier = torch.nn.functional.softplus(self.lagrange) alpha = self.log_alpha.exp()
log_multiplier = torch.log(self.lagrange)
# Return the raw multiplier if no constraint values provided # Return alpha directly if no constraints provided
if lhs is None: if lhs is None:
return log_multiplier.exp() return alpha
# Convert inputs to tensors and move to device # Convert inputs to tensors and move to device
lhs = torch.tensor(lhs, device=self.device) if not isinstance(lhs, torch.Tensor) else lhs.to(self.device) lhs = (
torch.tensor(lhs, device=self.device)
if not isinstance(lhs, torch.Tensor)
else lhs.to(self.device)
)
if rhs is not None: if rhs is not None:
rhs = torch.tensor(rhs, device=self.device) if not isinstance(rhs, torch.Tensor) else rhs.to(self.device) rhs = (
torch.tensor(rhs, device=self.device)
if not isinstance(rhs, torch.Tensor)
else rhs.to(self.device)
)
else: else:
rhs = torch.zeros_like(lhs, device=self.device) rhs = torch.zeros_like(lhs, device=self.device)
# Compute the difference and apply the multiplier
diff = lhs - rhs diff = lhs - rhs
assert diff.shape == log_multiplier.shape, f"Shape mismatch: {diff.shape} vs {log_multiplier.shape}" assert diff.shape == alpha.shape, f"Shape mismatch: {diff.shape} vs {alpha.shape}"
return log_multiplier.exp() * diff # numerically better return alpha * diff
def orthogonal_init(): def orthogonal_init():
@@ -580,6 +580,7 @@ def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: s
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}" assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
return nn.ModuleList(critics).to(device) return nn.ModuleList(critics).to(device)
# borrowed from tdmpc # borrowed from tdmpc
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor: def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor. """Helper to temporarily flatten extra dims at the start of the image tensor.
@@ -587,7 +588,7 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
Args: Args:
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
(B, *), where * is any number of dimensions. (B, *), where * is any number of dimensions.
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
can be more than 1 dimension, generally different from *. can be more than 1 dimension, generally different from *.
Returns: Returns:
A return value from the callable reshaped to (**, *). A return value from the callable reshaped to (**, *).
@@ -597,4 +598,4 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
start_dims = image_tensor.shape[:-3] start_dims = image_tensor.shape[:-3]
inp = torch.flatten(image_tensor, end_dim=-4) inp = torch.flatten(image_tensor, end_dim=-4)
flat_out = fn(inp) flat_out = fn(inp)
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:])) return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))