Compare commits

...

12 Commits

Author SHA1 Message Date
Adil Zouitine
278b56bce9 Add rlpd tricks 2025-01-15 15:49:24 +01:00
Adil Zouitine
0ffc0a7170 SAC works 2025-01-14 11:34:52 +01:00
Adil Zouitine
43d9890489 remove breakpoint 2025-01-13 17:58:00 +01:00
Adil Zouitine
963be41003 [WIP] correct sac implementation 2025-01-13 17:54:11 +01:00
Adil Zouitine
9edae4a8de Correct losses and factorisation 2025-01-07 17:07:55 +01:00
Ke-Wang1017
89d8189d8b remove unused debug lines 2025-01-06 10:18:40 +00:00
Ke-Wang1017
8b70b129dc improvements from JClinton to speed up loading offline data 2025-01-06 10:15:45 +00:00
joeclinton1
db3925df28 Potential fixes for SAC instability and NAN bug 2025-01-06 10:15:01 +00:00
Ke-Wang1017
f99e670976 Refactor SACPolicy and configuration for improved training dynamics
- Introduced target critic networks in SACPolicy to enhance stability during training.
- Updated TD target calculation to incorporate entropy adjustments, improving robustness.
- Increased online buffer capacity in configuration from 10,000 to 40,000 for better data handling.
- Adjusted learning rates for critic, actor, and temperature to 3e-4 for optimized training performance.

These changes aim to refine the SAC implementation, enhancing its robustness and performance during training and inference.
2025-01-06 10:14:34 +00:00
KeWang1017
eec28baa63 fix the bug of target critic updates, roll back to original temperature implementation, added debug logging info 2025-01-06 10:14:09 +00:00
KeWang1017
f1f04eb4f9 use mean instead of sampled action for the inference 2025-01-06 10:12:24 +00:00
KeWang1017
77a7f92139 1) add input normalization in configuration_sac.py; 2) add masking on loss computation 2025-01-06 10:11:51 +00:00
6 changed files with 1359 additions and 207 deletions

View File

@@ -611,11 +611,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
         return query_timestamps

     def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
-        return {
-            key: torch.stack(self.hf_dataset.select(q_idx)[key])
-            for key, q_idx in query_indices.items()
-            if key not in self.meta.video_keys
-        }
+        # Step 1: Combine all unique indices
+        all_indices = sorted({idx for indices in query_indices.values() for idx in indices})
+
+        # Step 2: Select all required data at once
+        selected_dataset = self.hf_dataset.select(all_indices).to_dict()
+        selected_dataset = {key: torch.tensor(values) for key, values in selected_dataset.items()}
+
+        # Step 3: Map original indices to their positions in the selected dataset
+        index_map = {original_idx: i for i, original_idx in enumerate(all_indices)}
+
+        # Step 4: Build the result for each key
+        results = {}
+        for key, q_indices in query_indices.items():
+            if key not in self.meta.video_keys:
+                mapped_indices = [index_map[idx] for idx in q_indices]
+                results[key] = torch.stack([selected_dataset[key][i] for i in mapped_indices])
+        return results

     def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
         """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function

View File

@@ -28,12 +28,18 @@ class SACConfig:
     )
     output_shapes: dict[str, list[int]] = field(
         default_factory=lambda: {
-            "action": [4],
+            "action": [2],
         }
     )

     # Normalization / Unnormalization
-    input_normalization_modes: dict[str, str] | None = None
+    input_normalization_modes: dict[str, str] = field(
+        default_factory=lambda: {
+            "observation.image": "mean_std",
+            "observation.state": "min_max",
+            "observation.environment_state": "min_max",
+        }
+    )
     output_normalization_modes: dict[str, str] = field(
         default_factory=lambda: {"action": "min_max"},
     )
@@ -48,9 +54,10 @@ class SACConfig:
     critic_target_update_weight = 0.005
     utd_ratio = 2
     state_encoder_hidden_dim = 256
-    latent_dim = 128
+    latent_dim = 256
     target_entropy = None
-    backup_entropy = True
+    # backup_entropy = False
+    use_backup_entropy = True
     critic_network_kwargs = {
         "hidden_dims": [256, 256],
         "activate_final": True,

View File

@@ -18,8 +18,7 @@
 # TODO: (1) better device management

 from collections import deque
-from copy import deepcopy
-from typing import Callable, Optional, Sequence, Tuple
+from typing import Callable, Optional, Sequence, Tuple, Union

 import einops
 import numpy as np
@@ -57,12 +56,20 @@ class SACPolicy(
             )
         else:
             self.normalize_inputs = nn.Identity()

+        # HACK: we need to pass the dataset_stats to the normalization functions
+        dataset_stats = dataset_stats or {
+            "action": {
+                "min": torch.tensor([-1.0, -1.0, -1.0, -1.0]),
+                "max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
+            }
+        }
         self.normalize_targets = Normalize(
             config.output_shapes, config.output_normalization_modes, dataset_stats
         )
         self.unnormalize_outputs = Unnormalize(
             config.output_shapes, config.output_normalization_modes, dataset_stats
         )
+
         encoder_critic = SACObservationEncoder(config)
         encoder_actor = SACObservationEncoder(config)

         # Define networks
@@ -72,26 +79,38 @@ class SACPolicy(
                 encoder=encoder_critic,
                 network=MLP(
                     input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
-                    **config.critic_network_kwargs
-                )
+                    **config.critic_network_kwargs,
+                ),
             )
             critic_nets.append(critic_net)

+        target_critic_nets = []
+        for _ in range(config.num_critics):
+            target_critic_net = Critic(
+                encoder=encoder_critic,
+                network=MLP(
+                    input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
+                    **config.critic_network_kwargs,
+                ),
+            )
+            target_critic_nets.append(target_critic_net)
+
         self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
-        self.critic_target = deepcopy(self.critic_ensemble)
+        self.critic_target = create_critic_ensemble(target_critic_nets, config.num_critics)
+        self.critic_target.load_state_dict(self.critic_ensemble.state_dict())

         self.actor = Policy(
             encoder=encoder_actor,
-            network=MLP(
-                input_dim=encoder_actor.output_dim,
-                **config.actor_network_kwargs
-            ),
+            network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
             action_dim=config.output_shapes["action"][0],
-            **config.policy_kwargs
+            **config.policy_kwargs,
         )

         if config.target_entropy is None:
-            config.target_entropy = -np.prod(config.output_shapes["action"][0])  # (-dim(A))
-        self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
+            config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2  # (-dim(A)/2)
+
+        # TODO: fix device handling later
+        # TODO: handle the case where the temperature parameter is fixed
+        self.log_alpha = torch.zeros(1, requires_grad=True, device="cpu")
+        self.temperature = self.log_alpha.exp().item()

     def reset(self):
         """
@@ -111,11 +130,13 @@ class SACPolicy(
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
-        actions, _ = self.actor(batch)
+        actions, _, _ = self.actor(batch)
         actions = self.unnormalize_outputs({"action": actions})["action"]
         return actions

-    def critic_forward(self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False) -> Tensor:
+    def critic_forward(
+        self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
+    ) -> Tensor:
         """Forward pass through a critic network ensemble

         Args:
@@ -130,16 +151,20 @@ class SACPolicy(
         q_values = torch.stack([critic(observations, actions) for critic in critics])
         return q_values

     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
         """Run the batch through the model and compute the loss.

         Returns a dictionary with loss as a tensor, and other information as native floats.
         """
+        # Refresh the temperature: log_alpha may have changed in the previous optimization step
+        self.temperature = self.log_alpha.exp().item()
+        temperature = self.temperature
+
         batch = self.normalize_inputs(batch)

         # batch shape is (b, 2, ...) where index 1 returns the current observation and
         # the next observation for calculating the right td index.
-        actions = batch["action"][:, 0]
+        # actions = batch["action"][:, 0]
+        actions = batch["action"]
         rewards = batch["next.reward"][:, 0]
         observations = {}
         next_observations = {}
@@ -147,18 +172,13 @@ class SACPolicy(
             if k.startswith("observation."):
                 observations[k] = batch[k][:, 0]
                 next_observations[k] = batch[k][:, 1]
+        done = batch["next.done"]

-        # perform image augmentation
-
-        # reward bias from HIL-SERL code base
-        # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
-
-        # calculate critics loss
-        # 1- compute actions from policy
-        action_preds, log_probs = self.actor(next_observations)
+        with torch.no_grad():
+            next_action_preds, next_log_probs, _ = self.actor(next_observations)
+
         # 2- compute q targets
-        q_targets = self.critic_forward(next_observations, action_preds, use_target=True)
+        q_targets = self.critic_forward(next_observations, next_action_preds, use_target=True)

         # subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
         if self.config.num_subsample_critics is not None:
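The next hunk assembles the soft Bellman target from these pieces. Written out in standard SAC notation, with alpha the temperature and d the done flag:

$$y = r + \gamma\,(1 - d)\,\Big[\min_i \bar{Q}_i(s', a') - \alpha \log \pi(a' \mid s')\Big], \qquad a' \sim \pi(\cdot \mid s')$$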
@@ -168,84 +188,125 @@ class SACPolicy(
         # critics subsample size
         min_q, _ = q_targets.min(dim=0)  # Get values from min operation
+        if self.config.use_backup_entropy:
+            min_q -= self.temperature * next_log_probs

-        # compute td target
-        td_target = rewards + self.config.discount * min_q  # + self.config.discount * self.temperature() * log_probs  # add entropy term
+        td_target = rewards + self.config.discount * min_q * ~done

         # 3- compute predicted qs
         q_preds = self.critic_forward(observations, actions, use_target=False)

         # 4- Calculate loss
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
-        critics_loss = F.mse_loss(
-            q_preds,  # shape: [num_critics, batch_size]
-            einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),  # expand td_target to match q_preds shape
-            reduction="none"
-        ).sum(0).mean()
-
-        # critics_loss = (
-        #     F.mse_loss(
-        #         q_preds,
-        #         einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
-        #         reduction="none",
-        #     ).sum(0)  # sum over ensemble
-        #     # `q_preds_ensemble` depends on the first observation and the actions.
-        #     * ~batch["observation.state_is_pad"][0]
-        #     * ~batch["action_is_pad"]
-        #     # q_targets depends on the reward and the next observations.
-        #     * ~batch["next.reward_is_pad"]
-        #     * ~batch["observation.state_is_pad"][1:]
-        # ).sum(0).mean()
-
-        # calculate actors loss
-        # 1- temperature
-        temperature = self.temperature()
-        # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
-        actions, log_probs = self.actor(observations)
-        # 3- get q-value predictions
+        td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
+        # Mean over the batch for each critic, then sum across critics for the final loss
+        critics_loss = (
+            F.mse_loss(
+                input=q_preds,
+                target=td_target_duplicate,
+                reduction="none",
+            ).mean(1)
+        ).sum()
+
+        actions_pi, log_probs, _ = self.actor(observations)
         with torch.inference_mode():
-            q_preds = self.critic_forward(observations, actions, use_target=False)
-        actor_loss = (
-            -(q_preds - temperature * log_probs).mean()
-            # * ~batch["observation.state_is_pad"][0]
-            # * ~batch["action_is_pad"]
-        ).mean()
+            q_preds = self.critic_forward(observations, actions_pi, use_target=False)
+        min_q_preds = q_preds.min(dim=0)[0]
+
+        actor_loss = ((temperature * log_probs) - min_q_preds).mean()

         # calculate temperature loss
-        # 1- calculate entropy
-        entropy = -log_probs.mean()
-        temperature_loss = self.temperature(
-            lhs=entropy,
-            rhs=self.config.target_entropy
-        )
+        with torch.no_grad():
+            _, log_probs, _ = self.actor(observations)
+        temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()

         loss = critics_loss + actor_loss + temperature_loss

         return {
             "critics_loss": critics_loss.item(),
             "actor_loss": actor_loss.item(),
+            "mean_q_predicts": min_q_preds.mean().item(),
+            "min_q_predicts": min_q_preds.min().item(),
+            "max_q_predicts": min_q_preds.max().item(),
             "temperature_loss": temperature_loss.item(),
-            "temperature": temperature.item(),
-            "entropy": entropy.item(),
+            "temperature": temperature,
+            "mean_log_probs": log_probs.mean().item(),
+            "min_log_probs": log_probs.min().item(),
+            "max_log_probs": log_probs.max().item(),
+            "td_target_mean": td_target.mean().item(),
+            "td_target_max": td_target.max().item(),
+            "action_mean": actions.mean().item(),
+            "entropy": log_probs.mean().item(),
             "loss": loss,
         }

-    def update(self):
-        # TODO: implement UTD update
-        # First update only critics for utd_ratio-1 times
-        # for critic_step in range(self.config.utd_ratio - 1):
-        #     only update critic and critic target
-        # Then update critic, critic target, actor and temperature
+    def update_target_networks(self):
         """Update target networks with exponential moving average"""
-        for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
-            for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False):
-                target_param.data.copy_(
-                    target_param.data * self.config.critic_target_update_weight +
-                    param.data * (1.0 - self.config.critic_target_update_weight)
-                )
+        with torch.no_grad():
+            for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
+                for target_param, param in zip(
+                    target_critic.parameters(), critic.parameters(), strict=False
+                ):
+                    target_param.data.copy_(
+                        param.data * self.config.critic_target_update_weight
+                        + target_param.data * (1.0 - self.config.critic_target_update_weight)
+                    )
+
+    def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor:
+        temperature = self.log_alpha.exp().item()
+        with torch.no_grad():
+            next_action_preds, next_log_probs, _ = self.actor(next_observations)
+
+            # 2- compute q targets
+            q_targets = self.critic_forward(
+                observations=next_observations, actions=next_action_preds, use_target=True
+            )
+
+            # subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
+            if self.config.num_subsample_critics is not None:
+                indices = torch.randperm(self.config.num_critics)
+                indices = indices[: self.config.num_subsample_critics]
+                q_targets = q_targets[indices]
+
+            # critics subsample size
+            min_q, _ = q_targets.min(dim=0)  # Get values from min operation
+            if self.config.use_backup_entropy:
+                min_q = min_q - (temperature * next_log_probs)
+
+            td_target = rewards + (1 - done) * self.config.discount * min_q
+
+        # 3- compute predicted qs
+        q_preds = self.critic_forward(observations, actions, use_target=False)

+        # 4- Calculate loss
+        # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
+        td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
+        # Mean over the batch for each critic, then sum across critics for the final loss
+        critics_loss = (
+            F.mse_loss(
+                input=q_preds,
+                target=td_target_duplicate,
+                reduction="none",
+            ).mean(1)
+        ).sum()
+        return critics_loss
+
+    def compute_loss_temperature(self, observations) -> Tensor:
+        """Compute the temperature loss"""
+        # calculate temperature loss
+        with torch.no_grad():
+            _, log_probs, _ = self.actor(observations)
+        temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
+        return temperature_loss
+
+    def compute_loss_actor(self, observations) -> Tensor:
+        temperature = self.log_alpha.exp().item()
+
+        actions_pi, log_probs, _ = self.actor(observations)
+
+        q_preds = self.critic_forward(observations, actions_pi, use_target=False)
+        min_q_preds = q_preds.min(dim=0)[0]
+
+        actor_loss = ((temperature * log_probs) - min_q_preds).mean()
+        return actor_loss


 class MLP(nn.Module):
     def __init__(
         self,
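Note the coefficient swap in update_target_networks: with critic_target_update_weight tau = 0.005, the corrected Polyak update moves the target only 0.5% toward the online critic per step, while the old ordering jumped almost all the way in a single step (the target-update bug named in commit eec28baa63). A quick numeric check on toy scalars:

tau = 0.005
target, param = 0.0, 1.0
fixed = tau * param + (1 - tau) * target  # 0.005: slow tracking, as intended
buggy = tau * target + (1 - tau) * param  # 0.995: target nearly equals the online critic at once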
@@ -276,7 +337,9 @@ class MLP(nn.Module):
             if dropout_rate is not None and dropout_rate > 0:
                 layers.append(nn.Dropout(p=dropout_rate))
             layers.append(nn.LayerNorm(hidden_dims[i]))
-            layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
+            layers.append(
+                activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
+            )

         self.net = nn.Sequential(*layers)
@@ -290,7 +353,7 @@ class Critic(nn.Module):
         encoder: Optional[nn.Module],
         network: nn.Module,
         init_final: Optional[float] = None,
-        device: str = "cuda"
+        device: str = "cpu",
     ):
         super().__init__()
         self.device = torch.device(device)
@@ -321,9 +384,7 @@ class Critic(nn.Module):
         actions: torch.Tensor,
     ) -> torch.Tensor:
         # Move each tensor in observations to device
-        observations = {
-            k: v.to(self.device) for k, v in observations.items()
-        }
+        observations = {k: v.to(self.device) for k, v in observations.items()}
         actions = actions.to(self.device)

         obs_enc = observations if self.encoder is None else self.encoder(observations)
@@ -345,7 +406,7 @@ class Policy(nn.Module):
         fixed_std: Optional[torch.Tensor] = None,
         init_final: Optional[float] = None,
         use_tanh_squash: bool = False,
-        device: str = "cuda"
+        device: str = "cpu",
     ):
         super().__init__()
         self.device = torch.device(device)
@@ -386,7 +447,6 @@ class Policy(nn.Module):
         self,
         observations: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-
         # Encode observations if encoder exists
         obs_enc = observations if self.encoder is None else self.encoder(observations)
@@ -397,22 +457,30 @@ class Policy(nn.Module):
         # Compute standard deviations
         if self.fixed_std is None:
             log_std = self.std_layer(outputs)
+            assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
+
             if self.use_tanh_squash:
                 log_std = torch.tanh(log_std)
-            log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
+                log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
+            else:
+                log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
         else:
             log_std = self.fixed_std.expand_as(means)

-        # uses tahn activation function to squash the action to be in the range of [-1, 1]
+        # uses tanh activation function to squash the action to be in the range of [-1, 1]
         normal = torch.distributions.Normal(means, torch.exp(log_std))
-        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
-        log_probs = normal.log_prob(x_t)
+        x_t = normal.rsample()  # Reparameterization trick (mean + std * N(0,1))
+        log_probs = normal.log_prob(x_t)  # Base log probability before Tanh

         if self.use_tanh_squash:
             actions = torch.tanh(x_t)
-            log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
-            log_probs = log_probs.sum(-1)  # sum over action dim
+            log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)  # Adjust log-probs for Tanh
+        else:
+            actions = x_t  # No Tanh; raw Gaussian sample

-        return actions, log_probs
+        log_probs = log_probs.sum(-1)  # Sum over action dimensions
+
+        means = torch.tanh(means) if self.use_tanh_squash else means
+        return actions, log_probs, means

     def get_features(self, observations: torch.Tensor) -> torch.Tensor:
         """Get encoded features from observations"""
@@ -461,19 +529,13 @@ class SACObservationEncoder(nn.Module):
             )
         if "observation.state" in config.input_shapes:
             self.state_enc_layers = nn.Sequential(
-                nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
-                nn.ELU(),
-                nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
+                nn.Linear(config.input_shapes["observation.state"][0], config.latent_dim),
                 nn.LayerNorm(config.latent_dim),
                 nn.Tanh(),
             )
         if "observation.environment_state" in config.input_shapes:
             self.env_state_enc_layers = nn.Sequential(
-                nn.Linear(
-                    config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
-                ),
-                nn.ELU(),
-                nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
+                nn.Linear(config.input_shapes["observation.environment_state"][0], config.latent_dim),
                 nn.LayerNorm(config.latent_dim),
                 nn.Tanh(),
             )
@@ -502,59 +564,16 @@ class SACObservationEncoder(nn.Module):
         return self.config.latent_dim


-class LagrangeMultiplier(nn.Module):
-    def __init__(
-        self,
-        init_value: float = 1.0,
-        constraint_shape: Sequence[int] = (),
-        device: str = "cuda"
-    ):
-        super().__init__()
-        self.device = torch.device(device)
-        init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
-
-        # Initialize the Lagrange multiplier as a parameter
-        self.lagrange = nn.Parameter(
-            torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device)
-        )
-
-        self.to(self.device)
-
-    def forward(
-        self,
-        lhs: Optional[torch.Tensor | float | int] = None,
-        rhs: Optional[torch.Tensor | float | int] = None
-    ) -> torch.Tensor:
-        # Get the multiplier value based on parameterization
-        multiplier = torch.nn.functional.softplus(self.lagrange)
-
-        # Return the raw multiplier if no constraint values provided
-        if lhs is None:
-            return multiplier
-
-        # Convert inputs to tensors and move to device
-        lhs = torch.tensor(lhs, device=self.device) if not isinstance(lhs, torch.Tensor) else lhs.to(self.device)
-        if rhs is not None:
-            rhs = torch.tensor(rhs, device=self.device) if not isinstance(rhs, torch.Tensor) else rhs.to(self.device)
-        else:
-            rhs = torch.zeros_like(lhs, device=self.device)
-
-        diff = lhs - rhs
-        assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}"
-
-        return multiplier * diff
-
-
 def orthogonal_init():
     return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)


-def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList:
+def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cpu") -> nn.ModuleList:
     """Creates an ensemble of critic networks"""
     assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
     return nn.ModuleList(critics).to(device)


 # borrowed from tdmpc
 def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
     """Helper to temporarily flatten extra dims at the start of the image tensor.

View File

@@ -15,11 +15,11 @@ training:
   # Offline training dataloader
   num_workers: 4

-  batch_size: 128
+  batch_size: 256
   grad_clip_norm: 10.0
   lr: 3e-4

-  eval_freq: 50000
+  eval_freq: 2500
   log_freq: 500
   save_freq: 50000
@@ -46,8 +46,8 @@ policy:
   # Input / output structure.
   n_action_repeats: 1
-  horizon: 5
-  n_action_steps: 5
+  horizon: 2
+  n_action_steps: 2

   input_shapes:
     # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
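Reducing horizon from 5 to 2 plausibly matches the (current, next) observation pair that SACPolicy.forward indexes out of each batch; a toy illustration of that slicing (hypothetical shapes):

import torch

batch = {"observation.state": torch.randn(8, 2, 6), "next.reward": torch.randn(8, 2)}
observations = batch["observation.state"][:, 0]       # current step, shape (8, 6)
next_observations = batch["observation.state"][:, 1]  # next step, shape (8, 6)
rewards = batch["next.reward"][:, 0]                  # shape (8,)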

View File

@@ -99,7 +99,8 @@ def make_optimizer_and_scheduler(cfg, policy):
             [
                 {"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
                 {"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
-                {"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
+                # Wrap log_alpha in a list: it is a plain torch tensor, not an nn.Module
+                {"params": [policy.log_alpha], "lr": policy.config.temperature_lr},
             ]
         )
         lr_scheduler = None
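An nn.Module exposes .parameters(), but the new log_alpha is a bare tensor, so it must reach the optimizer inside an iterable. A minimal standalone check of the same param-group pattern (hypothetical module and dimensions):

import torch

log_alpha = torch.zeros(1, requires_grad=True)
net = torch.nn.Linear(4, 2)  # stand-in for the actor/critic modules
optimizer = torch.optim.Adam(
    [
        {"params": net.parameters(), "lr": 3e-4},
        {"params": [log_alpha], "lr": 3e-4},  # list-wrapped plain tensor
    ]
)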

lerobot/scripts/train_sac.py (new file, 1112 lines)

File diff suppressed because it is too large.