Files
lerobot/lerobot/common/policies/sac/modeling_sac.py
2025-06-03 17:30:22 +02:00

1121 lines
43 KiB
Python

#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import asdict
from typing import Callable, List, Literal, Optional, Tuple
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
from lerobot.common.policies.normalize import NormalizeBuffer, UnnormalizeBuffer
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.sac.configuration_sac import SACConfig, is_image_feature
from lerobot.common.policies.utils import get_device_from_parameters
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
class SACPolicy(
PreTrainedPolicy,
):
config_class = SACConfig
name = "sac"
def __init__(
self,
config: SACConfig | None = None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
super().__init__(config)
config.validate_features()
self.config = config
# Determine action dimension and initialize all components
continuous_action_dim = config.output_features["action"].shape[0]
self._init_normalization(dataset_stats)
self._init_encoders()
self._init_critics(continuous_action_dim)
self._init_actor(continuous_action_dim)
self._init_temperature()
def get_optim_params(self) -> dict:
optim_params = {
"actor": [
p
for n, p in self.actor.named_parameters()
if not n.startswith("encoder") or not self.shared_encoder
],
"critic": self.critic_ensemble.parameters(),
"temperature": self.log_alpha,
}
if self.config.num_discrete_actions is not None:
optim_params["discrete_critic"] = self.discrete_critic.parameters()
return optim_params
def reset(self):
"""Reset the policy"""
pass
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select action for inference/evaluation"""
observations_features = None
if self.shared_encoder and self.actor.encoder.has_images:
# Cache and normalize image features
observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True)
actions, _, _ = self.actor(batch, observations_features)
actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.num_discrete_actions is not None:
discrete_action_value = self.discrete_critic(batch, observations_features)
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
actions = torch.cat([actions, discrete_action], dim=-1)
return actions
def critic_forward(
self,
observations: dict[str, Tensor],
actions: Tensor,
use_target: bool = False,
observation_features: Tensor | None = None,
) -> Tensor:
"""Forward pass through a critic network ensemble
Args:
observations: Dictionary of observations
actions: Action tensor
use_target: If True, use target critics, otherwise use ensemble critics
Returns:
Tensor of Q-values from all critics
"""
critics = self.critic_target if use_target else self.critic_ensemble
q_values = critics(observations, actions, observation_features)
return q_values
def discrete_critic_forward(
self, observations, use_target=False, observation_features=None
) -> torch.Tensor:
"""Forward pass through a discrete critic network
Args:
observations: Dictionary of observations
use_target: If True, use target critics, otherwise use ensemble critics
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
Returns:
Tensor of Q-values from the discrete critic network
"""
discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic
q_values = discrete_critic(observations, observation_features)
return q_values
def forward(
self,
batch: dict[str, Tensor | dict[str, Tensor]],
model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic",
) -> dict[str, Tensor]:
"""Compute the loss for the given model
Args:
batch: Dictionary containing:
- action: Action tensor
- reward: Reward tensor
- state: Observations tensor dict
- next_state: Next observations tensor dict
- done: Done mask tensor
- observation_feature: Optional pre-computed observation features
- next_observation_feature: Optional pre-computed next observation features
model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature")
Returns:
The computed loss tensor
"""
# Extract common components from batch
actions: Tensor = batch["action"]
observations: dict[str, Tensor] = batch["state"]
observation_features: Tensor = batch.get("observation_feature")
if model == "critic":
# Extract critic-specific components
rewards: Tensor = batch["reward"]
next_observations: dict[str, Tensor] = batch["next_state"]
done: Tensor = batch["done"]
next_observation_features: Tensor = batch.get("next_observation_feature")
loss_critic = self.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
return {"loss_critic": loss_critic}
if model == "discrete_critic" and self.config.num_discrete_actions is not None:
# Extract critic-specific components
rewards: Tensor = batch["reward"]
next_observations: dict[str, Tensor] = batch["next_state"]
done: Tensor = batch["done"]
next_observation_features: Tensor = batch.get("next_observation_feature")
complementary_info = batch.get("complementary_info")
loss_discrete_critic = self.compute_loss_discrete_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
complementary_info=complementary_info,
)
return {"loss_discrete_critic": loss_discrete_critic}
if model == "actor":
return {
"loss_actor": self.compute_loss_actor(
observations=observations,
observation_features=observation_features,
)
}
if model == "temperature":
return {
"loss_temperature": self.compute_loss_temperature(
observations=observations,
observation_features=observation_features,
)
}
raise ValueError(f"Unknown model type: {model}")
def update_target_networks(self):
"""Update target networks with exponential moving average"""
for target_param, param in zip(
self.critic_target.parameters(),
self.critic_ensemble.parameters(),
strict=False,
):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
if self.config.num_discrete_actions is not None:
for target_param, param in zip(
self.discrete_critic_target.parameters(),
self.discrete_critic.parameters(),
strict=False,
):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
def update_temperature(self):
self.temperature = self.log_alpha.exp().item()
def compute_loss_critic(
self,
observations,
actions,
rewards,
next_observations,
done,
observation_features: Tensor | None = None,
next_observation_features: Tensor | None = None,
) -> Tensor:
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
next_action_preds = self.unnormalize_outputs({"action": next_action_preds})["action"]
# 2- compute q targets
q_targets = self.critic_forward(
observations=next_observations,
actions=next_action_preds,
use_target=True,
observation_features=next_observation_features,
)
# subsample critics to prevent overfitting if use high UTD (update to date)
# TODO: Get indices before forward pass to avoid unnecessary computation
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
if self.config.use_backup_entropy:
min_q = min_q - (self.temperature * next_log_probs)
td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
if self.config.num_discrete_actions is not None:
# NOTE: We only want to keep the continuous action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
q_preds = self.critic_forward(
observations=observations,
actions=actions,
use_target=False,
observation_features=observation_features,
)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
critics_loss = (
F.mse_loss(
input=q_preds,
target=td_target_duplicate,
reduction="none",
).mean(dim=1)
).sum()
return critics_loss
def compute_loss_discrete_critic(
self,
observations,
actions,
rewards,
next_observations,
done,
observation_features=None,
next_observation_features=None,
complementary_info=None,
):
# NOTE: We only want to keep the discrete action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
actions_discrete = torch.round(actions_discrete)
actions_discrete = actions_discrete.long()
discrete_penalties: Tensor | None = None
if complementary_info is not None:
discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty")
with torch.no_grad():
# For DQN, select actions using online network, evaluate with target network
next_discrete_qs = self.discrete_critic_forward(
next_observations, use_target=False, observation_features=next_observation_features
)
best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
# Get target Q-values from target network
target_next_discrete_qs = self.discrete_critic_forward(
observations=next_observations,
use_target=True,
observation_features=next_observation_features,
)
# Use gather to select Q-values for best actions
target_next_discrete_q = torch.gather(
target_next_discrete_qs, dim=1, index=best_next_discrete_action
).squeeze(-1)
# Compute target Q-value with Bellman equation
rewards_discrete = rewards
if discrete_penalties is not None:
rewards_discrete = rewards + discrete_penalties
target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q
# Get predicted Q-values for current observations
predicted_discrete_qs = self.discrete_critic_forward(
observations=observations, use_target=False, observation_features=observation_features
)
# Use gather to select Q-values for taken actions
predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)
# Compute MSE loss between predicted and target Q-values
discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
return discrete_critic_loss
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
"""Compute the temperature loss"""
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations, observation_features)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
return temperature_loss
def compute_loss_actor(
self,
observations,
observation_features: Tensor | None = None,
) -> Tensor:
actions_pi, log_probs, _ = self.actor(observations, observation_features)
actions_pi: Tensor = self.unnormalize_outputs({"action": actions_pi})["action"]
q_preds = self.critic_forward(
observations=observations,
actions=actions_pi,
use_target=False,
observation_features=observation_features,
)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
return actor_loss
def _init_normalization(self, dataset_stats):
"""Initialize input/output normalization modules."""
self.normalize_inputs = nn.Identity()
self.normalize_targets = nn.Identity()
self.unnormalize_outputs = nn.Identity()
if self.config.dataset_stats is not None:
params = _convert_normalization_params_to_tensor(self.config.dataset_stats)
self.normalize_inputs = NormalizeBuffer(
self.config.input_features, self.config.normalization_mapping, params
)
stats = dataset_stats or params
self.normalize_targets = NormalizeBuffer(
self.config.output_features, self.config.normalization_mapping, stats
)
self.unnormalize_outputs = UnnormalizeBuffer(
self.config.output_features, self.config.normalization_mapping, stats
)
def _init_encoders(self):
"""Initialize shared or separate encoders for actor and critic."""
self.shared_encoder = self.config.shared_encoder
self.encoder_critic = SACObservationEncoder(self.config, self.normalize_inputs)
self.encoder_actor = (
self.encoder_critic
if self.shared_encoder
else SACObservationEncoder(self.config, self.normalize_inputs)
)
def _init_critics(self, continuous_action_dim):
"""Build critic ensemble, targets, and optional discrete critic."""
heads = [
CriticHead(
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
**asdict(self.config.critic_network_kwargs),
)
for _ in range(self.config.num_critics)
]
self.critic_ensemble = CriticEnsemble(
encoder=self.encoder_critic, ensemble=heads, output_normalization=self.normalize_targets
)
target_heads = [
CriticHead(
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
**asdict(self.config.critic_network_kwargs),
)
for _ in range(self.config.num_critics)
]
self.critic_target = CriticEnsemble(
encoder=self.encoder_critic, ensemble=target_heads, output_normalization=self.normalize_targets
)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
if self.config.use_torch_compile:
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)
if self.config.num_discrete_actions is not None:
self._init_discrete_critics()
def _init_discrete_critics(self):
"""Build discrete discrete critic ensemble and target networks."""
self.discrete_critic = DiscreteCritic(
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
self.discrete_critic_target = DiscreteCritic(
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
# TODO: (maractingi, azouitine) Compile the discrete critic
self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict())
def _init_actor(self, continuous_action_dim):
"""Initialize policy actor network and default target entropy."""
# NOTE: The actor select only the continuous action part
self.actor = Policy(
encoder=self.encoder_actor,
network=MLP(input_dim=self.encoder_actor.output_dim, **asdict(self.config.actor_network_kwargs)),
action_dim=continuous_action_dim,
encoder_is_shared=self.shared_encoder,
**asdict(self.config.policy_kwargs),
)
self.target_entropy = self.config.target_entropy
if self.target_entropy is None:
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
self.target_entropy = -np.prod(dim) / 2
def _init_temperature(self):
"""Set up temperature parameter and initial log_alpha."""
temp_init = self.config.temperature_init
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
self.temperature = self.log_alpha.exp().item()
class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None:
super().__init__()
self.config = config
self.input_normalization = input_normalizer
self._init_image_layers()
self._init_state_layers()
self._compute_output_dim()
def _init_image_layers(self) -> None:
self.image_keys = [k for k in self.config.input_features if is_image_feature(k)]
self.has_images = bool(self.image_keys)
if not self.has_images:
return
if self.config.vision_encoder_name is not None:
self.image_encoder = PretrainedImageEncoder(self.config)
else:
self.image_encoder = DefaultImageEncoder(self.config)
if self.config.freeze_vision_encoder:
freeze_image_encoder(self.image_encoder)
dummy = torch.zeros(1, *self.config.input_features[self.image_keys[0]].shape)
with torch.no_grad():
_, channels, height, width = self.image_encoder(dummy).shape
self.spatial_embeddings = nn.ModuleDict()
self.post_encoders = nn.ModuleDict()
for key in self.image_keys:
name = key.replace(".", "_")
self.spatial_embeddings[name] = SpatialLearnedEmbeddings(
height=height,
width=width,
channel=channels,
num_features=self.config.image_embedding_pooling_dim,
)
self.post_encoders[name] = nn.Sequential(
nn.Dropout(0.1),
nn.Linear(
in_features=channels * self.config.image_embedding_pooling_dim,
out_features=self.config.latent_dim,
),
nn.LayerNorm(normalized_shape=self.config.latent_dim),
nn.Tanh(),
)
def _init_state_layers(self) -> None:
self.has_env = "observation.environment_state" in self.config.input_features
self.has_state = "observation.state" in self.config.input_features
if self.has_env:
dim = self.config.input_features["observation.environment_state"].shape[0]
self.env_encoder = nn.Sequential(
nn.Linear(dim, self.config.latent_dim),
nn.LayerNorm(self.config.latent_dim),
nn.Tanh(),
)
if self.has_state:
dim = self.config.input_features["observation.state"].shape[0]
self.state_encoder = nn.Sequential(
nn.Linear(dim, self.config.latent_dim),
nn.LayerNorm(self.config.latent_dim),
nn.Tanh(),
)
def _compute_output_dim(self) -> None:
out = 0
if self.has_images:
out += len(self.image_keys) * self.config.latent_dim
if self.has_env:
out += self.config.latent_dim
if self.has_state:
out += self.config.latent_dim
self._out_dim = out
def forward(
self, obs: dict[str, Tensor], cache: Optional[dict[str, Tensor]] = None, detach: bool = False
) -> Tensor:
obs = self.input_normalization(obs)
parts = []
if self.has_images:
if cache is None:
cache = self.get_cached_image_features(obs, normalize=False)
parts.append(self._encode_images(cache, detach))
if self.has_env:
parts.append(self.env_encoder(obs["observation.environment_state"]))
if self.has_state:
parts.append(self.state_encoder(obs["observation.state"]))
if parts:
return torch.cat(parts, dim=-1)
raise ValueError(
"No parts to concatenate, you should have at least one image or environment state or state"
)
def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]:
"""Extract and optionally cache image features from observations.
This function processes image observations through the vision encoder once and returns
the resulting features.
When the image encoder is shared between actor and critics AND frozen, these features can be safely cached and
reused across policy components (actor, critic, discrete_critic), avoiding redundant forward passes.
Performance impact:
- The vision encoder forward pass is typically the main computational bottleneck during training and inference
- Caching these features can provide 2-4x speedup in training and inference
Normalization behavior:
- When called from inside forward(): set normalize=False since inputs are already normalized
- When called from outside forward(): set normalize=True to ensure proper input normalization
Usage patterns:
- Called in select_action() with normalize=True
- Called in learner_server.py's get_observation_features() to pre-compute features for all policy components
- Called internally by forward() with normalize=False
Args:
obs: Dictionary of observation tensors containing image keys
normalize: Whether to normalize observations before encoding
Set to True when calling directly from outside the encoder's forward method
Set to False when calling from within forward() where inputs are already normalized
Returns:
Dictionary mapping image keys to their corresponding encoded features
"""
if normalize:
obs = self.input_normalization(obs)
batched = torch.cat([obs[k] for k in self.image_keys], dim=0)
out = self.image_encoder(batched)
chunks = torch.chunk(out, len(self.image_keys), dim=0)
return dict(zip(self.image_keys, chunks, strict=False))
def _encode_images(self, cache: dict[str, Tensor], detach: bool) -> Tensor:
"""Encode image features from cached observations.
This function takes pre-encoded image features from the cache and applies spatial embeddings and post-encoders.
It also supports detaching the encoded features if specified.
Args:
cache (dict[str, Tensor]): The cached image features.
detach (bool): Usually when the encoder is shared between actor and critics,
we want to detach the encoded features on the policy side to avoid backprop through the encoder.
More detail here `https://cdn.aaai.org/ojs/17276/17276-13-20770-1-2-20210518.pdf`
Returns:
Tensor: The encoded image features.
"""
feats = []
for k, feat in cache.items():
safe_key = k.replace(".", "_")
x = self.spatial_embeddings[safe_key](feat)
x = self.post_encoders[safe_key](x)
if detach:
x = x.detach()
feats.append(x)
return torch.cat(feats, dim=-1)
@property
def output_dim(self) -> int:
return self._out_dim
class MLP(nn.Module):
"""Multi-layer perceptron builder.
Dynamically constructs a sequence of layers based on `hidden_dims`:
1) Linear (in_dim -> out_dim)
2) Optional Dropout if `dropout_rate` > 0 and (not final layer or `activate_final`)
3) LayerNorm on the output features
4) Activation (standard for intermediate layers, `final_activation` for last layer if `activate_final`)
Arguments:
input_dim (int): Size of input feature dimension.
hidden_dims (list[int]): Sizes for each hidden layer.
activations (Callable or str): Activation to apply between layers.
activate_final (bool): Whether to apply activation at the final layer.
dropout_rate (Optional[float]): Dropout probability applied before normalization and activation.
final_activation (Optional[Callable or str]): Activation for the final layer when `activate_final` is True.
For each layer, `in_dim` is updated to the previous `out_dim`. All constructed modules are
stored in `self.net` as an `nn.Sequential` container.
"""
def __init__(
self,
input_dim: int,
hidden_dims: list[int],
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: Optional[float] = None,
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
):
super().__init__()
layers: list[nn.Module] = []
in_dim = input_dim
total = len(hidden_dims)
for idx, out_dim in enumerate(hidden_dims):
# 1) linear transform
layers.append(nn.Linear(in_dim, out_dim))
is_last = idx == total - 1
# 2-4) optionally add dropout, normalization, and activation
if not is_last or activate_final:
if dropout_rate and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(out_dim))
act_cls = final_activation if is_last and final_activation else activations
act = act_cls if isinstance(act_cls, nn.Module) else getattr(nn, act_cls)()
layers.append(act)
in_dim = out_dim
self.net = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
class CriticHead(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dims: list[int],
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: Optional[float] = None,
init_final: Optional[float] = None,
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
):
super().__init__()
self.net = MLP(
input_dim=input_dim,
hidden_dims=hidden_dims,
activations=activations,
activate_final=activate_final,
dropout_rate=dropout_rate,
final_activation=final_activation,
)
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
if init_final is not None:
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.output_layer.weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.output_layer(self.net(x))
class CriticEnsemble(nn.Module):
"""
CriticEnsemble wraps multiple CriticHead modules into an ensemble.
Args:
encoder (SACObservationEncoder): encoder for observations.
ensemble (List[CriticHead]): list of critic heads.
output_normalization (nn.Module): normalization layer for actions.
init_final (Optional[float]): optional initializer scale for final layers.
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
"""
def __init__(
self,
encoder: SACObservationEncoder,
ensemble: List[CriticHead],
output_normalization: nn.Module,
init_final: Optional[float] = None,
):
super().__init__()
self.encoder = encoder
self.init_final = init_final
self.output_normalization = output_normalization
self.critics = nn.ModuleList(ensemble)
def forward(
self,
observations: dict[str, torch.Tensor],
actions: torch.Tensor,
observation_features: torch.Tensor | None = None,
) -> torch.Tensor:
device = get_device_from_parameters(self)
# Move each tensor in observations to device
observations = {k: v.to(device) for k, v in observations.items()}
# NOTE: We normalize actions it helps for sample efficiency
actions: dict[str, torch.tensor] = {"action": actions}
# NOTE: Normalization layer took dict in input and outputs a dict that why
actions = self.output_normalization(actions)["action"]
actions = actions.to(device)
obs_enc = self.encoder(observations, cache=observation_features)
inputs = torch.cat([obs_enc, actions], dim=-1)
# Loop through critics and collect outputs
q_values = []
for critic in self.critics:
q_values.append(critic(inputs))
# Stack outputs to match expected shape [num_critics, batch_size]
q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
return q_values
class DiscreteCritic(nn.Module):
def __init__(
self,
encoder: nn.Module,
input_dim: int,
hidden_dims: list[int],
output_dim: int = 3,
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: Optional[float] = None,
init_final: Optional[float] = None,
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
):
super().__init__()
self.encoder = encoder
self.output_dim = output_dim
self.net = MLP(
input_dim=input_dim,
hidden_dims=hidden_dims,
activations=activations,
activate_final=activate_final,
dropout_rate=dropout_rate,
final_activation=final_activation,
)
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim)
if init_final is not None:
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.output_layer.weight)
def forward(
self, observations: torch.Tensor, observation_features: torch.Tensor | None = None
) -> torch.Tensor:
device = get_device_from_parameters(self)
observations = {k: v.to(device) for k, v in observations.items()}
obs_enc = self.encoder(observations, cache=observation_features)
return self.output_layer(self.net(obs_enc))
class Policy(nn.Module):
def __init__(
self,
encoder: SACObservationEncoder,
network: nn.Module,
action_dim: int,
log_std_min: float = -5,
log_std_max: float = 2,
fixed_std: Optional[torch.Tensor] = None,
init_final: Optional[float] = None,
use_tanh_squash: bool = False,
encoder_is_shared: bool = False,
):
super().__init__()
self.encoder: SACObservationEncoder = encoder
self.network = network
self.action_dim = action_dim
self.log_std_min = log_std_min
self.log_std_max = log_std_max
self.fixed_std = fixed_std
self.use_tanh_squash = use_tanh_squash
self.encoder_is_shared = encoder_is_shared
# Find the last Linear layer's output dimension
for layer in reversed(network.net):
if isinstance(layer, nn.Linear):
out_features = layer.out_features
break
# Mean layer
self.mean_layer = nn.Linear(out_features, action_dim)
if init_final is not None:
nn.init.uniform_(self.mean_layer.weight, -init_final, init_final)
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.mean_layer.weight)
# Standard deviation layer or parameter
if fixed_std is None:
self.std_layer = nn.Linear(out_features, action_dim)
if init_final is not None:
nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.std_layer.weight)
def forward(
self,
observations: torch.Tensor,
observation_features: torch.Tensor | None = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# We detach the encoder if it is shared to avoid backprop through it
# This is important to avoid the encoder to be updated through the policy
obs_enc = self.encoder(observations, cache=observation_features, detach=self.encoder_is_shared)
# Get network outputs
outputs = self.network(obs_enc)
means = self.mean_layer(outputs)
# Compute standard deviations
if self.fixed_std is None:
log_std = self.std_layer(outputs)
std = torch.exp(log_std) # Match JAX "exp"
std = torch.clamp(std, self.log_std_min, self.log_std_max) # Match JAX default clip
else:
log_std = self.fixed_std.expand_as(means)
# Build transformed distribution
dist = TanhMultivariateNormalDiag(loc=means, scale_diag=std)
# Sample actions (reparameterized)
actions = dist.rsample()
# Compute log_probs
log_probs = dist.log_prob(actions)
return actions, log_probs, means
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
"""Get encoded features from observations"""
device = get_device_from_parameters(self)
observations = observations.to(device)
if self.encoder is not None:
with torch.inference_mode():
return self.encoder(observations)
return observations
class DefaultImageEncoder(nn.Module):
def __init__(self, config: SACConfig):
super().__init__()
image_key = next(key for key in config.input_features if is_image_feature(key))
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
in_channels=config.input_features[image_key].shape[0],
out_channels=config.image_encoder_hidden_dim,
kernel_size=7,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=5,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=3,
stride=2,
),
nn.ReLU(),
nn.Conv2d(
in_channels=config.image_encoder_hidden_dim,
out_channels=config.image_encoder_hidden_dim,
kernel_size=3,
stride=2,
),
nn.ReLU(),
)
def forward(self, x):
x = self.image_enc_layers(x)
return x
def freeze_image_encoder(image_encoder: nn.Module):
"""Freeze all parameters in the encoder"""
for param in image_encoder.parameters():
param.requires_grad = False
class PretrainedImageEncoder(nn.Module):
def __init__(self, config: SACConfig):
super().__init__()
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
def _load_pretrained_vision_encoder(self, config: SACConfig):
"""Set up CNN encoder"""
from transformers import AutoModel
self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
if hasattr(self.image_enc_layers.config, "hidden_sizes"):
self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension
elif hasattr(self.image_enc_layers, "fc"):
self.image_enc_out_shape = self.image_enc_layers.fc.in_features
else:
raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
return self.image_enc_layers, self.image_enc_out_shape
def forward(self, x):
enc_feat = self.image_enc_layers(x).last_hidden_state
return enc_feat
def orthogonal_init():
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
class SpatialLearnedEmbeddings(nn.Module):
def __init__(self, height, width, channel, num_features=8):
"""
PyTorch implementation of learned spatial embeddings
Args:
height: Spatial height of input features
width: Spatial width of input features
channel: Number of input channels
num_features: Number of output embedding dimensions
"""
super().__init__()
self.height = height
self.width = width
self.channel = channel
self.num_features = num_features
self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features))
nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear")
def forward(self, features):
"""
Forward pass for spatial embedding
Args:
features: Input tensor of shape [B, C, H, W] where B is batch size,
C is number of channels, H is height, and W is width
Returns:
Output tensor of shape [B, C*F] where F is the number of features
"""
features_expanded = features.unsqueeze(-1) # [B, C, H, W, 1]
kernel_expanded = self.kernel.unsqueeze(0) # [1, C, H, W, F]
# Element-wise multiplication and spatial reduction
output = (features_expanded * kernel_expanded).sum(dim=(2, 3)) # Sum over H,W dimensions
# Reshape to combine channel and feature dimensions
output = output.view(output.size(0), -1) # [B, C*F]
return output
class RescaleFromTanh(Transform):
def __init__(self, low: float = -1, high: float = 1):
super().__init__()
self.low = low
self.high = high
def _call(self, x):
# Rescale from (-1, 1) to (low, high)
return 0.5 * (x + 1.0) * (self.high - self.low) + self.low
def _inverse(self, y):
# Rescale from (low, high) back to (-1, 1)
return 2.0 * (y - self.low) / (self.high - self.low) - 1.0
def log_abs_det_jacobian(self, x, y):
# log|d(rescale)/dx| = sum(log(0.5 * (high - low)))
scale = 0.5 * (self.high - self.low)
return torch.sum(torch.log(scale), dim=-1)
class TanhMultivariateNormalDiag(TransformedDistribution):
def __init__(self, loc, scale_diag, low=None, high=None):
base_dist = MultivariateNormal(loc, torch.diag_embed(scale_diag))
transforms = [TanhTransform(cache_size=1)]
if low is not None and high is not None:
low = torch.as_tensor(low)
high = torch.as_tensor(high)
transforms.insert(0, RescaleFromTanh(low, high))
super().__init__(base_dist, transforms)
def mode(self):
# Mode is mean of base distribution, passed through transforms
x = self.base_dist.mean
for transform in self.transforms:
x = transform(x)
return x
def stddev(self):
std = self.base_dist.stddev
x = std
for transform in self.transforms:
x = transform(x)
return x
def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
converted_params = {}
for outer_key, inner_dict in normalization_params.items():
converted_params[outer_key] = {}
for key, value in inner_dict.items():
converted_params[outer_key][key] = torch.tensor(value)
if "image" in outer_key:
converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
return converted_params