forked from tangger/lerobot
Enhance MLP class in modeling_sac.py with detailed docstring and refactor layer construction for improved readability. Simplify layer addition logic by removing unnecessary conditions and ensuring consistent handling of activations and dropout.
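For illustration, a minimal sketch of what the simplified construction yields (hypothetical usage, assuming the refactored MLP class shown in the diff below):

    mlp = MLP(input_dim=8, hidden_dims=[64, 64, 10], activations="ReLU",
              activate_final=False, dropout_rate=0.0)
    print([type(m).__name__ for m in mlp.net])
    # -> ['Linear', 'LayerNorm', 'ReLU', 'Linear', 'LayerNorm', 'ReLU', 'Linear']
    # No Dropout modules (the rate is 0) and nothing after the final Linear,
    # since activate_final=False.

One edge case the simplification does change, as far as the diff shows: with a single entry in hidden_dims and activate_final=False, the old code still appended LayerNorm and an activation after the first Linear, whereas the refactored loop skips them.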
@@ -652,6 +652,26 @@ class SACObservationEncoder(nn.Module):
 
 class MLP(nn.Module):
+    """Multi-layer perceptron builder.
+
+    Dynamically constructs a sequence of layers based on `hidden_dims`:
+    1) Linear (in_dim -> out_dim)
+    2) Optional Dropout if `dropout_rate` > 0 and (not final layer or `activate_final`)
+    3) LayerNorm on the output features
+    4) Activation (standard for intermediate layers, `final_activation` for the last layer if `activate_final`)
+
+    Arguments:
+        input_dim (int): Size of the input feature dimension.
+        hidden_dims (list[int]): Sizes for each hidden layer.
+        activations (Callable or str): Activation to apply between layers.
+        activate_final (bool): Whether to apply an activation at the final layer.
+        dropout_rate (Optional[float]): Dropout probability applied before normalization and activation.
+        final_activation (Optional[Callable or str]): Activation for the final layer when `activate_final` is True.
+
+    For each layer, `in_dim` is updated to the previous `out_dim`. All constructed modules are
+    stored in `self.net` as an `nn.Sequential` container.
+    """
+
     def __init__(
         self,
         input_dim: int,
@@ -662,38 +682,25 @@ class MLP(nn.Module):
         final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
     ):
         super().__init__()
         self.activate_final = activate_final
-        layers = []
+        layers: list[nn.Module] = []
+        in_dim = input_dim
+        total = len(hidden_dims)
 
-        # First layer uses input_dim
-        layers.append(nn.Linear(input_dim, hidden_dims[0]))
+        for idx, out_dim in enumerate(hidden_dims):
+            # 1) linear transform
+            layers.append(nn.Linear(in_dim, out_dim))
 
-        # Add activation after first layer
-        if dropout_rate is not None and dropout_rate > 0:
-            layers.append(nn.Dropout(p=dropout_rate))
-        layers.append(nn.LayerNorm(hidden_dims[0]))
-        layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
-
-        # Rest of the layers
-        for i in range(1, len(hidden_dims)):
-            layers.append(nn.Linear(hidden_dims[i - 1], hidden_dims[i]))
-
-            if i + 1 < len(hidden_dims) or activate_final:
-                if dropout_rate is not None and dropout_rate > 0:
+            is_last = idx == total - 1
+            # 2-4) optionally add dropout, normalization, and activation
+            if not is_last or activate_final:
+                if dropout_rate and dropout_rate > 0:
                     layers.append(nn.Dropout(p=dropout_rate))
-                layers.append(nn.LayerNorm(hidden_dims[i]))
+                layers.append(nn.LayerNorm(out_dim))
+                act_cls = final_activation if is_last and final_activation else activations
+                act = act_cls if isinstance(act_cls, nn.Module) else getattr(nn, act_cls)()
+                layers.append(act)
 
-                # If we're at the final layer and a final activation is specified, use it
-                if i + 1 == len(hidden_dims) and activate_final and final_activation is not None:
-                    layers.append(
-                        final_activation
-                        if isinstance(final_activation, nn.Module)
-                        else getattr(nn, final_activation)()
-                    )
-                else:
-                    layers.append(
-                        activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
-                    )
+            in_dim = out_dim
 
         self.net = nn.Sequential(*layers)
 
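Concretely, the refactored loop above builds self.net following the 1)-4) ordering from the new docstring. A sketch (not part of the diff) of the equivalent module stack for a hypothetical call MLP(input_dim=4, hidden_dims=[32, 16], activations="ReLU", activate_final=True, dropout_rate=0.1, final_activation="Tanh"):

    nn.Sequential(
        nn.Linear(4, 32),   # idx 0: linear transform
        nn.Dropout(p=0.1),  # not the last layer, so dropout/norm/activation follow
        nn.LayerNorm(32),
        nn.ReLU(),          # intermediate layers use `activations`
        nn.Linear(32, 16),  # idx 1: the last layer
        nn.Dropout(p=0.1),  # still added, because activate_final=True
        nn.LayerNorm(16),
        nn.Tanh(),          # the last layer uses `final_activation`
    )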
@@ -734,41 +741,15 @@ class CriticHead(nn.Module):
 
 class CriticEnsemble(nn.Module):
     """
-    ┌──────────────────┬─────────────────────────────────────────────────────────┐
-    │ Critic Ensemble  │                                                         │
-    ├──────────────────┘                                                         │
-    │                                                                            │
-    │        ┌────┐              ┌────┐                              ┌────┐      │
-    │        │ Q1 │              │ Q2 │                              │ Qn │      │
-    │        └────┘              └────┘                              └────┘      │
-    │   ┌──────────────┐    ┌──────────────┐                 ┌──────────────┐    │
-    │   │              │    │              │                 │              │    │
-    │   │    MLP 1     │    │    MLP 2     │                 │     MLP      │    │
-    │   │              │    │              │        ...      │ num_critics  │    │
-    │   │              │    │              │                 │              │    │
-    │   └──────────────┘    └──────────────┘                 └──────────────┘    │
-    │          ▲                   ▲                                 ▲           │
-    │          └───────────────────┴───────┬─────────────────────────┘           │
-    │                                      │                                     │
-    │                                      │                                     │
-    │                          ┌───────────┴───────┐                             │
-    │                          │     Embedding     │                             │
-    │                          │                   │                             │
-    │                          └───────────────────┘                             │
-    │                                      ▲                                     │
-    │                                      │                                     │
-    │                       ┌──────────────┴───────────┐                         │
-    │                       │                          │                         │
-    │                       │  SACObservationEncoder   │                         │
-    │                       │                          │                         │
-    │                       └──────────────────────────┘                         │
-    │                                      ▲                                     │
-    │                                      │                                     │
-    │                                      │                                     │
-    │                                      │                                     │
-    └───────────────────────────┬──────────┴─────────┬───────────────────────────┘
-                                │    Observation     │
-                                └────────────────────┘
+    CriticEnsemble wraps multiple CriticHead modules into an ensemble.
+
+    Args:
+        encoder (SACObservationEncoder): encoder for observations.
+        ensemble (List[CriticHead]): list of critic heads.
+        output_normalization (nn.Module): normalization layer for actions.
+        init_final (Optional[float]): optional initializer scale for final layers.
+
+    Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
     """
 
     def __init__(
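The shape contract in the new docstring can be read as the following sketch (hypothetical pseudocode, not the implementation in this file; per the Args above, the real forward presumably also runs actions through output_normalization first):

    import torch

    def ensemble_forward_sketch(encoder, heads, observations, actions):
        # One shared embedding feeds every critic head
        obs_enc = encoder(observations)                 # (batch_size, emb_dim)
        inputs = torch.cat([obs_enc, actions], dim=-1)  # each Q conditions on the action
        q_values = [head(inputs).squeeze(-1) for head in heads]
        return torch.stack(q_values, dim=0)             # (num_critics, batch_size)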
@@ -850,7 +831,6 @@ class GraspCritic(nn.Module):
         self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
     ) -> torch.Tensor:
         device = get_device_from_parameters(self)
-        # Move each tensor in observations to device by cloning first to avoid inplace operations
         observations = {k: v.to(device) for k, v in observations.items()}
         obs_enc = self.encoder(observations, cache=observation_features)
         return self.output_layer(self.net(obs_enc))
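A note on the line deleted above: torch.Tensor.to returns a new tensor whenever a copy is actually required and is a no-op for a tensor already on the target device, and the dict comprehension builds a fresh dict rather than mutating the input, which may be why the comment about cloning was dropped. A small illustration (hypothetical values):

    import torch

    obs = {"state": torch.zeros(2, 4)}
    moved = {k: v.to("cpu") for k, v in obs.items()}
    assert moved is not obs                # a new dict is built
    assert moved["state"] is obs["state"]  # .to() is a no-op on the same device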