Refactor SAC policy and training loop to enhance discrete action support
- Updated SACPolicy to conditionally compute losses for grasp critic based on num_discrete_actions. - Simplified forward method to return loss outputs as a dictionary for better clarity. - Adjusted learner_server to handle both main and grasp critic losses during training. - Ensured optimizers are created conditionally for grasp critic based on configuration settings.
This commit is contained in:
committed by
Michel Aractingi
parent
7361a11a4d
commit
f83d215e7a
@@ -392,32 +392,30 @@ def add_actor_information_and_train(
|
||||
"next_observation_feature": next_observation_features,
|
||||
}
|
||||
|
||||
# Use the forward method for critic loss
|
||||
loss_critic = policy.forward(forward_batch, model="critic")
|
||||
# Use the forward method for critic loss (includes both main critic and grasp critic)
|
||||
critic_output = policy.forward(forward_batch, model="critic")
|
||||
|
||||
# Main critic optimization
|
||||
loss_critic = critic_output["loss_critic"]
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
|
||||
# clip gradients
|
||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
|
||||
)
|
||||
|
||||
optimizers["critic"].step()
|
||||
|
||||
# Add gripper critic optimization
|
||||
loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")
|
||||
optimizers["grasp_critic"].zero_grad()
|
||||
loss_grasp_critic.backward()
|
||||
|
||||
# clip gradients
|
||||
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
)
|
||||
|
||||
optimizers["grasp_critic"].step()
|
||||
# Grasp critic optimization (if available)
|
||||
if "loss_grasp_critic" in critic_output and hasattr(policy, "grasp_critic"):
|
||||
loss_grasp_critic = critic_output["loss_grasp_critic"]
|
||||
optimizers["grasp_critic"].zero_grad()
|
||||
loss_grasp_critic.backward()
|
||||
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
)
|
||||
optimizers["grasp_critic"].step()
|
||||
|
||||
# Update target networks
|
||||
policy.update_target_networks()
|
||||
policy.update_grasp_target_networks()
|
||||
|
||||
batch = replay_buffer.sample(batch_size=batch_size)
|
||||
|
||||
@@ -450,81 +448,80 @@ def add_actor_information_and_train(
|
||||
"next_observation_feature": next_observation_features,
|
||||
}
|
||||
|
||||
# Use the forward method for critic loss
|
||||
loss_critic = policy.forward(forward_batch, model="critic")
|
||||
# Use the forward method for critic loss (includes both main critic and grasp critic)
|
||||
critic_output = policy.forward(forward_batch, model="critic")
|
||||
|
||||
# Main critic optimization
|
||||
loss_critic = critic_output["loss_critic"]
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
|
||||
# clip gradients
|
||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
|
||||
optimizers["critic"].step()
|
||||
|
||||
training_infos = {}
|
||||
training_infos["loss_critic"] = loss_critic.item()
|
||||
training_infos["critic_grad_norm"] = critic_grad_norm
|
||||
# Initialize training info dictionary
|
||||
training_infos = {
|
||||
"loss_critic": loss_critic.item(),
|
||||
"critic_grad_norm": critic_grad_norm,
|
||||
}
|
||||
|
||||
# Add gripper critic optimization
|
||||
loss_grasp_critic = policy.forward(forward_batch, model="grasp_critic")
|
||||
optimizers["grasp_critic"].zero_grad()
|
||||
loss_grasp_critic.backward()
|
||||
|
||||
# clip gradients
|
||||
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
)
|
||||
|
||||
optimizers["grasp_critic"].step()
|
||||
|
||||
# Add training info for the grasp critic
|
||||
training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
|
||||
training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
|
||||
# Grasp critic optimization (if available)
|
||||
if "loss_grasp_critic" in critic_output:
|
||||
loss_grasp_critic = critic_output["loss_grasp_critic"]
|
||||
optimizers["grasp_critic"].zero_grad()
|
||||
loss_grasp_critic.backward()
|
||||
grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["grasp_critic"].step()
|
||||
|
||||
# Add grasp critic info to training info
|
||||
training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
|
||||
training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
|
||||
|
||||
# Actor and temperature optimization (at specified frequency)
|
||||
if optimization_step % policy_update_freq == 0:
|
||||
for _ in range(policy_update_freq):
|
||||
# Use the forward method for actor loss
|
||||
loss_actor = policy.forward(forward_batch, model="actor")
|
||||
|
||||
# Actor optimization
|
||||
actor_output = policy.forward(forward_batch, model="actor")
|
||||
loss_actor = actor_output["loss_actor"]
|
||||
optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
|
||||
# clip gradients
|
||||
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.actor.parameters_to_optimize, max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
|
||||
optimizers["actor"].step()
|
||||
|
||||
|
||||
# Add actor info to training info
|
||||
training_infos["loss_actor"] = loss_actor.item()
|
||||
training_infos["actor_grad_norm"] = actor_grad_norm
|
||||
|
||||
# Temperature optimization using forward method
|
||||
loss_temperature = policy.forward(forward_batch, model="temperature")
|
||||
# Temperature optimization
|
||||
temperature_output = policy.forward(forward_batch, model="temperature")
|
||||
loss_temperature = temperature_output["loss_temperature"]
|
||||
optimizers["temperature"].zero_grad()
|
||||
loss_temperature.backward()
|
||||
|
||||
# clip gradients
|
||||
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
|
||||
optimizers["temperature"].step()
|
||||
|
||||
|
||||
# Add temperature info to training info
|
||||
training_infos["loss_temperature"] = loss_temperature.item()
|
||||
training_infos["temperature_grad_norm"] = temp_grad_norm
|
||||
training_infos["temperature"] = policy.temperature
|
||||
|
||||
# Update temperature
|
||||
policy.update_temperature()
|
||||
|
||||
# Check if it's time to push updated policy to actors
|
||||
# Push policy to actors if needed
|
||||
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||
last_time_policy_pushed = time.time()
|
||||
|
||||
# Update target networks
|
||||
policy.update_target_networks()
|
||||
policy.update_grasp_target_networks()
|
||||
|
||||
# Log training metrics at specified intervals
|
||||
if optimization_step % log_freq == 0:
|
||||
@@ -727,7 +724,7 @@ def save_training_checkpoint(
|
||||
logging.info("Resume training")
|
||||
|
||||
|
||||
def make_optimizers_and_scheduler(cfg, policy: nn.Module):
|
||||
def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
|
||||
"""
|
||||
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
|
||||
|
||||
@@ -759,17 +756,20 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
|
||||
lr=cfg.policy.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
|
||||
optimizer_grasp_critic = torch.optim.Adam(
|
||||
params=policy.grasp_critic.parameters(), lr=policy.config.grasp_critic_lr
|
||||
)
|
||||
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
optimizer_grasp_critic = torch.optim.Adam(
|
||||
params=policy.grasp_critic.parameters(), lr=policy.critic_lr
|
||||
)
|
||||
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
|
||||
lr_scheduler = None
|
||||
optimizers = {
|
||||
"actor": optimizer_actor,
|
||||
"critic": optimizer_critic,
|
||||
"grasp_critic": optimizer_grasp_critic,
|
||||
"temperature": optimizer_temperature,
|
||||
}
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
optimizers["grasp_critic"] = optimizer_grasp_critic
|
||||
return optimizers, lr_scheduler
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user