fix the bug of target critic updates, roll back to origial temperature implementation, added debug logging info

This commit is contained in:
KeWang1017
2025-01-02 21:05:14 +00:00
committed by Ke-Wang1017
parent f1f04eb4f9
commit eec28baa63
3 changed files with 60 additions and 41 deletions

View File

@@ -56,7 +56,7 @@ class SACConfig:
state_encoder_hidden_dim = 256 state_encoder_hidden_dim = 256
latent_dim = 256 latent_dim = 256
target_entropy = None target_entropy = None
backup_entropy = True backup_entropy = False
critic_network_kwargs = { critic_network_kwargs = {
"hidden_dims": [256, 256], "hidden_dims": [256, 256],
"activate_final": True, "activate_final": True,

View File

@@ -90,7 +90,7 @@ class SACPolicy(
**config.policy_kwargs **config.policy_kwargs
) )
if config.target_entropy is None: if config.target_entropy is None:
config.target_entropy = -np.prod(config.output_shapes["action"][0]) # (-dim(A)) config.target_entropy = -np.prod(config.output_shapes["action"][0])/2 # (-dim(A)/2)
self.temperature = LagrangeMultiplier(init_value=config.temperature_init) self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
def reset(self): def reset(self):
@@ -111,7 +111,7 @@ class SACPolicy(
@torch.no_grad() @torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor: def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select action for inference/evaluation""" """Select action for inference/evaluation"""
_, _, actions = self.actor(batch) actions, _, _ = self.actor(batch)
actions = self.unnormalize_outputs({"action": actions})["action"] actions = self.unnormalize_outputs({"action": actions})["action"]
return actions return actions
@@ -155,23 +155,28 @@ class SACPolicy(
# calculate critics loss # calculate critics loss
# 1- compute actions from policy # 1- compute actions from policy
action_preds, log_probs, _ = self.actor(next_observations) with torch.no_grad():
action_preds, log_probs, _ = self.actor(next_observations)
# 2- compute q targets # 2- compute q targets
q_targets = self.critic_forward(next_observations, action_preds, use_target=True) q_targets = self.critic_forward(next_observations, action_preds, use_target=True)
# subsample critics to prevent overfitting if use high UTD (update to date) # subsample critics to prevent overfitting if use high UTD (update to date)
if self.config.num_subsample_critics is not None: if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics) indices = torch.randperm(self.config.num_critics)
indices = indices[:self.config.num_subsample_critics] indices = indices[:self.config.num_subsample_critics]
q_targets = q_targets[indices] q_targets = q_targets[indices]
# critics subsample size # critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation min_q, _ = q_targets.min(dim=0) # Get values from min operation
# compute td target
td_target = rewards + self.config.discount * min_q #+ self.config.discount * self.temperature() * log_probs # add entropy term
# compute td target
td_target = rewards + self.config.discount * min_q
if self.config.use_backup_entropy:
td_target -= self.config.discount * self.temperature() * log_probs \
* ~batch["observation.state_is_pad"][:,0] * ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon]
# print(f"Target Q-values: mean={td_target.mean():.3f}, max={td_target.max():.3f}")
# 3- compute predicted qs # 3- compute predicted qs
q_preds = self.critic_forward(observations, actions, use_target=False) q_preds = self.critic_forward(observations, actions, use_target=False)
@@ -197,10 +202,15 @@ class SACPolicy(
# 2- get actions (batch_size, action_dim) and log probs (batch_size,) # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
actions, log_probs, _ = self.actor(observations) actions, log_probs, _ = self.actor(observations)
# 3- get q-value predictions # 3- get q-value predictions
with torch.inference_mode(): # with torch.inference_mode():
q_preds = self.critic_forward(observations, actions, use_target=False) q_preds = self.critic_forward(observations, actions, use_target=False)
# q_preds_min = torch.min(q_preds, axis=0)
min_q_preds = q_preds.min(dim=0)[0]
# print(f"Q-values stats: mean={min_q_preds.mean():.3f}, min={min_q_preds.min():.3f}, max={min_q_preds.max():.3f}")
# print(f"Log probs stats: mean={log_probs.mean():.3f}, min={log_probs.min():.3f}, max={log_probs.max():.3f}")
# breakpoint()
actor_loss = ( actor_loss = (
-(q_preds - temperature * log_probs).mean() -(min_q_preds - temperature * log_probs).mean()
* ~batch["observation.state_is_pad"][:,0] # shape: [batch_size, horizon+1] * ~batch["observation.state_is_pad"][:,0] # shape: [batch_size, horizon+1]
* ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon] * ~batch["action_is_pad"][:,0] # shape: [batch_size, horizon]
).mean() ).mean()
@@ -208,6 +218,8 @@ class SACPolicy(
# calculate temperature loss # calculate temperature loss
# 1- calculate entropy # 1- calculate entropy
with torch.no_grad():
actions, log_probs, _ = self.actor(observations)
entropy = -log_probs.mean() entropy = -log_probs.mean()
temperature_loss = self.temperature( temperature_loss = self.temperature(
lhs=entropy, lhs=entropy,
@@ -219,8 +231,17 @@ class SACPolicy(
return { return {
"critics_loss": critics_loss.item(), "critics_loss": critics_loss.item(),
"actor_loss": actor_loss.item(), "actor_loss": actor_loss.item(),
"mean_q_predicts": min_q_preds.mean().item(),
"min_q_predicts":min_q_preds.min().item(),
"max_q_predicts":min_q_preds.max().item(),
"temperature_loss": temperature_loss.item(), "temperature_loss": temperature_loss.item(),
"temperature": temperature.item(), "temperature": temperature.item(),
"mean_log_probs": log_probs.mean().item(),
"min_log_probs": log_probs.min().item(),
"max_log_probs": log_probs.max().item(),
"td_target_mean": td_target.mean().item(),
"td_target_mean": td_target.max().item(),
"action_mean": actions.mean().item(),
"entropy": entropy.item(), "entropy": entropy.item(),
"loss": loss, "loss": loss,
} }
@@ -236,8 +257,8 @@ class SACPolicy(
for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False): for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False): for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False):
target_param.data.copy_( target_param.data.copy_(
target_param.data * self.config.critic_target_update_weight + param.data * self.config.critic_target_update_weight +
param.data * (1.0 - self.config.critic_target_update_weight) target_param.data * (1.0 - self.config.critic_target_update_weight)
) )
class MLP(nn.Module): class MLP(nn.Module):
@@ -391,15 +412,16 @@ class Policy(nn.Module):
# Compute standard deviations # Compute standard deviations
if self.fixed_std is None: if self.fixed_std is None:
log_std = self.std_layer(outputs) log_std = self.std_layer(outputs)
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
if self.use_tanh_squash: if self.use_tanh_squash:
log_std = torch.tanh(log_std) log_std = torch.tanh(log_std)
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
else: else:
log_std = self.fixed_std.expand_as(means) log_std = self.fixed_std.expand_as(means)
# uses tahn activation function to squash the action to be in the range of [-1, 1] # uses tahn activation function to squash the action to be in the range of [-1, 1]
normal = torch.distributions.Normal(means, torch.exp(log_std)) normal = torch.distributions.Normal(means, torch.exp(log_std))
x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1))
x_t = torch.clamp(x_t, -2.0, 2.0)
log_probs = normal.log_prob(x_t) log_probs = normal.log_prob(x_t)
if self.use_tanh_squash: if self.use_tanh_squash:
actions = torch.tanh(x_t) actions = torch.tanh(x_t)
@@ -456,19 +478,15 @@ class SACObservationEncoder(nn.Module):
) )
if "observation.state" in config.input_shapes: if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential( self.state_enc_layers = nn.Sequential(
nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim), nn.Linear(config.input_shapes["observation.state"][0], config.latent_dim),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim), nn.LayerNorm(config.latent_dim),
nn.Tanh(), nn.Tanh(),
) )
if "observation.environment_state" in config.input_shapes: if "observation.environment_state" in config.input_shapes:
self.env_state_enc_layers = nn.Sequential( self.env_state_enc_layers = nn.Sequential(
nn.Linear( nn.Linear(
config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim config.input_shapes["observation.environment_state"][0], config.latent_dim
), ),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim), nn.LayerNorm(config.latent_dim),
nn.Tanh(), nn.Tanh(),
) )
@@ -506,26 +524,27 @@ class LagrangeMultiplier(nn.Module):
): ):
super().__init__() super().__init__()
self.device = torch.device(device) self.device = torch.device(device)
init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1) # init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
init_value = torch.tensor(init_value, device=self.device)
# Initialize the Lagrange multiplier as a parameter # Initialize the Lagrange multiplier as a parameter
self.lagrange = nn.Parameter( self.lagrange = nn.Parameter(
torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device) torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device)
) )
self.to(self.device)
def forward( def forward(
self, self,
lhs: Optional[torch.Tensor | float | int] = None, lhs: Optional[torch.Tensor | float | int] = None,
rhs: Optional[torch.Tensor | float | int] = None rhs: Optional[torch.Tensor | float | int] = None
) -> torch.Tensor: ) -> torch.Tensor:
# Get the multiplier value based on parameterization # Get the multiplier value based on parameterization
multiplier = torch.nn.functional.softplus(self.lagrange) # multiplier = torch.nn.functional.softplus(self.lagrange)
log_multiplier = torch.log(self.lagrange)
# Return the raw multiplier if no constraint values provided # Return the raw multiplier if no constraint values provided
if lhs is None: if lhs is None:
return multiplier return log_multiplier.exp()
# Convert inputs to tensors and move to device # Convert inputs to tensors and move to device
lhs = torch.tensor(lhs, device=self.device) if not isinstance(lhs, torch.Tensor) else lhs.to(self.device) lhs = torch.tensor(lhs, device=self.device) if not isinstance(lhs, torch.Tensor) else lhs.to(self.device)
@@ -536,9 +555,9 @@ class LagrangeMultiplier(nn.Module):
diff = lhs - rhs diff = lhs - rhs
assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}" assert diff.shape == log_multiplier.shape, f"Shape mismatch: {diff.shape} vs {log_multiplier.shape}"
return multiplier * diff return log_multiplier.exp() * diff # numerically better
def orthogonal_init(): def orthogonal_init():

View File

@@ -19,7 +19,7 @@ training:
grad_clip_norm: 10.0 grad_clip_norm: 10.0
lr: 3e-4 lr: 3e-4
eval_freq: 10000 eval_freq: 2500
log_freq: 500 log_freq: 500
save_freq: 50000 save_freq: 50000
@@ -29,7 +29,7 @@ training:
online_steps_between_rollouts: 1000 online_steps_between_rollouts: 1000
online_sampling_ratio: 1.0 online_sampling_ratio: 1.0
online_env_seed: 10000 online_env_seed: 10000
online_buffer_capacity: 40000 online_buffer_capacity: 10000
online_buffer_seed_size: 0 online_buffer_seed_size: 0
do_online_rollout_async: false do_online_rollout_async: false
@@ -70,9 +70,9 @@ policy:
temperature_init: 1.0 temperature_init: 1.0
num_critics: 2 num_critics: 2
num_subsample_critics: None num_subsample_critics: None
critic_lr: 3e-4 critic_lr: 1e-4
actor_lr: 3e-4 actor_lr: 1e-4
temperature_lr: 3e-4 temperature_lr: 1e-4
critic_target_update_weight: 0.005 critic_target_update_weight: 0.005
utd_ratio: 2 utd_ratio: 2