forked from tangger/lerobot
Enhance SACPolicy and learner server for improved grasp critic integration
- Updated SACPolicy to conditionally compute grasp critic losses based on the presence of discrete actions. - Refactored the forward method to handle grasp critic model selection and loss computation more clearly. - Adjusted learner server to utilize optimized parameters for grasp critic during training. - Improved action handling in the ManiskillMockGripperWrapper to accommodate both tuple and single action inputs.
This commit is contained in:
committed by
Michel Aractingi
parent
f9fb9d4594
commit
6167886472
@@ -52,8 +52,6 @@ class SACPolicy(
|
||||
self.config = config
|
||||
|
||||
continuous_action_dim = config.output_features["action"].shape[0]
|
||||
if config.num_discrete_actions is not None:
|
||||
continuous_action_dim -= 1
|
||||
|
||||
if config.dataset_stats is not None:
|
||||
input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
|
||||
@@ -191,7 +189,7 @@ class SACPolicy(
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
discrete_action_value = self.grasp_critic(batch)
|
||||
discrete_action = torch.argmax(discrete_action_value, dim=-1)
|
||||
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
|
||||
actions = torch.cat([actions, discrete_action], dim=-1)
|
||||
|
||||
return actions
|
||||
@@ -236,7 +234,7 @@ class SACPolicy(
|
||||
def forward(
|
||||
self,
|
||||
batch: dict[str, Tensor | dict[str, Tensor]],
|
||||
model: Literal["actor", "critic", "temperature"] = "critic",
|
||||
model: Literal["actor", "critic", "temperature", "grasp_critic"] = "critic",
|
||||
) -> dict[str, Tensor]:
|
||||
"""Compute the loss for the given model
|
||||
|
||||
@@ -275,18 +273,25 @@ class SACPolicy(
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
)
|
||||
if self.config.num_discrete_actions is not None:
|
||||
loss_grasp_critic = self.compute_loss_grasp_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
)
|
||||
return {"loss_critic": loss_critic, "loss_grasp_critic": loss_grasp_critic}
|
||||
|
||||
return {"loss_critic": loss_critic}
|
||||
|
||||
if model == "grasp_critic" and self.config.num_discrete_actions is not None:
|
||||
# Extract critic-specific components
|
||||
rewards: Tensor = batch["reward"]
|
||||
next_observations: dict[str, Tensor] = batch["next_state"]
|
||||
done: Tensor = batch["done"]
|
||||
next_observation_features: Tensor = batch.get("next_observation_feature")
|
||||
loss_grasp_critic = self.compute_loss_grasp_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
)
|
||||
return {"loss_grasp_critic": loss_grasp_critic}
|
||||
if model == "actor":
|
||||
return {
|
||||
"loss_actor": self.compute_loss_actor(
|
||||
@@ -373,7 +378,6 @@ class SACPolicy(
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
|
||||
|
||||
q_preds = self.critic_forward(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
@@ -407,30 +411,38 @@ class SACPolicy(
|
||||
# NOTE: We only want to keep the discrete action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:]
|
||||
actions = actions.long()
|
||||
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
|
||||
actions_discrete = actions_discrete.long()
|
||||
|
||||
with torch.no_grad():
|
||||
# For DQN, select actions using online network, evaluate with target network
|
||||
next_grasp_qs = self.grasp_critic_forward(next_observations, use_target=False)
|
||||
best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1)
|
||||
next_grasp_qs = self.grasp_critic_forward(
|
||||
next_observations, use_target=False, observation_features=next_observation_features
|
||||
)
|
||||
best_next_grasp_action = torch.argmax(next_grasp_qs, dim=-1, keepdim=True)
|
||||
|
||||
# Get target Q-values from target network
|
||||
target_next_grasp_qs = self.grasp_critic_forward(observations=next_observations, use_target=True)
|
||||
target_next_grasp_qs = self.grasp_critic_forward(
|
||||
observations=next_observations,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for best actions
|
||||
target_next_grasp_q = torch.gather(
|
||||
target_next_grasp_qs, dim=1, index=best_next_grasp_action.unsqueeze(-1)
|
||||
target_next_grasp_qs, dim=1, index=best_next_grasp_action
|
||||
).squeeze(-1)
|
||||
|
||||
# Compute target Q-value with Bellman equation
|
||||
target_grasp_q = rewards + (1 - done) * self.config.discount * target_next_grasp_q
|
||||
|
||||
# Get predicted Q-values for current observations
|
||||
predicted_grasp_qs = self.grasp_critic_forward(observations=observations, use_target=False)
|
||||
predicted_grasp_qs = self.grasp_critic_forward(
|
||||
observations=observations, use_target=False, observation_features=observation_features
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for taken actions
|
||||
predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions.unsqueeze(-1)).squeeze(-1)
|
||||
predicted_grasp_q = torch.gather(predicted_grasp_qs, dim=1, index=actions_discrete).squeeze(-1)
|
||||
|
||||
# Compute MSE loss between predicted and target Q-values
|
||||
grasp_critic_loss = F.mse_loss(input=predicted_grasp_q, target=target_grasp_q)
|
||||
@@ -642,49 +654,52 @@ class GraspCritic(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: Optional[nn.Module],
|
||||
network: nn.Module,
|
||||
output_dim: int = 3, # TODO (azouitine): rename it number of discret acitons smth like that
|
||||
input_dim: int,
|
||||
hidden_dims: list[int],
|
||||
output_dim: int = 3,
|
||||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
||||
activate_final: bool = False,
|
||||
dropout_rate: Optional[float] = None,
|
||||
init_final: Optional[float] = None,
|
||||
encoder_is_shared: bool = False,
|
||||
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.network = network
|
||||
self.output_dim = output_dim
|
||||
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
out_features = layer.out_features
|
||||
break
|
||||
self.net = MLP(
|
||||
input_dim=input_dim,
|
||||
hidden_dims=hidden_dims,
|
||||
activations=activations,
|
||||
activate_final=activate_final,
|
||||
dropout_rate=dropout_rate,
|
||||
final_activation=final_activation,
|
||||
)
|
||||
|
||||
self.parameters_to_optimize += list(self.network.parameters())
|
||||
|
||||
if self.encoder is not None and not encoder_is_shared:
|
||||
self.parameters_to_optimize += list(self.encoder.parameters())
|
||||
|
||||
self.output_layer = nn.Linear(in_features=out_features, out_features=self.output_dim)
|
||||
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
self.parameters_to_optimize = []
|
||||
self.parameters_to_optimize += list(self.net.parameters())
|
||||
self.parameters_to_optimize += list(self.output_layer.parameters())
|
||||
|
||||
def forward(
|
||||
self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
# Move each tensor in observations to device
|
||||
# Move each tensor in observations to device by cloning first to avoid inplace operations
|
||||
observations = {k: v.to(device) for k, v in observations.items()}
|
||||
# Encode observations if encoder exists
|
||||
obs_enc = (
|
||||
observation_features
|
||||
observation_features.to(device)
|
||||
if observation_features is not None
|
||||
else (observations if self.encoder is None else self.encoder(observations))
|
||||
)
|
||||
return self.output_layer(self.network(obs_enc))
|
||||
return self.output_layer(self.net(obs_enc))
|
||||
|
||||
|
||||
class Policy(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user