[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Michel Aractingi
parent
bb69cb3c8c
commit
85fe8a3f4e
@@ -59,7 +59,9 @@ class SACPolicy(
|
||||
config.input_normalization_params
|
||||
)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, input_normalization_params
|
||||
config.input_shapes,
|
||||
config.input_normalization_modes,
|
||||
input_normalization_params,
|
||||
)
|
||||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
@@ -90,7 +92,8 @@ class SACPolicy(
|
||||
ensemble=Ensemble(
|
||||
[
|
||||
CriticHead(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
input_dim=encoder_critic.output_dim
|
||||
+ config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
@@ -104,7 +107,8 @@ class SACPolicy(
|
||||
ensemble=Ensemble(
|
||||
[
|
||||
CriticHead(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
input_dim=encoder_critic.output_dim
|
||||
+ config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
)
|
||||
for _ in range(config.num_critics)
|
||||
@@ -120,13 +124,17 @@ class SACPolicy(
|
||||
|
||||
self.actor = Policy(
|
||||
encoder=encoder_actor,
|
||||
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
|
||||
network=MLP(
|
||||
input_dim=encoder_actor.output_dim, **config.actor_network_kwargs
|
||||
),
|
||||
action_dim=config.output_shapes["action"][0],
|
||||
encoder_is_shared=config.shared_encoder,
|
||||
**config.policy_kwargs,
|
||||
)
|
||||
if config.target_entropy is None:
|
||||
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
|
||||
config.target_entropy = (
|
||||
-np.prod(config.output_shapes["action"][0]) / 2
|
||||
) # (-dim(A)/2)
|
||||
|
||||
# TODO (azouitine): Handle the case where the temparameter is a fixed
|
||||
# TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
|
||||
@@ -153,7 +161,11 @@ class SACPolicy(
|
||||
return actions
|
||||
|
||||
def critic_forward(
|
||||
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False, observation_features: Tensor | None = None
|
||||
self,
|
||||
observations: dict[str, Tensor],
|
||||
actions: Tensor,
|
||||
use_target: bool = False,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
|
||||
@@ -173,21 +185,37 @@ class SACPolicy(
|
||||
def update_target_networks(self):
|
||||
"""Update target networks with exponential moving average"""
|
||||
for target_param, param in zip(
|
||||
self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=False
|
||||
self.critic_target.parameters(),
|
||||
self.critic_ensemble.parameters(),
|
||||
strict=False,
|
||||
):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
def compute_loss_critic(self, observations, actions, rewards, next_observations, done, observation_features: Tensor | None = None, next_observation_features: Tensor | None = None) -> Tensor:
|
||||
def compute_loss_critic(
|
||||
self,
|
||||
observations,
|
||||
actions,
|
||||
rewards,
|
||||
next_observations,
|
||||
done,
|
||||
observation_features: Tensor | None = None,
|
||||
next_observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
temperature = self.log_alpha.exp().item()
|
||||
with torch.no_grad():
|
||||
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
|
||||
next_action_preds, next_log_probs, _ = self.actor(
|
||||
next_observations, next_observation_features
|
||||
)
|
||||
|
||||
# 2- compute q targets
|
||||
q_targets = self.critic_forward(
|
||||
observations=next_observations, actions=next_action_preds, use_target=True, observation_features=next_observation_features
|
||||
observations=next_observations,
|
||||
actions=next_action_preds,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
@@ -204,7 +232,12 @@ class SACPolicy(
|
||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
# 3- compute predicted qs
|
||||
q_preds = self.critic_forward(observations, actions, use_target=False, observation_features=observation_features)
|
||||
q_preds = self.critic_forward(
|
||||
observations,
|
||||
actions,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
@@ -219,20 +252,31 @@ class SACPolicy(
|
||||
).sum()
|
||||
return critics_loss
|
||||
|
||||
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
|
||||
def compute_loss_temperature(
|
||||
self, observations, observation_features: Tensor | None = None
|
||||
) -> Tensor:
|
||||
"""Compute the temperature loss"""
|
||||
# calculate temperature loss
|
||||
with torch.no_grad():
|
||||
_, log_probs, _ = self.actor(observations, observation_features)
|
||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
|
||||
temperature_loss = (
|
||||
-self.log_alpha.exp() * (log_probs + self.config.target_entropy)
|
||||
).mean()
|
||||
return temperature_loss
|
||||
|
||||
def compute_loss_actor(self, observations, observation_features: Tensor | None = None) -> Tensor:
|
||||
def compute_loss_actor(
|
||||
self, observations, observation_features: Tensor | None = None
|
||||
) -> Tensor:
|
||||
temperature = self.log_alpha.exp().item()
|
||||
|
||||
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
||||
|
||||
q_preds = self.critic_forward(observations, actions_pi, use_target=False, observation_features=observation_features)
|
||||
q_preds = self.critic_forward(
|
||||
observations,
|
||||
actions_pi,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
|
||||
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
|
||||
@@ -259,7 +303,11 @@ class MLP(nn.Module):
|
||||
if dropout_rate is not None and dropout_rate > 0:
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(hidden_dims[0]))
|
||||
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
|
||||
layers.append(
|
||||
activations
|
||||
if isinstance(activations, nn.Module)
|
||||
else getattr(nn, activations)()
|
||||
)
|
||||
|
||||
# Rest of the layers
|
||||
for i in range(1, len(hidden_dims)):
|
||||
@@ -270,7 +318,9 @@ class MLP(nn.Module):
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(hidden_dims[i]))
|
||||
layers.append(
|
||||
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
|
||||
activations
|
||||
if isinstance(activations, nn.Module)
|
||||
else getattr(nn, activations)()
|
||||
)
|
||||
|
||||
self.net = nn.Sequential(*layers)
|
||||
@@ -381,7 +431,11 @@ class CriticEnsemble(nn.Module):
|
||||
actions = self.output_normalization(actions)["action"]
|
||||
actions = actions.to(device)
|
||||
|
||||
obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations))
|
||||
obs_enc = (
|
||||
observation_features
|
||||
if observation_features is not None
|
||||
else (observations if self.encoder is None else self.encoder(observations))
|
||||
)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
q_values = self.ensemble(inputs) # [num_critics, B, 1]
|
||||
@@ -445,7 +499,11 @@ class Policy(nn.Module):
|
||||
observation_features: torch.Tensor | None = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Encode observations if encoder exists
|
||||
obs_enc = observation_features if observation_features is not None else (observations if self.encoder is None else self.encoder(observations))
|
||||
obs_enc = (
|
||||
observation_features
|
||||
if observation_features is not None
|
||||
else (observations if self.encoder is None else self.encoder(observations))
|
||||
)
|
||||
|
||||
# Get network outputs
|
||||
outputs = self.network(obs_enc)
|
||||
@@ -454,11 +512,15 @@ class Policy(nn.Module):
|
||||
# Compute standard deviations
|
||||
if self.fixed_std is None:
|
||||
log_std = self.std_layer(outputs)
|
||||
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
|
||||
assert not torch.isnan(
|
||||
log_std
|
||||
).any(), "[ERROR] log_std became NaN after std_layer!"
|
||||
|
||||
if self.use_tanh_squash:
|
||||
log_std = torch.tanh(log_std)
|
||||
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
|
||||
log_std = self.log_std_min + 0.5 * (
|
||||
self.log_std_max - self.log_std_min
|
||||
) * (log_std + 1.0)
|
||||
else:
|
||||
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
|
||||
else:
|
||||
@@ -471,7 +533,9 @@ class Policy(nn.Module):
|
||||
|
||||
if self.use_tanh_squash:
|
||||
actions = torch.tanh(x_t)
|
||||
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh
|
||||
log_probs -= torch.log(
|
||||
(1 - actions.pow(2)) + 1e-6
|
||||
) # Adjust log-probs for Tanh
|
||||
else:
|
||||
actions = x_t # No Tanh; raw Gaussian sample
|
||||
|
||||
@@ -518,12 +582,15 @@ class SACObservationEncoder(nn.Module):
|
||||
freeze_image_encoder(self.image_enc_layers)
|
||||
else:
|
||||
self.parameters_to_optimize += list(self.image_enc_layers.parameters())
|
||||
self.all_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
self.all_image_keys = [
|
||||
k for k in config.input_shapes if k.startswith("observation.image")
|
||||
]
|
||||
|
||||
if "observation.state" in config.input_shapes:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
in_features=config.input_shapes["observation.state"][0], out_features=config.latent_dim
|
||||
in_features=config.input_shapes["observation.state"][0],
|
||||
out_features=config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.Tanh(),
|
||||
@@ -544,7 +611,9 @@ class SACObservationEncoder(nn.Module):
|
||||
self.aggregation_size += config.latent_dim
|
||||
self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())
|
||||
|
||||
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
|
||||
self.aggregation_layer = nn.Linear(
|
||||
in_features=self.aggregation_size, out_features=config.latent_dim
|
||||
)
|
||||
self.parameters_to_optimize += list(self.aggregation_layer.parameters())
|
||||
|
||||
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
|
||||
@@ -557,13 +626,19 @@ class SACObservationEncoder(nn.Module):
|
||||
obs_dict = self.input_normalization(obs_dict)
|
||||
# Batch all images along the batch dimension, then encode them.
|
||||
if len(self.all_image_keys) > 0:
|
||||
images_batched = torch.cat([obs_dict[key] for key in self.all_image_keys], dim=0)
|
||||
images_batched = torch.cat(
|
||||
[obs_dict[key] for key in self.all_image_keys], dim=0
|
||||
)
|
||||
images_batched = self.image_enc_layers(images_batched)
|
||||
embeddings_chunks = torch.chunk(images_batched, dim=0, chunks=len(self.all_image_keys))
|
||||
embeddings_chunks = torch.chunk(
|
||||
images_batched, dim=0, chunks=len(self.all_image_keys)
|
||||
)
|
||||
feat.extend(embeddings_chunks)
|
||||
|
||||
if "observation.environment_state" in self.config.input_shapes:
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
feat.append(
|
||||
self.env_state_enc_layers(obs_dict["observation.environment_state"])
|
||||
)
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
|
||||
@@ -631,7 +706,9 @@ class PretrainedImageEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
|
||||
self.image_enc_layers, self.image_enc_out_shape = (
|
||||
self._load_pretrained_vision_encoder(config)
|
||||
)
|
||||
self.image_enc_proj = nn.Sequential(
|
||||
nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
@@ -642,15 +719,21 @@ class PretrainedImageEncoder(nn.Module):
|
||||
"""Set up CNN encoder"""
|
||||
from transformers import AutoModel
|
||||
|
||||
self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
|
||||
self.image_enc_layers = AutoModel.from_pretrained(
|
||||
config.vision_encoder_name, trust_remote_code=True
|
||||
)
|
||||
# self.image_enc_layers.pooler = Identity()
|
||||
|
||||
if hasattr(self.image_enc_layers.config, "hidden_sizes"):
|
||||
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension
|
||||
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[
|
||||
-1
|
||||
] # Last channel dimension
|
||||
elif hasattr(self.image_enc_layers, "fc"):
|
||||
self.image_enc_out_shape = self.image_enc_layers.fc.in_features
|
||||
else:
|
||||
raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
|
||||
raise ValueError(
|
||||
"Unsupported vision encoder architecture, make sure you are using a CNN"
|
||||
)
|
||||
return self.image_enc_layers, self.image_enc_out_shape
|
||||
|
||||
def forward(self, x):
|
||||
@@ -673,7 +756,7 @@ def orthogonal_init():
|
||||
|
||||
class Identity(nn.Module):
|
||||
def __init__(self):
|
||||
super(Identity, self).__init__()
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
@@ -701,7 +784,9 @@ class Ensemble(nn.Module):
|
||||
return self.module(*args, **kwargs)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return torch.vmap(self._call, (0, None), randomness="different")(self.params, *args, **kwargs)
|
||||
return torch.vmap(self._call, (0, None), randomness="different")(
|
||||
self.params, *args, **kwargs
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Vectorized {len(self)}x " + self._repr
|
||||
@@ -710,7 +795,9 @@ class Ensemble(nn.Module):
|
||||
# TODO (azouitine): I think in our case this function is not usefull we should remove it
|
||||
# after some investigation
|
||||
# borrowed from tdmpc
|
||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||
def flatten_forward_unflatten(
|
||||
fn: Callable[[Tensor], Tensor], image_tensor: Tensor
|
||||
) -> Tensor:
|
||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||
|
||||
Args:
|
||||
@@ -736,7 +823,9 @@ def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
|
||||
for key, value in inner_dict.items():
|
||||
converted_params[outer_key][key] = torch.tensor(value)
|
||||
if "image" in outer_key:
|
||||
converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
|
||||
converted_params[outer_key][key] = converted_params[outer_key][
|
||||
key
|
||||
].view(3, 1, 1)
|
||||
|
||||
return converted_params
|
||||
|
||||
|
||||
Reference in New Issue
Block a user