forked from tangger/lerobot
Potential fixes for SAC instability and NAN bug
This commit is contained in:
@@ -18,8 +18,7 @@
|
||||
# TODO: (1) better device management
|
||||
|
||||
from collections import deque
|
||||
from copy import deepcopy
|
||||
from typing import Callable, Optional, Sequence, Tuple
|
||||
from typing import Callable, Optional, Sequence, Tuple, Union
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
@@ -72,8 +71,8 @@ class SACPolicy(
|
||||
encoder=encoder_critic,
|
||||
network=MLP(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs
|
||||
)
|
||||
**config.critic_network_kwargs,
|
||||
),
|
||||
)
|
||||
critic_nets.append(critic_net)
|
||||
|
||||
@@ -83,8 +82,8 @@ class SACPolicy(
|
||||
encoder=encoder_critic,
|
||||
network=MLP(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs
|
||||
)
|
||||
**config.critic_network_kwargs,
|
||||
),
|
||||
)
|
||||
target_critic_nets.append(target_critic_net)
|
||||
|
||||
@@ -93,16 +92,13 @@ class SACPolicy(
|
||||
|
||||
self.actor = Policy(
|
||||
encoder=encoder_actor,
|
||||
network=MLP(
|
||||
input_dim=encoder_actor.output_dim,
|
||||
**config.actor_network_kwargs
|
||||
),
|
||||
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
|
||||
action_dim=config.output_shapes["action"][0],
|
||||
**config.policy_kwargs
|
||||
**config.policy_kwargs,
|
||||
)
|
||||
if config.target_entropy is None:
|
||||
config.target_entropy = -np.prod(config.output_shapes["action"][0])/2 # (-dim(A)/2)
|
||||
self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
|
||||
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
|
||||
self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
@@ -125,15 +121,17 @@ class SACPolicy(
|
||||
actions, _, _ = self.actor(batch)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
return actions
|
||||
|
||||
def critic_forward(self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False) -> Tensor:
|
||||
|
||||
def critic_forward(
|
||||
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
|
||||
) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
actions: Action tensor
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from all critics
|
||||
"""
|
||||
@@ -141,15 +139,14 @@ class SACPolicy(
|
||||
q_values = torch.stack([critic(observations, actions) for critic in critics])
|
||||
return q_values
|
||||
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||
"""Run the batch through the model and compute the loss.
|
||||
|
||||
|
||||
Returns a dictionary with loss as a tensor, and other information as native floats.
|
||||
"""
|
||||
batch = self.normalize_inputs(batch)
|
||||
# batch shape is (b, 2, ...) where index 1 returns the current observation and
|
||||
# the next observation for calculating the right td index.
|
||||
# batch shape is (b, 2, ...) where index 1 returns the current observation and
|
||||
# the next observation for calculating the right td index.
|
||||
actions = batch["action"][:, 0]
|
||||
rewards = batch["next.reward"][:, 0]
|
||||
observations = {}
|
||||
@@ -158,12 +155,12 @@ class SACPolicy(
|
||||
if k.startswith("observation."):
|
||||
observations[k] = batch[k][:, 0]
|
||||
next_observations[k] = batch[k][:, 1]
|
||||
|
||||
|
||||
# perform image augmentation
|
||||
|
||||
# reward bias from HIL-SERL code base
|
||||
# reward bias from HIL-SERL code base
|
||||
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
|
||||
|
||||
|
||||
# calculate critics loss
|
||||
# 1- compute actions from policy
|
||||
with torch.no_grad():
|
||||
@@ -175,103 +172,105 @@ class SACPolicy(
|
||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[:self.config.num_subsample_critics]
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
# critics subsample size
|
||||
min_q, _ = q_targets.min(dim=0) # Get values from min operation
|
||||
# breakpoint()
|
||||
if self.config.use_backup_entropy:
|
||||
min_q -= self.temperature() * log_probs * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon]
|
||||
min_q -= (
|
||||
self.temperature()
|
||||
* log_probs
|
||||
* ~batch["observation.state_is_pad"][:, 0]
|
||||
* ~batch["action_is_pad"][:, 0]
|
||||
) # shape: [batch_size, horizon]
|
||||
td_target = rewards + self.config.discount * min_q * ~batch["next.done"]
|
||||
# td_target -= self.config.discount * self.temperature() * log_probs \
|
||||
# * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon]
|
||||
# td_target -= self.config.discount * self.temperature() * log_probs \
|
||||
# * ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon]
|
||||
# print(f"Target Q-values: mean={td_target.mean():.3f}, max={td_target.max():.3f}")
|
||||
|
||||
|
||||
# 3- compute predicted qs
|
||||
q_preds = self.critic_forward(observations, actions, use_target=False)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
critics_loss = (
|
||||
critics_loss = (
|
||||
F.mse_loss(
|
||||
q_preds,
|
||||
einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
|
||||
reduction="none",
|
||||
).sum(0) # sum over ensemble
|
||||
# `q_preds_ensemble` depends on the first observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][:,0] # shape: [batch_size, horizon+1]
|
||||
* ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon]
|
||||
# q_targets depends on the reward and the next observations.
|
||||
* ~batch["next.reward_is_pad"][:,0] # shape: [batch_size, horizon]
|
||||
* ~batch["observation.state_is_pad"][:,1] # shape: [batch_size, horizon+1]
|
||||
).mean()
|
||||
|
||||
q_preds,
|
||||
einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
|
||||
reduction="none",
|
||||
).sum(0) # sum over ensemble
|
||||
# `q_preds_ensemble` depends on the first observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][:, 0] # shape: [batch_size, horizon+1]
|
||||
* ~batch["action_is_pad"][:, 0] # shape: [batch_size, horizon]
|
||||
# q_targets depends on the reward and the next observations.
|
||||
* ~batch["next.reward_is_pad"][:, 0] # shape: [batch_size, horizon]
|
||||
* ~batch["observation.state_is_pad"][:, 1] # shape: [batch_size, horizon+1]
|
||||
).mean()
|
||||
|
||||
# calculate actors loss
|
||||
# 1- temperature
|
||||
temperature = self.temperature()
|
||||
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
|
||||
actions, log_probs, _ = self.actor(observations)
|
||||
# 3- get q-value predictions
|
||||
# with torch.inference_mode():
|
||||
q_preds = self.critic_forward(observations, actions, use_target=False)
|
||||
with torch.inference_mode():
|
||||
q_preds = self.critic_forward(observations, actions, use_target=False)
|
||||
# q_preds_min = torch.min(q_preds, axis=0)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
# print(f"Q-values stats: mean={min_q_preds.mean():.3f}, min={min_q_preds.min():.3f}, max={min_q_preds.max():.3f}")
|
||||
# print(f"Log probs stats: mean={log_probs.mean():.3f}, min={log_probs.min():.3f}, max={log_probs.max():.3f}")
|
||||
# breakpoint()
|
||||
actor_loss = (
|
||||
-(min_q_preds - temperature * log_probs).mean()
|
||||
* ~batch["observation.state_is_pad"][:,0] # shape: [batch_size, horizon+1]
|
||||
* ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon]
|
||||
* ~batch["observation.state_is_pad"][:, 0] # shape: [batch_size, horizon+1]
|
||||
* ~batch["action_is_pad"][:, 0] # shape: [batch_size, horizon]
|
||||
).mean()
|
||||
|
||||
|
||||
# calculate temperature loss
|
||||
# 1- calculate entropy
|
||||
with torch.no_grad():
|
||||
actions, log_probs, _ = self.actor(observations)
|
||||
entropy = -log_probs.mean()
|
||||
temperature_loss = self.temperature(
|
||||
lhs=entropy,
|
||||
rhs=self.config.target_entropy
|
||||
)
|
||||
temperature_loss = self.temperature(lhs=entropy, rhs=self.config.target_entropy)
|
||||
|
||||
loss = critics_loss + actor_loss + temperature_loss
|
||||
|
||||
return {
|
||||
"critics_loss": critics_loss.item(),
|
||||
"actor_loss": actor_loss.item(),
|
||||
"mean_q_predicts": min_q_preds.mean().item(),
|
||||
"min_q_predicts":min_q_preds.min().item(),
|
||||
"max_q_predicts":min_q_preds.max().item(),
|
||||
"temperature_loss": temperature_loss.item(),
|
||||
"temperature": temperature.item(),
|
||||
"mean_log_probs": log_probs.mean().item(),
|
||||
"min_log_probs": log_probs.min().item(),
|
||||
"max_log_probs": log_probs.max().item(),
|
||||
"td_target_mean": td_target.mean().item(),
|
||||
"td_target_mean": td_target.max().item(),
|
||||
"action_mean": actions.mean().item(),
|
||||
"entropy": entropy.item(),
|
||||
"loss": loss,
|
||||
}
|
||||
|
||||
"critics_loss": critics_loss.item(),
|
||||
"actor_loss": actor_loss.item(),
|
||||
"mean_q_predicts": min_q_preds.mean().item(),
|
||||
"min_q_predicts": min_q_preds.min().item(),
|
||||
"max_q_predicts": min_q_preds.max().item(),
|
||||
"temperature_loss": temperature_loss.item(),
|
||||
"temperature": temperature.item(),
|
||||
"mean_log_probs": log_probs.mean().item(),
|
||||
"min_log_probs": log_probs.min().item(),
|
||||
"max_log_probs": log_probs.max().item(),
|
||||
"td_target_mean": td_target.mean().item(),
|
||||
"td_target_max": td_target.max().item(),
|
||||
"action_mean": actions.mean().item(),
|
||||
"entropy": entropy.item(),
|
||||
"loss": loss,
|
||||
}
|
||||
|
||||
def update(self):
|
||||
# TODO: implement UTD update
|
||||
# First update only critics for utd_ratio-1 times
|
||||
#for critic_step in range(self.config.utd_ratio - 1):
|
||||
# only update critic and critic target
|
||||
# for critic_step in range(self.config.utd_ratio - 1):
|
||||
# only update critic and critic target
|
||||
# Then update critic, critic target, actor and temperature
|
||||
"""Update target networks with exponential moving average"""
|
||||
with torch.no_grad():
|
||||
for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
|
||||
for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight +
|
||||
target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -284,52 +283,54 @@ class MLP(nn.Module):
|
||||
super().__init__()
|
||||
self.activate_final = activate_final
|
||||
layers = []
|
||||
|
||||
|
||||
# First layer uses input_dim
|
||||
layers.append(nn.Linear(input_dim, hidden_dims[0]))
|
||||
|
||||
|
||||
# Add activation after first layer
|
||||
if dropout_rate is not None and dropout_rate > 0:
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(hidden_dims[0]))
|
||||
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
|
||||
|
||||
|
||||
# Rest of the layers
|
||||
for i in range(1, len(hidden_dims)):
|
||||
layers.append(nn.Linear(hidden_dims[i-1], hidden_dims[i]))
|
||||
|
||||
layers.append(nn.Linear(hidden_dims[i - 1], hidden_dims[i]))
|
||||
|
||||
if i + 1 < len(hidden_dims) or activate_final:
|
||||
if dropout_rate is not None and dropout_rate > 0:
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(hidden_dims[i]))
|
||||
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
|
||||
|
||||
layers.append(
|
||||
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
|
||||
)
|
||||
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.net(x)
|
||||
|
||||
|
||||
|
||||
|
||||
class Critic(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: Optional[nn.Module],
|
||||
network: nn.Module,
|
||||
init_final: Optional[float] = None,
|
||||
device: str = "cuda"
|
||||
device: str = "cuda",
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
self.encoder = encoder
|
||||
self.network = network
|
||||
self.init_final = init_final
|
||||
|
||||
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
out_features = layer.out_features
|
||||
break
|
||||
|
||||
|
||||
# Output layer
|
||||
if init_final is not None:
|
||||
self.output_layer = nn.Linear(out_features, 1)
|
||||
@@ -338,22 +339,20 @@ class Critic(nn.Module):
|
||||
else:
|
||||
self.output_layer = nn.Linear(out_features, 1)
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: dict[str, torch.Tensor],
|
||||
self,
|
||||
observations: dict[str, torch.Tensor],
|
||||
actions: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Move each tensor in observations to device
|
||||
observations = {
|
||||
k: v.to(self.device) for k, v in observations.items()
|
||||
}
|
||||
observations = {k: v.to(self.device) for k, v in observations.items()}
|
||||
actions = actions.to(self.device)
|
||||
|
||||
|
||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
x = self.network(inputs)
|
||||
value = self.output_layer(x)
|
||||
@@ -371,7 +370,7 @@ class Policy(nn.Module):
|
||||
fixed_std: Optional[torch.Tensor] = None,
|
||||
init_final: Optional[float] = None,
|
||||
use_tanh_squash: bool = False,
|
||||
device: str = "cuda"
|
||||
device: str = "cuda",
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
@@ -382,13 +381,13 @@ class Policy(nn.Module):
|
||||
self.log_std_max = log_std_max
|
||||
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
|
||||
self.use_tanh_squash = use_tanh_squash
|
||||
|
||||
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
out_features = layer.out_features
|
||||
break
|
||||
|
||||
|
||||
# Mean layer
|
||||
self.mean_layer = nn.Linear(out_features, action_dim)
|
||||
if init_final is not None:
|
||||
@@ -396,7 +395,7 @@ class Policy(nn.Module):
|
||||
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.mean_layer.weight)
|
||||
|
||||
|
||||
# Standard deviation layer or parameter
|
||||
if fixed_std is None:
|
||||
self.std_layer = nn.Linear(out_features, action_dim)
|
||||
@@ -405,43 +404,47 @@ class Policy(nn.Module):
|
||||
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.std_layer.weight)
|
||||
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
# Encode observations if encoder exists
|
||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||
|
||||
# Get network outputs
|
||||
outputs = self.network(obs_enc)
|
||||
means = self.mean_layer(outputs)
|
||||
|
||||
|
||||
# Compute standard deviations
|
||||
if self.fixed_std is None:
|
||||
log_std = self.std_layer(outputs)
|
||||
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
|
||||
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
|
||||
|
||||
if self.use_tanh_squash:
|
||||
log_std = torch.tanh(log_std)
|
||||
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
|
||||
else:
|
||||
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
|
||||
else:
|
||||
log_std = self.fixed_std.expand_as(means)
|
||||
|
||||
# uses tahn activation function to squash the action to be in the range of [-1, 1]
|
||||
|
||||
# uses tanh activation function to squash the action to be in the range of [-1, 1]
|
||||
normal = torch.distributions.Normal(means, torch.exp(log_std))
|
||||
x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1))
|
||||
x_t = torch.clamp(x_t, -2.0, 2.0)
|
||||
log_probs = normal.log_prob(x_t)
|
||||
x_t = normal.rsample() # Reparameterization trick (mean + std * N(0,1))
|
||||
log_probs = normal.log_prob(x_t) # Base log probability before Tanh
|
||||
|
||||
if self.use_tanh_squash:
|
||||
actions = torch.tanh(x_t)
|
||||
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
|
||||
log_probs = log_probs.sum(-1) # sum over action dim
|
||||
means = torch.tanh(means)
|
||||
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh
|
||||
else:
|
||||
actions = x_t # No Tanh; raw Gaussian sample
|
||||
|
||||
log_probs = log_probs.sum(-1) # Sum over action dimensions
|
||||
return actions, log_probs, means
|
||||
|
||||
|
||||
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
|
||||
"""Get encoded features from observations"""
|
||||
observations = observations.to(self.device)
|
||||
@@ -495,9 +498,7 @@ class SACObservationEncoder(nn.Module):
|
||||
)
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
config.input_shapes["observation.environment_state"][0], config.latent_dim
|
||||
),
|
||||
nn.Linear(config.input_shapes["observation.environment_state"][0], config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
@@ -519,7 +520,7 @@ class SACObservationEncoder(nn.Module):
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
# TODO(ke-wang): currently average over all features, concatenate all features maybe a better way
|
||||
return torch.stack(feat, dim=0).mean(0)
|
||||
|
||||
|
||||
@property
|
||||
def output_dim(self) -> int:
|
||||
"""Returns the dimension of the encoder output"""
|
||||
@@ -527,48 +528,47 @@ class SACObservationEncoder(nn.Module):
|
||||
|
||||
|
||||
class LagrangeMultiplier(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
init_value: float = 1.0,
|
||||
constraint_shape: Sequence[int] = (),
|
||||
device: str = "cuda"
|
||||
):
|
||||
def __init__(self, init_value: float = 1.0, constraint_shape: Sequence[int] = (), device: str = "cuda"):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
# init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
|
||||
init_value = torch.tensor(init_value, device=self.device)
|
||||
|
||||
|
||||
# Initialize the Lagrange multiplier as a parameter
|
||||
self.lagrange = nn.Parameter(
|
||||
torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device)
|
||||
)
|
||||
|
||||
# Parameterize log(alpha) directly to ensure positivity
|
||||
log_alpha = torch.log(torch.tensor(init_value, dtype=torch.float32, device=self.device))
|
||||
self.log_alpha = nn.Parameter(torch.full(constraint_shape, log_alpha))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
lhs: Optional[torch.Tensor | float | int] = None,
|
||||
rhs: Optional[torch.Tensor | float | int] = None
|
||||
self,
|
||||
lhs: Optional[Union[torch.Tensor, float, int]] = None,
|
||||
rhs: Optional[Union[torch.Tensor, float, int]] = None,
|
||||
) -> torch.Tensor:
|
||||
# Get the multiplier value based on parameterization
|
||||
# multiplier = torch.nn.functional.softplus(self.lagrange)
|
||||
log_multiplier = torch.log(self.lagrange)
|
||||
# Compute alpha = exp(log_alpha)
|
||||
alpha = self.log_alpha.exp()
|
||||
|
||||
# Return the raw multiplier if no constraint values provided
|
||||
# Return alpha directly if no constraints provided
|
||||
if lhs is None:
|
||||
return log_multiplier.exp()
|
||||
|
||||
return alpha
|
||||
|
||||
# Convert inputs to tensors and move to device
|
||||
lhs = torch.tensor(lhs, device=self.device) if not isinstance(lhs, torch.Tensor) else lhs.to(self.device)
|
||||
lhs = (
|
||||
torch.tensor(lhs, device=self.device)
|
||||
if not isinstance(lhs, torch.Tensor)
|
||||
else lhs.to(self.device)
|
||||
)
|
||||
if rhs is not None:
|
||||
rhs = torch.tensor(rhs, device=self.device) if not isinstance(rhs, torch.Tensor) else rhs.to(self.device)
|
||||
rhs = (
|
||||
torch.tensor(rhs, device=self.device)
|
||||
if not isinstance(rhs, torch.Tensor)
|
||||
else rhs.to(self.device)
|
||||
)
|
||||
else:
|
||||
rhs = torch.zeros_like(lhs, device=self.device)
|
||||
|
||||
|
||||
# Compute the difference and apply the multiplier
|
||||
diff = lhs - rhs
|
||||
|
||||
assert diff.shape == log_multiplier.shape, f"Shape mismatch: {diff.shape} vs {log_multiplier.shape}"
|
||||
|
||||
return log_multiplier.exp() * diff # numerically better
|
||||
|
||||
assert diff.shape == alpha.shape, f"Shape mismatch: {diff.shape} vs {alpha.shape}"
|
||||
|
||||
return alpha * diff
|
||||
|
||||
|
||||
def orthogonal_init():
|
||||
@@ -580,6 +580,7 @@ def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: s
|
||||
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
|
||||
return nn.ModuleList(critics).to(device)
|
||||
|
||||
|
||||
# borrowed from tdmpc
|
||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||
@@ -587,7 +588,7 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
|
||||
Args:
|
||||
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
|
||||
(B, *), where * is any number of dimensions.
|
||||
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
|
||||
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
|
||||
can be more than 1 dimensions, generally different from *.
|
||||
Returns:
|
||||
A return value from the callable reshaped to (**, *).
|
||||
@@ -597,4 +598,4 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
|
||||
start_dims = image_tensor.shape[:-3]
|
||||
inp = torch.flatten(image_tensor, end_dim=-4)
|
||||
flat_out = fn(inp)
|
||||
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
|
||||
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
|
||||
|
||||
Reference in New Issue
Block a user