Refactor input and output normalization handling in SACPolicy for improved clarity and efficiency. Consolidate encoder initialization logic and remove redundant else statements.

AdilZouitine
2025-04-17 16:05:11 +00:00
parent dc1548fe1a
commit 7a3d8756b4

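At a glance, the refactor replaces the duplicated "if config.dataset_stats is not None: ... else: nn.Identity()" branches with identity defaults that are overridden once when stats are provided, and collapses the actor/critic encoder branches into a single conditional expression. Below is a minimal, self-contained sketch of that pattern; the StubNormalize, StubEncoder, and TinyPolicy names are illustrative stand-ins, not the real Normalize, SACObservationEncoder, or SACPolicy classes.

# Sketch of the "default to identity, override when stats exist" pattern.
# All classes here are stand-ins, not the real lerobot implementations.
import torch
from torch import nn


class StubNormalize(nn.Module):
    """Stand-in for Normalize/Unnormalize: stores the stats it was given."""

    def __init__(self, stats: dict):
        super().__init__()
        self.stats = stats

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x  # the real modules would apply mean/std or min/max here


class StubEncoder(nn.Module):
    """Stand-in for SACObservationEncoder."""

    def __init__(self, normalize: nn.Module):
        super().__init__()
        self.normalize = normalize


class TinyPolicy(nn.Module):
    def __init__(self, dataset_stats: dict | None = None, shared_encoder: bool = True):
        super().__init__()
        # Default to identity normalization ...
        self.normalize_inputs = nn.Identity()
        self.normalize_targets = nn.Identity()
        self.unnormalize_outputs = nn.Identity()
        # ... and override only when stats are provided (no else branches needed).
        if dataset_stats:
            self.normalize_inputs = StubNormalize(dataset_stats)
            self.normalize_targets = StubNormalize(dataset_stats)
            self.unnormalize_outputs = StubNormalize(dataset_stats)
        # One encoder is always built; the actor reuses it when sharing is enabled.
        self.shared_encoder = shared_encoder
        self.encoder_critic = StubEncoder(self.normalize_inputs)
        self.encoder_actor = (
            self.encoder_critic if self.shared_encoder else StubEncoder(self.normalize_inputs)
        )


policy = TinyPolicy(dataset_stats={"observation.state": {"mean": 0.0, "std": 1.0}})
print(policy.encoder_actor is policy.encoder_critic)  # True: the encoder is shared

The benefit of the identity defaults is that downstream code can always call normalize_inputs, normalize_targets, or unnormalize_outputs without first checking whether dataset stats were supplied.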

@@ -54,39 +54,36 @@ class SACPolicy(
         continuous_action_dim = config.output_features["action"].shape[0]

-        if config.dataset_stats is not None:
-            input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
+        # Default to identity normalizations
+        self.normalize_inputs = nn.Identity()
+        self.normalize_targets = nn.Identity()
+        self.unnormalize_outputs = nn.Identity()
+
+        # Apply normalization if dataset stats provided
+        if config.dataset_stats:
+            params = _convert_normalization_params_to_tensor(config.dataset_stats)
             self.normalize_inputs = Normalize(
                 config.input_features,
                 config.normalization_mapping,
-                input_normalization_params,
+                params,
             )
-        else:
-            self.normalize_inputs = nn.Identity()
-
-        if config.dataset_stats is not None:
-            output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
             # HACK: This is hacky and should be removed
-            dataset_stats = dataset_stats or output_normalization_params
+            stats = dataset_stats or params
             self.normalize_targets = Normalize(
-                config.output_features, config.normalization_mapping, dataset_stats
+                config.output_features,
+                config.normalization_mapping,
+                stats,
             )
             self.unnormalize_outputs = Unnormalize(
-                config.output_features, config.normalization_mapping, dataset_stats
+                config.output_features,
+                config.normalization_mapping,
+                stats,
             )
-        else:
-            self.normalize_targets = nn.Identity()
-            self.unnormalize_outputs = nn.Identity()

         # NOTE: For images the encoder should be shared between the actor and critic
-        if config.shared_encoder:
-            encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
-            encoder_actor: SACObservationEncoder = encoder_critic
-        else:
-            encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
-            encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
         self.shared_encoder = config.shared_encoder
+        encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
+        encoder_actor = (
+            encoder_critic if self.shared_encoder else SACObservationEncoder(config, self.normalize_inputs)
+        )

         # Create a list of critic heads
         critic_heads = [
@@ -161,10 +158,6 @@ class SACPolicy(
         )
         config.target_entropy = -np.prod(continuous_action_dim + discrete_actions_dim) / 2  # (-dim(A)/2)

-        # TODO (azouitine): Handle the case where the temparameter is a fixed
-        # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
-        # it triggers "can't optimize a non-leaf Tensor"
-
         temperature_init = config.temperature_init
         self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
         self.temperature = self.log_alpha.exp().item()
@@ -187,17 +180,10 @@ class SACPolicy(
         """Reset the policy"""
        pass

-    def to(self, *args, **kwargs):
-        """Override .to(device) method to involve moving the log_alpha fixed_std"""
-        if self.actor.fixed_std is not None:
-            self.actor.fixed_std = self.actor.fixed_std.to(*args, **kwargs)
-        # self.log_alpha = self.log_alpha.to(*args, **kwargs)
-        super().to(*args, **kwargs)
-
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
-        # We cached the encoder output to avoid recomputing it
+        # We cached the encoder output to avoid recomputing it if the encoder is shared
         observations_features = None
         if self.shared_encoder:
             observations_features = self.actor.encoder.get_cached_image_features(batch=batch, normalize=True)
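The select_action hunk above relies on caching the shared encoder's image features so the actor and critic do not re-encode the same observation. The sketch below illustrates that idea only; the SharedEncoder class, batch key, and shapes are hypothetical stand-ins, and of the names used only get_cached_image_features appears in the actual diff.

# Sketch of the shared-encoder caching idea: encode images once, reuse the
# resulting features for every consumer. Shapes and names are illustrative.
import torch
from torch import nn


class SharedEncoder(nn.Module):
    def __init__(self, feature_dim: int = 32):
        super().__init__()
        self.net = nn.Linear(3 * 8 * 8, feature_dim)

    def get_cached_image_features(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        # Flatten and encode the (dummy) image a single time.
        img = batch["observation.image"].flatten(start_dim=1)
        return self.net(img)

    def forward(
        self, batch: dict[str, torch.Tensor], cached: torch.Tensor | None = None
    ) -> torch.Tensor:
        # Reuse the cached features when they are provided.
        return cached if cached is not None else self.get_cached_image_features(batch)


encoder = SharedEncoder()
batch = {"observation.image": torch.randn(4, 3, 8, 8)}

features = encoder.get_cached_image_features(batch)  # computed once
actor_in = encoder(batch, cached=features)            # reused by the actor
critic_in = encoder(batch, cached=features)           # reused by the critic
print(torch.equal(actor_in, critic_in))  # True: no recomputation needed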
@@ -1172,90 +1158,3 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
                 converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)

     return converted_params
-
-
-if __name__ == "__main__":
-    # # Benchmark the CriticEnsemble performance
-    # import time
-
-    # # Configuration
-    # num_critics = 10
-    # batch_size = 32
-    # action_dim = 7
-    # obs_dim = 64
-    # hidden_dims = [256, 256]
-    # num_iterations = 100
-
-    # print("Creating test environment...")
-
-    # # Create a simple dummy encoder
-    # class DummyEncoder(nn.Module):
-    #     def __init__(self):
-    #         super().__init__()
-    #         self.output_dim = obs_dim
-    #         self.parameters_to_optimize = []
-
-    #     def forward(self, obs):
-    #         # Just return a random tensor of the right shape
-    #         # In practice, this would encode the observations
-    #         return torch.randn(batch_size, obs_dim, device=device)
-
-    # # Create critic heads
-    # print(f"Creating {num_critics} critic heads...")
-    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    # critic_heads = [
-    #     CriticHead(
-    #         input_dim=obs_dim + action_dim,
-    #         hidden_dims=hidden_dims,
-    #     ).to(device)
-    #     for _ in range(num_critics)
-    # ]
-
-    # # Create the critic ensemble
-    # print("Creating CriticEnsemble...")
-    # critic_ensemble = CriticEnsemble(
-    #     encoder=DummyEncoder().to(device),
-    #     ensemble=critic_heads,
-    #     output_normalization=nn.Identity(),
-    # ).to(device)
-
-    # # Create random input data
-    # print("Creating input data...")
-    # obs_dict = {
-    #     "observation.state": torch.randn(batch_size, obs_dim, device=device),
-    # }
-    # actions = torch.randn(batch_size, action_dim, device=device)
-
-    # # Warmup run
-    # print("Warming up...")
-    # _ = critic_ensemble(obs_dict, actions)
-
-    # # Time the forward pass
-    # print(f"Running benchmark with {num_iterations} iterations...")
-    # start_time = time.perf_counter()
-    # for _ in range(num_iterations):
-    #     q_values = critic_ensemble(obs_dict, actions)
-    # end_time = time.perf_counter()
-
-    # # Print results
-    # elapsed_time = end_time - start_time
-    # print(f"Total time: {elapsed_time:.4f} seconds")
-    # print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
-    # print(f"Output shape: {q_values.shape}")  # Should be [num_critics, batch_size]
-
-    # Verify that all critic heads produce different outputs
-    # This confirms each critic head is unique
-    # print("\nVerifying critic outputs are different:")
-    # for i in range(num_critics):
-    #     for j in range(i + 1, num_critics):
-    #         diff = torch.abs(q_values[i] - q_values[j]).mean().item()
-    #         print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
-
-    from lerobot.configs import parser
-
-    @parser.wrap()
-    def main(config: SACConfig):
-        policy = SACPolicy(config=config)
-        print("yolo")
-
-    main()