From fb075a709d703cbf40f6588f73e63784fa5fceb6 Mon Sep 17 00:00:00 2001
From: AdilZouitine
Date: Thu, 17 Apr 2025 16:05:11 +0000
Subject: [PATCH] Refactor input and output normalization handling in
 SACPolicy for improved clarity and efficiency. Consolidate encoder
 initialization logic and remove redundant else statements.

---
 lerobot/common/policies/sac/modeling_sac.py | 141 +++-----------------
 1 file changed, 20 insertions(+), 121 deletions(-)

diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 24d536002..98b0df935 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -54,39 +54,36 @@ class SACPolicy(
 
         continuous_action_dim = config.output_features["action"].shape[0]
 
-        if config.dataset_stats is not None:
-            input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
+        # Default to identity normalizations
+        self.normalize_inputs = nn.Identity()
+        self.normalize_targets = nn.Identity()
+        self.unnormalize_outputs = nn.Identity()
+        # Apply normalization if dataset stats provided
+        if config.dataset_stats:
+            params = _convert_normalization_params_to_tensor(config.dataset_stats)
             self.normalize_inputs = Normalize(
                 config.input_features,
                 config.normalization_mapping,
-                input_normalization_params,
+                params,
             )
-        else:
-            self.normalize_inputs = nn.Identity()
-
-        if config.dataset_stats is not None:
-            output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
-
-            # HACK: This is hacky and should be removed
-            dataset_stats = dataset_stats or output_normalization_params
+            stats = dataset_stats or params
             self.normalize_targets = Normalize(
-                config.output_features, config.normalization_mapping, dataset_stats
+                config.output_features,
+                config.normalization_mapping,
+                stats,
             )
             self.unnormalize_outputs = Unnormalize(
-                config.output_features, config.normalization_mapping, dataset_stats
+                config.output_features,
+                config.normalization_mapping,
+                stats,
            )
-        else:
-            self.normalize_targets = nn.Identity()
-            self.unnormalize_outputs = nn.Identity()
 
         # NOTE: For images the encoder should be shared between the actor and critic
-        if config.shared_encoder:
-            encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
-            encoder_actor: SACObservationEncoder = encoder_critic
-        else:
-            encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
-            encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
         self.shared_encoder = config.shared_encoder
+        encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
+        encoder_actor = (
+            encoder_critic if self.shared_encoder else SACObservationEncoder(config, self.normalize_inputs)
+        )
 
         # Create a list of critic heads
         critic_heads = [
@@ -161,10 +158,6 @@ class SACPolicy(
         )
 
         config.target_entropy = -np.prod(continuous_action_dim + discrete_actions_dim) / 2  # (-dim(A)/2)
-        # TODO (azouitine): Handle the case where the temparameter is a fixed
-        # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
-        # it triggers "can't optimize a non-leaf Tensor"
-
         temperature_init = config.temperature_init
         self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
         self.temperature = self.log_alpha.exp().item()
@@ -187,17 +180,10 @@ class SACPolicy(
         """Reset the policy"""
         pass
 
-    def to(self, *args, **kwargs):
-        """Override .to(device) method to involve moving the log_alpha fixed_std"""
-        if self.actor.fixed_std is not None:
-            self.actor.fixed_std = self.actor.fixed_std.to(*args, **kwargs)
-        # self.log_alpha = self.log_alpha.to(*args, **kwargs)
-        super().to(*args, **kwargs)
-
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
-        # We cached the encoder output to avoid recomputing it
+        # We cached the encoder output to avoid recomputing it if the encoder is shared
         observations_features = None
         if self.shared_encoder:
             observations_features = self.actor.encoder.get_cached_image_features(batch=batch, normalize=True)
@@ -1172,90 +1158,3 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
                 converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
 
     return converted_params
-
-
-if __name__ == "__main__":
-    # # Benchmark the CriticEnsemble performance
-    # import time
-
-    # # Configuration
-    # num_critics = 10
-    # batch_size = 32
-    # action_dim = 7
-    # obs_dim = 64
-    # hidden_dims = [256, 256]
-    # num_iterations = 100
-
-    # print("Creating test environment...")
-
-    # # Create a simple dummy encoder
-    # class DummyEncoder(nn.Module):
-    #     def __init__(self):
-    #         super().__init__()
-    #         self.output_dim = obs_dim
-    #         self.parameters_to_optimize = []
-
-    #     def forward(self, obs):
-    #         # Just return a random tensor of the right shape
-    #         # In practice, this would encode the observations
-    #         return torch.randn(batch_size, obs_dim, device=device)
-
-    # # Create critic heads
-    # print(f"Creating {num_critics} critic heads...")
-    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    # critic_heads = [
-    #     CriticHead(
-    #         input_dim=obs_dim + action_dim,
-    #         hidden_dims=hidden_dims,
-    #     ).to(device)
-    #     for _ in range(num_critics)
-    # ]
-
-    # # Create the critic ensemble
-    # print("Creating CriticEnsemble...")
-    # critic_ensemble = CriticEnsemble(
-    #     encoder=DummyEncoder().to(device),
-    #     ensemble=critic_heads,
-    #     output_normalization=nn.Identity(),
-    # ).to(device)
-
-    # # Create random input data
-    # print("Creating input data...")
-    # obs_dict = {
-    #     "observation.state": torch.randn(batch_size, obs_dim, device=device),
-    # }
-    # actions = torch.randn(batch_size, action_dim, device=device)
-
-    # # Warmup run
-    # print("Warming up...")
-    # _ = critic_ensemble(obs_dict, actions)
-
-    # # Time the forward pass
-    # print(f"Running benchmark with {num_iterations} iterations...")
-    # start_time = time.perf_counter()
-    # for _ in range(num_iterations):
-    #     q_values = critic_ensemble(obs_dict, actions)
-    # end_time = time.perf_counter()
-
-    # # Print results
-    # elapsed_time = end_time - start_time
-    # print(f"Total time: {elapsed_time:.4f} seconds")
-    # print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
-    # print(f"Output shape: {q_values.shape}")  # Should be [num_critics, batch_size]
-
-    # Verify that all critic heads produce different outputs
-    # This confirms each critic head is unique
-    # print("\nVerifying critic outputs are different:")
-    # for i in range(num_critics):
-    #     for j in range(i + 1, num_critics):
-    #         diff = torch.abs(q_values[i] - q_values[j]).mean().item()
-    #         print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
-
-    from lerobot.configs import parser
-
-    @parser.wrap()
-    def main(config: SACConfig):
-        policy = SACPolicy(config=config)
-        print("yolo")
-
-    main()
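
The construction introduced by this patch boils down to one pattern: initialize every normalization module to nn.Identity() and overwrite it only when config.dataset_stats is provided, so no else branches are needed and downstream code never checks whether stats exist. The following minimal sketch shows that pattern in isolation; MeanStdNormalize and TinyPolicy are simplified stand-ins written for this illustration, not lerobot's Normalize or SACPolicy APIs.

import torch
from torch import nn


class MeanStdNormalize(nn.Module):
    """Standardize inputs with fixed statistics (illustrative stand-in, not lerobot's Normalize)."""

    def __init__(self, mean: torch.Tensor, std: torch.Tensor):
        super().__init__()
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean) / (self.std + 1e-8)


class TinyPolicy(nn.Module):
    """Identity normalization by default, replaced only when dataset stats are given."""

    def __init__(self, dataset_stats=None):
        super().__init__()
        # Default to identity normalization; no else branch required later.
        self.normalize_inputs: nn.Module = nn.Identity()
        # Apply real normalization only if dataset stats are provided.
        if dataset_stats:
            self.normalize_inputs = MeanStdNormalize(dataset_stats["mean"], dataset_stats["std"])

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.normalize_inputs(obs)


# Without stats the observation passes through unchanged; with stats it is standardized.
policy = TinyPolicy()
print(policy(torch.ones(2, 3)))
policy = TinyPolicy({"mean": torch.zeros(3), "std": torch.ones(3)})
print(policy(torch.ones(2, 3)))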