Compare commits
12 Commits
user/miche...user/adil-
| Author | SHA1 | Date |
|---|---|---|
|  | 278b56bce9 |  |
|  | 0ffc0a7170 |  |
|  | 43d9890489 |  |
|  | 963be41003 |  |
|  | 9edae4a8de |  |
|  | 89d8189d8b |  |
|  | 8b70b129dc |  |
|  | db3925df28 |  |
|  | f99e670976 |  |
|  | eec28baa63 |  |
|  | f1f04eb4f9 |  |
|  | 77a7f92139 |  |
```diff
@@ -611,11 +611,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
         return query_timestamps

     def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
-        return {
-            key: torch.stack(self.hf_dataset.select(q_idx)[key])
-            for key, q_idx in query_indices.items()
-            if key not in self.meta.video_keys
-        }
+        # Step 1: Combine all unique indices
+        all_indices = sorted({idx for indices in query_indices.values() for idx in indices})
+
+        # Step 2: Select all required data at once
+        selected_dataset = self.hf_dataset.select(all_indices).to_dict()
+        selected_dataset = {key: torch.tensor(values) for key, values in selected_dataset.items()}
+
+        # Step 3: Map original indices to their positions in the selected dataset
+        index_map = {original_idx: i for i, original_idx in enumerate(all_indices)}
+
+        # Step 4: Build the result for each key
+        results = {}
+        for key, q_indices in query_indices.items():
+            if key not in self.meta.video_keys:
+                mapped_indices = [index_map[idx] for idx in q_indices]
+                results[key] = torch.stack([selected_dataset[key][i] for i in mapped_indices])
+
+        return results

     def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
         """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
```
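The rewrite above replaces one `select` call per key with a single batched `select` over the union of requested indices, then maps each query back to its position in the batch. A minimal standalone sketch of the same technique (toy data and names are illustrative, not from the PR):

```python
import torch
from datasets import Dataset  # Hugging Face datasets

# Toy table standing in for self.hf_dataset.
ds = Dataset.from_dict({"state": [[0.0], [1.0], [2.0], [3.0]]})
query_indices = {"state": [3, 1]}

# One batched select over the union of all requested indices.
all_indices = sorted({i for idxs in query_indices.values() for i in idxs})
selected = {k: torch.tensor(v) for k, v in ds.select(all_indices).to_dict().items()}

# Map original indices back to their positions inside the batched selection.
pos = {orig: i for i, orig in enumerate(all_indices)}
result = {k: torch.stack([selected[k][pos[i]] for i in q]) for k, q in query_indices.items()}
print(result["state"])  # tensor([[3.], [1.]]) -- rows in query order
```

The payoff is fewer round-trips into the Arrow-backed dataset when several keys request overlapping indices.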
```diff
@@ -28,12 +28,18 @@ class SACConfig:
     )
     output_shapes: dict[str, list[int]] = field(
         default_factory=lambda: {
-            "action": [4],
+            "action": [2],
         }
     )

     # Normalization / Unnormalization
-    input_normalization_modes: dict[str, str] | None = None
+    input_normalization_modes: dict[str, str] = field(
+        default_factory=lambda: {
+            "observation.image": "mean_std",
+            "observation.state": "min_max",
+            "observation.environment_state": "min_max",
+        }
+    )
     output_normalization_modes: dict[str, str] = field(
         default_factory=lambda: {"action": "min_max"},
     )
```
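A side note on the pattern used in this hunk: dataclass fields with mutable defaults must go through `field(default_factory=...)`; a bare dict literal raises at class-definition time. A small illustration (hypothetical `ExampleConfig`, not part of the PR):

```python
from dataclasses import dataclass, field

@dataclass
class ExampleConfig:
    # Each instance gets its own fresh dict.
    output_shapes: dict[str, list[int]] = field(default_factory=lambda: {"action": [2]})
    # Writing `= {"action": [2]}` directly would raise:
    # ValueError: mutable default <class 'dict'> for field output_shapes
    # is not allowed: use default_factory

print(ExampleConfig().output_shapes)  # {'action': [2]}
```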
```diff
@@ -48,9 +54,10 @@ class SACConfig:
     critic_target_update_weight = 0.005
     utd_ratio = 2
     state_encoder_hidden_dim = 256
-    latent_dim = 128
+    latent_dim = 256
     target_entropy = None
-    backup_entropy = True
+    # backup_entropy = False
+    use_backup_entropy = True
     critic_network_kwargs = {
         "hidden_dims": [256, 256],
         "activate_final": True,
```
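For readers tracking the `backup_entropy` → `use_backup_entropy` rename above: in SAC, "backing up" the entropy means including the policy's entropy bonus in the TD target rather than only in the actor loss. Judging from the code later in this comparison (see `compute_loss_critic`), the flag toggles the $-\alpha \log \pi$ term in a target of roughly this form, with $\alpha$ the temperature and $d$ the done flag:

$$
y = r + \gamma\,(1 - d)\Big(\min_i Q_{\bar{\theta}_i}(s', a') - \alpha \log \pi(a' \mid s')\Big), \qquad a' \sim \pi(\cdot \mid s')
$$

This reading is inferred from the diff itself; the PR does not state it explicitly.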
```diff
@@ -18,8 +18,7 @@
 # TODO: (1) better device management

 from collections import deque
-from copy import deepcopy
-from typing import Callable, Optional, Sequence, Tuple
+from typing import Callable, Optional, Sequence, Tuple, Union

 import einops
 import numpy as np
```
```diff
@@ -57,12 +56,20 @@ class SACPolicy(
             )
         else:
             self.normalize_inputs = nn.Identity()
+        # HACK: we need to pass the dataset_stats to the normalization functions
+        dataset_stats = dataset_stats or {
+            "action": {
+                "min": torch.tensor([-1.0, -1.0, -1.0, -1.0]),
+                "max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
+            }
+        }
         self.normalize_targets = Normalize(
             config.output_shapes, config.output_normalization_modes, dataset_stats
         )
         self.unnormalize_outputs = Unnormalize(
             config.output_shapes, config.output_normalization_modes, dataset_stats
         )

         encoder_critic = SACObservationEncoder(config)
         encoder_actor = SACObservationEncoder(config)
         # Define networks
```
```diff
@@ -72,26 +79,38 @@ class SACPolicy(
                 encoder=encoder_critic,
                 network=MLP(
                     input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
-                    **config.critic_network_kwargs
-                )
+                    **config.critic_network_kwargs,
+                ),
             )
             critic_nets.append(critic_net)

+        target_critic_nets = []
+        for _ in range(config.num_critics):
+            target_critic_net = Critic(
+                encoder=encoder_critic,
+                network=MLP(
+                    input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
+                    **config.critic_network_kwargs,
+                ),
+            )
+            target_critic_nets.append(target_critic_net)
+
         self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
-        self.critic_target = deepcopy(self.critic_ensemble)
+        self.critic_target = create_critic_ensemble(target_critic_nets, config.num_critics)
+        self.critic_target.load_state_dict(self.critic_ensemble.state_dict())

         self.actor = Policy(
             encoder=encoder_actor,
-            network=MLP(
-                input_dim=encoder_actor.output_dim,
-                **config.actor_network_kwargs
-            ),
+            network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
             action_dim=config.output_shapes["action"][0],
-            **config.policy_kwargs
+            **config.policy_kwargs,
         )
         if config.target_entropy is None:
-            config.target_entropy = -np.prod(config.output_shapes["action"][0])  # (-dim(A))
-        self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
+            config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2  # (-dim(A)/2)
+        # TODO: fix later device
+        # TODO: Handle the case where the temperature is fixed
+        self.log_alpha = torch.zeros(1, requires_grad=True, device="cpu")
+        self.temperature = self.log_alpha.exp().item()

     def reset(self):
         """
```
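One design note on this hunk: replacing `deepcopy(self.critic_ensemble)` with freshly constructed target critics plus `load_state_dict` makes the target construction explicit (each target receives the same `encoder_critic` object, whereas `deepcopy` would have cloned it). A minimal sketch of the pattern with a hypothetical `QNet` (not the PR's classes):

```python
import torch
import torch.nn as nn

class QNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 1)

    def forward(self, x):
        return self.net(x)

critics = nn.ModuleList([QNet() for _ in range(2)])

# Build targets separately, then synchronize weights by copying the state dict.
targets = nn.ModuleList([QNet() for _ in range(2)])
targets.load_state_dict(critics.state_dict())

x = torch.randn(3, 4)
assert torch.allclose(targets[0](x), critics[0](x))  # same initial weights
```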
```diff
@@ -111,11 +130,13 @@ class SACPolicy(
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
-        actions, _ = self.actor(batch)
+        actions, _, _ = self.actor(batch)
         actions = self.unnormalize_outputs({"action": actions})["action"]
         return actions

-    def critic_forward(self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False) -> Tensor:
+    def critic_forward(
+        self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
+    ) -> Tensor:
         """Forward pass through a critic network ensemble

         Args:
```
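The reflowed `critic_forward` signature does not change behavior: the method stacks each ensemble member's output, so callers receive a `(num_critics, batch)` tensor they can reduce with `min(dim=0)`. A tiny illustration of that shape contract (toy critics, not the PR's):

```python
import torch
import torch.nn as nn

ensemble = nn.ModuleList([nn.Linear(6, 1) for _ in range(2)])
obs_and_act = torch.randn(5, 6)  # stand-in for encoded obs concatenated with actions

q_values = torch.stack([c(obs_and_act).squeeze(-1) for c in ensemble])
print(q_values.shape)  # torch.Size([2, 5]) -> (num_critics, batch)
min_q = q_values.min(dim=0).values  # clipped double-Q style pessimistic estimate
```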
```diff
@@ -130,16 +151,20 @@ class SACPolicy(
         q_values = torch.stack([critic(observations, actions) for critic in critics])
         return q_values


     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
         """Run the batch through the model and compute the loss.

         Returns a dictionary with loss as a tensor, and other information as native floats.
         """
+        # We have to actualize the value of the temperature because in the previous
+        self.temperature = self.log_alpha.exp().item()
+        temperature = self.temperature
+
         batch = self.normalize_inputs(batch)
         # batch shape is (b, 2, ...) where index 1 returns the current observation and
         # the next observation for calculating the right td index.
-        actions = batch["action"][:, 0]
+        # actions = batch["action"][:, 0]
+        actions = batch["action"]
         rewards = batch["next.reward"][:, 0]
         observations = {}
         next_observations = {}
```
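The `log_alpha` lines added here follow the usual SAC temperature parameterization: optimize $\log \alpha$ directly so that $\alpha = e^{\log \alpha}$ stays positive, with loss $J(\alpha) = \mathbb{E}\big[-\alpha\,(\log \pi(a \mid s) + \mathcal{H}_{\text{target}})\big]$. A hedged sketch of one temperature update step (all names illustrative):

```python
import torch

target_entropy = -2.0  # e.g. -dim(A), or -dim(A)/2 as in this PR
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)

log_probs = torch.randn(256)  # stand-in for log pi(a|s) from the actor (detached)

alpha_loss = (-log_alpha.exp() * (log_probs + target_entropy)).mean()
alpha_optimizer.zero_grad()
alpha_loss.backward()
alpha_optimizer.step()

alpha = log_alpha.exp().item()  # scalar used to weight the entropy terms
```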
```diff
@@ -147,18 +172,13 @@ class SACPolicy(
             if k.startswith("observation."):
                 observations[k] = batch[k][:, 0]
                 next_observations[k] = batch[k][:, 1]
+        done = batch["next.done"]

-        # perform image augmentation
-
-        # reward bias from HIL-SERL code base
-        # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
-
-        # calculate critics loss
-        # 1- compute actions from policy
-        action_preds, log_probs = self.actor(next_observations)
+        with torch.no_grad():
+            next_action_preds, next_log_probs, _ = self.actor(next_observations)

         # 2- compute q targets
-        q_targets = self.critic_forward(next_observations, action_preds, use_target=True)
+        q_targets = self.critic_forward(next_observations, next_action_preds, use_target=True)

         # subsample critics to prevent overfitting if use high UTD (update to date)
         if self.config.num_subsample_critics is not None:
```
```diff
@@ -168,84 +188,125 @@ class SACPolicy(

         # critics subsample size
         min_q, _ = q_targets.min(dim=0)  # Get values from min operation
-        # compute td target
-        td_target = rewards + self.config.discount * min_q  #+ self.config.discount * self.temperature() * log_probs # add entropy term
+        if self.config.use_backup_entropy:
+            min_q -= self.temperature * next_log_probs
+        td_target = rewards + self.config.discount * min_q * ~done

         # 3- compute predicted qs
         q_preds = self.critic_forward(observations, actions, use_target=False)

         # 4- Calculate loss
         # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
-        critics_loss = F.mse_loss(
-            q_preds,  # shape: [num_critics, batch_size]
-            einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),  # expand td_target to match q_preds shape
-            reduction="none"
-        ).sum(0).mean()
-
-        # critics_loss = (
-        #     F.mse_loss(
-        #         q_preds,
-        #         einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
-        #         reduction="none",
-        #     ).sum(0)  # sum over ensemble
-        #     # `q_preds_ensemble` depends on the first observation and the actions.
-        #     * ~batch["observation.state_is_pad"][0]
-        #     * ~batch["action_is_pad"]
-        #     # q_targets depends on the reward and the next observations.
-        #     * ~batch["next.reward_is_pad"]
-        #     * ~batch["observation.state_is_pad"][1:]
-        # ).sum(0).mean()
-
-        # calculate actors loss
-        # 1- temperature
-        temperature = self.temperature()
-        # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
-        actions, log_probs = self.actor(observations)
-        # 3- get q-value predictions
+        td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
+        # You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
+        critics_loss = (
+            F.mse_loss(
+                input=q_preds,
+                target=td_target_duplicate,
+                reduction="none",
+            ).mean(1)
+        ).sum()
+
+        actions_pi, log_probs, _ = self.actor(observations)
+
         with torch.inference_mode():
-            q_preds = self.critic_forward(observations, actions, use_target=False)
-        actor_loss = (
-            -(q_preds - temperature * log_probs).mean()
-            # * ~batch["observation.state_is_pad"][0]
-            # * ~batch["action_is_pad"]
-        ).mean()
+            q_preds = self.critic_forward(observations, actions_pi, use_target=False)
+        min_q_preds = q_preds.min(dim=0)[0]
+
+        actor_loss = ((temperature * log_probs) - min_q_preds).mean()

         # calculate temperature loss
-        # 1- calculate entropy
-        entropy = -log_probs.mean()
-        temperature_loss = self.temperature(
-            lhs=entropy,
-            rhs=self.config.target_entropy
-        )
+        with torch.no_grad():
+            _, log_probs, _ = self.actor(observations)
+        temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()

         loss = critics_loss + actor_loss + temperature_loss

         return {
             "critics_loss": critics_loss.item(),
             "actor_loss": actor_loss.item(),
+            "mean_q_predicts": min_q_preds.mean().item(),
+            "min_q_predicts": min_q_preds.min().item(),
+            "max_q_predicts": min_q_preds.max().item(),
             "temperature_loss": temperature_loss.item(),
-            "temperature": temperature.item(),
-            "entropy": entropy.item(),
+            "temperature": temperature,
+            "mean_log_probs": log_probs.mean().item(),
+            "min_log_probs": log_probs.min().item(),
+            "max_log_probs": log_probs.max().item(),
+            "td_target_mean": td_target.mean().item(),
+            "td_target_max": td_target.max().item(),
+            "action_mean": actions.mean().item(),
+            "entropy": log_probs.mean().item(),
             "loss": loss,
         }

-    def update(self):
-        # TODO: implement UTD update
-        # First update only critics for utd_ratio-1 times
-        #for critic_step in range(self.config.utd_ratio - 1):
-        # only update critic and critic target
-        # Then update critic, critic target, actor and temperature
+    def update_target_networks(self):
         """Update target networks with exponential moving average"""
-        with torch.no_grad():
         for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
             for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False):
                 target_param.data.copy_(
-                    target_param.data * self.config.critic_target_update_weight
-                    + param.data * (1.0 - self.config.critic_target_update_weight)
+                    param.data * self.config.critic_target_update_weight
+                    + target_param.data * (1.0 - self.config.critic_target_update_weight)
                 )

+    def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor:
+        temperature = self.log_alpha.exp().item()
+        with torch.no_grad():
+            next_action_preds, next_log_probs, _ = self.actor(next_observations)
+
+            # 2- compute q targets
+            q_targets = self.critic_forward(
+                observations=next_observations, actions=next_action_preds, use_target=True
+            )
+
+            # subsample critics to prevent overfitting if use high UTD (update to date)
+            if self.config.num_subsample_critics is not None:
+                indices = torch.randperm(self.config.num_critics)
+                indices = indices[: self.config.num_subsample_critics]
+                q_targets = q_targets[indices]
+
+            # critics subsample size
+            min_q, _ = q_targets.min(dim=0)  # Get values from min operation
+            if self.config.use_backup_entropy:
+                min_q = min_q - (temperature * next_log_probs)
+
+            td_target = rewards + (1 - done) * self.config.discount * min_q
+
+        # 3- compute predicted qs
+        q_preds = self.critic_forward(observations, actions, use_target=False)
+
+        # 4- Calculate loss
+        # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
+        td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
+        # You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
+        critics_loss = (
+            F.mse_loss(
+                input=q_preds,
+                target=td_target_duplicate,
+                reduction="none",
+            ).mean(1)
+        ).sum()
+        return critics_loss
+
+    def compute_loss_temperature(self, observations) -> Tensor:
+        """Compute the temperature loss"""
+        # calculate temperature loss
+        with torch.no_grad():
+            _, log_probs, _ = self.actor(observations)
+        temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
+        return temperature_loss
+
+    def compute_loss_actor(self, observations) -> Tensor:
+        temperature = self.log_alpha.exp().item()
+
+        actions_pi, log_probs, _ = self.actor(observations)
+
+        q_preds = self.critic_forward(observations, actions_pi, use_target=False)
+        min_q_preds = q_preds.min(dim=0)[0]
+
+        actor_loss = ((temperature * log_probs) - min_q_preds).mean()
+        return actor_loss


 class MLP(nn.Module):
     def __init__(
         self,
```
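Taken together, this hunk splits the monolithic `forward` into `compute_loss_critic`, `compute_loss_temperature`, and `compute_loss_actor`, and renames `update` to `update_target_networks`. Note the direction fix in that method: the old code weighted the *target* parameters by τ = `critic_target_update_weight` (0.005) and the online parameters by 1 − τ, which would have made the target jump almost all the way to the online network at every update; the new code applies the conventional slow EMA, target ← τ·online + (1 − τ)·target. A sketch of how a training loop might drive the new methods; the optimizers and batch unpacking are assumptions, since the actual driver lives in `lerobot/scripts/train_sac.py`, whose diff is suppressed below:

```python
# Hypothetical driver loop: `policy`, the batch tensors, and the three
# optimizers are assumed to be created elsewhere (e.g. in train_sac.py).
critics_loss = policy.compute_loss_critic(
    observations=observations,
    actions=actions,
    rewards=rewards,
    next_observations=next_observations,
    done=done,
)
critic_optimizer.zero_grad()
critics_loss.backward()
critic_optimizer.step()

actor_loss = policy.compute_loss_actor(observations)
actor_optimizer.zero_grad()
actor_loss.backward()
actor_optimizer.step()

temperature_loss = policy.compute_loss_temperature(observations)
temperature_optimizer.zero_grad()
temperature_loss.backward()
temperature_optimizer.step()

# Polyak averaging: target <- tau * online + (1 - tau) * target.
policy.update_target_networks()
```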
```diff
@@ -276,7 +337,9 @@ class MLP(nn.Module):
                 if dropout_rate is not None and dropout_rate > 0:
                     layers.append(nn.Dropout(p=dropout_rate))
                 layers.append(nn.LayerNorm(hidden_dims[i]))
-                layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
+                layers.append(
+                    activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
+                )

         self.net = nn.Sequential(*layers)
```
```diff
@@ -290,7 +353,7 @@ class Critic(nn.Module):
         encoder: Optional[nn.Module],
         network: nn.Module,
         init_final: Optional[float] = None,
-        device: str = "cuda"
+        device: str = "cpu",
     ):
         super().__init__()
         self.device = torch.device(device)
```
```diff
@@ -321,9 +384,7 @@ class Critic(nn.Module):
         actions: torch.Tensor,
     ) -> torch.Tensor:
         # Move each tensor in observations to device
-        observations = {
-            k: v.to(self.device) for k, v in observations.items()
-        }
+        observations = {k: v.to(self.device) for k, v in observations.items()}
         actions = actions.to(self.device)

         obs_enc = observations if self.encoder is None else self.encoder(observations)
```
```diff
@@ -345,7 +406,7 @@ class Policy(nn.Module):
         fixed_std: Optional[torch.Tensor] = None,
         init_final: Optional[float] = None,
         use_tanh_squash: bool = False,
-        device: str = "cuda"
+        device: str = "cpu",
     ):
         super().__init__()
         self.device = torch.device(device)
```
```diff
@@ -386,7 +447,6 @@ class Policy(nn.Module):
         self,
         observations: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-
         # Encode observations if encoder exists
         obs_enc = observations if self.encoder is None else self.encoder(observations)

```
```diff
@@ -397,22 +457,30 @@ class Policy(nn.Module):
         # Compute standard deviations
         if self.fixed_std is None:
             log_std = self.std_layer(outputs)
+            assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
+
             if self.use_tanh_squash:
                 log_std = torch.tanh(log_std)
-            log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
+                log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
+            else:
+                log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
         else:
             log_std = self.fixed_std.expand_as(means)

-        # uses tahn activation function to squash the action to be in the range of [-1, 1]
+        # uses tanh activation function to squash the action to be in the range of [-1, 1]
         normal = torch.distributions.Normal(means, torch.exp(log_std))
-        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
-        log_probs = normal.log_prob(x_t)
+        x_t = normal.rsample()  # Reparameterization trick (mean + std * N(0,1))
+        log_probs = normal.log_prob(x_t)  # Base log probability before Tanh

         if self.use_tanh_squash:
             actions = torch.tanh(x_t)
-            log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
-            log_probs = log_probs.sum(-1)  # sum over action dim
+            log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)  # Adjust log-probs for Tanh
+        else:
+            actions = x_t  # No Tanh; raw Gaussian sample

-        return actions, log_probs
+        log_probs = log_probs.sum(-1)  # Sum over action dimensions
+        means = torch.tanh(means) if self.use_tanh_squash else means
+        return actions, log_probs, means

     def get_features(self, observations: torch.Tensor) -> torch.Tensor:
         """Get encoded features from observations"""
```
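The tanh-squash branch above is the standard change-of-variables correction for a squashed Gaussian policy: with pre-squash sample $u \sim \mathcal{N}(\mu, \sigma)$ and action $a = \tanh(u)$,

$$
\log \pi(a \mid s) = \sum_j \Big( \log \mathcal{N}(u_j;\, \mu_j, \sigma_j) - \log\big(1 - \tanh^2(u_j) + \epsilon\big) \Big),
$$

with $\epsilon = 10^{-6}$ in this code for numerical stability. The diff also hoists `log_probs.sum(-1)` out of the squash branch, so log-probabilities are summed over action dimensions in both the squashed and unsquashed paths, and the method now returns the squashed `means` as a third output, which is why callers switched to the `actions, _, _` unpacking seen earlier.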
```diff
@@ -461,19 +529,13 @@ class SACObservationEncoder(nn.Module):
             )
         if "observation.state" in config.input_shapes:
             self.state_enc_layers = nn.Sequential(
-                nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
-                nn.ELU(),
-                nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
+                nn.Linear(config.input_shapes["observation.state"][0], config.latent_dim),
                 nn.LayerNorm(config.latent_dim),
                 nn.Tanh(),
             )
         if "observation.environment_state" in config.input_shapes:
             self.env_state_enc_layers = nn.Sequential(
-                nn.Linear(
-                    config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
-                ),
-                nn.ELU(),
-                nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
+                nn.Linear(config.input_shapes["observation.environment_state"][0], config.latent_dim),
                 nn.LayerNorm(config.latent_dim),
                 nn.Tanh(),
             )
```
```diff
@@ -502,59 +564,16 @@ class SACObservationEncoder(nn.Module):
         return self.config.latent_dim


-class LagrangeMultiplier(nn.Module):
-    def __init__(
-        self,
-        init_value: float = 1.0,
-        constraint_shape: Sequence[int] = (),
-        device: str = "cuda"
-    ):
-        super().__init__()
-        self.device = torch.device(device)
-        init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
-
-        # Initialize the Lagrange multiplier as a parameter
-        self.lagrange = nn.Parameter(
-            torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device)
-        )
-
-        self.to(self.device)
-
-    def forward(
-        self,
-        lhs: Optional[torch.Tensor | float | int] = None,
-        rhs: Optional[torch.Tensor | float | int] = None
-    ) -> torch.Tensor:
-        # Get the multiplier value based on parameterization
-        multiplier = torch.nn.functional.softplus(self.lagrange)
-
-        # Return the raw multiplier if no constraint values provided
-        if lhs is None:
-            return multiplier
-
-        # Convert inputs to tensors and move to device
-        lhs = torch.tensor(lhs, device=self.device) if not isinstance(lhs, torch.Tensor) else lhs.to(self.device)
-        if rhs is not None:
-            rhs = torch.tensor(rhs, device=self.device) if not isinstance(rhs, torch.Tensor) else rhs.to(self.device)
-        else:
-            rhs = torch.zeros_like(lhs, device=self.device)
-
-        diff = lhs - rhs
-
-        assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}"
-
-        return multiplier * diff
-
-
 def orthogonal_init():
     return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)


-def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList:
+def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cpu") -> nn.ModuleList:
     """Creates an ensemble of critic networks"""
     assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
     return nn.ModuleList(critics).to(device)


 # borrowed from tdmpc
 def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
     """Helper to temporarily flatten extra dims at the start of the image tensor.
```
```diff
@@ -15,11 +15,11 @@ training:
   # Offline training dataloader
   num_workers: 4

-  batch_size: 128
+  batch_size: 256
   grad_clip_norm: 10.0
   lr: 3e-4

-  eval_freq: 50000
+  eval_freq: 2500
   log_freq: 500
   save_freq: 50000
```
```diff
@@ -46,8 +46,8 @@ policy:

   # Input / output structure.
   n_action_repeats: 1
-  horizon: 5
-  n_action_steps: 5
+  horizon: 2
+  n_action_steps: 2

   input_shapes:
     # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
```
```diff
@@ -99,7 +99,8 @@ def make_optimizer_and_scheduler(cfg, policy):
             [
                 {"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
                 {"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
-                {"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
+                # We wrap policy log temperature in list because this is a torch tensor and not a nn.Module
+                {"params": [policy.log_alpha], "lr": policy.config.temperature_lr},
             ]
         )
         lr_scheduler = None
```
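As the new comment in this hunk notes, `log_alpha` is a plain tensor, not an `nn.Module`, so it is wrapped in a list instead of going through `.parameters()`. Both forms are valid `torch.optim` param groups; a small sketch (not the project's actual setup):

```python
import torch
import torch.nn as nn

actor = nn.Linear(4, 2)
log_alpha = torch.zeros(1, requires_grad=True)

optimizer = torch.optim.Adam(
    [
        {"params": actor.parameters(), "lr": 3e-4},  # modules expose .parameters()
        {"params": [log_alpha], "lr": 3e-4},         # a bare tensor must be wrapped in a list
    ]
)
```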
lerobot/scripts/train_sac.py (new file, 1112 lines)

File diff suppressed because it is too large