Refactor SAC policy and training loop to enhance discrete action support

- Updated SACPolicy to conditionally compute the grasp-critic loss, depending on num_discrete_actions.
- Simplified the forward method to return its loss outputs as a dictionary, for clarity.
- Adjusted learner_server to handle both the main critic and grasp-critic losses during training (see the sketch after this list).
- Ensured the grasp-critic optimizer is created only when the configuration enables discrete actions.
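The learner_server changes themselves are not part of the diff shown below. As a rough illustration only, assuming `forward` returns the loss dictionary introduced in this commit, the conditional optimizer creation and loss handling might look like the following sketch. The names `make_optimizers`, `critic_update_step`, `critic_ensemble`, and the learning-rate fields are hypothetical, not taken from this commit:

```python
import torch

# Hypothetical sketch of the learner-side logic; names are illustrative.
def make_optimizers(policy, cfg):
    optimizers = {
        "actor": torch.optim.Adam(policy.actor.parameters(), lr=cfg.actor_lr),
        "critic": torch.optim.Adam(policy.critic_ensemble.parameters(), lr=cfg.critic_lr),
        "temperature": torch.optim.Adam([policy.log_alpha], lr=cfg.temperature_lr),
    }
    # Only create a grasp-critic optimizer when discrete actions are configured.
    if cfg.num_discrete_actions is not None:
        optimizers["grasp_critic"] = torch.optim.Adam(
            policy.grasp_critic.parameters(), lr=cfg.critic_lr
        )
    return optimizers


def critic_update_step(policy, batch, optimizers):
    losses = policy.forward(batch, model="critic")

    # The main critic loss is always present.
    optimizers["critic"].zero_grad()
    losses["loss_critic"].backward()
    optimizers["critic"].step()

    # The grasp-critic loss (and its optimizer) exists only when
    # num_discrete_actions is set in the config. The two losses come from
    # separate compute_loss_* calls, so separate backward passes are assumed safe.
    if "loss_grasp_critic" in losses:
        optimizers["grasp_critic"].zero_grad()
        losses["loss_grasp_critic"].backward()
        optimizers["grasp_critic"].step()
```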
Author: AdilZouitine
Date: 2025-04-01 11:42:28 +00:00
Committed by: Adil Zouitine
Parent: c3f2487026
Commit: e35ee47b07
3 changed files with 86 additions and 90 deletions


@@ -228,7 +228,7 @@ class SACPolicy(
     def forward(
         self,
         batch: dict[str, Tensor | dict[str, Tensor]],
-        model: Literal["actor", "critic", "grasp_critic", "temperature"] = "critic",
+        model: Literal["actor", "critic", "temperature"] = "critic",
     ) -> dict[str, Tensor]:
         """Compute the loss for the given model
@@ -246,7 +246,6 @@ class SACPolicy(
         Returns:
             The computed loss tensor
         """
-        # TODO: (maractingi, azouitine) Respect the function signature we output tensors
         # Extract common components from batch
         actions: Tensor = batch["action"]
         observations: dict[str, Tensor] = batch["state"]
@@ -259,7 +258,7 @@ class SACPolicy(
             done: Tensor = batch["done"]
             next_observation_features: Tensor = batch.get("next_observation_feature")
 
-            return self.compute_loss_critic(
+            loss_critic = self.compute_loss_critic(
                 observations=observations,
                 actions=actions,
                 rewards=rewards,
@@ -268,29 +267,28 @@ class SACPolicy(
                 observation_features=observation_features,
                 next_observation_features=next_observation_features,
             )
+            if self.config.num_discrete_actions is not None:
+                loss_grasp_critic = self.compute_loss_grasp_critic(
+                    observations=observations,
+                    actions=actions,
+                    rewards=rewards,
+                    next_observations=next_observations,
+                    done=done,
+                )
+            return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic}
 
-        if model == "grasp_critic":
-            return self.compute_loss_grasp_critic(
-                observations=observations,
-                actions=actions,
-                rewards=rewards,
-                next_observations=next_observations,
-                done=done,
-                observation_features=observation_features,
-                next_observation_features=next_observation_features,
-            )
-
         if model == "actor":
-            return self.compute_loss_actor(
+            return {"loss_actor": self.compute_loss_actor(
                 observations=observations,
                 observation_features=observation_features,
-            )
+            )}
 
         if model == "temperature":
-            return self.compute_loss_temperature(
+            return {"loss_temperature": self.compute_loss_temperature(
                 observations=observations,
                 observation_features=observation_features,
-            )
+            )}
 
         raise ValueError(f"Unknown model type: {model}")
@@ -305,18 +303,16 @@ class SACPolicy(
                 param.data * self.config.critic_target_update_weight
                 + target_param.data * (1.0 - self.config.critic_target_update_weight)
             )
-
-    def update_grasp_target_networks(self):
-        """Update grasp target networks with exponential moving average"""
-        for target_param, param in zip(
-            self.grasp_critic_target.parameters(),
-            self.grasp_critic.parameters(),
-            strict=False,
-        ):
-            target_param.data.copy_(
-                param.data * self.config.critic_target_update_weight
-                + target_param.data * (1.0 - self.config.critic_target_update_weight)
-            )
+        if self.config.num_discrete_actions is not None:
+            for target_param, param in zip(
+                self.grasp_critic_target.parameters(),
+                self.grasp_critic.parameters(),
+                strict=False,
+            ):
+                target_param.data.copy_(
+                    param.data * self.config.critic_target_update_weight
+                    + target_param.data * (1.0 - self.config.critic_target_update_weight)
+                )
 
     def update_temperature(self):
         self.temperature = self.log_alpha.exp().item()
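Both the continuous-critic and grasp-critic target updates above apply the same exponential-moving-average (Polyak) rule, target ← τ·online + (1 − τ)·target with τ = config.critic_target_update_weight, while update_temperature simply caches α = exp(log α) as a Python float. A self-contained sketch of the soft-update rule; the two-layer critics here are generic stand-ins, not modules from the repository:

```python
import torch
from torch import nn

def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    """Polyak averaging: theta_target <- tau * theta_source + (1 - tau) * theta_target."""
    with torch.no_grad():
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.mul_(1.0 - tau).add_(param, alpha=tau)

# Illustrative critics; the real networks live inside SACPolicy.
critic = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 1))
critic_target = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 1))
critic_target.load_state_dict(critic.state_dict())  # start from identical weights
soft_update(critic_target, critic, tau=0.005)
```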