Change SAC policy implementation with configuration and modeling classes

This commit is contained in:
Adil Zouitine
2025-01-17 09:39:04 +01:00
committed by AdilZouitine
parent bd8c768f62
commit 4b78ab2789
4 changed files with 51 additions and 714 deletions

View File

@@ -59,7 +59,6 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
from lerobot.common.policies.sac.modeling_sac import SACPolicy
return SACPolicy, SACConfig
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
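For reference, a minimal usage sketch of the factory above (the import path and constructor call are assumptions; the "sac" → (SACPolicy, SACConfig) mapping comes from the return statement in this hunk):

    # Hypothetical usage; the import path is assumed, not taken from this diff.
    from lerobot.common.policies.factory import get_policy_class

    policy_cls, config_cls = get_policy_class("sac")   # -> (SACPolicy, SACConfig)
    config = config_cls()
    policy = policy_cls(config=config)                 # constructor signature assumed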

View File

@@ -28,29 +28,41 @@ class SACConfig:
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [4],
"action": [2],
}
)
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] | None = None
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "min_max",
"observation.environment_state": "min_max",
}
)
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"},
)
shared_encoder = False
discount = 0.99
temperature_init = 1.0
num_critics = 2
# num_critics = 8
num_subsample_critics = None
# num_subsample_critics = 2
# critic_lr = 1e-3
critic_lr = 3e-4
actor_lr = 3e-4
temperature_lr = 3e-4
critic_target_update_weight = 0.005
utd_ratio = 2
# utd_ratio = 8
utd_ratio = 1  # To enable UTD (update-to-data) updates, set this to a value greater than 1
state_encoder_hidden_dim = 256
latent_dim = 128
latent_dim = 256
target_entropy = None
backup_entropy = True
# backup_entropy = False
use_backup_entropy = True
critic_network_kwargs = {
"hidden_dims": [256, 256],
"activate_final": True,

View File

@@ -57,19 +57,22 @@ class SACPolicy(
else:
self.normalize_inputs = nn.Identity()
# HACK: we need to pass the dataset_stats to the normalization functions
# NOTE: This is for the BipedalWalker environment
dataset_stats = dataset_stats or {
"action": {
"min": torch.tensor([-1.0, -1.0, -1.0, -1.0]),
"max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
}
}
# NOTE: This is for pusht environment
# dataset_stats = dataset_stats or {
# "action": {
# "min": torch.tensor([0, 0]),
# "max": torch.tensor([512, 512]),
# }
# }
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
@@ -77,8 +80,12 @@ class SACPolicy(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
encoder_critic = SACObservationEncoder(config)
encoder_actor = SACObservationEncoder(config)
if config.shared_encoder:
encoder_critic = SACObservationEncoder(config)
encoder_actor = encoder_critic
else:
encoder_critic = SACObservationEncoder(config)
encoder_actor = SACObservationEncoder(config)
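# NOTE: with shared_encoder=True the actor reuses the critic's encoder instance,
# so the two networks share encoder weights and gradients; otherwise each trains its own.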
# Define networks
critic_nets = []
for _ in range(config.num_critics):
@@ -105,7 +112,6 @@ class SACPolicy(
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
self.critic_target = create_critic_ensemble(target_critic_nets, config.num_critics)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
self.actor = Policy(
encoder=encoder_actor,
@@ -159,100 +165,7 @@ class SACPolicy(
q_values = torch.stack([critic(observations, actions) for critic in critics])
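# stacked along dim 0: one row of Q-values per critic in the (possibly target) ensemble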
return q_values
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
"""Run the batch through the model and compute the loss.
Returns a dictionary with loss as a tensor, and other information as native floats.
"""
# Refresh the temperature value, since log_alpha may have been updated by the previous optimization step
self.temperature = self.log_alpha.exp().item()
temperature = self.temperature
batch = self.normalize_inputs(batch)
# batch has shape (b, 2, ...): index 0 holds the current observation and index 1 the next
# observation, which is needed to compute the TD target.
# actions = batch["action"][:, 0]
actions = batch["action"]
# actions = batch["action"][:, 0]
actions = batch["action"]
rewards = batch["next.reward"][:, 0]
observations = {}
next_observations = {}
for k in batch:
if k.startswith("observation."):
observations[k] = batch[k][:, 0]
next_observations[k] = batch[k][:, 1]
done = batch["next.done"]
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations)
# 2- compute q targets
q_targets = self.critic_forward(next_observations, next_action_preds, use_target=True)
# Subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
# q_targets now only contains the subsampled critics
min_q, _ = q_targets.min(dim=0)  # element-wise minimum over the critic ensemble
if self.config.use_backup_entropy:
min_q -= self.temperature * next_log_probs
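# TD target: y = r + discount * (1 - done) * (min_e Q_target_e(s', a') - temperature * log_prob(a'|s')),
# where the entropy term is included only when use_backup_entropy is set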
td_target = rewards + self.config.discount * min_q * ~done
# 3- compute predicted qs
q_preds = self.critic_forward(observations, actions, use_target=False)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# Average the TD error over the batch for each critic, then sum the per-critic means to obtain the final critic loss
critics_loss = (
F.mse_loss(
input=q_preds,
target=td_target_duplicate,
reduction="none",
).mean(1)
).sum()
actions_pi, log_probs, _ = self.actor(observations)
with torch.inference_mode():
q_preds = self.critic_forward(observations, actions_pi, use_target=False)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
# calculate temperature loss
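# alpha (= log_alpha.exp()) is trained so that the policy entropy stays near target_entropy;
# log_probs are recomputed under no_grad so only log_alpha receives gradients here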
with torch.no_grad():
_, log_probs, _ = self.actor(observations)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
loss = critics_loss + actor_loss + temperature_loss
return {
"critics_loss": critics_loss.item(),
"actor_loss": actor_loss.item(),
"mean_q_predicts": min_q_preds.mean().item(),
"min_q_predicts": min_q_preds.min().item(),
"max_q_predicts": min_q_preds.max().item(),
"temperature_loss": temperature_loss.item(),
"temperature": temperature,
"mean_log_probs": log_probs.mean().item(),
"min_log_probs": log_probs.min().item(),
"max_log_probs": log_probs.max().item(),
"td_target_mean": td_target.mean().item(),
"td_target_max": td_target.max().item(),
"action_mean": actions.mean().item(),
"entropy": log_probs.mean().item(),
"loss": loss,
}
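As a usage note, a hypothetical single-optimizer training step built on the return dictionary above (only the dictionary keys and update_target_networks come from this diff; the optimizer setup is an assumption, and separate actor/critic/temperature optimizers are equally common):

    # Sketch: one optimization step driven by the combined loss.
    import torch

    optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)

    out = policy(batch)              # batch stacks current/next observations along dim 1
    optimizer.zero_grad()
    out["loss"].backward()           # "loss" is the only tensor entry; the rest are logging floats
    optimizer.step()
    policy.update_target_networks()  # EMA update of the target critics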
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
def update_target_networks(self):
"""Update target networks with exponential moving average"""
for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
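The loop body is cut off by the hunk boundary; for context, a hedged sketch of the soft (Polyak) update the docstring describes, assuming critic_target_update_weight plays the role of tau (the config name is from this commit, the exact body is an assumption):

    # Hypothetical loop body for the EMA / soft target update.
    tau = self.config.critic_target_update_weight   # 0.005 by default in this commit
    for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False):
        target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)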
@@ -271,9 +184,6 @@ class SACPolicy(
q_targets = self.critic_forward(
observations=next_observations, actions=next_action_preds, use_target=True
)
# Subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
if self.config.num_subsample_critics is not None:
@@ -440,7 +350,6 @@ class Policy(nn.Module):
if isinstance(layer, nn.Linear):
out_features = layer.out_features
break
# Mean layer
self.mean_layer = nn.Linear(out_features, action_dim)
if init_final is not None: