forked from tangger/lerobot
- Updated action selection to use distribution sampling and log probabilities for better stochastic behavior. - Enhanced standard deviation clamping to prevent extreme values, ensuring stability in policy outputs. - Cleaned up code by removing unnecessary comments and improving readability. These changes aim to refine the SAC implementation, enhancing its robustness and performance during training and inference.
693 lines
27 KiB
Python
693 lines
27 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team.
|
|
# All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# TODO: (1) better device management
|
|
|
|
from collections import deque
|
|
from copy import deepcopy
|
|
import math
|
|
from typing import Callable, Optional, Sequence, Tuple
|
|
|
|
import einops
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F # noqa: N812
|
|
from huggingface_hub import PyTorchModelHubMixin
|
|
from torch import Tensor
|
|
|
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
|
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
|
|
|
|
|
class SACPolicy(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="lerobot",
    repo_url="https://github.com/huggingface/lerobot",
    tags=["robotics", "RL", "SAC"],
):
    """Soft Actor-Critic (SAC) policy.

    Holds a stochastic actor (`Policy`), an ensemble of Q critics (`Critic`) with a
    Polyak-averaged target copy, and a learned entropy temperature (`LagrangeMultiplier`).
    The actor and every critic share one `SACObservationEncoder` instance.

    NOTE(review): reads `config.discount`, `config.critic_target_update_weight` (used as the
    Polyak rate tau) and, optionally, `config.use_backup_entropy` (defaults to True when
    absent); confirm these fields exist on `SACConfig`.
    """

    name = "sac"

    def __init__(
        self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
    ):
        super().__init__()

        if config is None:
            config = SACConfig()
        self.config = config

        if config.input_normalization_modes is not None:
            self.normalize_inputs = Normalize(
                config.input_shapes, config.input_normalization_modes, dataset_stats
            )
        else:
            self.normalize_inputs = nn.Identity()
        self.normalize_targets = Normalize(
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )

        # One observation encoder shared by the actor and all critics.
        encoder = SACObservationEncoder(config)

        critic_nets = [
            Critic(encoder=encoder, network=MLP(**config.critic_network_kwargs))
            for _ in range(config.num_critics)
        ]
        self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
        self.critic_target = deepcopy(self.critic_ensemble)
        # The target critics are only moved by Polyak averaging in `update()`, never by gradients.
        for param in self.critic_target.parameters():
            param.requires_grad_(False)

        self.actor = Policy(
            encoder=encoder,
            network=MLP(**config.actor_network_kwargs),
            action_dim=config.output_shapes["action"][0],
            **config.policy_kwargs,
        )
        if config.target_entropy is None:
            config.target_entropy = -np.prod(config.output_shapes["action"][0])  # (-dim(A))
        self.temperature = LagrangeMultiplier(init_value=config.temperature_init)

    def reset(self):
        """
        Clear observation and action queues. Should be called on `env.reset()`.

        The queues are populated during rollout of the policy and keep only the latest
        observation and action (maxlen=1).
        """
        self._queues = {
            "observation.state": deque(maxlen=1),
            "action": deque(maxlen=1),
        }
        if "observation.image" in self.config.input_shapes:
            self._queues["observation.image"] = deque(maxlen=1)
        if "observation.environment_state" in self.config.input_shapes:
            self._queues["observation.environment_state"] = deque(maxlen=1)

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select a deterministic action (the distribution's mode) for inference/evaluation.

        Inputs are normalized exactly as in `forward`, the action is clamped to [-1, 1] like the
        training-time samples, and then mapped back to the dataset's action space.
        """
        batch = self.normalize_inputs(batch)
        distribution = self.actor(batch)
        actions = self._deterministic_action(distribution)
        actions = torch.clamp(actions, -1, +1)
        return self.unnormalize_outputs({"action": actions})["action"]

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
        """Run the batch through the model and compute the SAC losses.

        Every entry of ``batch`` carries two steps along dim 1: index 0 is the current step
        and index 1 is the next step, used for the TD target.

        Returns a dictionary with the total ``loss`` as a tensor and the individual loss terms,
        the temperature and the policy entropy as native floats.
        """
        batch = self.normalize_inputs(batch)
        # The actor acts in normalized action space (see `select_action`), so the dataset actions
        # fed to the critics must be normalized the same way.
        batch = self.normalize_targets(batch)

        actions = batch["action"][:, 0]
        rewards = batch["next.reward"][:, 0]
        observation_keys = [key for key in batch if key.startswith("observation.")]
        observations = {key: batch[key][:, 0] for key in observation_keys}
        next_observations = {key: batch[key][:, 1] for key in observation_keys}

        # NOTE(review): "next.done" is optional; when present it stops bootstrapping past terminal
        # transitions. Confirm the key name and layout against the dataset.
        if "next.done" in batch:
            dones = batch["next.done"]
            if dones.ndim > 1:
                dones = dones[:, 0]
            dones = dones.to(rewards.dtype)
        else:
            dones = torch.zeros_like(rewards)

        # TODO: image augmentation.
        # TODO: reward bias from the HIL-SERL code base
        # (add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]}).
        # NOTE(review): padded transitions (`*_is_pad` keys) are not masked out of these losses.

        temperature = self.temperature()
        # The temperature is trained only by its own loss, so detach it everywhere else.
        alpha = temperature.detach()

        # ---- critic loss ----
        with torch.no_grad():
            # Soft TD target from the *next* observations using freshly sampled next actions.
            next_actions, next_log_probs = self._sample_actions(next_observations)
            q_targets = self._critic_forward(self.critic_target, next_observations, next_actions)
            # Subsample critics to prevent overfitting with a high UTD (update-to-data) ratio.
            if self.config.num_subsample_critics is not None:
                indices = torch.randperm(self.config.num_critics, device=q_targets.device)
                q_targets = q_targets[indices[: self.config.num_subsample_critics]]
            min_q = q_targets.min(dim=0).values
            if getattr(self.config, "use_backup_entropy", True):
                min_q = min_q - alpha * next_log_probs
            td_target = rewards + self.config.discount * (1.0 - dones) * min_q

        q_preds = self._critic_forward(self.critic_ensemble, observations, actions)
        # TD loss for every Q function in the ensemble: summed over critics, averaged over the batch.
        critics_loss = (
            F.mse_loss(
                q_preds,
                einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
                reduction="none",
            )
            .sum(0)
            .mean()
        )

        # ---- actor loss ----
        actions_pi, log_probs = self._sample_actions(observations)
        # Freeze the critic parameters (including the shared encoder) so the actor loss reaches
        # the actor only through the reparameterized actions and does not move the critics.
        critic_params = [param for param in self.critic_ensemble.parameters() if param.requires_grad]
        for param in critic_params:
            param.requires_grad_(False)
        try:
            q_pi = self._critic_forward(self.critic_ensemble, observations, actions_pi).mean(dim=0)
        finally:
            for param in critic_params:
                param.requires_grad_(True)
        actor_loss = (alpha * log_probs - q_pi).mean()

        # ---- temperature loss ----
        entropy = -log_probs.detach().mean()
        target_entropy = torch.as_tensor(
            float(self.config.target_entropy), dtype=entropy.dtype, device=entropy.device
        )
        temperature_loss = self.temperature(lhs=entropy, rhs=target_entropy)

        loss = critics_loss + actor_loss + temperature_loss

        return {
            "critics_loss": critics_loss.item(),
            "actor_loss": actor_loss.item(),
            "temperature_loss": temperature_loss.item(),
            "temperature": temperature.item(),
            "entropy": entropy.item(),
            "loss": loss,
        }

    def update(self):
        """Polyak-average the online critics into the target critics.

        target <- (1 - tau) * target + tau * online, with tau = config.critic_target_update_weight.
        """
        tau = self.config.critic_target_update_weight
        with torch.no_grad():
            for target_param, param in zip(
                self.critic_target.parameters(), self.critic_ensemble.parameters()
            ):
                target_param.lerp_(param, tau)
        # TODO: implement UTD update
        # First update only critics for utd_ratio-1 times, then update critic, critic target,
        # actor and temperature.

    def _sample_actions(self, observations: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Draw reparameterized actions from the actor together with log pi(a|s).

        Returns the actions clamped to [-1, 1] and one log-probability per batch element,
        computed on the unclamped sample.
        """
        distribution = self.actor(observations)
        actions = distribution.rsample()
        log_probs = distribution.log_prob(actions)
        if log_probs.ndim == actions.ndim:
            # Per-dimension log-probs (e.g. a plain Normal): sum them to get the joint log pi(a|s).
            log_probs = log_probs.sum(dim=-1)
        return torch.clamp(actions, -1, +1), log_probs

    @staticmethod
    def _critic_forward(critics: nn.ModuleList, observations: dict[str, Tensor], actions: Tensor) -> Tensor:
        """Evaluate every critic in ``critics`` and stack the Q-values along a new leading dim."""
        return torch.stack([critic(observations, actions) for critic in critics], dim=0)

    @staticmethod
    def _deterministic_action(distribution: torch.distributions.Distribution) -> Tensor:
        """Return the mode of the actor's distribution.

        `TanhMultivariateNormalDiag` defines ``mode()`` as a method, while torch's `Normal`
        exposes ``mode`` as a property (absent on older torch releases). In that case fall back
        to the mean, which equals the mode for a Normal.
        """
        mode = getattr(distribution, "mode", None)
        if callable(mode):
            return mode()
        return distribution.mean
|
|
|
|
|
|
class MLP(nn.Module):
|
|
def __init__(
|
|
self,
|
|
hidden_dims: list[int],
|
|
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
|
activate_final: bool = False,
|
|
dropout_rate: Optional[float] = None,
|
|
):
|
|
super().__init__()
|
|
self.activate_final = activate_final
|
|
layers = []
|
|
|
|
for i, size in enumerate(hidden_dims):
|
|
layers.append(nn.Linear(hidden_dims[i-1] if i > 0 else hidden_dims[0], size))
|
|
|
|
if i + 1 < len(hidden_dims) or activate_final:
|
|
if dropout_rate is not None and dropout_rate > 0:
|
|
layers.append(nn.Dropout(p=dropout_rate))
|
|
layers.append(nn.LayerNorm(size))
|
|
layers.append(
|
|
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
|
|
)
|
|
|
|
self.net = nn.Sequential(*layers)
|
|
|
|
def forward(self, x: torch.Tensor, train: bool = False) -> torch.Tensor:
|
|
# in training mode or not. TODO: find better way to do this
|
|
self.train(train)
|
|
return self.net(x)
|
|
|
|
|
|
class Critic(nn.Module):
    """Q-function head: Q(s, a) = output_layer(network(concat(encoder(s), a))).

    Args:
        encoder: Optional module mapping observations to features. If None, ``observations``
            must already be a feature tensor.
        network: MLP-like module exposing its layers as ``network.net``. Its last ``nn.Linear``
            fixes the output layer's input size, and its input size must equal
            feature size + action size.
        init_final: If given, the output layer's weight and bias are initialized uniformly in
            [-init_final, init_final]; otherwise its weight is orthogonally initialized.
        activate_final: Stored on the module but not used by this class.
        device: Device the module (and its inputs) are moved to.

    Raises:
        ValueError: if ``network.net`` contains no ``nn.Linear`` layer.
    """

    def __init__(
        self,
        encoder: Optional[nn.Module],
        network: nn.Module,
        init_final: Optional[float] = None,
        activate_final: bool = False,
        device: str = "cuda",
    ):
        super().__init__()
        self.device = torch.device(device)
        self.encoder = encoder
        self.network = network
        self.init_final = init_final
        self.activate_final = activate_final

        # Output dimension of the last Linear layer in the network.
        out_features = next(
            (layer.out_features for layer in reversed(network.net) if isinstance(layer, nn.Linear)),
            None,
        )
        if out_features is None:
            raise ValueError("`network.net` must contain at least one nn.Linear layer")

        self.output_layer = nn.Linear(out_features, 1)
        if init_final is not None:
            nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
            nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
        else:
            orthogonal_init()(self.output_layer.weight)

        self.to(self.device)

    def _to_device(self, observations):
        """Move a tensor, or every tensor value of an observation dict, to ``self.device``."""
        if isinstance(observations, dict):
            return {
                key: value.to(self.device) if isinstance(value, torch.Tensor) else value
                for key, value in observations.items()
            }
        return observations.to(self.device)

    def forward(self, observations, actions: torch.Tensor, train: bool = False) -> torch.Tensor:
        """Return Q(s, a) with the trailing singleton dimension squeezed away.

        ``observations`` may be a tensor or a dict of tensors (as produced for the shared
        observation encoder).
        """
        self.train(train)

        observations = self._to_device(observations)
        actions = actions.to(self.device)

        obs_enc = observations if self.encoder is None else self.encoder(observations)

        inputs = torch.cat([obs_enc, actions], dim=-1)
        x = self.network(inputs)
        value = self.output_layer(x)
        return value.squeeze(-1)

    def q_value_ensemble(self, observations, actions: torch.Tensor, train: bool = False) -> torch.Tensor:
        """Q-values for one action per observation, or several candidate actions per observation.

        When ``actions`` is 3-D (batch_size, num_actions, action_dim), each observation (tensor or
        every tensor in an observation dict) is repeated ``num_actions`` times along a new dim 1
        and the result is reshaped to (batch_size, num_actions). Otherwise this equals ``forward``.
        """
        observations = self._to_device(observations)
        actions = actions.to(self.device)

        if actions.ndim != 3:
            return self(observations, actions, train)

        batch_size, num_actions = actions.shape[:2]

        def repeat_per_action(tensor: torch.Tensor) -> torch.Tensor:
            expanded = tensor.unsqueeze(1).expand(-1, num_actions, *tensor.shape[1:])
            return expanded.reshape(-1, *tensor.shape[1:])

        if isinstance(observations, dict):
            obs_flat = {
                key: repeat_per_action(value) if isinstance(value, torch.Tensor) else value
                for key, value in observations.items()
            }
        else:
            obs_flat = repeat_per_action(observations)
        actions_flat = actions.reshape(-1, actions.shape[-1])
        q_values = self(obs_flat, actions_flat, train)
        return q_values.reshape(batch_size, num_actions)
|
|
|
|
|
|
class Policy(nn.Module):
    """Gaussian actor head mapping (encoded) observations to an action distribution.

    Args:
        encoder: Optional module applied to the raw observations; it is called with the
            observations as its only argument.
        network: MLP-like module exposing its layers as ``network.net``. Its last ``nn.Linear``
            fixes the feature size fed to the mean/std heads.
        action_dim: Size of the action vector.
        std_parameterization: How standard deviations are produced: "exp" (head predicts
            log-std), "softplus", "uniform" (state-independent learned log-std) or "fixed"
            (uses ``fixed_std``).
        std_min, std_max: Bounds the standard deviation is clamped to.
        tanh_squash_distribution: Return a `TanhMultivariateNormalDiag` instead of a `Normal`.
        fixed_std: Constant standard deviation; requires ``std_parameterization="fixed"``.
        init_final: If given, the mean/std heads are initialized uniformly in
            [-init_final, init_final]; otherwise their weights are orthogonally initialized.
        activate_final: Stored on the module but not used by this class.
        device: Device the module is moved to.

    Raises:
        ValueError: if ``network.net`` contains no ``nn.Linear`` layer.
    """

    def __init__(
        self,
        encoder: Optional[nn.Module],
        network: nn.Module,
        action_dim: int,
        std_parameterization: str = "exp",
        std_min: float = 0.05,
        std_max: float = 2.0,
        tanh_squash_distribution: bool = False,
        fixed_std: Optional[torch.Tensor] = None,
        init_final: Optional[float] = None,
        activate_final: bool = False,
        device: str = "cuda",
    ):
        super().__init__()
        self.device = torch.device(device)
        self.encoder = encoder
        self.network = network
        self.action_dim = action_dim
        self.std_parameterization = std_parameterization
        self.std_min = std_min
        self.std_max = std_max
        self.tanh_squash_distribution = tanh_squash_distribution
        self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
        self.activate_final = activate_final

        # Output dimension of the last Linear layer in the network.
        out_features = next(
            (layer.out_features for layer in reversed(network.net) if isinstance(layer, nn.Linear)),
            None,
        )
        if out_features is None:
            raise ValueError("`network.net` must contain at least one nn.Linear layer")

        # Mean head.
        self.mean_layer = nn.Linear(out_features, action_dim)
        self._init_head(self.mean_layer, init_final)

        # Standard deviation head or state-independent parameter.
        if fixed_std is None:
            if std_parameterization == "uniform":
                self.log_stds = nn.Parameter(torch.zeros(action_dim, device=self.device))
            else:
                self.std_layer = nn.Linear(out_features, action_dim)
                self._init_head(self.std_layer, init_final)

        self.to(self.device)

    @staticmethod
    def _init_head(layer: nn.Linear, init_final: Optional[float]) -> None:
        """Initialize an output head uniformly in [-init_final, init_final], or orthogonally."""
        if init_final is not None:
            nn.init.uniform_(layer.weight, -init_final, init_final)
            nn.init.uniform_(layer.bias, -init_final, init_final)
        else:
            orthogonal_init()(layer.weight)

    def _to_device(self, observations):
        """Move a tensor, or every tensor value of an observation dict, to ``self.device``."""
        if isinstance(observations, dict):
            return {
                key: value.to(self.device) if isinstance(value, torch.Tensor) else value
                for key, value in observations.items()
            }
        return observations.to(self.device)

    def _compute_stds(self, outputs: torch.Tensor, means: torch.Tensor) -> torch.Tensor:
        """Standard deviations according to ``std_parameterization`` (before temperature scaling)."""
        if self.fixed_std is not None:
            if self.std_parameterization != "fixed":
                raise ValueError("`fixed_std` requires std_parameterization='fixed'")
            return self.fixed_std.expand_as(means)

        if self.std_parameterization == "exp":
            # Clamp log-stds to keep the exponentiated values in [std_min, std_max].
            log_stds = torch.clamp(
                self.std_layer(outputs), math.log(self.std_min), math.log(self.std_max)
            )
            return torch.exp(log_stds)
        if self.std_parameterization == "softplus":
            stds = torch.nn.functional.softplus(self.std_layer(outputs))
            return torch.clamp(stds, self.std_min, self.std_max)
        if self.std_parameterization == "uniform":
            log_stds = torch.clamp(self.log_stds, math.log(self.std_min), math.log(self.std_max))
            return torch.exp(log_stds).expand_as(means)
        raise ValueError(f"Invalid std_parameterization: {self.std_parameterization}")

    def forward(
        self,
        observations: torch.Tensor,
        temperature: float = 1.0,
        train: bool = False,
        non_squash_distribution: bool = False,
    ) -> torch.distributions.Distribution:
        """Build the action distribution for ``observations``.

        Args:
            observations: Raw observations; passed to ``encoder`` if present, else used directly.
            temperature: Variance multiplier (standard deviations are scaled by sqrt(temperature)).
            train: Sets train/eval mode and whether gradients are recorded through the encoder.
            non_squash_distribution: Force a plain `Normal` even if tanh squashing is enabled.
        """
        self.train(train)

        if self.encoder is not None:
            with torch.set_grad_enabled(train):
                # The encoder takes only the observations (SACObservationEncoder.forward has no
                # `train` argument).
                obs_enc = self.encoder(observations)
        else:
            obs_enc = observations

        outputs = self.network(obs_enc)
        means = self.mean_layer(outputs)
        stds = self._compute_stds(outputs, means)

        # Scale with temperature.
        temperature = torch.as_tensor(temperature, dtype=means.dtype, device=means.device)
        stds = torch.clamp(stds, self.std_min, self.std_max) * torch.sqrt(temperature)

        if self.tanh_squash_distribution and not non_squash_distribution:
            return TanhMultivariateNormalDiag(loc=means, scale_diag=stds)
        return torch.distributions.Normal(loc=means, scale=stds)

    def get_features(self, observations):
        """Get encoded features from observations (tensor or dict of tensors) without gradients."""
        observations = self._to_device(observations)
        if self.encoder is None:
            return observations
        with torch.no_grad():
            return self.encoder(observations)
|
|
|
|
|
|
class SACObservationEncoder(nn.Module):
    """Encode image and/or state vector observations.

    Every input key starting with "observation.image" is encoded by one shared convolutional
    encoder; "observation.state" and "observation.environment_state" each get an MLP encoder.
    Each modality yields a ``latent_dim`` feature and the features are averaged.

    TODO(ke-wang): The original work allows for (1) stacking multiple history frames and (2) using pretrained resnet encoders.
    """

    def __init__(self, config: SACConfig):
        """
        Creates encoders for pixel and/or state modalities.
        """
        super().__init__()
        self.config = config
        # The image keys the forward pass encodes; the image encoder is built whenever any exist.
        self.image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]

        if self.image_keys:
            # Prefer the canonical key for the conv input shape.
            # NOTE(review): all image keys are assumed to share this shape.
            ref_key = "observation.image" if "observation.image" in config.input_shapes else self.image_keys[0]
            image_shape = config.input_shapes[ref_key]
            hidden_dim = config.image_encoder_hidden_dim
            self.image_enc_layers = nn.Sequential(
                nn.Conv2d(image_shape[0], hidden_dim, 7, stride=2),
                nn.ReLU(),
                nn.Conv2d(hidden_dim, hidden_dim, 5, stride=2),
                nn.ReLU(),
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride=2),
                nn.ReLU(),
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride=2),
                nn.ReLU(),
            )
            # Probe the conv stack with a dummy image to size the projection layer.
            dummy_batch = torch.zeros(1, *image_shape)
            with torch.inference_mode():
                out_shape = self.image_enc_layers(dummy_batch).shape[1:]
            self.image_enc_layers.extend(
                nn.Sequential(
                    nn.Flatten(),
                    nn.Linear(int(np.prod(out_shape)), config.latent_dim),
                    nn.LayerNorm(config.latent_dim),
                    nn.Tanh(),
                )
            )
        if "observation.state" in config.input_shapes:
            self.state_enc_layers = nn.Sequential(
                nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
                nn.ELU(),
                nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
                nn.LayerNorm(config.latent_dim),
                nn.Tanh(),
            )
        if "observation.environment_state" in config.input_shapes:
            self.env_state_enc_layers = nn.Sequential(
                nn.Linear(
                    config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
                ),
                nn.ELU(),
                nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
                nn.LayerNorm(config.latent_dim),
                nn.Tanh(),
            )

    def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
        """Encode the image and/or state vector.

        Each modality (every image key separately, the environment state and the state) is
        encoded into a feature vector of size (latent_dim,), and then a uniform mean is taken
        over all features.

        Raises:
            ValueError: if ``config.input_shapes`` lists no image, state or environment-state input.
        """
        feat = [flatten_forward_unflatten(self.image_enc_layers, obs_dict[key]) for key in self.image_keys]
        if "observation.environment_state" in self.config.input_shapes:
            feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
        if "observation.state" in self.config.input_shapes:
            feat.append(self.state_enc_layers(obs_dict["observation.state"]))
        if not feat:
            raise ValueError(
                "SACObservationEncoder needs at least one of 'observation.image*', "
                "'observation.state' or 'observation.environment_state' in config.input_shapes"
            )
        return torch.stack(feat, dim=0).mean(0)
