Refactor SAC policy and training loop to enhance discrete action support
- Updated SACPolicy to conditionally compute losses for the grasp critic based on num_discrete_actions.
- Simplified the forward method to return loss outputs as a dictionary for better clarity.
- Adjusted learner_server to handle both main and grasp critic losses during training.
- Ensured optimizers are created conditionally for the grasp critic based on configuration settings.
committed by Adil Zouitine
parent c3f2487026
commit e35ee47b07
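The hunks below cover only the SACPolicy side; the learner_server and optimizer changes listed in the message are not shown. As a rough sketch of the consuming side, assuming hypothetical helper names (`make_optimizers`, `critic_update_step`) and attribute/learning-rate names (`critic_ensemble`, `actor_lr`, ...) — only `num_discrete_actions`, `grasp_critic`, `log_alpha`, and the loss-dict keys are confirmed by this commit:

```python
import torch

# Hypothetical sketch of the learner-side changes described in the commit
# message; helper and attribute names are illustrative, not the repo's API.
def make_optimizers(policy, cfg):
    optimizers = {
        "actor": torch.optim.Adam(policy.actor.parameters(), lr=cfg.actor_lr),
        "critic": torch.optim.Adam(policy.critic_ensemble.parameters(), lr=cfg.critic_lr),
        "temperature": torch.optim.Adam([policy.log_alpha], lr=cfg.temperature_lr),
    }
    # Create the grasp-critic optimizer only when discrete actions are enabled.
    if cfg.num_discrete_actions is not None:
        optimizers["grasp_critic"] = torch.optim.Adam(
            policy.grasp_critic.parameters(), lr=cfg.critic_lr
        )
    return optimizers


def critic_update_step(policy, batch, optimizers):
    # forward(model="critic") returns {"loss_critic": ...} and, when the
    # grasp head exists, a "loss_grasp_critic" entry as well.
    losses = policy.forward(batch, model="critic")
    names = ["critic"] + (["grasp_critic"] if "loss_grasp_critic" in losses else [])

    for name in names:
        optimizers[name].zero_grad()
    # One backward pass over the summed losses avoids re-traversing any
    # encoder graph the two critic losses might share.
    sum(losses.values()).backward()
    for name in names:
        optimizers[name].step()
```

The real learner may step each loss separately instead of summing; the sketch only illustrates the conditional wiring the message describes.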
@@ -228,7 +228,7 @@ class SACPolicy(
     def forward(
         self,
         batch: dict[str, Tensor | dict[str, Tensor]],
-        model: Literal["actor", "critic", "grasp_critic", "temperature"] = "critic",
+        model: Literal["actor", "critic", "temperature"] = "critic",
     ) -> dict[str, Tensor]:
         """Compute the loss for the given model
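With "grasp_critic" removed from the `Literal`, callers pick one of three heads and always get a dict back. A hedged usage sketch (`policy` and `batch` stand for a configured SACPolicy and a replay-buffer sample; this is not code from the commit):

```python
critic_losses = policy.forward(batch, model="critic")
# {"loss_critic": Tensor}, plus "loss_grasp_critic" when
# config.num_discrete_actions is not None (see the hunks below)

actor_losses = policy.forward(batch, model="actor")        # {"loss_actor": Tensor}
temp_losses = policy.forward(batch, model="temperature")   # {"loss_temperature": Tensor}

policy.forward(batch, model="grasp_critic")  # no longer a selector; raises ValueError
```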
@@ -246,7 +246,6 @@ class SACPolicy(
         Returns:
             The computed loss tensor
         """
-        # TODO: (maractingi, azouitine) Respect the function signature we output tensors
         # Extract common components from batch
         actions: Tensor = batch["action"]
         observations: dict[str, Tensor] = batch["state"]
@@ -259,7 +258,7 @@ class SACPolicy(
         done: Tensor = batch["done"]
         next_observation_features: Tensor = batch.get("next_observation_feature")

         if model == "critic":
-            return self.compute_loss_critic(
+            loss_critic = self.compute_loss_critic(
                 observations=observations,
                 actions=actions,
                 rewards=rewards,
@@ -268,29 +267,28 @@ class SACPolicy(
                 observation_features=observation_features,
                 next_observation_features=next_observation_features,
             )
+            if self.config.num_discrete_actions is not None:
+                loss_grasp_critic = self.compute_loss_grasp_critic(
+                    observations=observations,
+                    actions=actions,
+                    rewards=rewards,
+                    next_observations=next_observations,
+                    done=done,
+                )
+                return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic}
+            return {"loss_critic": loss_critic}

-        if model == "grasp_critic":
-            return self.compute_loss_grasp_critic(
-                observations=observations,
-                actions=actions,
-                rewards=rewards,
-                next_observations=next_observations,
-                done=done,
-                observation_features=observation_features,
-                next_observation_features=next_observation_features,
-            )
-
         if model == "actor":
-            return self.compute_loss_actor(
+            return {"loss_actor": self.compute_loss_actor(
                 observations=observations,
                 observation_features=observation_features,
-            )
+            )}

         if model == "temperature":
-            return self.compute_loss_temperature(
+            return {"loss_temperature": self.compute_loss_temperature(
                 observations=observations,
                 observation_features=observation_features,
-            )
+            )}

         raise ValueError(f"Unknown model type: {model}")
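One payoff of the dictionary contract visible in this hunk: a learner no longer needs a dedicated `model="grasp_critic"` pass and can treat every returned loss uniformly. A small logging sketch under the same assumptions as above (`logger` is illustrative):

```python
losses = policy.forward(batch, model="critic")

# The dict carries one or two entries depending on configuration,
# so a single loop covers both cases.
for name, loss in losses.items():
    logger.log_scalar(name, loss.item())
```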
@@ -305,18 +303,16 @@ class SACPolicy(
                 param.data * self.config.critic_target_update_weight
                 + target_param.data * (1.0 - self.config.critic_target_update_weight)
             )
-
-    def update_grasp_target_networks(self):
-        """Update grasp target networks with exponential moving average"""
-        for target_param, param in zip(
-            self.grasp_critic_target.parameters(),
-            self.grasp_critic.parameters(),
-            strict=False,
-        ):
-            target_param.data.copy_(
-                param.data * self.config.critic_target_update_weight
-                + target_param.data * (1.0 - self.config.critic_target_update_weight)
-            )
+        if self.config.num_discrete_actions is not None:
+            for target_param, param in zip(
+                self.grasp_critic_target.parameters(),
+                self.grasp_critic.parameters(),
+                strict=False,
+            ):
+                target_param.data.copy_(
+                    param.data * self.config.critic_target_update_weight
+                    + target_param.data * (1.0 - self.config.critic_target_update_weight)
+                )

     def update_temperature(self):
         self.temperature = self.log_alpha.exp().item()
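Both loops above implement the standard Polyak (exponential moving average) target update; writing `critic_target_update_weight` as τ:

```latex
\theta_{\text{target}} \leftarrow \tau\,\theta + (1 - \tau)\,\theta_{\text{target}}
```

A small τ lets the target networks trail the online networks slowly, which stabilizes the bootstrapped critic targets; the new guard simply skips the grasp-critic EMA when no discrete action head is configured.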