Enhance SACPolicy and learner server for improved grasp critic integration

- Updated SACPolicy to conditionally compute grasp critic losses based on the presence of discrete actions.
- Refactored the forward method to handle grasp critic model selection and loss computation more clearly.
- Adjusted the learner server to optimize the grasp critic's dedicated parameters (parameters_to_optimize) during training.
- Improved action handling in the ManiskillMockGripperWrapper to accommodate both tuple and single action inputs.
Author: AdilZouitine
Date: 2025-04-02 15:50:39 +00:00
Committed by: Michel Aractingi
Parent: f9fb9d4594
Commit: 6167886472
3 changed files with 72 additions and 50 deletions
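
A plausible reading of the learner-server bullet above: the grasp critic gets its own optimizer built over the parameters the policy exposes for it. A minimal hypothetical sketch (not the actual learner-server code; optimizer choice, learning rate, and the policy/batch names are illustrative):

import torch

# Assumes the policy exposes grasp_critic.parameters_to_optimize, as introduced in the diff below.
grasp_critic_optimizer = torch.optim.Adam(policy.grasp_critic.parameters_to_optimize, lr=3e-4)

if policy.config.num_discrete_actions is not None:
    loss = policy.forward(batch, model="grasp_critic")["loss_grasp_critic"]
    grasp_critic_optimizer.zero_grad()
    loss.backward()
    grasp_critic_optimizer.step()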

@@ -52,8 +52,6 @@ class SACPolicy(
self.config = config
continuous_action_dim = config.output_features["action"].shape[0]
if config.num_discrete_actions is not None:
continuous_action_dim -= 1
if config.dataset_stats is not None:
input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
@@ -191,7 +189,7 @@ class SACPolicy(
if self.config.num_discrete_actions is not None:
discrete_action_value = self.grasp_critic(batch)
discrete_action = torch.argmax(discrete_action_value, dim=-1)
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
actions = torch.cat([actions, discrete_action], dim=-1)
return actions
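
The keepdim=True change is what makes the concatenation well formed: without it, argmax returns a 1-D tensor that cannot be concatenated with the (batch, continuous_dim) action tensor. A standalone illustration (shapes are illustrative):

import torch

continuous = torch.randn(4, 3)                        # (batch, continuous_dim)
q_values = torch.randn(4, 2)                          # (batch, num_discrete_actions)

flat = torch.argmax(q_values, dim=-1)                 # shape (4,); torch.cat with continuous would raise
kept = torch.argmax(q_values, dim=-1, keepdim=True)   # shape (4, 1)

actions = torch.cat([continuous, kept.float()], dim=-1)  # shape (4, 4)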
@@ -236,7 +234,7 @@ class SACPolicy(
def forward(
self,
batch: dict[str, Tensor | dict[str, Tensor]],
model: Literal["actor", "critic", "temperature"] = "critic",
model: Literal["actor", "critic", "temperature", "grasp_critic"] = "critic",
) -> dict[str, Tensor]:
"""Compute the loss for the given model
@@ -275,18 +273,25 @@ class SACPolicy(
observation_features=observation_features,
next_observation_features=next_observation_features,
)
if self.config.num_discrete_actions is not None:
loss_grasp_critic = self.compute_loss_grasp_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
)
return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic}
return {"loss_critic": loss_critic}
if model == "grasp_critic" and self.config.num_discrete_actions is not None:
# Extract critic-specific components
rewards: Tensor = batch["reward"]
next_observations: dict[str, Tensor] = batch["next_state"]
done: Tensor = batch["done"]
next_observation_features: Tensor = batch.get("next_observation_feature")
loss_grasp_critic = self.compute_loss_grasp_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
return {"loss_grasp_critic": loss_grasp_critic}
if model == "actor":
return {
"loss_actor": self.compute_loss_actor(
@@ -373,7 +378,6 @@ class SACPolicy(
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
q_preds = self.critic_forward(
observations=observations,
actions=actions,
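
This hunk and the grasp-critic hunk that follows slice the replay-buffer action with the same convention: the buffer stores the continuous action with the discrete grasp action appended, and DISCRETE_DIMENSION_INDEX marks the boundary. A small illustration, assuming DISCRETE_DIMENSION_INDEX is -1 (i.e., a single trailing discrete dimension; not confirmed by this diff):

import torch

DISCRETE_DIMENSION_INDEX = -1  # assumed: the last action column holds the discrete grasp action

buffer_actions = torch.tensor([[0.1, -0.3, 0.7, 2.0],
                               [0.0, 0.5, -0.2, 1.0]])

continuous_part = buffer_actions[:, :DISCRETE_DIMENSION_INDEX]               # (2, 3), fed to the continuous critic
discrete_part = buffer_actions[:, DISCRETE_DIMENSION_INDEX:].clone().long()  # (2, 1), fed to the grasp critic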
@@ -407,30 +411,38 @@ class SACPolicy(
# NOTE: We only want to keep the discrete action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:]
actions = actions.long()
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
actions_discrete = actions_discrete.long()
with torch.no_grad():
# For DQN, select actions using online network, evaluate with target network
next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=False)
best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1)
next_grasp_qs = self.grasp_critic_forward(
next_observations, use_target=False, observation_features=next_observation_features
)
best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1, keepdim=True)
# Get target Q-values from target network
target_next_grasp_qs = self.grasp_critic_forward(observations=next_observations, use_target=True)
target_next_grasp_qs = self.grasp_critic_forward(
observations=next_observations,
use_target=True,
observation_features=next_observation_features,
)
# Use gather to select Q-values for best actions
target_next_grasp_q = torch.gather(
target_next_grasp_qs, dim=1, index=best_next_grasp_action.unsqueeze(-1)
target_next_grasp_qs, dim=1, index=best_next_grasp_action
).squeeze(-1)
# Compute target Q-value with Bellman equation
target_grasp_q = rewards + (1 - done) * self.config.discount * target_next_grasp_q
# Get predicted Q-values for current observations
predicted_grasp_qs = self.grasp_critic_forward(observations=observations, use_target=False)
predicted_grasp_qs = self.grasp_critic_forward(
observations=observations, use_target=False, observation_features=observation_features
)
# Use gather to select Q-values for taken actions
predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions.unsqueeze(-1)).squeeze(-1)
predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions_discrete).squeeze(-1)
# Compute MSE loss between predicted and target Q-values
grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q)
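
For reference, the loss above follows the DQN-style target described in the diff's own comments: the online grasp critic selects the best next discrete action, the target network evaluates it (the double-DQN trick), and the Bellman target is regressed with MSE. A self-contained sketch of just the tensor mechanics (all shapes and values are illustrative):

import torch
import torch.nn.functional as F

batch_size, num_discrete = 4, 3
rewards = torch.rand(batch_size)
done = torch.zeros(batch_size)
discount = 0.99

next_q_online = torch.randn(batch_size, num_discrete)     # online net on next observations
next_q_target = torch.randn(batch_size, num_discrete)     # target net on next observations
q_pred = torch.randn(batch_size, num_discrete)            # online net on current observations
taken = torch.randint(0, num_discrete, (batch_size, 1))   # discrete actions from the buffer

best_next = torch.argmax(next_q_online, dim=-1, keepdim=True)               # action selection (online)
target_next_q = torch.gather(next_q_target, dim=1, index=best_next).squeeze(-1)  # evaluation (target)
target_q = rewards + (1 - done) * discount * target_next_q                  # Bellman target
predicted_q = torch.gather(q_pred, dim=1, index=taken).squeeze(-1)
loss = F.mse_loss(predicted_q, target_q)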
@@ -642,49 +654,52 @@ class GraspCritic(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
network: nn.Module,
output_dim: int = 3, # TODO (azouitine): rename it number of discret acitons smth like that
input_dim: int,
hidden_dims: list[int],
output_dim: int = 3,
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: Optional[float] = None,
init_final: Optional[float] = None,
encoder_is_shared: bool = False,
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
):
super().__init__()
self.encoder = encoder
self.network = network
self.output_dim = output_dim
# Find the last Linear layer's output dimension
for layer in reversed(network.net):
if isinstance(layer, nn.Linear):
out_features = layer.out_features
break
self.net = MLP(
input_dim=input_dim,
hidden_dims=hidden_dims,
activations=activations,
activate_final=activate_final,
dropout_rate=dropout_rate,
final_activation=final_activation,
)
self.parameters_to_optimize += list(self.network.parameters())
if self.encoder is not None and not encoder_is_shared:
self.parameters_to_optimize += list(self.encoder.parameters())
self.output_layer = nn.Linear(in_features=out_features, out_features=self.output_dim)
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim)
if init_final is not None:
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.output_layer.weight)
self.parameters_to_optimize = []
self.parameters_to_optimize += list(self.net.parameters())
self.parameters_to_optimize += list(self.output_layer.parameters())
def forward(
self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
) -> torch.Tensor:
device = get_device_from_parameters(self)
# Move each tensor in observations to device
# Move each tensor in observations to device by cloning first to avoid inplace operations
observations = {k: v.to(device) for k, v in observations.items()}
# Encode observations if encoder exists
obs_enc = (
observation_features
observation_features.to(device)
if observation_features is not None
else (observations if self.encoder is None else self.encoder(observations))
)
return self.output_layer(self.network(obs_enc))
return self.output_layer(self.net(obs_enc))
class Policy(nn.Module):
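
Finally, a hedged construction example for the reworked GraspCritic, which now builds its MLP internally from input_dim/hidden_dims instead of receiving a prebuilt network. All values below are illustrative, and shared_encoder and observations are placeholders:

grasp_critic = GraspCritic(
    encoder=shared_encoder,        # may be None, in which case raw observations feed the MLP
    input_dim=256,                 # illustrative: dimensionality of the encoder output
    hidden_dims=[256, 256],
    output_dim=3,                  # number of discrete gripper actions
    encoder_is_shared=True,
)
q_values = grasp_critic(observations)   # shape: (batch, output_dim)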