Refactor SACPolicy for improved readability and action dimension handling

- Cleaned up code formatting for better readability, including consistent spacing and removal of unnecessary blank lines.
- Consolidated continuous action dimension calculation to enhance clarity and maintainability.
- Simplified loss return statements in the forward method to improve code structure.
- Ensured grasp critic parameters are included only when the configuration enables the grasp critic.
This commit is contained in:
AdilZouitine
2025-04-01 15:43:29 +00:00
committed by Michel Aractingi
parent d86d29fe21
commit f9fb9d4594
2 changed files with 37 additions and 36 deletions

View File

@@ -394,7 +394,7 @@ def add_actor_information_and_train(
# Use the forward method for critic loss (includes both main critic and grasp critic)
critic_output = policy.forward(forward_batch, model="critic")
# Main critic optimization
loss_critic = critic_output["loss_critic"]
optimizers["critic"].zero_grad()
@@ -405,7 +405,7 @@ def add_actor_information_and_train(
optimizers["critic"].step()
# Grasp critic optimization (if available)
if "loss_grasp_critic" in critic_output and hasattr(policy, "grasp_critic"):
if "loss_grasp_critic" in critic_output:
loss_grasp_critic = critic_output["loss_grasp_critic"]
optimizers["grasp_critic"].zero_grad()
loss_grasp_critic.backward()
@@ -450,7 +450,7 @@ def add_actor_information_and_train(
# Use the forward method for critic loss (includes both main critic and grasp critic)
critic_output = policy.forward(forward_batch, model="critic")
# Main critic optimization
loss_critic = critic_output["loss_critic"]
optimizers["critic"].zero_grad()
@@ -475,7 +475,7 @@ def add_actor_information_and_train(
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["grasp_critic"].step()
# Add grasp critic info to training info
training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
@@ -492,7 +492,7 @@ def add_actor_information_and_train(
parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
).item()
optimizers["actor"].step()
# Add actor info to training info
training_infos["loss_actor"] = loss_actor.item()
training_infos["actor_grad_norm"] = actor_grad_norm
@@ -506,7 +506,7 @@ def add_actor_information_and_train(
parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
).item()
optimizers["temperature"].step()
# Add temperature info to training info
training_infos["loss_temperature"] = loss_temperature.item()
training_infos["temperature_grad_norm"] = temp_grad_norm
@@ -756,7 +756,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
lr=cfg.policy.actor_lr,
)
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
if cfg.policy.num_discrete_actions is not None:
optimizer_grasp_critic = torch.optim.Adam(
params=policy.grasp_critic.parameters(), lr=policy.critic_lr