diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index aa130984e..9705d5176 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -652,6 +652,26 @@ class SACObservationEncoder(nn.Module):
 
 
 class MLP(nn.Module):
+    """Multi-layer perceptron builder.
+
+    Dynamically constructs a sequence of layers based on `hidden_dims`:
+    1) Linear (in_dim -> out_dim)
+    2) Optional Dropout if `dropout_rate` > 0 and (not final layer or `activate_final`)
+    3) LayerNorm on the output features
+    4) Activation (standard for intermediate layers, `final_activation` for last layer if `activate_final`)
+
+    Arguments:
+        input_dim (int): Size of input feature dimension.
+        hidden_dims (list[int]): Sizes for each hidden layer.
+        activations (Callable or str): Activation to apply between layers.
+        activate_final (bool): Whether to apply activation at the final layer.
+        dropout_rate (Optional[float]): Dropout probability applied before normalization and activation.
+        final_activation (Optional[Callable or str]): Activation for the final layer when `activate_final` is True.
+
+    For each layer, `in_dim` is updated to the previous `out_dim`. All constructed modules are
+    stored in `self.net` as an `nn.Sequential` container.
+    """
+
     def __init__(
         self,
         input_dim: int,
@@ -662,38 +682,25 @@ class MLP(nn.Module):
         final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
     ):
         super().__init__()
-        self.activate_final = activate_final
-        layers = []
+        layers: list[nn.Module] = []
+        in_dim = input_dim
+        total = len(hidden_dims)
 
-        # First layer uses input_dim
-        layers.append(nn.Linear(input_dim, hidden_dims[0]))
+        for idx, out_dim in enumerate(hidden_dims):
+            # 1) linear transform
+            layers.append(nn.Linear(in_dim, out_dim))
 
-        # Add activation after first layer
-        if dropout_rate is not None and dropout_rate > 0:
-            layers.append(nn.Dropout(p=dropout_rate))
-        layers.append(nn.LayerNorm(hidden_dims[0]))
-        layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
-
-        # Rest of the layers
-        for i in range(1, len(hidden_dims)):
-            layers.append(nn.Linear(hidden_dims[i - 1], hidden_dims[i]))
-
-            if i + 1 < len(hidden_dims) or activate_final:
-                if dropout_rate is not None and dropout_rate > 0:
+            is_last = idx == total - 1
+            # 2-4) optionally add dropout, normalization, and activation
+            if not is_last or activate_final:
+                if dropout_rate and dropout_rate > 0:
                     layers.append(nn.Dropout(p=dropout_rate))
-                layers.append(nn.LayerNorm(hidden_dims[i]))
+                layers.append(nn.LayerNorm(out_dim))
+                act_cls = final_activation if is_last and final_activation else activations
+                act = act_cls if isinstance(act_cls, nn.Module) else getattr(nn, act_cls)()
+                layers.append(act)
 
-                # If we're at the final layer and a final activation is specified, use it
-                if i + 1 == len(hidden_dims) and activate_final and final_activation is not None:
-                    layers.append(
-                        final_activation
-                        if isinstance(final_activation, nn.Module)
-                        else getattr(nn, final_activation)()
-                    )
-                else:
-                    layers.append(
-                        activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
-                    )
+            in_dim = out_dim
 
         self.net = nn.Sequential(*layers)
 
@@ -734,41 +741,15 @@ class CriticHead(nn.Module):
 
 class CriticEnsemble(nn.Module):
     """
-    ┌──────────────────┬─────────────────────────────────────────────────────────┐
-    │ Critic Ensemble  │                                                         │
-    ├──────────────────┘                                                         │
-    │                                                                            │
-    │      ┌────┐              ┌────┐                   ┌────┐                   │
-    │      │ Q1 │              │ Q2 │                   │ Qn │                   │
-    │      └────┘              └────┘                   └────┘                   │
-    │ ┌──────────────┐    ┌──────────────┐         ┌──────────────┐              │
-    │ │              │    │              │         │              │              │
-    │ │    MLP 1     │    │    MLP 2     │         │     MLP      │              │
-    │ │              │    │              │   ...   │ num_critics  │              │
-    │ │              │    │              │         │              │              │
-    │ └──────────────┘    └──────────────┘         └──────────────┘              │
-    │        ▲                   ▲                        ▲                      │
-    │        └───────────────────┴───────┬────────────────┘                      │
-    │                                    │                                       │
-    │                                    │                                       │
-    │                          ┌───────────────────┐                             │
-    │                          │     Embedding     │                             │
-    │                          │                   │                             │
-    │                          └───────────────────┘                             │
-    │                                    ▲                                       │
-    │                                    │                                       │
-    │                      ┌─────────────┴────────────┐                          │
-    │                      │                          │                          │
-    │                      │  SACObservationEncoder   │                          │
-    │                      │                          │                          │
-    │                      └──────────────────────────┘                          │
-    │                                    ▲                                       │
-    │                                    │                                       │
-    │                                    │                                       │
-    │                                    │                                       │
-    └───────────────────────────┬────────────────────┬───────────────────────────┘
-                                │    Observation     │
-                                └────────────────────┘
+    CriticEnsemble wraps multiple CriticHead modules into an ensemble.
+
+    Args:
+        encoder (SACObservationEncoder): encoder for observations.
+        ensemble (List[CriticHead]): list of critic heads.
+        output_normalization (nn.Module): normalization layer for actions.
+        init_final (Optional[float]): optional initializer scale for final layers.
+
+    Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
     """
 
     def __init__(
@@ -850,7 +831,6 @@ class GraspCritic(nn.Module):
         self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
     ) -> torch.Tensor:
         device = get_device_from_parameters(self)
-        # Move each tensor in observations to device by cloning first to avoid inplace operations
        observations = {k: v.to(device) for k, v in observations.items()}
         obs_enc = self.encoder(observations, cache=observation_features)
         return self.output_layer(self.net(obs_enc))
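
To sanity-check the refactored construction loop, here is a minimal standalone sketch of the same layer-ordering logic. The `build_mlp` helper below is hypothetical and re-implements the patched `MLP.__init__` semantics for illustration; it is not code from lerobot:

```python
import torch
import torch.nn as nn


def build_mlp(input_dim, hidden_dims, activations="SiLU", activate_final=False,
              dropout_rate=None, final_activation=None):
    # Hypothetical standalone re-implementation of the patched MLP.__init__ loop.
    layers = []
    in_dim = input_dim
    total = len(hidden_dims)
    for idx, out_dim in enumerate(hidden_dims):
        # 1) linear transform
        layers.append(nn.Linear(in_dim, out_dim))
        is_last = idx == total - 1
        # 2-4) dropout, normalization, and activation (skipped on the bare final layer)
        if not is_last or activate_final:
            if dropout_rate and dropout_rate > 0:
                layers.append(nn.Dropout(p=dropout_rate))
            layers.append(nn.LayerNorm(out_dim))
            act_cls = final_activation if is_last and final_activation else activations
            layers.append(act_cls if isinstance(act_cls, nn.Module) else getattr(nn, act_cls)())
        in_dim = out_dim
    return nn.Sequential(*layers)


net = build_mlp(input_dim=8, hidden_dims=[32, 32, 1], dropout_rate=0.1)
print(net)  # Linear, Dropout, LayerNorm, SiLU, Linear, Dropout, LayerNorm, SiLU, Linear
out = net(torch.randn(4, 8))
print(out.shape)  # torch.Size([4, 1])
```

With `activate_final=False` (the default), the last `nn.Linear` is emitted bare, with no dropout, normalization, or activation, which is the behavior both the old and the new code intend for regression-style output heads.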
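
The new `CriticEnsemble` docstring also pins down the output contract: Q-values stacked to shape `(num_critics, batch_size)`. The toy sketch below illustrates that contract only; `ToyCriticEnsemble` is not lerobot's implementation, which additionally runs observations through `SACObservationEncoder` and normalizes actions:

```python
import torch
import torch.nn as nn


class ToyCriticEnsemble(nn.Module):
    """Illustrative ensemble: each head maps an encoded obs-action vector to a scalar Q-value."""

    def __init__(self, in_dim: int, num_critics: int):
        super().__init__()
        self.heads = nn.ModuleList(nn.Linear(in_dim, 1) for _ in range(num_critics))

    def forward(self, obs_action: torch.Tensor) -> torch.Tensor:
        # Each head yields (batch,); stacking gives (num_critics, batch).
        return torch.stack([head(obs_action).squeeze(-1) for head in self.heads])


ensemble = ToyCriticEnsemble(in_dim=16, num_critics=3)
q_values = ensemble(torch.randn(5, 16))
print(q_values.shape)  # torch.Size([3, 5])
```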