Correct losses and factorisation
@@ -56,7 +56,8 @@ class SACConfig:
state_encoder_hidden_dim = 256
latent_dim = 256
target_entropy = None
backup_entropy = False
# backup_entropy = False
use_backup_entropy = True
critic_network_kwargs = {
"hidden_dims": [256, 256],
"activate_final": True,

@@ -98,7 +98,10 @@ class SACPolicy(
)
if config.target_entropy is None:
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
# TODO: fix device handling later
# TODO: handle the case where the temperature parameter is fixed
self.log_alpha = torch.zeros(1, requires_grad=True, device="cpu")
self.temperature = self.log_alpha.exp().item()

def reset(self):
"""
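The hunk above swaps the LagrangeMultiplier module for a bare learnable log_alpha tensor. A minimal sketch of that pattern, with a hypothetical action_dim, showing how the default target entropy and the scalar temperature fit together (not the repo's exact code):

import numpy as np
import torch

action_dim = 6                                  # hypothetical; in the policy it comes from output_shapes["action"]
target_entropy = -np.prod(action_dim) / 2       # the -dim(A)/2 heuristic used above
log_alpha = torch.zeros(1, requires_grad=True)  # log(alpha) = 0, so alpha = 1 at initialization
temperature = log_alpha.exp().item()            # plain float consumed by the loss terms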
@@ -144,6 +147,9 @@ class SACPolicy(

Returns a dictionary with loss as a tensor, and other information as native floats.
"""
# We have to refresh the value of the temperature here because log_alpha may have
# been updated by the previous optimizer step.
self.temperature = self.log_alpha.exp().item()

batch = self.normalize_inputs(batch)
# batch shape is (b, 2, ...): along dim 1, index 0 is the current observation and
# index 1 is the next observation, used to form the TD target.
@@ -156,18 +162,11 @@ class SACPolicy(
observations[k] = batch[k][:, 0]
next_observations[k] = batch[k][:, 1]

# perform image augmentation

# reward bias from HIL-SERL code base
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch

# calculate critics loss
# 1- compute actions from policy
with torch.no_grad():
action_preds, log_probs, _ = self.actor(next_observations)
next_action_preds, next_log_probs, _ = self.actor(next_observations)

# 2- compute q targets
q_targets = self.critic_forward(next_observations, action_preds, use_target=True)
q_targets = self.critic_forward(next_observations, next_action_preds, use_target=True)

# subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
if self.config.num_subsample_critics is not None:
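The hunk cuts off inside the subsampling branch. For context, REDQ-style critic subsampling under a high update-to-data ratio typically looks like the self-contained sketch below (assumed names and sizes, not the repo's exact code):

import torch

num_critics, batch_size, num_subsample_critics = 10, 4, 2    # hypothetical sizes
q_targets = torch.randn(num_critics, batch_size)             # stand-in for the target Q values
if num_subsample_critics is not None:
    indices = torch.randperm(num_critics)[:num_subsample_critics]
    q_targets = q_targets[indices]                            # keep a random subset of critics
min_q, _ = q_targets.min(dim=0)                               # pessimistic estimate used in the TD target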
@@ -177,14 +176,8 @@ class SACPolicy(

# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
# breakpoint()
if self.config.use_backup_entropy:
min_q -= (
self.temperature()
* log_probs
* ~batch["observation.state_is_pad"][:, 0]
* ~batch["action_is_pad"][:, 0]
) # shape: [batch_size, horizon]
min_q -= self.temperature * next_log_probs
td_target = rewards + self.config.discount * min_q * ~batch["next.done"]

# 3- compute predicted qs
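In words, the corrected backup now uses the next-state log-probabilities: y = r + gamma * (1 - done) * (min_i Q_target_i(s', a') - alpha * log pi(a'|s')) when the entropy backup is enabled. A minimal self-contained sketch of that target with hypothetical tensors (not the repo's code):

import torch

batch_size, discount, temperature = 4, 0.99, 0.1              # hypothetical values
rewards = torch.randn(batch_size)
done = torch.zeros(batch_size, dtype=torch.bool)
min_q = torch.randn(batch_size)                               # min over the (possibly subsampled) target critics
next_log_probs = torch.randn(batch_size)                      # log pi(a'|s') for freshly sampled next actions
use_backup_entropy = True

if use_backup_entropy:
    min_q = min_q - temperature * next_log_probs              # soft value: V(s') = min Q - alpha * log pi
td_target = rewards + discount * min_q * ~done                # zero bootstrap on terminal transitions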
@@ -192,43 +185,28 @@ class SACPolicy(

# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# Compute the mean TD loss over the batch for each critic, then sum across critics for the final loss.
critics_loss = (
F.mse_loss(
q_preds,
einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
input=q_preds,
target=td_target_duplicate,
reduction="none",
).sum(0) # sum over ensemble
# `q_preds_ensemble` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][:, 0] # shape: [batch_size, horizon+1]
* ~batch["action_is_pad"][:, 0] # shape: [batch_size, horizon]
# q_targets depends on the reward and the next observations.
* ~batch["next.reward_is_pad"][:, 0] # shape: [batch_size, horizon]
* ~batch["observation.state_is_pad"][:, 1] # shape: [batch_size, horizon+1]
).mean()
).mean(1)
).sum()

# calculate actors loss
# 1- temperature
temperature = self.temperature()
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
temperature = self.temperature
actions, log_probs, _ = self.actor(observations)
# 3- get q-value predictions
with torch.inference_mode():
q_preds = self.critic_forward(observations, actions, use_target=False)
# q_preds_min = torch.min(q_preds, axis=0)
min_q_preds = q_preds.min(dim=0)[0]

actor_loss = (
-(min_q_preds - temperature * log_probs).mean()
* ~batch["observation.state_is_pad"][:, 0] # shape: [batch_size, horizon+1]
* ~batch["action_is_pad"][:, 0] # shape: [batch_size, horizon]
).mean()
actor_loss = ((temperature * log_probs) - min_q_preds).mean()

# calculate temperature loss
# 1- calculate entropy
with torch.no_grad():
actions, log_probs, _ = self.actor(observations)
entropy = -log_probs.mean()
temperature_loss = self.temperature(lhs=entropy, rhs=self.config.target_entropy)
_, log_probs, _ = self.actor(observations)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()

loss = critics_loss + actor_loss + temperature_loss

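After the correction, the three terms follow the standard SAC objectives: per-critic mean-over-batch MSE summed across the ensemble, an actor loss trading off Q value against entropy, and a temperature loss that pulls the sampled entropy toward the target. A compact, self-contained sketch with hypothetical shapes and values (not the repo's exact code):

import torch
import torch.nn.functional as F

num_critics, batch_size = 2, 4                                # hypothetical sizes
q_preds = torch.randn(num_critics, batch_size, requires_grad=True)
td_target = torch.randn(batch_size)                           # computed under no_grad, as above
log_probs = torch.randn(batch_size)                           # log pi(a|s) for freshly sampled actions
min_q_preds = torch.randn(batch_size)                         # min over critics for the sampled actions
log_alpha = torch.zeros(1, requires_grad=True)
temperature = log_alpha.exp().item()
target_entropy = -3.0                                         # e.g. -dim(A)/2

# Critic: mean over the batch for each critic, then sum across the ensemble.
critics_loss = F.mse_loss(q_preds, td_target.expand(num_critics, -1), reduction="none").mean(1).sum()
# Actor: push actions toward high Q while paying an alpha * log pi entropy penalty.
actor_loss = (temperature * log_probs - min_q_preds).mean()
# Temperature: grow alpha when entropy is below target, shrink it otherwise.
temperature_loss = (-log_alpha.exp() * (log_probs + target_entropy)).mean()
loss = critics_loss + actor_loss + temperature_loss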
@@ -239,14 +217,14 @@ class SACPolicy(
"min_q_predicts": min_q_preds.min().item(),
"max_q_predicts": min_q_preds.max().item(),
"temperature_loss": temperature_loss.item(),
"temperature": temperature.item(),
"temperature": temperature,
"mean_log_probs": log_probs.mean().item(),
"min_log_probs": log_probs.min().item(),
"max_log_probs": log_probs.max().item(),
"td_target_mean": td_target.mean().item(),
"td_target_max": td_target.max().item(),
"action_mean": actions.mean().item(),
"entropy": entropy.item(),
"entropy": log_probs.mean().item(),
"loss": loss,
}

@@ -312,7 +290,7 @@ class Critic(nn.Module):
encoder: Optional[nn.Module],
network: nn.Module,
init_final: Optional[float] = None,
device: str = "cuda",
device: str = "cpu",
):
super().__init__()
self.device = torch.device(device)
@@ -365,7 +343,7 @@ class Policy(nn.Module):
fixed_std: Optional[torch.Tensor] = None,
init_final: Optional[float] = None,
use_tanh_squash: bool = False,
device: str = "cuda",
device: str = "cpu",
):
super().__init__()
self.device = torch.device(device)
@@ -438,6 +416,7 @@ class Policy(nn.Module):
actions = x_t # No Tanh; raw Gaussian sample

log_probs = log_probs.sum(-1) # Sum over action dimensions
means = torch.tanh(means) if self.use_tanh_squash else means
return actions, log_probs, means

def get_features(self, observations: torch.Tensor) -> torch.Tensor:
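The added line squashes the distribution mean when use_tanh_squash is set. For reference, the usual reparameterized sampling with a tanh squash and its change-of-variables correction looks roughly like this sketch (assumed names and shapes, not the repo's exact code):

import torch

means, stds = torch.zeros(4, 3), torch.ones(4, 3)             # hypothetical batch of 4, action dim 3
use_tanh_squash = True

dist = torch.distributions.Normal(means, stds)
x_t = dist.rsample()                                          # reparameterized sample
log_probs = dist.log_prob(x_t)
if use_tanh_squash:
    actions = torch.tanh(x_t)
    log_probs -= torch.log(1 - actions.pow(2) + 1e-6)         # tanh change-of-variables correction
else:
    actions = x_t                                             # no tanh; raw Gaussian sample
log_probs = log_probs.sum(-1)                                 # sum over action dimensions
means = torch.tanh(means) if use_tanh_squash else means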
@@ -522,55 +501,11 @@ class SACObservationEncoder(nn.Module):
return self.config.latent_dim


class LagrangeMultiplier(nn.Module):
def __init__(self, init_value: float = 1.0, constraint_shape: Sequence[int] = (), device: str = "cuda"):
super().__init__()
self.device = torch.device(device)

# Parameterize log(alpha) directly to ensure positivity
log_alpha = torch.log(torch.tensor(init_value, dtype=torch.float32, device=self.device))
self.log_alpha = nn.Parameter(torch.full(constraint_shape, log_alpha))

def forward(
self,
lhs: Optional[Union[torch.Tensor, float, int]] = None,
rhs: Optional[Union[torch.Tensor, float, int]] = None,
) -> torch.Tensor:
# Compute alpha = exp(log_alpha)
alpha = self.log_alpha.exp()

# Return alpha directly if no constraints provided
if lhs is None:
return alpha

# Convert inputs to tensors and move to device
lhs = (
torch.tensor(lhs, device=self.device)
if not isinstance(lhs, torch.Tensor)
else lhs.to(self.device)
)
if rhs is not None:
rhs = (
torch.tensor(rhs, device=self.device)
if not isinstance(rhs, torch.Tensor)
else rhs.to(self.device)
)
else:
rhs = torch.zeros_like(lhs, device=self.device)

# Compute the difference and apply the multiplier
diff = lhs - rhs

assert diff.shape == alpha.shape, f"Shape mismatch: {diff.shape} vs {alpha.shape}"

return alpha * diff


def orthogonal_init():
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)


def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList:
def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cpu") -> nn.ModuleList:
"""Creates an ensemble of critic networks"""
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
return nn.ModuleList(critics).to(device)

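The LagrangeMultiplier class can be dropped because, with lhs set to the sampled entropy and rhs to the target entropy, alpha * (lhs - rhs) matches the explicit temperature loss now written inline in the update. A small self-contained check of that equivalence (hypothetical numbers, not the repo's code):

import torch

log_alpha = torch.zeros(1, requires_grad=True)
log_probs = torch.tensor([-1.2, -0.7])                        # hypothetical per-sample log pi(a|s)
target_entropy = -3.0                                         # e.g. -dim(A)/2 as in the config default

# Removed module's objective: alpha * (entropy - target_entropy), with entropy = -log_probs
old_form = (log_alpha.exp() * (-log_probs - target_entropy)).mean()
# New inline objective
new_form = (-log_alpha.exp() * (log_probs + target_entropy)).mean()

assert torch.allclose(old_form, new_form)                     # same objective, hence same gradient w.r.t. log_alpha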
@@ -99,7 +99,8 @@ def make_optimizer_and_scheduler(cfg, policy):
[
{"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
{"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
{"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
# We wrap the policy's log_alpha in a list because it is a raw torch tensor, not an nn.Module
{"params": [policy.log_alpha], "lr": policy.config.temperature_lr},
]
)
lr_scheduler = None

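Since log_alpha is a plain requires_grad tensor rather than an nn.Module, it has no .parameters() method; handing the optimizer a one-element list in its own parameter group keeps a separate learning rate for the temperature. A minimal sketch of the same pattern outside the policy (stand-in networks and learning rates):

import torch
import torch.nn as nn

actor = nn.Linear(8, 2)                                       # stand-ins for the real networks
critic_ensemble = nn.ModuleList([nn.Linear(10, 1), nn.Linear(10, 1)])
log_alpha = torch.zeros(1, requires_grad=True)

optimizer = torch.optim.Adam(
    [
        {"params": actor.parameters(), "lr": 3e-4},
        {"params": critic_ensemble.parameters(), "lr": 3e-4},
        {"params": [log_alpha], "lr": 3e-4},                  # raw tensor wrapped in a list
    ]
)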