[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
AdilZouitine
parent
2945bbb221
commit
7c05755823
@@ -18,8 +18,8 @@
|
||||
# TODO: (1) better device management
|
||||
|
||||
import math
|
||||
from typing import Callable, Optional, Tuple, Union, Dict, List
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
@@ -124,17 +124,13 @@ class SACPolicy(
|
||||
|
||||
self.actor = Policy(
|
||||
encoder=encoder_actor,
|
||||
network=MLP(
|
||||
input_dim=encoder_actor.output_dim, **config.actor_network_kwargs
|
||||
),
|
||||
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
|
||||
action_dim=config.output_shapes["action"][0],
|
||||
encoder_is_shared=config.shared_encoder,
|
||||
**config.policy_kwargs,
|
||||
)
|
||||
if config.target_entropy is None:
|
||||
config.target_entropy = (
|
||||
-np.prod(config.output_shapes["action"][0]) / 2
|
||||
) # (-dim(A)/2)
|
||||
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
|
||||
|
||||
# TODO (azouitine): Handle the case where the temparameter is a fixed
|
||||
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
|
||||
@@ -146,10 +142,11 @@ class SACPolicy(
|
||||
|
||||
def _save_pretrained(self, save_directory):
|
||||
"""Custom save method to handle TensorDict properly"""
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
from dataclasses import asdict
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
|
||||
|
||||
from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
|
||||
from safetensors.torch import save_model
|
||||
|
||||
save_model(self, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE))
|
||||
@@ -177,12 +174,14 @@ class SACPolicy(
|
||||
**model_kwargs,
|
||||
) -> "SACPolicy":
|
||||
"""Custom load method to handle loading SAC policy from saved files"""
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
|
||||
from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
|
||||
from safetensors.torch import load_model
|
||||
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
|
||||
# Check if model_id is a local path or a hub model ID
|
||||
@@ -302,14 +301,10 @@ class SACPolicy(
|
||||
) -> Tensor:
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
with torch.no_grad():
|
||||
next_action_preds, next_log_probs, _ = self.actor(
|
||||
next_observations, next_observation_features
|
||||
)
|
||||
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
|
||||
|
||||
# TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way
|
||||
next_action_preds = self.unnormalize_outputs({"action": next_action_preds})[
|
||||
"action"
|
||||
]
|
||||
next_action_preds = self.unnormalize_outputs({"action": next_action_preds})["action"]
|
||||
|
||||
# 2- compute q targets
|
||||
q_targets = self.critic_forward(
|
||||
@@ -353,21 +348,15 @@ class SACPolicy(
|
||||
).sum()
|
||||
return critics_loss
|
||||
|
||||
def compute_loss_temperature(
|
||||
self, observations, observation_features: Tensor | None = None
|
||||
) -> Tensor:
|
||||
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
|
||||
"""Compute the temperature loss"""
|
||||
# calculate temperature loss
|
||||
with torch.no_grad():
|
||||
_, log_probs, _ = self.actor(observations, observation_features)
|
||||
temperature_loss = (
|
||||
-self.log_alpha.exp() * (log_probs + self.config.target_entropy)
|
||||
).mean()
|
||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
|
||||
return temperature_loss
|
||||
|
||||
def compute_loss_actor(
|
||||
self, observations, observation_features: Tensor | None = None
|
||||
) -> Tensor:
|
||||
def compute_loss_actor(self, observations, observation_features: Tensor | None = None) -> Tensor:
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
||||
@@ -408,11 +397,7 @@ class MLP(nn.Module):
|
||||
if dropout_rate is not None and dropout_rate > 0:
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(hidden_dims[0]))
|
||||
layers.append(
|
||||
activations
|
||||
if isinstance(activations, nn.Module)
|
||||
else getattr(nn, activations)()
|
||||
)
|
||||
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
|
||||
|
||||
# Rest of the layers
|
||||
for i in range(1, len(hidden_dims)):
|
||||
@@ -424,11 +409,7 @@ class MLP(nn.Module):
|
||||
layers.append(nn.LayerNorm(hidden_dims[i]))
|
||||
|
||||
# If we're at the final layer and a final activation is specified, use it
|
||||
if (
|
||||
i + 1 == len(hidden_dims)
|
||||
and activate_final
|
||||
and final_activation is not None
|
||||
):
|
||||
if i + 1 == len(hidden_dims) and activate_final and final_activation is not None:
|
||||
layers.append(
|
||||
final_activation
|
||||
if isinstance(final_activation, nn.Module)
|
||||
@@ -436,9 +417,7 @@ class MLP(nn.Module):
|
||||
)
|
||||
else:
|
||||
layers.append(
|
||||
activations
|
||||
if isinstance(activations, nn.Module)
|
||||
else getattr(nn, activations)()
|
||||
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
|
||||
)
|
||||
|
||||
self.net = nn.Sequential(*layers)
|
||||
@@ -639,15 +618,11 @@ class Policy(nn.Module):
|
||||
# Compute standard deviations
|
||||
if self.fixed_std is None:
|
||||
log_std = self.std_layer(outputs)
|
||||
assert not torch.isnan(log_std).any(), (
|
||||
"[ERROR] log_std became NaN after std_layer!"
|
||||
)
|
||||
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
|
||||
|
||||
if self.use_tanh_squash:
|
||||
log_std = torch.tanh(log_std)
|
||||
log_std = self.log_std_min + 0.5 * (
|
||||
self.log_std_max - self.log_std_min
|
||||
) * (log_std + 1.0)
|
||||
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
|
||||
else:
|
||||
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
|
||||
else:
|
||||
@@ -660,9 +635,7 @@ class Policy(nn.Module):
|
||||
|
||||
if self.use_tanh_squash:
|
||||
actions = torch.tanh(x_t)
|
||||
log_probs -= torch.log(
|
||||
(1 - actions.pow(2)) + 1e-6
|
||||
) # Adjust log-probs for Tanh
|
||||
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh
|
||||
else:
|
||||
actions = x_t # No Tanh; raw Gaussian sample
|
||||
|
||||
@@ -709,9 +682,7 @@ class SACObservationEncoder(nn.Module):
|
||||
freeze_image_encoder(self.image_enc_layers)
|
||||
else:
|
||||
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
|
||||
self.all_image_keys = [
|
||||
k for k in config.input_shapes if k.startswith("observation.image")
|
||||
]
|
||||
self.all_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
|
||||
if "observation.state" in config.input_shapes:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
@@ -738,9 +709,7 @@ class SACObservationEncoder(nn.Module):
|
||||
self.aggregation_size += config.latent_dim
|
||||
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
|
||||
|
||||
self.aggregation_layer = nn.Linear(
|
||||
in_features=self.aggregation_size, out_features=config.latent_dim
|
||||
)
|
||||
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
|
||||
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
|
||||
|
||||
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
|
||||
@@ -753,19 +722,13 @@ class SACObservationEncoder(nn.Module):
|
||||
obs_dict = self.input_normalization(obs_dict)
|
||||
# Batch all images along the batch dimension, then encode them.
|
||||
if len(self.all_image_keys) > 0:
|
||||
images_batched = torch.cat(
|
||||
[obs_dict[key] for key in self.all_image_keys], dim=0
|
||||
)
|
||||
images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0)
|
||||
images_batched = self.image_enc_layers(images_batched)
|
||||
embeddings_chunks = torch.chunk(
|
||||
images_batched, dim=0, chunks=len(self.all_image_keys)
|
||||
)
|
||||
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
|
||||
feat.extend(embeddings_chunks)
|
||||
|
||||
if "observation.environment_state" in self.config.input_shapes:
|
||||
feat.append(
|
||||
self.env_state_enc_layers(obs_dict["observation.environment_state"])
|
||||
)
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
|
||||
@@ -833,9 +796,7 @@ class PretrainedImageEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.image_enc_layers, self.image_enc_out_shape = (
|
||||
self._load_pretrained_vision_encoder(config)
|
||||
)
|
||||
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
|
||||
self.image_enc_proj = nn.Sequential(
|
||||
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
@@ -846,21 +807,15 @@ class PretrainedImageEncoder(nn.Module):
|
||||
"""Set up CNN encoder"""
|
||||
from transformers import AutoModel
|
||||
|
||||
self.image_enc_layers = AutoModel.from_pretrained(
|
||||
config.vision_encoder_name, trust_remote_code=True
|
||||
)
|
||||
self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
|
||||
# self.image_enc_layers.pooler = Identity()
|
||||
|
||||
if hasattr(self.image_enc_layers.config, "hidden_sizes"):
|
||||
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[
|
||||
-1
|
||||
] # Last channel dimension
|
||||
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension
|
||||
elif hasattr(self.image_enc_layers, "fc"):
|
||||
self.image_enc_out_shape = self.image_enc_layers.fc.in_features
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported vision encoder architecture, make sure you are using a CNN"
|
||||
)
|
||||
raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
|
||||
return self.image_enc_layers, self.image_enc_out_shape
|
||||
|
||||
def forward(self, x):
|
||||
@@ -896,9 +851,7 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
|
||||
for key, value in inner_dict.items():
|
||||
converted_params[outer_key][key] = torch.tensor(value)
|
||||
if "image" in outer_key:
|
||||
converted_params[outer_key][key] = converted_params[outer_key][
|
||||
key
|
||||
].view(3, 1, 1)
|
||||
converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
|
||||
|
||||
return converted_params
|
||||
|
||||
|
||||
Reference in New Issue
Block a user