Add grasp critic to the training loop

- Integrated the grasp critic gradient update to the training loop in learner_server
- Added Adam optimizer and configured grasp critic learning rate in configuration_sac
- Added target critics networks update after the critics gradient step
This commit is contained in:
s1lent4gnt
2025-03-31 18:06:21 +02:00
committed by Michel Aractingi
parent fdd04efdb7
commit 3a2308d86f
3 changed files with 53 additions and 2 deletions

View File

@@ -375,6 +375,7 @@ def add_actor_information_and_train(
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
complementary_info = batch["complementary_info"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observation_features, next_observation_features = get_observation_features(
@@ -390,6 +391,7 @@ def add_actor_information_and_train(
"done": done,
"observation_feature": observation_features,
"next_observation_feature": next_observation_features,
"complementary_info": complementary_info,
}
# Use the forward method for critic loss
@@ -404,7 +406,20 @@ def add_actor_information_and_train(
optimizers["critic"].step()
# Add gripper critic optimization
loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")
optimizers["grasp_critic"].zero_grad()
loss_grasp_critic.backward()
# clip gradients
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
)
optimizers["grasp_critic"].step()
policy.update_target_networks()
policy.update_grasp_target_networks()
batch = replay_buffer.sample(batch_size=batch_size)
@@ -435,6 +450,7 @@ def add_actor_information_and_train(
"done": done,
"observation_feature": observation_features,
"next_observation_feature": next_observation_features,
"complementary_info": complementary_info,
}
# Use the forward method for critic loss
@@ -453,6 +469,22 @@ def add_actor_information_and_train(
training_infos["loss_critic"] = loss_critic.item()
training_infos["critic_grad_norm"] = critic_grad_norm
# Add gripper critic optimization
loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")
optimizers["grasp_critic"].zero_grad()
loss_grasp_critic.backward()
# clip gradients
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
)
optimizers["grasp_critic"].step()
# Add training info for the grasp critic
training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
if optimization_step % policy_update_freq == 0:
for _ in range(policy_update_freq):
# Use the forward method for actor loss
@@ -495,6 +527,7 @@ def add_actor_information_and_train(
last_time_policy_pushed = time.time()
policy.update_target_networks()
policy.update_grasp_target_networks()
# Log training metrics at specified intervals
if optimization_step % log_freq == 0:
@@ -729,11 +762,13 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
lr=cfg.policy.actor_lr,
)
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
optimizer_grasp_critic = torch.optim.Adam(params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,
"critic": optimizer_critic,
"grasp_critic": optimizer_grasp_critic,
"temperature": optimizer_temperature,
}
return optimizers, lr_scheduler