Enhance SAC configuration and policy with gradient clipping and temperature management

- Introduced `grad_clip_norm` parameter in SAC configuration for gradient clipping
- Updated SACPolicy to store temperature as an instance variable for consistent usage
- Modified loss calculations in SACPolicy to utilize the instance temperature
- Enhanced MLP and CriticHead to support a customizable final activation function
- Implemented gradient clipping in the learner server during training steps for both actor and critic
- Added tracking for gradient norms in training information
This commit is contained in:
AdilZouitine
2025-03-17 10:50:28 +00:00
committed by Michel Aractingi
parent 599326508f
commit 66816fd871
3 changed files with 60 additions and 9 deletions

View File

@@ -390,6 +390,10 @@ def add_actor_information_and_train(
if cfg.resume
else None,
)
# Update the policy config with the grad_clip_norm value from training config if it exists
clip_grad_norm_value = cfg.training.grad_clip_norm
# compile policy
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
@@ -507,6 +511,12 @@ def add_actor_information_and_train(
)
optimizers["critic"].zero_grad()
loss_critic.backward()
# clip gradients
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
policy.critic_ensemble.parameters(), clip_grad_norm_value
)
optimizers["critic"].step()
batch = replay_buffer.sample(batch_size)
@@ -541,10 +551,17 @@ def add_actor_information_and_train(
)
optimizers["critic"].zero_grad()
loss_critic.backward()
# clip gradients
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
policy.critic_ensemble.parameters(), clip_grad_norm_value
).item()
optimizers["critic"].step()
training_infos = {}
training_infos["loss_critic"] = loss_critic.item()
training_infos["critic_grad_norm"] = critic_grad_norm
if optimization_step % cfg.training.policy_update_freq == 0:
for _ in range(cfg.training.policy_update_freq):
@@ -555,19 +572,35 @@ def add_actor_information_and_train(
optimizers["actor"].zero_grad()
loss_actor.backward()
# clip gradients
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
policy.actor.parameters_to_optimize, clip_grad_norm_value
).item()
optimizers["actor"].step()
training_infos["loss_actor"] = loss_actor.item()
training_infos["actor_grad_norm"] = actor_grad_norm
# Temperature optimization
loss_temperature = policy.compute_loss_temperature(
observations=observations,
observation_features=observation_features,
)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
# clip gradients
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
[policy.log_alpha], clip_grad_norm_value
).item()
optimizers["temperature"].step()
training_infos["loss_temperature"] = loss_temperature.item()
training_infos["temperature_grad_norm"] = temp_grad_norm
training_infos["temperature"] = policy.temperature
if (
time.time() - last_time_policy_pushed