Refactor SACPolicy and its configuration to replace the 'grasp_critic' terminology with 'discrete_critic'. Update the related methods and comments for clarity and consistency in handling discrete actions.
@@ -833,7 +833,7 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
         if self.gripper_penalty_in_reward:
             reward += gripper_penalty
         else:
-            info["gripper_penalty"] = gripper_penalty
+            info["discrete_penalty"] = gripper_penalty
 
         return obs, reward, terminated, truncated, info
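For context, the hunk above only renames the info key; the penalty logic itself is unchanged. Below is a minimal, self-contained sketch of the same wrapper pattern, assuming a gymnasium-style environment whose last action dimension carries the discrete gripper command. The wrapper name, penalty value, and gripper test are illustrative assumptions, not code from this repository.

# Illustrative sketch only (not part of the diff): either fold a fixed penalty for
# discrete gripper actions into the reward, or report it through `info`, mirroring
# the `gripper_penalty_in_reward` branch above.
import gymnasium as gym


class DiscretePenaltyWrapper(gym.Wrapper):
    def __init__(self, env, penalty=-0.02, penalty_in_reward=True):
        super().__init__(env)
        self.penalty = penalty
        self.penalty_in_reward = penalty_in_reward

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        # Assumed convention: the last action dimension is the discrete gripper
        # command, and any non-zero command incurs the fixed penalty.
        discrete_penalty = self.penalty if action[-1] != 0 else 0.0
        if self.penalty_in_reward:
            reward += discrete_penalty
        else:
            info["discrete_penalty"] = discrete_penalty
        return obs, reward, terminated, truncated, info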
@@ -410,7 +410,7 @@ def add_actor_information_and_train(
             "complementary_info": batch["complementary_info"],
         }
 
-        # Use the forward method for critic loss (includes both main critic and grasp critic)
+        # Use the forward method for critic loss (includes both main critic and discrete critic)
         critic_output = policy.forward(forward_batch, model="critic")
 
         # Main critic optimization
@@ -422,16 +422,16 @@ def add_actor_information_and_train(
         )
         optimizers["critic"].step()
 
-        # Grasp critic optimization (if available)
+        # Discrete critic optimization (if available)
         if policy.config.num_discrete_actions is not None:
-            discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
-            loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
-            optimizers["grasp_critic"].zero_grad()
-            loss_grasp_critic.backward()
-            grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
-                parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+            discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
+            loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
+            optimizers["discrete_critic"].zero_grad()
+            loss_discrete_critic.backward()
+            discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+                parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
             )
-            optimizers["grasp_critic"].step()
+            optimizers["discrete_critic"].step()
 
         # Update target networks
         policy.update_target_networks()
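The renamed discrete-critic update follows the same optimizer pattern as the main critic: compute the loss from the policy's forward pass, back-propagate, clip the gradient norm, and step. Below is a self-contained sketch of that zero_grad / backward / clip_grad_norm_ / step sequence; the network, data, and hyperparameters are placeholders chosen so the snippet runs on its own, not the repository's models.

# Illustrative sketch only: the gradient-clipped optimization step used for each critic.
import torch
import torch.nn as nn

critic = nn.Linear(8, 3)                      # stand-in for a discrete critic head (3 discrete actions)
optimizer = torch.optim.Adam(critic.parameters(), lr=3e-4)
clip_grad_norm_value = 10.0

obs = torch.randn(32, 8)                      # fake batch of observation features
target_q = torch.randn(32, 3)                 # fake TD targets for each discrete action

loss = nn.functional.mse_loss(critic(obs), target_q)
optimizer.zero_grad()
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
    parameters=critic.parameters(), max_norm=clip_grad_norm_value
).item()                                      # .item() so the norm can be logged as a plain float
optimizer.step()
print(f"loss={loss.item():.4f} grad_norm={grad_norm:.4f}")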
@@ -468,7 +468,7 @@ def add_actor_information_and_train(
             "next_observation_feature": next_observation_features,
         }
 
-        # Use the forward method for critic loss (includes both main critic and grasp critic)
+        # Use the forward method for critic loss (includes both main critic and discrete critic)
        critic_output = policy.forward(forward_batch, model="critic")
 
         # Main critic optimization
@@ -486,20 +486,20 @@ def add_actor_information_and_train(
             "critic_grad_norm": critic_grad_norm,
         }
 
-        # Grasp critic optimization (if available)
+        # Discrete critic optimization (if available)
         if policy.config.num_discrete_actions is not None:
-            discrete_critic_output = policy.forward(forward_batch, model="grasp_critic")
-            loss_grasp_critic = discrete_critic_output["loss_grasp_critic"]
-            optimizers["grasp_critic"].zero_grad()
-            loss_grasp_critic.backward()
-            grasp_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
-                parameters=policy.grasp_critic.parameters(), max_norm=clip_grad_norm_value
+            discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
+            loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
+            optimizers["discrete_critic"].zero_grad()
+            loss_discrete_critic.backward()
+            discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
+                parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
             ).item()
-            optimizers["grasp_critic"].step()
+            optimizers["discrete_critic"].step()
 
-            # Add grasp critic info to training info
-            training_infos["loss_grasp_critic"] = loss_grasp_critic.item()
-            training_infos["grasp_critic_grad_norm"] = grasp_critic_grad_norm
+            # Add discrete critic info to training info
+            training_infos["loss_discrete_critic"] = loss_discrete_critic.item()
+            training_infos["discrete_critic_grad_norm"] = discrete_critic_grad_norm
 
         # Actor and temperature optimization (at specified frequency)
         if optimization_step % policy_update_freq == 0:
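Both training hunks read named losses such as "loss_discrete_critic" out of the dict returned by policy.forward(forward_batch, model=...). The sketch below shows one plausible way such a dispatching forward can be written; it is an assumption for illustration, and the actual SACPolicy.forward in the repository may be structured quite differently. All module and key names below are placeholders.

# Illustrative sketch only: a forward() that dispatches on a `model` argument and
# returns named losses, matching how the training loop reads them from a dict.
import torch
import torch.nn as nn


class TinyPolicy(nn.Module):
    def __init__(self, obs_dim=8, num_discrete_actions=3):
        super().__init__()
        self.critic_ensemble = nn.Linear(obs_dim + 1, 1)      # stand-in continuous-action critic
        self.discrete_critic = nn.Linear(obs_dim, num_discrete_actions)

    def forward(self, batch, model="critic"):
        obs = batch["observation"]
        if model == "critic":
            q = self.critic_ensemble(torch.cat([obs, batch["action"][:, :1]], dim=-1))
            return {"loss_critic": ((q - batch["target_q"]) ** 2).mean()}
        if model == "discrete_critic":
            q_all = self.discrete_critic(obs)                  # Q-values for every discrete action
            q_taken = q_all.gather(1, batch["discrete_action"].unsqueeze(1)).squeeze(1)
            return {"loss_discrete_critic": ((q_taken - batch["target_q_discrete"]) ** 2).mean()}
        raise ValueError(f"unknown model: {model}")


# Tiny usage example with fake data.
batch = {
    "observation": torch.randn(4, 8),
    "action": torch.randn(4, 2),
    "target_q": torch.randn(4, 1),
    "discrete_action": torch.randint(0, 3, (4,)),
    "target_q_discrete": torch.randn(4),
}
loss = TinyPolicy()(batch, model="discrete_critic")["loss_discrete_critic"]
loss.backward()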
@@ -782,8 +782,8 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
     optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
 
     if cfg.policy.num_discrete_actions is not None:
-        optimizer_grasp_critic = torch.optim.Adam(
-            params=policy.grasp_critic.parameters(), lr=cfg.policy.critic_lr
+        optimizer_discrete_critic = torch.optim.Adam(
+            params=policy.discrete_critic.parameters(), lr=cfg.policy.critic_lr
         )
     optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
     lr_scheduler = None
@@ -793,7 +793,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
         "temperature": optimizer_temperature,
     }
     if cfg.policy.num_discrete_actions is not None:
-        optimizers["grasp_critic"] = optimizer_grasp_critic
+        optimizers["discrete_critic"] = optimizer_discrete_critic
     return optimizers, lr_scheduler
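The optimizer dictionary keeps its discrete-critic entry optional, so the training loop can branch on num_discrete_actions as shown in the earlier hunks. Below is a self-contained sketch of that conditional-optimizer pattern, with placeholder modules and a placeholder learning rate standing in for the policy and config.

# Illustrative sketch only: build an optimizer dict whose "discrete_critic" entry
# exists only when discrete actions are configured.
import torch
import torch.nn as nn


def make_optimizers(policy: nn.Module, num_discrete_actions, critic_lr=3e-4):
    optimizers = {
        "critic": torch.optim.Adam(policy.critic_ensemble.parameters(), lr=critic_lr),
        "temperature": torch.optim.Adam([policy.log_alpha], lr=critic_lr),
    }
    # Mirror the `num_discrete_actions is not None` checks used in the training loop.
    if num_discrete_actions is not None:
        optimizers["discrete_critic"] = torch.optim.Adam(
            policy.discrete_critic.parameters(), lr=critic_lr
        )
    return optimizers


# Placeholder policy exposing just the attributes this helper reads.
policy = nn.Module()
policy.critic_ensemble = nn.Linear(8, 1)
policy.discrete_critic = nn.Linear(8, 3)
policy.log_alpha = nn.Parameter(torch.zeros(1))

optimizers = make_optimizers(policy, num_discrete_actions=3)
assert set(optimizers) == {"critic", "temperature", "discrete_critic"}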