Refactor SACPolicy and configuration to replace 'grasp_critic' terminology with 'discrete_critic'. Update related methods and comments for clarity and consistency in handling discrete actions.

commit a7a51cfc9c
parent 0d70f0b85c
Author: AdilZouitine
Date:   2025-04-18 14:38:22 +00:00

4 changed files with 70 additions and 68 deletions
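
In short, this is a mechanical rename. The identifiers touched in the hunks shown below map as follows (the remaining changed files are not part of this excerpt):

    policy.grasp_critic                        -> policy.discrete_critic
    policy.forward(..., model="grasp_critic")  -> policy.forward(..., model="discrete_critic")
    forward output key "loss_grasp_critic"     -> "loss_discrete_critic"
    optimizers["grasp_critic"]                 -> optimizers["discrete_critic"]
    training_infos["loss_grasp_critic"]        -> training_infos["loss_discrete_critic"]
    training_infos["grasp_critic_grad_norm"]   -> training_infos["discrete_critic_grad_norm"]
    info["gripper_penalty"]                    -> info["discrete_penalty"]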

@@ -833,7 +833,7 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
         if self.gripper_penalty_in_reward:
             reward += gripper_penalty
         else:
-            info["gripper_penalty"] = gripper_penalty
+            info["discrete_penalty"] = gripper_penalty
         return obs, reward, terminated, truncated, info
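
This hunk only renames the reported info key. For context, a minimal self-contained wrapper with the same shape, assuming a gymnasium-style five-tuple step API; the class name, fixed penalty value, and constructor arguments are illustrative stand-ins, not lerobot's actual implementation:

import gymnasium as gym


class DiscretePenaltyWrapper(gym.Wrapper):
    """Illustrative sketch: fold a penalty into the scalar reward, or report
    it separately under info["discrete_penalty"], mirroring the branch in
    the hunk above."""

    def __init__(self, env, penalty=-0.01, penalty_in_reward=False):
        super().__init__(env)
        self.penalty = penalty
        self.penalty_in_reward = penalty_in_reward

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        if self.penalty_in_reward:
            reward += self.penalty  # penalty baked into the reward signal
        else:
            info["discrete_penalty"] = self.penalty  # reported out-of-band
        return obs, reward, terminated, truncated, info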

@@ -410,7 +410,7 @@ def add_actor_information_and_train(
             "complementary_info": batch["complementary_info"],
         }

-        # Use the forward method for critic loss (includes both main critic and grasp critic)
+        # Use the forward method for critic loss (includes both main critic and discrete critic)
         critic_output = policy.forward(forward_batch, model="critic")

         # Main critic optimization
@@ -422,16 +422,16 @@ def add_actor_information_and_train(
)
optimizers["critic"].step()
# Grasp critic optimization (if available)
# Discrete critic optimization (if available)
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
optimizers["grasp_critic"].zero_grad()
loss_grasp_critic.backward()
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
optimizers["discrete_critic"].zero_grad()
loss_discrete_critic.backward()
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
)
optimizers["grasp_critic"].step()
optimizers["discrete_critic"].step()
# Update target networks
policy.update_target_networks()
@@ -468,7 +468,7 @@ def add_actor_information_and_train(
             "next_observation_feature": next_observation_features,
         }

-        # Use the forward method for critic loss (includes both main critic and grasp critic)
+        # Use the forward method for critic loss (includes both main critic and discrete critic)
         critic_output = policy.forward(forward_batch, model="critic")

         # Main critic optimization
@@ -486,20 +486,20 @@ def add_actor_information_and_train(
"critic_grad_norm": critic_grad_norm,
}
# Grasp critic optimization (if available)
# Discrete critic optimization (if available)
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
optimizers["grasp_critic"].zero_grad()
loss_grasp_critic.backward()
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
optimizers["discrete_critic"].zero_grad()
loss_discrete_critic.backward()
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["grasp_critic"].step()
optimizers["discrete_critic"].step()
# Add grasp critic info to training info
training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
# Add discrete critic info to training info
training_infos["loss_discrete_critic"] = loss_discrete_critic.item()
training_infos["discrete_critic_grad_norm"] = discrete_critic_grad_norm
# Actor and temperature optimization (at specified frequency)
if optimization_step % policy_update_freq == 0:
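
This update block appears in both training paths of add_actor_information_and_train. A stripped-down, runnable sketch of the step in isolation; a generic linear head, a placeholder loss, and the clip value stand in for policy.discrete_critic, the TD objective returned by policy.forward, and the configured max norm:

import torch
import torch.nn as nn

# Stand-ins for policy.discrete_critic and optimizers["discrete_critic"] (assumptions).
discrete_critic = nn.Linear(8, 3)  # e.g. 3 discrete gripper actions
optimizer = torch.optim.Adam(discrete_critic.parameters(), lr=3e-4)
clip_grad_norm_value = 40.0  # illustrative; the real value comes from the config

# Placeholder loss; the real code reads "loss_discrete_critic" out of policy.forward(...).
q_values = discrete_critic(torch.randn(32, 8))  # (batch, num_discrete_actions)
loss_discrete_critic = q_values.pow(2).mean()

optimizer.zero_grad()
loss_discrete_critic.backward()
# clip_grad_norm_ returns the total norm, converted to a plain float for logging
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
    parameters=discrete_critic.parameters(), max_norm=clip_grad_norm_value
).item()
optimizer.step()

training_infos = {
    "loss_discrete_critic": loss_discrete_critic.item(),
    "discrete_critic_grad_norm": discrete_critic_grad_norm,
}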
@@ -782,8 +782,8 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
     optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
     if cfg.policy.num_discrete_actions is not None:
-        optimizer_grasp_critic = torch.optim.Adam(
-            params=policy.grasp_critic.parameters(), lr=cfg.policy.critic_lr
+        optimizer_discrete_critic = torch.optim.Adam(
+            params=policy.discrete_critic.parameters(), lr=cfg.policy.critic_lr
         )
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
     lr_scheduler = None
@@ -793,7 +793,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
         "temperature": optimizer_temperature,
     }
     if cfg.policy.num_discrete_actions is not None:
-        optimizers["grasp_critic"] = optimizer_grasp_critic
+        optimizers["discrete_critic"] = optimizer_discrete_critic
     return optimizers, lr_scheduler
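
Pieced together from the last two hunks, the optimizer factory after this commit reads roughly as below. The actor optimizer and the lines elided between the hunks are assumptions; only the critic, discrete-critic, and temperature pieces are taken verbatim from the diff:

import torch
import torch.nn as nn

def make_optimizers_and_scheduler(cfg, policy: nn.Module):
    # Assumed: an actor optimizer is built in the lines elided between the hunks.
    optimizer_actor = torch.optim.Adam(params=policy.actor.parameters(), lr=cfg.policy.actor_lr)
    optimizer_critic = torch.optim.Adam(
        params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr
    )
    if cfg.policy.num_discrete_actions is not None:
        optimizer_discrete_critic = torch.optim.Adam(
            params=policy.discrete_critic.parameters(), lr=cfg.policy.critic_lr
        )
    optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
    lr_scheduler = None

    optimizers = {
        "actor": optimizer_actor,
        "critic": optimizer_critic,
        "temperature": optimizer_temperature,
    }
    if cfg.policy.num_discrete_actions is not None:
        optimizers["discrete_critic"] = optimizer_discrete_critic
    return optimizers, lr_scheduler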