|
|
|
|
|
|
class LagrangeMultiplier(nn.Module):
    """Learnable positive multiplier, stored as a raw parameter and exposed through softplus.

    Calling the module with no arguments returns the multiplier. Calling it with ``lhs``
    (and optionally ``rhs``, default 0) returns the Lagrange penalty ``multiplier * (lhs - rhs)``.

    Args:
        init_value: Initial (positive) multiplier value.
        constraint_shape: Shape of the multiplier; ``lhs - rhs`` must match it exactly.
        device: Device the parameter lives on; inputs are moved there.

    Raises:
        ValueError: if ``init_value`` is not positive (softplus cannot represent it finitely).
    """

    def __init__(self, init_value: float = 1.0, constraint_shape: Sequence[int] = (), device: str = "cuda"):
        super().__init__()
        self.device = torch.device(device)

        init_value = float(init_value)
        if init_value <= 0:
            raise ValueError(f"init_value must be positive for a softplus multiplier, got {init_value}")
        # Inverse softplus, log(exp(v) - 1), written as v + log(1 - exp(-v)) so that large
        # values do not overflow exp().
        raw_init = init_value + math.log(-math.expm1(-init_value))

        # Initialize the Lagrange multiplier as a parameter
        self.lagrange = nn.Parameter(
            torch.full(tuple(constraint_shape), raw_init, dtype=torch.float32, device=self.device)
        )

        self.to(self.device)

    def forward(self, lhs=None, rhs=None) -> torch.Tensor:
        """Return the multiplier, or ``multiplier * (lhs - rhs)`` when ``lhs`` is given.

        ``lhs`` and ``rhs`` may be tensors or Python/NumPy scalars. Tensors keep their autograd
        graph when moved to the multiplier's device.

        Raises:
            ValueError: if ``lhs - rhs`` does not have the multiplier's shape.
        """
        multiplier = torch.nn.functional.softplus(self.lagrange)

        # Return the raw multiplier if no constraint values provided
        if lhs is None:
            return multiplier

        lhs = torch.as_tensor(lhs, device=self.device)
        rhs = torch.zeros_like(lhs) if rhs is None else torch.as_tensor(rhs, device=self.device)

        diff = lhs - rhs
        if diff.shape != multiplier.shape:
            raise ValueError(f"Shape mismatch: {diff.shape} vs {multiplier.shape}")

        return multiplier * diff
