import torch from torch import nn from lerobot.common.policies.sac.modeling_sac import MLP def test_mlp_with_default_args(): mlp = MLP(input_dim=10, hidden_dims=[256, 256]) x = torch.randn(10) y = mlp(x) assert y.shape == (256,) def test_mlp_with_batch_dim(): mlp = MLP(input_dim=10, hidden_dims=[256, 256]) x = torch.randn(2, 10) y = mlp(x) assert y.shape == (2, 256) def test_forward_with_empty_hidden_dims(): mlp = MLP(input_dim=10, hidden_dims=[]) x = torch.randn(1, 10) assert mlp(x).shape == (1, 10) def test_mlp_with_dropout(): mlp = MLP(input_dim=10, hidden_dims=[256, 256, 11], dropout_rate=0.1) x = torch.randn(1, 10) y = mlp(x) assert y.shape == (1, 11) drop_out_layers_count = sum(isinstance(layer, nn.Dropout) for layer in mlp.net) assert drop_out_layers_count == 2 def test_mlp_with_custom_final_activation(): mlp = MLP(input_dim=10, hidden_dims=[256, 256], final_activation=torch.nn.Tanh()) x = torch.randn(1, 10) y = mlp(x) assert y.shape == (1, 256) assert (y >= -1).all() and (y <= 1).all()