Compare commits

...

12 Commits

Author SHA1 Message Date
Adil Zouitine
278b56bce9 Add rlpd tricks 2025-01-15 15:49:24 +01:00
Adil Zouitine
0ffc0a7170 SAC works 2025-01-14 11:34:52 +01:00
Adil Zouitine
43d9890489 remove breakpoint 2025-01-13 17:58:00 +01:00
Adil Zouitine
963be41003 [WIP] correct sac implementation 2025-01-13 17:54:11 +01:00
Adil Zouitine
9edae4a8de Correct losses and factorisation 2025-01-07 17:07:55 +01:00
Ke-Wang1017
89d8189d8b remove unused debug lines 2025-01-06 10:18:40 +00:00
Ke-Wang1017
8b70b129dc improvements from JClinton to speed up loading offline data 2025-01-06 10:15:45 +00:00
joeclinton1
db3925df28 Potential fixes for SAC instability and NAN bug 2025-01-06 10:15:01 +00:00
Ke-Wang1017
f99e670976 Refactor SACPolicy and configuration for improved training dynamics
- Introduced target critic networks in SACPolicy to enhance stability during training.
- Updated TD target calculation to incorporate entropy adjustments, improving robustness.
- Increased online buffer capacity in configuration from 10,000 to 40,000 for better data handling.
- Adjusted learning rates for critic, actor, and temperature to 3e-4 for optimized training performance.

These changes aim to refine the SAC implementation, enhancing its robustness and performance during training and inference.
2025-01-06 10:14:34 +00:00
KeWang1017
eec28baa63 fix the bug of target critic updates, roll back to origial temperature implementation, added debug logging info 2025-01-06 10:14:09 +00:00
KeWang1017
f1f04eb4f9 use mean instead of sampled action for the inference 2025-01-06 10:12:24 +00:00
KeWang1017
77a7f92139 1, add input normalization in configuration_sac.py 2, add masking on loss computation 2025-01-06 10:11:51 +00:00
6 changed files with 1359 additions and 207 deletions

View File

@@ -611,11 +611,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
return query_timestamps
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
return {
key: torch.stack(self.hf_dataset.select(q_idx)[key])
for key, q_idx in query_indices.items()
if key not in self.meta.video_keys
}
# Step 1: Combine all unique indices
all_indices = sorted({idx for indices in query_indices.values() for idx in indices})
# Step 2: Select all required data at once
selected_dataset = self.hf_dataset.select(all_indices).to_dict()
selected_dataset = {key: torch.tensor(values) for key, values in selected_dataset.items()}
# Step 3: Map original indices to their positions in the selected dataset
index_map = {original_idx: i for i, original_idx in enumerate(all_indices)}
# Step 4: Build the result for each key
results = {}
for key, q_indices in query_indices.items():
if key not in self.meta.video_keys:
mapped_indices = [index_map[idx] for idx in q_indices]
results[key] = torch.stack([selected_dataset[key][i] for i in mapped_indices])
return results
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function

View File