|
|
|
|
|
|
class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
    """Diagonal Gaussian squashed through tanh, optionally rescaled to [low, high].

    1. The base distribution is a diagonal normal over the last (action) dimension.
    2. Samples are passed through tanh, squashing them into (-1, 1).
    3. If both ``low`` and ``high`` are given, an affine map rescales (-1, 1) to (low, high).

    This is commonly used for bounded continuous action spaces in reinforcement learning.
    ``log_prob`` and ``entropy`` treat the last dimension as the event dimension and return
    one value per leading index.
    """

    DEFAULT_SAMPLE_SHAPE = torch.Size()

    def __init__(
        self,
        loc: torch.Tensor,
        scale_diag: torch.Tensor,
        low: Optional[torch.Tensor] = None,
        high: Optional[torch.Tensor] = None,
    ):
        base_distribution = torch.distributions.Normal(loc=loc, scale=scale_diag)

        # cache_size=1 lets log_prob() recover the exact pre-tanh sample produced by rsample()
        # instead of recomputing atanh, which becomes infinite once tanh saturates to +/-1.
        transforms = [torch.distributions.transforms.TanhTransform(cache_size=1)]
        if low is not None and high is not None:
            transforms.append(
                torch.distributions.transforms.AffineTransform(
                    loc=(high + low) / 2, scale=(high - low) / 2, cache_size=1
                )
            )

        super().__init__(base_distribution=base_distribution, transforms=transforms)

        self.loc = loc
        self.scale_diag = scale_diag
        self.low = low
        self.high = high

    def mode(self) -> torch.Tensor:
        """Mode of the base normal (its mean) pushed through the transforms."""
        mode = self.loc
        for transform in self.transforms:
            mode = transform(mode)
        return mode

    def rsample(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> torch.Tensor:
        """Reparameterized sample: base-normal ``rsample`` pushed through the transforms."""
        x = self.base_dist.rsample(sample_shape)
        for transform in self.transforms:
            x = transform(x)
        return x

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Joint log-density of ``value`` over the last dimension (change of variables).

        Each transform's elementwise log|det J| is summed over the last dimension so the
        result has one entry per leading index, matching the summed base log-probability.
        """
        log_prob = torch.zeros_like(value[..., 0])
        y = value
        for transform in reversed(self.transforms):
            x = transform.inv(y)
            log_prob = log_prob - transform.log_abs_det_jacobian(x, y).sum(-1)
            y = x
        return log_prob + self.base_dist.log_prob(y).sum(-1)

    def sample_and_log_prob(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> Tuple[torch.Tensor, torch.Tensor]:
        """Draw a reparameterized sample and return it together with its log-probability."""
        x = self.rsample(sample_shape)
        log_prob = self.log_prob(x)
        return x, log_prob

    def entropy(self) -> torch.Tensor:
        """Single-sample Monte Carlo estimate of the entropy of the transformed distribution.

        H[Y] = H[X] + E[log|det J(X)|], where the expectation is estimated with one draw from
        the base normal.
        """
        entropy = self.base_dist.entropy().sum(-1)
        x = self.base_dist.rsample()
        for transform in self.transforms:
            y = transform(x)
            entropy = entropy + transform.log_abs_det_jacobian(x, y).sum(-1)
            x = y
        return entropy
|
|
|
|
|
|
def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList:
    """Wrap ``critics`` in an ``nn.ModuleList`` moved to ``device``.

    Raises:
        ValueError: if ``len(critics)`` differs from ``num_critics``. A real exception is used
            instead of ``assert`` so the check survives ``python -O``.
    """
    if len(critics) != num_critics:
        raise ValueError(f"Expected {num_critics} critics, got {len(critics)}")
    return nn.ModuleList(critics).to(device)
|
|
|
|
|
|
def orthogonal_init():
    """Return an initializer that fills a tensor in place with an orthogonal matrix (gain 1.0)."""

    def _initialize(tensor):
        return torch.nn.init.orthogonal_(tensor, gain=1.0)

    return _initialize
|
|
|
|
|
|
# borrowed from tdmpc
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
    """Apply a (B, C, H, W) image function to a tensor with any number of leading dims.

    Args:
        fn: Callable taking a (B, C, H, W) tensor and returning a (B, *) tensor.
        image_tensor: Tensor of shape (**, C, H, W); ** may be one or more dimensions.

    Returns:
        The output of ``fn`` with the leading dims restored, i.e. shape (**, *).
    """
    # Already a plain 4-D image batch: nothing to flatten.
    if image_tensor.ndim == 4:
        return fn(image_tensor)

    leading_shape = image_tensor.shape[:-3]
    output = fn(image_tensor.flatten(end_dim=-4))
    return output.reshape(*leading_shape, *output.shape[1:])
|