Handle new config with sac

This commit is contained in:
AdilZouitine
2025-03-24 20:19:28 +00:00
parent b2025b852c
commit f483931fc0
4 changed files with 242 additions and 240 deletions

View File

@@ -29,18 +29,17 @@ import torch.nn.functional as F # noqa: N812
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.utils import get_device_from_parameters
class SACPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "RL", "SAC"],
PreTrainedPolicy,
):
config_class = SACConfig
name = "sac"
def __init__(
@@ -48,35 +47,33 @@ class SACPolicy(
config: SACConfig | None = None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
super().__init__()
if config is None:
config = SACConfig()
super().__init__(config)
config.validate_features()
self.config = config
if config.input_normalization_modes is not None:
if config.dataset_stats is not None:
input_normalization_params = _convert_normalization_params_to_tensor(
config.input_normalization_params
config.dataset_stats
)
self.normalize_inputs = Normalize(
config.input_shapes,
config.input_normalization_modes,
config.input_features,
config.normalization_mapping,
input_normalization_params,
)
else:
self.normalize_inputs = nn.Identity()
output_normalization_params = _convert_normalization_params_to_tensor(
config.output_normalization_params
config.dataset_stats
)
# HACK: This is hacky and should be removed
dataset_stats = dataset_stats or output_normalization_params
dataset_stats = dataset_stats or output_normalization_params
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
config.output_features, config.normalization_mapping, dataset_stats
)
# NOTE: For images the encoder should be shared between the actor and critic
@@ -90,7 +87,7 @@ class SACPolicy(
# Create a list of critic heads
critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
@@ -105,7 +102,7 @@ class SACPolicy(
# Create target critic heads as deepcopies of the original critic heads
target_critic_heads = [
CriticHead(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
input_dim=encoder_critic.output_dim + config.output_features["action"].shape[0],
**config.critic_network_kwargs,
)
for _ in range(config.num_critics)
@@ -125,12 +122,12 @@ class SACPolicy(
self.actor = Policy(
encoder=encoder_actor,
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
action_dim=config.output_shapes["action"][0],
action_dim=config.output_features["action"].shape[0],
encoder_is_shared=config.shared_encoder,
**config.policy_kwargs,
)
if config.target_entropy is None:
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
config.target_entropy = -np.prod(config.output_features["action"].shape[0]) / 2 # (-dim(A)/2)
# TODO (azouitine): Handle the case where the temparameter is a fixed
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
@@ -140,102 +137,13 @@ class SACPolicy(
self.log_alpha = nn.Parameter(torch.tensor([math.log(temperature_init)]))
self.temperature = self.log_alpha.exp().item()
def _save_pretrained(self, save_directory):
"""Custom save method to handle TensorDict properly"""
import json
import os
from dataclasses import asdict
from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
from safetensors.torch import save_model
save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE))
# Save config
config_dict = asdict(self.config)
with open(os.path.join(save_directory, CONFIG_NAME), "w") as f:
json.dump(config_dict, f, indent=2)
print(f"Saved config to {os.path.join(save_directory, CONFIG_NAME)}")
@classmethod
def _from_pretrained(
cls,
*,
model_id: str,
revision: Optional[str],
cache_dir: Optional[Union[str, Path]],
force_download: bool,
proxies: Optional[Dict],
resume_download: Optional[bool],
local_files_only: bool,
token: Optional[Union[str, bool]],
map_location: str = "cpu",
strict: bool = False,
**model_kwargs,
) -> "SACPolicy":
"""Custom load method to handle loading SAC policy from saved files"""
import json
import os
from pathlib import Path
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
from safetensors.torch import load_model
from lerobot.common.policies.sac.configuration_sac import SACConfig
# Check if model_id is a local path or a hub model ID
if os.path.isdir(model_id):
model_path = Path(model_id)
safetensors_file = os.path.join(model_path, SAFETENSORS_SINGLE_FILE)
config_file = os.path.join(model_path, CONFIG_NAME)
else:
# Download the safetensors file from the hub
safetensors_file = hf_hub_download(
repo_id=model_id,
filename=SAFETENSORS_SINGLE_FILE,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
# Download the config file
try:
config_file = hf_hub_download(
repo_id=model_id,
filename=CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except Exception:
config_file = None
# Load or create config
if config_file and os.path.exists(config_file):
# Load config from file
with open(config_file) as f:
config_dict = json.load(f)
config = SACConfig(**config_dict)
else:
# Use the provided config or create a default one
config = model_kwargs.get("config", SACConfig())
# Create a new instance with the loaded config
model = cls(config=config)
# Load state dict from safetensors file
if os.path.exists(safetensors_file):
load_model(model, filename=safetensors_file, device=map_location)
return model
def get_optim_params(self) -> dict:
return {
"actor": self.actor.parameters_to_optimize,
"critic": self.critic_ensemble.parameters_to_optimize,
"temperature": self.log_alpha,
}
def reset(self):
"""Reset the policy"""
@@ -667,7 +575,7 @@ class SACObservationEncoder(nn.Module):
self.parameters_to_optimize = []
self.aggregation_size: int = 0
if any("observation.image" in key for key in config.input_shapes):
if any("observation.image" in key for key in config.input_features):
self.camera_number = config.camera_number
if self.config.vision_encoder_name is not None:
@@ -682,12 +590,12 @@ class SACObservationEncoder(nn.Module):
freeze_image_encoder(self.image_enc_layers)
else:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.all_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
self.all_image_keys = [k for k in config.input_features if k.startswith("observation.image")]
if "observation.state" in config.input_shapes:
if "observation.state" in config.input_features:
self.state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_shapes["observation.state"][0],
in_features=config.input_features["observation.state"].shape[0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
@@ -697,10 +605,10 @@ class SACObservationEncoder(nn.Module):
self.parameters_to_optimize += list(self.state_enc_layers.parameters())
if "observation.environment_state" in config.input_shapes:
if "observation.environment_state" in config.input_features:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
in_features=config.input_shapes["observation.environment_state"][0],
in_features=config.input_features["observation.environment_state"].shape[0],
out_features=config.latent_dim,
),
nn.LayerNorm(normalized_shape=config.latent_dim),
@@ -727,9 +635,9 @@ class SACObservationEncoder(nn.Module):
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
feat.extend(embeddings_chunks)
if "observation.environment_state" in self.config.input_shapes:
if "observation.environment_state" in self.config.input_features:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_shapes:
if "observation.state" in self.config.input_features:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
features = torch.cat(tensors=feat, dim=-1)
@@ -744,11 +652,11 @@ class SACObservationEncoder(nn.Module):
class DefaultImageEncoder(nn.Module):
def __init__(self, config):
def __init__(self, config: SACConfig):
super().__init__()
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
in_channels=config.input_shapes["observation.image"][0],
in_channels=config.input_features["observation.image"].shape[0],
out_channels=config.image_encoder_hidden_dim,
kernel_size=7,
stride=2,
@@ -776,7 +684,7 @@ class DefaultImageEncoder(nn.Module):
),
nn.ReLU(),
)
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
dummy_batch = torch.zeros(1, *config.input_features["observation.image"].shape)
with torch.inference_mode():
self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
self.image_enc_layers.extend(
@@ -793,7 +701,7 @@ class DefaultImageEncoder(nn.Module):
class PretrainedImageEncoder(nn.Module):
def __init__(self, config):
def __init__(self, config: SACConfig):
super().__init__()
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
@@ -803,7 +711,7 @@ class PretrainedImageEncoder(nn.Module):
nn.Tanh(),
)
def _load_pretrained_vision_encoder(self, config):
def _load_pretrained_vision_encoder(self, config: SACConfig):
"""Set up CNN encoder"""
from transformers import AutoModel
@@ -857,73 +765,73 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
if __name__ == "__main__":
# Benchmark the CriticEnsemble performance
import time
# # Benchmark the CriticEnsemble performance
# import time
# Configuration
num_critics = 10
batch_size = 32
action_dim = 7
obs_dim = 64
hidden_dims = [256, 256]
num_iterations = 100
# # Configuration
# num_critics = 10
# batch_size = 32
# action_dim = 7
# obs_dim = 64
# hidden_dims = [256, 256]
# num_iterations = 100
print("Creating test environment...")
# print("Creating test environment...")
# Create a simple dummy encoder
class DummyEncoder(nn.Module):
def __init__(self):
super().__init__()
self.output_dim = obs_dim
self.parameters_to_optimize = []
# # Create a simple dummy encoder
# class DummyEncoder(nn.Module):
# def __init__(self):
# super().__init__()
# self.output_dim = obs_dim
# self.parameters_to_optimize = []
def forward(self, obs):
# Just return a random tensor of the right shape
# In practice, this would encode the observations
return torch.randn(batch_size, obs_dim, device=device)
# def forward(self, obs):
# # Just return a random tensor of the right shape
# # In practice, this would encode the observations
# return torch.randn(batch_size, obs_dim, device=device)
# Create critic heads
print(f"Creating {num_critics} critic heads...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
critic_heads = [
CriticHead(
input_dim=obs_dim + action_dim,
hidden_dims=hidden_dims,
).to(device)
for _ in range(num_critics)
]
# # Create critic heads
# print(f"Creating {num_critics} critic heads...")
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# critic_heads = [
# CriticHead(
# input_dim=obs_dim + action_dim,
# hidden_dims=hidden_dims,
# ).to(device)
# for _ in range(num_critics)
# ]
# Create the critic ensemble
print("Creating CriticEnsemble...")
critic_ensemble = CriticEnsemble(
encoder=DummyEncoder().to(device),
ensemble=critic_heads,
output_normalization=nn.Identity(),
).to(device)
# # Create the critic ensemble
# print("Creating CriticEnsemble...")
# critic_ensemble = CriticEnsemble(
# encoder=DummyEncoder().to(device),
# ensemble=critic_heads,
# output_normalization=nn.Identity(),
# ).to(device)
# Create random input data
print("Creating input data...")
obs_dict = {
"observation.state": torch.randn(batch_size, obs_dim, device=device),
}
actions = torch.randn(batch_size, action_dim, device=device)
# # Create random input data
# print("Creating input data...")
# obs_dict = {
# "observation.state": torch.randn(batch_size, obs_dim, device=device),
# }
# actions = torch.randn(batch_size, action_dim, device=device)
# Warmup run
print("Warming up...")
_ = critic_ensemble(obs_dict, actions)
# # Warmup run
# print("Warming up...")
# _ = critic_ensemble(obs_dict, actions)
# Time the forward pass
print(f"Running benchmark with {num_iterations} iterations...")
start_time = time.perf_counter()
for _ in range(num_iterations):
q_values = critic_ensemble(obs_dict, actions)
end_time = time.perf_counter()
# # Time the forward pass
# print(f"Running benchmark with {num_iterations} iterations...")
# start_time = time.perf_counter()
# for _ in range(num_iterations):
# q_values = critic_ensemble(obs_dict, actions)
# end_time = time.perf_counter()
# Print results
elapsed_time = end_time - start_time
print(f"Total time: {elapsed_time:.4f} seconds")
print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
print(f"Output shape: {q_values.shape}") # Should be [num_critics, batch_size]
# # Print results
# elapsed_time = end_time - start_time
# print(f"Total time: {elapsed_time:.4f} seconds")
# print(f"Average time per iteration: {elapsed_time / num_iterations * 1000:.4f} ms")
# print(f"Output shape: {q_values.shape}") # Should be [num_critics, batch_size]
# Verify that all critic heads produce different outputs
# This confirms each critic head is unique
@@ -932,3 +840,11 @@ if __name__ == "__main__":
# for j in range(i + 1, num_critics):
# diff = torch.abs(q_values[i] - q_values[j]).mean().item()
# print(f"Mean difference between critic {i} and {j}: {diff:.6f}")
import draccus
from lerobot.configs import parser
@parser.wrap()
def main(config: SACConfig):
policy = SACPolicy(config=config)
print("yolo")
main()