Enhance MLP class in modeling_sac.py with detailed docstring and refactor layer construction for improved readability. Simplify layer addition logic by removing unnecessary conditions and ensuring consistent handling of activations and dropout.

Author: AdilZouitine
Date:   2025-04-18 14:15:06 +00:00
parent fb92935601
commit 54c3c6d684


@@ -652,6 +652,26 @@ class SACObservationEncoder(nn.Module):
 class MLP(nn.Module):
+    """Multi-layer perceptron builder.
+
+    Dynamically constructs a sequence of layers based on `hidden_dims`:
+        1) Linear (in_dim -> out_dim)
+        2) Optional Dropout if `dropout_rate` > 0 and (not final layer or `activate_final`)
+        3) LayerNorm on the output features
+        4) Activation (standard for intermediate layers, `final_activation` for last layer if `activate_final`)
+
+    Arguments:
+        input_dim (int): Size of input feature dimension.
+        hidden_dims (list[int]): Sizes for each hidden layer.
+        activations (Callable or str): Activation to apply between layers.
+        activate_final (bool): Whether to apply activation at the final layer.
+        dropout_rate (Optional[float]): Dropout probability applied before normalization and activation.
+        final_activation (Optional[Callable or str]): Activation for the final layer when `activate_final` is True.
+
+    For each layer, `in_dim` is updated to the previous `out_dim`. All constructed modules are
+    stored in `self.net` as an `nn.Sequential` container.
+    """
+
     def __init__(
         self,
         input_dim: int,
@@ -662,38 +682,25 @@ class MLP(nn.Module):
         final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
     ):
         super().__init__()
         self.activate_final = activate_final
-        layers = []
-        # First layer uses input_dim
-        layers.append(nn.Linear(input_dim, hidden_dims[0]))
-        # Add activation after first layer
-        if dropout_rate is not None and dropout_rate > 0:
-            layers.append(nn.Dropout(p=dropout_rate))
-        layers.append(nn.LayerNorm(hidden_dims[0]))
-        layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
-        # Rest of the layers
-        for i in range(1, len(hidden_dims)):
-            layers.append(nn.Linear(hidden_dims[i - 1], hidden_dims[i]))
-            if i + 1 < len(hidden_dims) or activate_final:
-                if dropout_rate is not None and dropout_rate > 0:
+        layers: list[nn.Module] = []
+        in_dim = input_dim
+        total = len(hidden_dims)
+        for idx, out_dim in enumerate(hidden_dims):
+            # 1) linear transform
+            layers.append(nn.Linear(in_dim, out_dim))
+            is_last = idx == total - 1
+            # 2-4) optionally add dropout, normalization, and activation
+            if not is_last or activate_final:
+                if dropout_rate and dropout_rate > 0:
                     layers.append(nn.Dropout(p=dropout_rate))
-                layers.append(nn.LayerNorm(hidden_dims[i]))
-                # If we're at the final layer and a final activation is specified, use it
-                if i + 1 == len(hidden_dims) and activate_final and final_activation is not None:
-                    layers.append(
-                        final_activation
-                        if isinstance(final_activation, nn.Module)
-                        else getattr(nn, final_activation)()
-                    )
-                else:
-                    layers.append(
-                        activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
-                    )
+                layers.append(nn.LayerNorm(out_dim))
+                act_cls = final_activation if is_last and final_activation else activations
+                act = act_cls if isinstance(act_cls, nn.Module) else getattr(nn, act_cls)()
+                layers.append(act)
+            in_dim = out_dim
         self.net = nn.Sequential(*layers)
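
For reference, here is a small standalone sketch (not part of the diff) that reproduces the refactored construction loop with plain torch.nn modules and shows the module sequence it yields. The concrete values (input_dim=32, hidden_dims=[256, 256], dropout_rate=0.1) are illustrative assumptions, not repository defaults.

import torch.nn as nn

# Illustrative configuration (assumed values for demonstration only).
input_dim, hidden_dims = 32, [256, 256]
activations, dropout_rate, activate_final, final_activation = "ReLU", 0.1, False, None

layers: list[nn.Module] = []
in_dim = input_dim
total = len(hidden_dims)
for idx, out_dim in enumerate(hidden_dims):
    # 1) linear transform
    layers.append(nn.Linear(in_dim, out_dim))
    is_last = idx == total - 1
    # 2-4) optionally add dropout, normalization, and activation
    if not is_last or activate_final:
        if dropout_rate and dropout_rate > 0:
            layers.append(nn.Dropout(p=dropout_rate))
        layers.append(nn.LayerNorm(out_dim))
        act_cls = final_activation if is_last and final_activation else activations
        layers.append(act_cls if isinstance(act_cls, nn.Module) else getattr(nn, act_cls)())
    in_dim = out_dim

net = nn.Sequential(*layers)
print(net)
# Expected order: Linear(32->256), Dropout(0.1), LayerNorm(256), ReLU, Linear(256->256);
# the last Linear gets no dropout/norm/activation because activate_final is False.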
@@ -734,41 +741,15 @@ class CriticHead(nn.Module):
 class CriticEnsemble(nn.Module):
     """
-    ┌──────────────────┬─────────────────────────────────────────────────────────┐
-    │ Critic Ensemble  │                                                         │
-    ├──────────────────┘                                                         │
-    │                                                                            │
-    │       ┌────┐              ┌────┐                               ┌────┐      │
-    │       │ Q1 │              │ Q2 │                               │ Qn │      │
-    │       └────┘              └────┘                               └────┘      │
-    │  ┌──────────────┐    ┌──────────────┐                     ┌──────────────┐ │
-    │  │              │    │              │                     │              │ │
-    │  │    MLP 1     │    │    MLP 2     │                     │     MLP      │ │
-    │  │              │    │              │         ...         │ num_critics  │ │
-    │  │              │    │              │                     │              │ │
-    │  └──────────────┘    └──────────────┘                     └──────────────┘ │
-    │         ▲                   ▲                                    ▲         │
-    │         └───────────────────┴───────┬────────────────────────────┘         │
-    │                                     │                                      │
-    │                                     │                                      │
-    │                           ┌───────────────────┐                            │
-    │                           │     Embedding     │                            │
-    │                           │                   │                            │
-    │                           └───────────────────┘                            │
-    │                                     ▲                                      │
-    │                                     │                                      │
-    │                       ┌─────────────┴────────────┐                         │
-    │                       │                          │                         │
-    │                       │  SACObservationEncoder   │                         │
-    │                       │                          │                         │
-    │                       └──────────────────────────┘                         │
-    │                                     ▲                                      │
-    │                                     │                                      │
-    │                                     │                                      │
-    │                                     │                                      │
-    └───────────────────────────┬────────────────────┬───────────────────────────┘
-                                │    Observation     │
-                                └────────────────────┘
+    CriticEnsemble wraps multiple CriticHead modules into an ensemble.
+
+    Args:
+        encoder (SACObservationEncoder): encoder for observations.
+        ensemble (List[CriticHead]): list of critic heads.
+        output_normalization (nn.Module): normalization layer for actions.
+        init_final (Optional[float]): optional initializer scale for final layers.
+
+    Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
     """

     def __init__(
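
As a quick aside (not part of the diff), a toy stand-in illustrating the (num_critics, batch_size) output shape documented above. The real CriticEnsemble stacks CriticHead outputs computed from an observation/action embedding; this sketch just uses plain linear heads with assumed sizes.

import torch
import torch.nn as nn

# Toy stand-in with assumed sizes; only the output shape mirrors the docstring above.
num_critics, batch_size, emb_dim = 3, 8, 16
heads = nn.ModuleList([nn.Linear(emb_dim, 1) for _ in range(num_critics)])
obs_action_emb = torch.randn(batch_size, emb_dim)
q_values = torch.stack([head(obs_action_emb).squeeze(-1) for head in heads])
print(q_values.shape)  # torch.Size([3, 8]) -> (num_critics, batch_size)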
@@ -850,7 +831,6 @@ class GraspCritic(nn.Module):
         self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
     ) -> torch.Tensor:
         device = get_device_from_parameters(self)
-        # Move each tensor in observations to device by cloning first to avoid inplace operations
         observations = {k: v.to(device) for k, v in observations.items()}
         obs_enc = self.encoder(observations, cache=observation_features)
         return self.output_layer(self.net(obs_enc))
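
A brief note on the forward above: the dict comprehension rebuilds the observations mapping with each tensor moved to the module's device instead of mutating the caller's dict. A minimal sketch of that pattern, with an assumed example key name:

import torch

# Assumed example key; real observation dicts come from the environment/policy config.
observations = {"observation.state": torch.randn(4, 10)}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build a new dict so the caller's tensors are left untouched.
observations = {k: v.to(device) for k, v in observations.items()}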