Refactor input and output normalization handling in SACPolicy for improved clarity and efficiency. Consolidate encoder initialization logic and remove redundant else statements.
@@ -54,39 +54,36 @@ class SACPolicy(
         continuous_action_dim = config.output_features["action"].shape[0]
 
-        if config.dataset_stats is not None:
-            input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
-            self.normalize_inputs = Normalize(
-                config.input_features,
-                config.normalization_mapping,
-                input_normalization_params,
-            )
-        else:
-            self.normalize_inputs = nn.Identity()
-
-        if config.dataset_stats is not None:
-            output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
-
-            # HACK: This is hacky and should be removed
-            dataset_stats = dataset_stats or output_normalization_params
-            self.normalize_targets = Normalize(
-                config.output_features, config.normalization_mapping, dataset_stats
-            )
-            self.unnormalize_outputs = Unnormalize(
-                config.output_features, config.normalization_mapping, dataset_stats
-            )
-        else:
-            self.normalize_targets = nn.Identity()
-            self.unnormalize_outputs = nn.Identity()
+        # Default to identity normalizations
+        self.normalize_inputs = nn.Identity()
+        self.normalize_targets = nn.Identity()
+        self.unnormalize_outputs = nn.Identity()
+        # Apply normalization if dataset stats provided
+        if config.dataset_stats:
+            params = _convert_normalization_params_to_tensor(config.dataset_stats)
+            self.normalize_inputs = Normalize(
+                config.input_features,
+                config.normalization_mapping,
+                params,
+            )
+            stats = dataset_stats or params
+            self.normalize_targets = Normalize(
+                config.output_features,
+                config.normalization_mapping,
+                stats,
+            )
+            self.unnormalize_outputs = Unnormalize(
+                config.output_features,
+                config.normalization_mapping,
+                stats,
+            )
 
         # NOTE: For images the encoder should be shared between the actor and critic
-        if config.shared_encoder:
-            encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
-            encoder_actor: SACObservationEncoder = encoder_critic
-        else:
-            encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
-            encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
+        self.shared_encoder = config.shared_encoder
+        encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
+        encoder_actor = (
+            encoder_critic if self.shared_encoder else SACObservationEncoder(config, self.normalize_inputs)
+        )
 
         # Create a list of critic heads
         critic_heads = [
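The new code replaces two `if ... else` blocks with identity defaults that are overridden only when `config.dataset_stats` is provided, and collapses the shared-encoder branch into a single conditional expression. The standalone sketch below illustrates the same default-then-override pattern with simplified stand-ins; `_IdentityOrScale` and `build_normalizers` are illustrative names and are not part of the lerobot API.

import torch
from torch import nn


class _IdentityOrScale(nn.Module):
    # Stand-in for lerobot's Normalize: divides by precomputed std stats.
    def __init__(self, std: torch.Tensor):
        super().__init__()
        self.std = std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x / self.std


def build_normalizers(dataset_stats: dict | None) -> dict[str, nn.Module]:
    # Default to identity normalizations, mirroring the refactored __init__.
    normalizers = {
        "normalize_inputs": nn.Identity(),
        "normalize_targets": nn.Identity(),
        "unnormalize_outputs": nn.Identity(),
    }
    # Override the defaults only when dataset stats are provided.
    if dataset_stats:
        std = torch.as_tensor(dataset_stats["observation.state"]["std"])
        normalizers = {key: _IdentityOrScale(std) for key in normalizers}
    return normalizers


stats = {"observation.state": {"std": [2.0, 4.0]}}
print(build_normalizers(None)["normalize_inputs"](torch.ones(2)))   # tensor([1., 1.])
print(build_normalizers(stats)["normalize_inputs"](torch.ones(2)))  # tensor([0.5000, 0.2500])

The design choice here is that callers never need to check whether stats exist: every code path gets a callable module, and the identity modules make the "no stats" case explicit rather than implied by an `else`.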
@@ -161,10 +158,6 @@ class SACPolicy(
         )
         config.target_entropy = -np.prod(continuous_action_dim + discrete_actions_dim) / 2  # (-dim(A)/2)
 
-        # TODO (azouitine): Handle the case where the temparameter is a fixed
-        # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
-        # it triggers "can't optimize a non-leaf Tensor"
-
         temperature_init = config.temperature_init
         self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
         self.temperature = self.log_alpha.exp().item()
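For context on the lines kept in this hunk: SAC learns the entropy temperature in log space, which is why `log_alpha` is the `nn.Parameter` handed to the optimizer (it stays a leaf tensor and `exp()` keeps the temperature positive). The sketch below shows one common form of the temperature update; the numeric values and the loss expression are assumptions for illustration, not taken from this file.

import math

import torch
from torch import nn

temperature_init = 0.1  # assumed value; the real one comes from config.temperature_init
action_dim = 7          # assumed continuous action dimension

# The temperature alpha is parameterized as exp(log_alpha) so that gradient
# steps on log_alpha can never make it negative.
log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
optimizer = torch.optim.Adam([log_alpha], lr=3e-4)

target_entropy = -action_dim / 2   # matches the -dim(A)/2 heuristic above
log_probs = torch.randn(32)        # placeholder log pi(a|s) from the actor

# One common temperature loss: raise alpha when the policy's entropy is below
# the target, lower it when the entropy is above the target.
temperature_loss = (-log_alpha.exp() * (log_probs + target_entropy).detach()).mean()
temperature_loss.backward()
optimizer.step()

print(f"temperature after one update: {log_alpha.exp().item():.4f}")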
@@ -187,17 +180,10 @@ class SACPolicy(
         """Reset the policy"""
         pass
 
-    def to(self, *args, **kwargs):
-        """Override .to(device) method to involve moving the log_alpha fixed_std"""
-        if self.actor.fixed_std is not None:
-            self.actor.fixed_std = self.actor.fixed_std.to(*args, **kwargs)
-        # self.log_alpha = self.log_alpha.to(*args, **kwargs)
-        super().to(*args, **kwargs)
-
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
-        # We cached the encoder output to avoid recomputing it
+        # We cached the encoder output to avoid recomputing it if the encoder is shared
         observations_features = None
         if self.shared_encoder:
             observations_features = self.actor.encoder.get_cached_image_features(batch=batch, normalize=True)
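The `select_action` change relies on computing the shared encoder's features once and reusing them for every consumer. The sketch below demonstrates the caching idea with a toy module; `TinyEncoder` and the two heads are simplified stand-ins for `SACObservationEncoder`, the actor, and the critic, not the real classes.

import torch
from torch import nn


class TinyEncoder(nn.Module):
    # Simplified stand-in for the observation encoder's trunk.
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 4)
        self.calls = 0  # track how many times the trunk actually runs

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        self.calls += 1
        return self.net(obs)


encoder = TinyEncoder()
actor_head = nn.Linear(4, 2)
critic_head = nn.Linear(4 + 2, 1)

obs = torch.randn(1, 8)

# When the encoder is shared, encode the observation once and pass the cached
# features to both the actor and the critic instead of re-running the trunk.
cached_features = encoder(obs)
action = actor_head(cached_features)
q_value = critic_head(torch.cat([cached_features, action], dim=-1))

print(q_value.shape, "- encoder ran", encoder.calls, "time(s)")  # encoder ran 1 time(s)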
@@ -1172,90 +1158,3 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
             converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
 
     return converted_params
-
-
-if __name__ == "__main__":
-    # # Benchmark the CriticEnsemble performance
-    # import time
-
-    # # Configuration
-    # num_critics = 10
-    # batch_size = 32
-    # action_dim = 7
-    # obs_dim = 64
-    # hidden_dims = [256, 256]
-    # num_iterations = 100
-
-    # print("Creating test environment...")
-
-    # # Create a simple dummy encoder
-    # class DummyEncoder(nn.Module):
-    #     def __init__(self):
-    #         super().__init__()
-    #         self.output_dim = obs_dim
-    #         self.parameters_to_optimize = []
-
-    #     def forward(self, obs):
-    #         # Just return a random tensor of the right shape
-    #         # In practice, this would encode the observations
-    #         return torch.randn(batch_size, obs_dim, device=device)
-
-    # # Create critic heads
-    # print(f"Creating {num_critics} critic heads...")
-    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    # critic_heads = [
-    #     CriticHead(
-    #         input_dim=obs_dim + action_dim,
-    #         hidden_dims=hidden_dims,
-    #     ).to(device)
-    #     for _ in range(num_critics)
-    # ]
-
-    # # Create the critic ensemble
-    # print("Creating CriticEnsemble...")
-    # critic_ensemble = CriticEnsemble(
-    #     encoder=DummyEncoder().to(device),
-    #     ensemble=critic_heads,
-    #     output_normalization=nn.Identity(),
-    # ).to(device)
-
-    # # Create random input data
-    # print("Creating input data...")
-    # obs_dict = {
-    #     "observation.state": torch.randn(batch_size, obs_dim, device=device),
-    # }
-    # actions = torch.randn(batch_size, action_dim, device=device)
-
-    # # Warmup run
-    # print("Warming up...")
-    # _ = critic_ensemble(obs_dict, actions)
-
-    # # Time the forward pass
-    # print(f"Running benchmark with {num_iterations} iterations...")
-    # start_time = time.perf_counter()
-    # for _ in range(num_iterations):
-    #     q_values = critic_ensemble(obs_dict, actions)
-    # end_time = time.perf_counter()
-
-    # # Print results
-    # elapsed_time = end_time - start_time
-    # print(f"Total time: {elapsed_time:.4f} seconds")
-    # print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
-    # print(f"Output shape: {q_values.shape}")  # Should be [num_critics, batch_size]
-
-    # Verify that all critic heads produce different outputs
-    # This confirms each critic head is unique
-    # print("\nVerifying critic outputs are different:")
-    # for i in range(num_critics):
-    #     for j in range(i + 1, num_critics):
-    #         diff = torch.abs(q_values[i] - q_values[j]).mean().item()
-    #         print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
-
-    from lerobot.configs import parser
-
-    @parser.wrap()
-    def main(config: SACConfig):
-        policy = SACPolicy(config=config)
-        print("yolo")
-
-    main()
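Only the tail of `_convert_normalization_params_to_tensor` is visible in this hunk: the `.view(3, 1, 1)` reshape and the return statement. The sketch below approximates what such a conversion does, assuming per-channel image stats are reshaped so they broadcast over (C, H, W) tensors; the key names, the `observation.image` condition, and the function name `convert_stats_to_tensors` are assumptions, not the real implementation.

import torch


def convert_stats_to_tensors(stats: dict) -> dict:
    # Turn nested normalization stats into tensors; reshape 3-element image
    # stats to (3, 1, 1) so they broadcast over (C, H, W) image tensors.
    converted = {}
    for outer_key, inner in stats.items():
        converted[outer_key] = {}
        for key, value in inner.items():
            tensor = torch.as_tensor(value, dtype=torch.float32)
            if outer_key.startswith("observation.image") and tensor.numel() == 3:
                tensor = tensor.view(3, 1, 1)
            converted[outer_key][key] = tensor
    return converted


stats = {
    "observation.image": {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
    "action": {"min": [-1.0] * 7, "max": [1.0] * 7},
}
converted = convert_stats_to_tensors(stats)
print(converted["observation.image"]["mean"].shape)  # torch.Size([3, 1, 1])
image = torch.rand(3, 64, 64)
normalized = (image - converted["observation.image"]["mean"]) / converted["observation.image"]["std"]
print(normalized.shape)  # torch.Size([3, 64, 64])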