Refactor input and output normalization handling in SACPolicy for improved clarity and efficiency. Consolidate encoder initialization logic and remove redundant else statements.
commit fb075a709d
parent 3424644ecd
committed by Michel Aractingi
@@ -54,39 +54,36 @@ class SACPolicy(
         continuous_action_dim = config.output_features["action"].shape[0]
 
-        if config.dataset_stats is not None:
-            input_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
-            self.normalize_inputs = Normalize(
-                config.input_features,
-                config.normalization_mapping,
-                input_normalization_params,
-            )
-        else:
-            self.normalize_inputs = nn.Identity()
-
-        if config.dataset_stats is not None:
-            output_normalization_params = _convert_normalization_params_to_tensor(config.dataset_stats)
-
-            # HACK: This is hacky and should be removed
-            dataset_stats = dataset_stats or output_normalization_params
-            self.normalize_targets = Normalize(
-                config.output_features, config.normalization_mapping, dataset_stats
-            )
-            self.unnormalize_outputs = Unnormalize(
-                config.output_features, config.normalization_mapping, dataset_stats
-            )
-        else:
-            self.normalize_targets = nn.Identity()
-            self.unnormalize_outputs = nn.Identity()
+        # Default to identity normalizations
+        self.normalize_inputs = nn.Identity()
+        self.normalize_targets = nn.Identity()
+        self.unnormalize_outputs = nn.Identity()
+
+        # Apply normalization if dataset stats provided
+        if config.dataset_stats:
+            params = _convert_normalization_params_to_tensor(config.dataset_stats)
+            self.normalize_inputs = Normalize(
+                config.input_features,
+                config.normalization_mapping,
+                params,
+            )
+            stats = dataset_stats or params
+            self.normalize_targets = Normalize(
+                config.output_features,
+                config.normalization_mapping,
+                stats,
+            )
+            self.unnormalize_outputs = Unnormalize(
+                config.output_features,
+                config.normalization_mapping,
+                stats,
+            )
 
         # NOTE: For images the encoder should be shared between the actor and critic
-        if config.shared_encoder:
-            encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
-            encoder_actor: SACObservationEncoder = encoder_critic
-        else:
-            encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
-            encoder_actor = SACObservationEncoder(config, self.normalize_inputs)
         self.shared_encoder = config.shared_encoder
+        encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
+        encoder_actor = (
+            encoder_critic if self.shared_encoder else SACObservationEncoder(config, self.normalize_inputs)
+        )
 
         # Create a list of critic heads
         critic_heads = [
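The change above defaults every normalization module to nn.Identity and only swaps in a real normalizer when dataset stats are available, which is what removes the duplicated else branches. Below is a minimal standalone sketch of that default-then-override pattern; MeanStdNormalize and build_normalizers are hypothetical stand-ins for illustration, not lerobot's Normalize/Unnormalize API.

    import torch
    from torch import nn


    class MeanStdNormalize(nn.Module):
        # Hypothetical stand-in for a real normalizer; only illustrates the pattern.
        def __init__(self, mean: torch.Tensor, std: torch.Tensor):
            super().__init__()
            self.register_buffer("mean", mean)
            self.register_buffer("std", std)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return (x - self.mean) / (self.std + 1e-8)


    def build_normalizers(dataset_stats: dict | None) -> tuple[nn.Module, nn.Module]:
        # Default to identity normalization, then override only if stats are provided.
        normalize_inputs: nn.Module = nn.Identity()
        normalize_targets: nn.Module = nn.Identity()
        if dataset_stats:
            normalize_inputs = MeanStdNormalize(dataset_stats["obs"]["mean"], dataset_stats["obs"]["std"])
            normalize_targets = MeanStdNormalize(dataset_stats["action"]["mean"], dataset_stats["action"]["std"])
        return normalize_inputs, normalize_targets


    stats = {
        "obs": {"mean": torch.zeros(4), "std": torch.ones(4)},
        "action": {"mean": torch.zeros(2), "std": torch.ones(2)},
    }
    norm_in, norm_out = build_normalizers(stats)
    print(norm_in(torch.randn(4)), norm_out(torch.randn(2)))

The same conditional-expression idea is used for encoder_actor in the hunk, reducing the shared/non-shared encoder choice to a single assignment.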
@@ -161,10 +158,6 @@ class SACPolicy(
         )
         config.target_entropy = -np.prod(continuous_action_dim + discrete_actions_dim) / 2  # (-dim(A)/2)
 
-        # TODO (azouitine): Handle the case where the temparameter is a fixed
-        # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
-        # it triggers "can't optimize a non-leaf Tensor"
-
         temperature_init = config.temperature_init
         self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
         self.temperature = self.log_alpha.exp().item()
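This hunk keeps the SAC entropy-temperature setup: the target entropy follows the -dim(A)/2 heuristic, and the temperature is parameterized through log_alpha so the optimized value stays positive. A small illustrative sketch of that arithmetic (not the SACPolicy code itself):

    import math

    import numpy as np
    import torch
    from torch import nn

    continuous_action_dim = 7
    discrete_actions_dim = 0

    # -dim(A)/2 heuristic: e.g. 7 continuous + 0 discrete action dims -> target entropy of -3.5
    target_entropy = -np.prod(continuous_action_dim + discrete_actions_dim) / 2

    temperature_init = 1.0
    log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
    temperature = log_alpha.exp().item()  # exp keeps the effective temperature positive

    print(target_entropy, temperature)  # -3.5 1.0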
@@ -187,17 +180,10 @@ class SACPolicy(
         """Reset the policy"""
         pass
 
-    def to(self, *args, **kwargs):
-        """Override .to(device) method to involve moving the log_alpha fixed_std"""
-        if self.actor.fixed_std is not None:
-            self.actor.fixed_std = self.actor.fixed_std.to(*args, **kwargs)
-        # self.log_alpha = self.log_alpha.to(*args, **kwargs)
-        super().to(*args, **kwargs)
-
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
-        # We cached the encoder output to avoid recomputing it
+        # We cached the encoder output to avoid recomputing it if the encoder is shared
        observations_features = None
        if self.shared_encoder:
            observations_features = self.actor.encoder.get_cached_image_features(batch=batch, normalize=True)
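The reworded comment refers to computing encoder features once and reusing them when the actor and critic share an encoder. A rough sketch of that caching idea, using a hypothetical TinyEncoder rather than the actual SACObservationEncoder.get_cached_image_features API:

    import torch
    from torch import nn


    class TinyEncoder(nn.Module):
        # Hypothetical stand-in encoder used only to illustrate feature caching.
        def __init__(self):
            super().__init__()
            self.net = nn.Linear(8, 16)

        def forward(self, obs: torch.Tensor) -> torch.Tensor:
            return self.net(obs)


    shared_encoder = True
    encoder_critic = TinyEncoder()
    encoder_actor = encoder_critic if shared_encoder else TinyEncoder()

    obs = torch.randn(2, 8)

    # Encode once and reuse the features for both heads when the encoder is shared,
    # instead of re-encoding the same observation for actor and critic.
    cached_features = encoder_actor(obs) if shared_encoder else None
    actor_features = cached_features if cached_features is not None else encoder_actor(obs)
    critic_features = cached_features if cached_features is not None else encoder_critic(obs)
    print(actor_features.shape, critic_features.shape)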
@@ -1172,90 +1158,3 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
             converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
 
     return converted_params
-
-
-if __name__ == "__main__":
-    # # Benchmark the CriticEnsemble performance
-    # import time
-
-    # # Configuration
-    # num_critics = 10
-    # batch_size = 32
-    # action_dim = 7
-    # obs_dim = 64
-    # hidden_dims = [256, 256]
-    # num_iterations = 100
-
-    # print("Creating test environment...")
-
-    # # Create a simple dummy encoder
-    # class DummyEncoder(nn.Module):
-    #     def __init__(self):
-    #         super().__init__()
-    #         self.output_dim = obs_dim
-    #         self.parameters_to_optimize = []
-
-    #     def forward(self, obs):
-    #         # Just return a random tensor of the right shape
-    #         # In practice, this would encode the observations
-    #         return torch.randn(batch_size, obs_dim, device=device)
-
-    # # Create critic heads
-    # print(f"Creating {num_critics} critic heads...")
-    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    # critic_heads = [
-    #     CriticHead(
-    #         input_dim=obs_dim + action_dim,
-    #         hidden_dims=hidden_dims,
-    #     ).to(device)
-    #     for _ in range(num_critics)
-    # ]
-
-    # # Create the critic ensemble
-    # print("Creating CriticEnsemble...")
-    # critic_ensemble = CriticEnsemble(
-    #     encoder=DummyEncoder().to(device),
-    #     ensemble=critic_heads,
-    #     output_normalization=nn.Identity(),
-    # ).to(device)
-
-    # # Create random input data
-    # print("Creating input data...")
-    # obs_dict = {
-    #     "observation.state": torch.randn(batch_size, obs_dim, device=device),
-    # }
-    # actions = torch.randn(batch_size, action_dim, device=device)
-
-    # # Warmup run
-    # print("Warming up...")
-    # _ = critic_ensemble(obs_dict, actions)
-
-    # # Time the forward pass
-    # print(f"Running benchmark with {num_iterations} iterations...")
-    # start_time = time.perf_counter()
-    # for _ in range(num_iterations):
-    #     q_values = critic_ensemble(obs_dict, actions)
-    # end_time = time.perf_counter()
-
-    # # Print results
-    # elapsed_time = end_time - start_time
-    # print(f"Total time: {elapsed_time:.4f} seconds")
-    # print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
-    # print(f"Output shape: {q_values.shape}")  # Should be [num_critics, batch_size]
-
-    # Verify that all critic heads produce different outputs
-    # This confirms each critic head is unique
-    # print("\nVerifying critic outputs are different:")
-    # for i in range(num_critics):
-    #     for j in range(i + 1, num_critics):
-    #         diff = torch.abs(q_values[i] - q_values[j]).mean().item()
-    #         print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
-
-    from lerobot.configs import parser
-
-    @parser.wrap()
-    def main(config: SACConfig):
-        policy = SACPolicy(config=config)
-        print("yolo")
-
-    main()
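The kept context lines show _convert_normalization_params_to_tensor reshaping three-channel image stats to (3, 1, 1) so they broadcast over (C, H, W) image tensors. As a hedged illustration of what such a conversion might look like (convert_stats_to_tensors below is a hypothetical helper, not the lerobot implementation):

    import torch


    def convert_stats_to_tensors(stats: dict) -> dict:
        # Illustrative sketch: turn nested numeric stats into tensors and reshape
        # 3-channel image stats to (3, 1, 1) for broadcasting over (C, H, W) images.
        converted = {}
        for feature, params in stats.items():
            converted[feature] = {}
            for name, value in params.items():
                tensor = torch.as_tensor(value, dtype=torch.float32)
                if "image" in feature and tensor.numel() == 3:
                    tensor = tensor.view(3, 1, 1)
                converted[feature][name] = tensor
        return converted


    stats = {
        "observation.image": {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
        "action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]},
    }
    print(convert_stats_to_tensors(stats)["observation.image"]["mean"].shape)  # torch.Size([3, 1, 1])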