forked from tangger/lerobot
Add grasp critic to the training loop
- Integrated the grasp critic gradient update into the training loop in learner_server
- Added an Adam optimizer and configured the grasp critic learning rate in configuration_sac
- Added the target critic networks update after the critics gradient step
This commit is contained in:
committed by
Michel Aractingi
parent
fdd04efdb7
commit
3a2308d86f
@@ -167,6 +167,7 @@ class SACConfig(PreTrainedConfig):
|
||||
num_critics: int = 2
|
||||
num_subsample_critics: int | None = None
|
||||
critic_lr: float = 3e-4
|
||||
grasp_critic_lr: float = 3e-4
|
||||
actor_lr: float = 3e-4
|
||||
temperature_lr: float = 3e-4
|
||||
critic_target_update_weight: float = 0.005
|
||||
|
||||
@@ -214,7 +214,7 @@ class SACPolicy(
|
||||
def forward(
|
||||
self,
|
||||
batch: dict[str, Tensor | dict[str, Tensor]],
|
||||
model: Literal["actor", "critic", "temperature"] = "critic",
|
||||
model: Literal["actor", "critic", "grasp_critic", "temperature"] = "critic",
|
||||
) -> dict[str, Tensor]:
|
||||
"""Compute the loss for the given model
|
||||
|
||||
@@ -227,7 +227,7 @@ class SACPolicy(
|
||||
- done: Done mask tensor
|
||||
- observation_feature: Optional pre-computed observation features
|
||||
- next_observation_feature: Optional pre-computed next observation features
|
||||
model: Which model to compute the loss for ("actor", "critic", or "temperature")
|
||||
model: Which model to compute the loss for ("actor", "critic", "grasp_critic", or "temperature")
|
||||
|
||||
Returns:
|
||||
The computed loss tensor
|
||||
@@ -254,6 +254,21 @@ class SACPolicy(
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
if model == "grasp_critic":
|
||||
# Extract grasp_critic-specific components
|
||||
complementary_info: dict[str, Tensor] = batch["complementary_info"]
|
||||
|
||||
return self.compute_loss_grasp_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
|
||||
if model == "actor":
|
||||
return self.compute_loss_actor(
|
||||
|
||||
Reference in New Issue
Block a user