[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-24 13:41:27 +00:00
committed by AdilZouitine
parent 2945bbb221
commit 7c05755823
123 changed files with 1161 additions and 3425 deletions

View File

@@ -18,8 +18,8 @@
# TODO: (1) better device management
import math
from typing import Callable, Optional, Tuple, Union, Dict, List
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union
import einops
import numpy as np
@@ -124,17 +124,13 @@ class SACPolicy(
self.actor = Policy(
encoder=encoder_actor,
network=MLP(
input_dim=encoder_actor.output_dim, **config.actor_network_kwargs
),
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
action_dim=config.output_shapes["action"][0],
encoder_is_shared=config.shared_encoder,
**config.policy_kwargs,
)
if config.target_entropy is None:
config.target_entropy = (
-np.prod(config.output_shapes["action"][0]) / 2
) # (-dim(A)/2)
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
# TODO (azouitine): Handle the case where the temparameter is a fixed
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
@@ -146,10 +142,11 @@ class SACPolicy(
def _save_pretrained(self, save_directory):
"""Custom save method to handle TensorDict properly"""
import os
import json
import os
from dataclasses import asdict
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
from safetensors.torch import save_model
save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE))
@@ -177,12 +174,14 @@ class SACPolicy(
**model_kwargs,
) -> "SACPolicy":
"""Custom load method to handle loading SAC policy from saved files"""
import os
import json
import os
from pathlib import Path
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
from safetensors.torch import load_model
from lerobot.common.policies.sac.configuration_sac import SACConfig
# Check if model_id is a local path or a hub model ID
@@ -302,14 +301,10 @@ class SACPolicy(
) -> Tensor:
self.temperature = self.log_alpha.exp().item()
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(
next_observations, next_observation_features
)
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
# TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way
next_action_preds = self.unnormalize_outputs({"action": next_action_preds})[
"action"
]
next_action_preds = self.unnormalize_outputs({"action": next_action_preds})["action"]
# 2- compute q targets
q_targets = self.critic_forward(
@@ -353,21 +348,15 @@ class SACPolicy(
).sum()
return critics_loss
def compute_loss_temperature(
self, observations, observation_features: Tensor | None = None
) -> Tensor:
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
"""Compute the temperature loss"""
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations, observation_features)
temperature_loss = (
-self.log_alpha.exp() * (log_probs + self.config.target_entropy)
).mean()
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
return temperature_loss
def compute_loss_actor(
self, observations, observation_features: Tensor | None = None
) -> Tensor:
def compute_loss_actor(self, observations, observation_features: Tensor | None = None) -> Tensor:
self.temperature = self.log_alpha.exp().item()
actions_pi, log_probs, _ = self.actor(observations, observation_features)
@@ -408,11 +397,7 @@ class MLP(nn.Module):
if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(hidden_dims[0]))
layers.append(
activations
if isinstance(activations, nn.Module)
else getattr(nn, activations)()
)
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
# Rest of the layers
for i in range(1, len(hidden_dims)):
@@ -424,11 +409,7 @@ class MLP(nn.Module):
layers.append(nn.LayerNorm(hidden_dims[i]))
# If we're at the final layer and a final activation is specified, use it
if (
i + 1 == len(hidden_dims)
and activate_final
and final_activation is not None
):
if i + 1 == len(hidden_dims) and activate_final and final_activation is not None:
layers.append(
final_activation
if isinstance(final_activation, nn.Module)
@@ -436,9 +417,7 @@ class MLP(nn.Module):
)
else:
layers.append(
activations
if isinstance(activations, nn.Module)
else getattr(nn, activations)()
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
)
self.net = nn.Sequential(*layers)
@@ -639,15 +618,11 @@ class Policy(nn.Module):
# Compute standard deviations
if self.fixed_std is None:
log_std = self.std_layer(outputs)
assert not torch.isnan(log_std).any(), (
"[ERROR] log_std became NaN after std_layer!"
)
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
if self.use_tanh_squash:
log_std = torch.tanh(log_std)
log_std = self.log_std_min + 0.5 * (
self.log_std_max - self.log_std_min
) * (log_std + 1.0)
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
else:
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
else:
@@ -660,9 +635,7 @@ class Policy(nn.Module):
if self.use_tanh_squash:
actions = torch.tanh(x_t)
log_probs -= torch.log(
(1 - actions.pow(2)) + 1e-6
) # Adjust log-probs for Tanh
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh
else:
actions = x_t # No Tanh; raw Gaussian sample
@@ -709,9 +682,7 @@ class SACObservationEncoder(nn.Module):
freeze_image_encoder(self.image_enc_layers)
else:
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
self.all_image_keys = [
k for k in config.input_shapes if k.startswith("observation.image")
]
self.all_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential(
@@ -738,9 +709,7 @@ class SACObservationEncoder(nn.Module):
self.aggregation_size += config.latent_dim
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
self.aggregation_layer = nn.Linear(
in_features=self.aggregation_size, out_features=config.latent_dim
)
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
@@ -753,19 +722,13 @@ class SACObservationEncoder(nn.Module):
obs_dict = self.input_normalization(obs_dict)
# Batch all images along the batch dimension, then encode them.
if len(self.all_image_keys) > 0:
images_batched = torch.cat(
[obs_dict[key] for key in self.all_image_keys], dim=0
)
images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0)
images_batched = self.image_enc_layers(images_batched)
embeddings_chunks = torch.chunk(
images_batched, dim=0, chunks=len(self.all_image_keys)
)
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
feat.extend(embeddings_chunks)
if "observation.environment_state" in self.config.input_shapes:
feat.append(
self.env_state_enc_layers(obs_dict["observation.environment_state"])
)
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_shapes:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
@@ -833,9 +796,7 @@ class PretrainedImageEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.image_enc_layers, self.image_enc_out_shape = (
self._load_pretrained_vision_encoder(config)
)
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
self.image_enc_proj = nn.Sequential(
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
@@ -846,21 +807,15 @@ class PretrainedImageEncoder(nn.Module):
"""Set up CNN encoder"""
from transformers import AutoModel
self.image_enc_layers = AutoModel.from_pretrained(
config.vision_encoder_name, trust_remote_code=True
)
self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
# self.image_enc_layers.pooler = Identity()
if hasattr(self.image_enc_layers.config, "hidden_sizes"):
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[
-1
] # Last channel dimension
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension
elif hasattr(self.image_enc_layers, "fc"):
self.image_enc_out_shape = self.image_enc_layers.fc.in_features
else:
raise ValueError(
"Unsupported vision encoder architecture, make sure you are using a CNN"
)
raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
return self.image_enc_layers, self.image_enc_out_shape
def forward(self, x):
@@ -896,9 +851,7 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
for key, value in inner_dict.items():
converted_params[outer_key][key] = torch.tensor(value)
if "image" in outer_key:
converted_params[outer_key][key] = converted_params[outer_key][
key
].view(3, 1, 1)
converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
return converted_params