@@ -28,12 +28,18 @@ class SACConfig:
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [4],
"action": [2],
}
)
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] | None = None
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "min_max",
"observation.environment_state": "min_max",
}
)
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"},
)
@@ -48,9 +54,10 @@ class SACConfig:
critic_target_update_weight = 0.005
utd_ratio = 2
state_encoder_hidden_dim = 256
latent_dim = 128
latent_dim = 256
target_entropy = None
backup_entropy = True
# backup_entropy = False
use_backup_entropy = True
critic_network_kwargs = {
"hidden_dims": [256, 256],
"activate_final": True,

View File

@@ -18,8 +18,7 @@
# TODO: (1) better device management
from collections import deque
from copy import deepcopy
from typing import Callable, Optional, Sequence, Tuple
from typing import Callable, Optional, Sequence, Tuple, Union
import einops
import numpy as np
@@ -57,12 +56,20 @@ class SACPolicy(
)
else:
self.normalize_inputs = nn.Identity()
# HACK: we need to pass the dataset_stats to the normalization functions
dataset_stats = dataset_stats or {
"action": {
"min": torch.tensor([-1.0, -1.0, -1.0, -1.0]),
"max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
}
}
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
encoder_critic = SACObservationEncoder(config)
encoder_actor = SACObservationEncoder(config)
# Define networks
@@ -72,26 +79,38 @@ class SACPolicy(
encoder=encoder_critic,
network=MLP(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs
)
**config.critic_network_kwargs,
),
)
critic_nets.append(critic_net)
target_critic_nets = []
for _ in range(config.num_critics):
target_critic_net = Critic(
encoder=encoder_critic,
network=MLP(
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
**config.critic_network_kwargs,
),
)
target_critic_nets.append(target_critic_net)
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
self.critic_target = deepcopy(self.critic_ensemble)
self.critic_target = create_critic_ensemble(target_critic_nets, config.num_critics)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
self.actor = Policy(
encoder=encoder_actor,
network=MLP(
input_dim=encoder_actor.output_dim,
**config.actor_network_kwargs
),
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
action_dim=config.output_shapes["action"][0],
**config.policy_kwargs
**config.policy_kwargs,
)
if config.target_entropy is None:
config.target_entropy = -np.prod(config.output_shapes["action"][0]) # (-dim(A))
self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
# TODO: fix later device
# TODO: Handle the case where the temparameter is a fixed
self.log_alpha = torch.zeros(1, requires_grad=True, device="cpu")
self.temperature = self.log_alpha.exp().item()
def reset(self):
"""
@@ -111,18 +130,20 @@ class SACPolicy(
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select action for inference/evaluation"""
actions, _ = self.actor(batch)
actions, _, _ = self.actor(batch)
actions = self.unnormalize_outputs({"action": actions})["action"]
return actions
def critic_forward(self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False) -> Tensor:
def critic_forward(
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
) -> Tensor:
"""Forward pass through a critic network ensemble
Args:
observations: Dictionary of observations
actions: Action tensor
use_target: If True, use target critics, otherwise use ensemble critics
Returns:
Tensor of Q-values from all critics
"""
@@ -130,16 +151,20 @@ class SACPolicy(
q_values = torch.stack([critic(observations, actions) for critic in critics])
return q_values
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
"""Run the batch through the model and compute the loss.
Returns a dictionary with loss as a tensor, and other information as native floats.
"""
# We have to actualize the value of the temperature because in the previous
self.temperature = self.log_alpha.exp().item()
temperature = self.temperature
batch = self.normalize_inputs(batch)
# batch shape is (b, 2, ...) where index 1 returns the current observation and
# the next observation for calculating the right td index.
actions = batch["action"][:, 0]
# batch shape is (b, 2, ...) where index 1 returns the current observation and
# the next observation for calculating the right td index.
# actions = batch["action"][:, 0]
actions = batch["action"]
rewards = batch["next.reward"][:, 0]
observations = {}
next_observations = {}
@@ -147,105 +172,141 @@ class SACPolicy(
if k.startswith("observation."):
observations[k] = batch[k][:, 0]
next_observations[k] = batch[k][:, 1]
# perform image augmentation
done = batch["next.done"]
# reward bias from HIL-SERL code base
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
# calculate critics loss
# 1- compute actions from policy
action_preds, log_probs = self.actor(next_observations)
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations)
# 2- compute q targets
q_targets = self.critic_forward(next_observations, action_preds, use_target=True)
# 2- compute q targets
q_targets = self.critic_forward(next_observations, next_action_preds, use_target=True)
# subsample critics to prevent overfitting if use high UTD (update to date)
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[:self.config.num_subsample_critics]
q_targets = q_targets[indices]
# subsample critics to prevent overfitting if use high UTD (update to date)
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
# compute td target
td_target = rewards + self.config.discount * min_q #+ self.config.discount * self.temperature() * log_probs # add entropy term
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
if self.config.use_backup_entropy:
min_q -= self.temperature * next_log_probs
td_target = rewards + self.config.discount * min_q * ~done
# 3- compute predicted qs
q_preds = self.critic_forward(observations, actions, use_target=False)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
critics_loss = F.mse_loss(
q_preds, # shape: [num_critics, batch_size]
einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]), # expand td_target to match q_preds shape
reduction="none"
).sum(0).mean()
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
critics_loss = (
F.mse_loss(
input=q_preds,
target=td_target_duplicate,
reduction="none",
).mean(1)
).sum()
# critics_loss = (
# F.mse_loss(
# q_preds,
# einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
# reduction="none",
# ).sum(0) # sum over ensemble
# # `q_preds_ensemble` depends on the first observation and the actions.
# * ~batch["observation.state_is_pad"][0]
# * ~batch["action_is_pad"]
# # q_targets depends on the reward and the next observations.
# * ~batch["next.reward_is_pad"]
# * ~batch["observation.state_is_pad"][1:]
# ).sum(0).mean()
# calculate actors loss
# 1- temperature
temperature = self.temperature()
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
actions, log_probs = self.actor(observations)
# 3- get q-value predictions
actions_pi, log_probs, _ = self.actor(observations)
with torch.inference_mode():
q_preds = self.critic_forward(observations, actions, use_target=False)
actor_loss = (
-(q_preds - temperature * log_probs).mean()
# * ~batch["observation.state_is_pad"][0]
# * ~batch["action_is_pad"]
).mean()
q_preds = self.critic_forward(observations, actions_pi, use_target=False)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
# calculate temperature loss
# 1- calculate entropy
entropy = -log_probs.mean()
temperature_loss = self.temperature(
lhs=entropy,
rhs=self.config.target_entropy
)
with torch.no_grad():
_, log_probs, _ = self.actor(observations)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
loss = critics_loss + actor_loss + temperature_loss
return {
"critics_loss": critics_loss.item(),
"actor_loss": actor_loss.item(),
"temperature_loss": temperature_loss.item(),
"temperature": temperature.item(),
"entropy": entropy.item(),
"loss": loss,
}
def update(self):
# TODO: implement UTD update
# First update only critics for utd_ratio-1 times
#for critic_step in range(self.config.utd_ratio - 1):
# only update critic and critic target
# Then update critic, critic target, actor and temperature
"critics_loss": critics_loss.item(),
"actor_loss": actor_loss.item(),
"mean_q_predicts": min_q_preds.mean().item(),
"min_q_predicts": min_q_preds.min().item(),
"max_q_predicts": min_q_preds.max().item(),
"temperature_loss": temperature_loss.item(),
"temperature": temperature,
"mean_log_probs": log_probs.mean().item(),
"min_log_probs": log_probs.min().item(),
"max_log_probs": log_probs.max().item(),
"td_target_mean": td_target.mean().item(),
"td_target_max": td_target.max().item(),
"action_mean": actions.mean().item(),
"entropy": log_probs.mean().item(),
"loss": loss,
}
def update_target_networks(self):
"""Update target networks with exponential moving average"""
for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor:
temperature = self.log_alpha.exp().item()
with torch.no_grad():
for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False):
target_param.data.copy_(
target_param.data * self.config.critic_target_update_weight +
param.data * (1.0 - self.config.critic_target_update_weight)
)
next_action_preds, next_log_probs, _ = self.actor(next_observations)
# 2- compute q targets
q_targets = self.critic_forward(
observations=next_observations, actions=next_action_preds, use_target=True
)
# subsample critics to prevent overfitting if use high UTD (update to date)
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
if self.config.use_backup_entropy:
min_q = min_q - (temperature * next_log_probs)
td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
q_preds = self.critic_forward(observations, actions, use_target=False)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
critics_loss = (
F.mse_loss(
input=q_preds,
target=td_target_duplicate,
reduction="none",
).mean(1)
).sum()
return critics_loss
def compute_loss_temperature(self, observations) -> Tensor:
"""Compute the temperature loss"""
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
return temperature_loss
def compute_loss_actor(self, observations) -> Tensor:
temperature = self.log_alpha.exp().item()
actions_pi, log_probs, _ = self.actor(observations)
q_preds = self.critic_forward(observations, actions_pi, use_target=False)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
return actor_loss
class MLP(nn.Module):
def __init__(
self,
@@ -258,52 +319,54 @@ class MLP(nn.Module):
super().__init__()
self.activate_final = activate_final
layers = []
# First layer uses input_dim
layers.append(nn.Linear(input_dim, hidden_dims[0]))
# Add activation after first layer
if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(hidden_dims[0]))
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
# Rest of the layers
for i in range(1, len(hidden_dims)):
layers.append(nn.Linear(hidden_dims[i-1], hidden_dims[i]))
layers.append(nn.Linear(hidden_dims[i - 1], hidden_dims[i]))
if i + 1 < len(hidden_dims) or activate_final:
if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(hidden_dims[i]))
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
layers.append(
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
)
self.net = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
class Critic(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
network: nn.Module,
init_final: Optional[float] = None,
device: str = "cuda"
device: str = "cpu",
):
super().__init__()
self.device = torch.device(device)
self.encoder = encoder
self.network = network
self.init_final = init_final
# Find the last Linear layer's output dimension
for layer in reversed(network.net):
if isinstance(layer, nn.Linear):
out_features = layer.out_features
break
# Output layer
if init_final is not None:
self.output_layer = nn.Linear(out_features, 1)
@@ -312,22 +375,20 @@ class Critic(nn.Module):
else:
self.output_layer = nn.Linear(out_features, 1)
orthogonal_init()(self.output_layer.weight)
self.to(self.device)
def forward(
self,
observations: dict[str, torch.Tensor],
self,
observations: dict[str, torch.Tensor],
actions: torch.Tensor,
) -> torch.Tensor:
# Move each tensor in observations to device
observations = {
k: v.to(self.device) for k, v in observations.items()
}
observations = {k: v.to(self.device) for k, v in observations.items()}
actions = actions.to(self.device)
obs_enc = observations if self.encoder is None else self.encoder(observations)
inputs = torch.cat([obs_enc, actions], dim=-1)
x = self.network(inputs)
value = self.output_layer(x)
@@ -345,7 +406,7 @@ class Policy(nn.Module):
fixed_std: Optional[torch.Tensor] = None,
init_final: Optional[float] = None,
use_tanh_squash: bool = False,
device: str = "cuda"
device: str = "cpu",
):
super().__init__()
self.device = torch.device(device)
@@ -356,13 +417,13 @@ class Policy(nn.Module):
self.log_std_max = log_std_max
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
self.use_tanh_squash = use_tanh_squash
# Find the last Linear layer's output dimension
for layer in reversed(network.net):
if isinstance(layer, nn.Linear):
out_features = layer.out_features
break
# Mean layer
self.mean_layer = nn.Linear(out_features, action_dim)
if init_final is not None:
@@ -370,7 +431,7 @@ class Policy(nn.Module):
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.mean_layer.weight)
# Standard deviation layer or parameter
if fixed_std is None:
self.std_layer = nn.Linear(out_features, action_dim)
@@ -379,41 +440,48 @@ class Policy(nn.Module):
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.std_layer.weight)
self.to(self.device)
def forward(
self,
self,
observations: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Encode observations if encoder exists
obs_enc = observations if self.encoder is None else self.encoder(observations)
# Get network outputs
outputs = self.network(obs_enc)
means = self.mean_layer(outputs)
# Compute standard deviations
if self.fixed_std is None:
log_std = self.std_layer(outputs)
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
if self.use_tanh_squash:
log_std = torch.tanh(log_std)
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
else:
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
else:
log_std = self.fixed_std.expand_as(means)
# uses tahn activation function to squash the action to be in the range of [-1, 1]
# uses tanh activation function to squash the action to be in the range of [-1, 1]
normal = torch.distributions.Normal(means, torch.exp(log_std))
x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1))
log_probs = normal.log_prob(x_t)
x_t = normal.rsample() # Reparameterization trick (mean + std * N(0,1))
log_probs = normal.log_prob(x_t) # Base log probability before Tanh
if self.use_tanh_squash:
actions = torch.tanh(x_t)
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
log_probs = log_probs.sum(-1) # sum over action dim
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh
else:
actions = x_t # No Tanh; raw Gaussian sample
log_probs = log_probs.sum(-1) # Sum over action dimensions
means = torch.tanh(means) if self.use_tanh_squash else means
return actions, log_probs, means
return actions, log_probs
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
"""Get encoded features from observations"""
observations = observations.to(self.device)
@@ -461,19 +529,13 @@ class SACObservationEncoder(nn.Module):
)
if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential(
nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.Linear(config.input_shapes["observation.state"][0], config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
if "observation.environment_state" in config.input_shapes:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.Linear(config.input_shapes["observation.environment_state"][0], config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
@@ -495,66 +557,23 @@ class SACObservationEncoder(nn.Module):
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
# TODO(ke-wang): currently average over all features, concatenate all features maybe a better way
return torch.stack(feat, dim=0).mean(0)
@property
def output_dim(self) -> int:
"""Returns the dimension of the encoder output"""
return self.config.latent_dim
class LagrangeMultiplier(nn.Module):
def __init__(
self,
init_value: float = 1.0,
constraint_shape: Sequence[int] = (),
device: str = "cuda"
):
super().__init__()
self.device = torch.device(device)
init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
# Initialize the Lagrange multiplier as a parameter
self.lagrange = nn.Parameter(
torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device)
)
self.to(self.device)
def forward(
self,
lhs: Optional[torch.Tensor | float | int] = None,
rhs: Optional[torch.Tensor | float | int] = None
) -> torch.Tensor:
# Get the multiplier value based on parameterization
multiplier = torch.nn.functional.softplus(self.lagrange)
# Return the raw multiplier if no constraint values provided
if lhs is None:
return multiplier
# Convert inputs to tensors and move to device
lhs = torch.tensor(lhs, device=self.device) if not isinstance(lhs, torch.Tensor) else lhs.to(self.device)
if rhs is not None:
rhs = torch.tensor(rhs, device=self.device) if not isinstance(rhs, torch.Tensor) else rhs.to(self.device)
else:
rhs = torch.zeros_like(lhs, device=self.device)
diff = lhs - rhs
assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}"
return multiplier * diff
def orthogonal_init():
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList:
def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cpu") -> nn.ModuleList:
"""Creates an ensemble of critic networks"""
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
return nn.ModuleList(critics).to(device)
# borrowed from tdmpc
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.
@@ -562,7 +581,7 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
Args:
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
(B, *), where * is any number of dimensions.
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
can be more than 1 dimensions, generally different from *.
Returns:
A return value from the callable reshaped to (**, *).
@@ -572,4 +591,4 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tens
start_dims = image_tensor.shape[:-3]
inp = torch.flatten(image_tensor, end_dim=-4)
flat_out = fn(inp)
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))

View File

@@ -15,11 +15,11 @@ training:
# Offline training dataloader
num_workers: 4
batch_size: 128
batch_size: 256
grad_clip_norm: 10.0
lr: 3e-4
eval_freq: 50000
eval_freq: 2500
log_freq: 500
save_freq: 50000
@@ -46,8 +46,8 @@ policy:
# Input / output structure.
n_action_repeats: 1
horizon: 5
n_action_steps: 5
horizon: 2
n_action_steps: 2
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?

View File

@@ -99,7 +99,8 @@ def make_optimizer_and_scheduler(cfg, policy):
[
{"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
{"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
{"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
# We wrap policy log temperature in list because this is a torch tensor and not a nn.Module
{"params": [policy.log_alpha], "lr": policy.config.temperature_lr},
]
)
lr_scheduler = None

1112
lerobot/scripts/train_sac.py Normal file

File diff suppressed because it is too large Load Diff