Handle gripper penalty

AdilZouitine authored on 2025-04-07 08:23:49 +00:00; committed by Michel Aractingi
parent 8bcf41761d, commit d5a87f67cf
3 changed files with 147 additions and 33 deletions


@@ -288,6 +288,7 @@ class SACPolicy(
             next_observations: dict[str, Tensor] = batch["next_state"]
             done: Tensor = batch["done"]
             next_observation_features: Tensor = batch.get("next_observation_feature")
+            complementary_info = batch.get("complementary_info")
             loss_grasp_critic = self.compute_loss_grasp_critic(
                 observations=observations,
                 actions=actions,
@@ -296,6 +297,7 @@ class SACPolicy(
                 done=done,
                 observation_features=observation_features,
                 next_observation_features=next_observation_features,
+                complementary_info=complementary_info,
             )
             return {"loss_grasp_critic": loss_grasp_critic}
         if model == "actor":
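
The first two hunks thread an optional complementary_info entry from the training batch down into compute_loss_grasp_critic. A minimal sketch of the batch layout this expects; only the top-level keys and the "gripper_penalty" entry are confirmed by the diff, while the inner observation key and all tensor shapes are illustrative assumptions:

import torch

batch_size = 8
batch = {
    "state": {"observation.state": torch.randn(batch_size, 10)},  # inner key name assumed
    "action": torch.randn(batch_size, 4),  # continuous arm dims + discrete gripper dim
    "reward": torch.randn(batch_size),
    "next_state": {"observation.state": torch.randn(batch_size, 10)},
    "done": torch.zeros(batch_size),
    # Optional per-transition extras; batch.get(...) returns None when absent.
    "complementary_info": {"gripper_penalty": 0.1 * torch.ones(batch_size)},
}

Because both lookups go through dict.get, a batch without complementary_info (or without a "gripper_penalty" entry inside it) still trains and simply falls back to the unpenalized target.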
@@ -413,6 +415,7 @@ class SACPolicy(
         done,
         observation_features=None,
         next_observation_features=None,
+        complementary_info=None,
     ):
         # NOTE: We only want to keep the discrete action part
         # In the buffer we have the full action space (continuous + discrete)
@@ -420,6 +423,9 @@ class SACPolicy(
         actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
         actions_discrete = actions_discrete.long()
 
+        gripper_penalties: Tensor | None = None
+        if complementary_info is not None:
+            gripper_penalties = complementary_info.get("gripper_penalty")
         with torch.no_grad():
             # For DQN, select actions using online network, evaluate with target network
             next_grasp_qs = self.grasp_critic_forward(
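
These hunks add the new keyword argument and pull the penalty out of complementary_info (initializing gripper_penalties to None first, so the later check is safe when no extras are passed). The surrounding context also shows how the discrete gripper command is recovered from the combined action vector; a sketch, assuming the first three action dimensions are continuous and DISCRETE_DIMENSION_INDEX marks where the discrete part starts:

import torch

DISCRETE_DIMENSION_INDEX = 3  # assumed split point: dims [0:3] continuous, [3:] discrete

actions = torch.tensor([[0.10, -0.20, 0.05, 2.0]])  # one buffered transition
actions_discrete = actions[:, DISCRETE_DIMENSION_INDEX:].clone().long()
print(actions_discrete)  # tensor([[2]]) -- integer gripper action indexing the critic's Q-heads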
@@ -440,7 +446,10 @@ class SACPolicy(
             ).squeeze(-1)
 
             # Compute target Q-value with Bellman equation
-            target_grasp_q = rewards + (1 - done) * self.config.discount * target_next_grasp_q
+            rewards_gripper = rewards
+            if gripper_penalties is not None:
+                rewards_gripper = rewards - gripper_penalties
+            target_grasp_q = rewards_gripper + (1 - done) * self.config.discount * target_next_grasp_q
 
         # Get predicted Q-values for current observations
         predicted_grasp_qs = self.grasp_critic_forward(
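
The last hunk is the substance of the commit: when a penalty is present, it is subtracted from the environment reward before the Bellman backup, so the grasp critic's targets reflect the cost of needless gripper actuation. A self-contained sketch of the target computation; the standalone function, its name, and its defaults are ours, but the arithmetic mirrors the diff:

import torch

def grasp_critic_target(
    rewards: torch.Tensor,
    done: torch.Tensor,
    target_next_grasp_q: torch.Tensor,
    gripper_penalties: torch.Tensor | None = None,
    discount: float = 0.99,
) -> torch.Tensor:
    # Sketch: standalone version of the target computed inside compute_loss_grasp_critic.
    # Subtract the gripper penalty from the reward when one is provided.
    rewards_gripper = rewards if gripper_penalties is None else rewards - gripper_penalties
    # One-step Bellman target; (1 - done) stops bootstrapping at episode boundaries.
    return rewards_gripper + (1 - done) * discount * target_next_grasp_q

The TD loss then compares this target against the online network's Q-value for the stored discrete action, e.g. an MSE between predicted_grasp_qs and target_grasp_q.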