Change SAC policy implementation with configuration and modeling classes
committed by AdilZouitine
parent bd8c768f62
commit 4b78ab2789
@@ -59,7 +59,6 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
         from lerobot.common.policies.sac.modeling_sac import SACPolicy

         return SACPolicy, SACConfig
     else:
         raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -28,29 +28,41 @@ class SACConfig:
     )
     output_shapes: dict[str, list[int]] = field(
         default_factory=lambda: {
-            "action": [4],
+            "action": [2],
         }
     )

     # Normalization / Unnormalization
-    input_normalization_modes: dict[str, str] | None = None
+    input_normalization_modes: dict[str, str] = field(
+        default_factory=lambda: {
+            "observation.image": "mean_std",
+            "observation.state": "min_max",
+            "observation.environment_state": "min_max",
+        }
+    )
     output_normalization_modes: dict[str, str] = field(
         default_factory=lambda: {"action": "min_max"},
     )

     shared_encoder = False
     discount = 0.99
     temperature_init = 1.0
     num_critics = 2
     # num_critics = 8
     num_subsample_critics = None
     # num_subsample_critics = 2
     # critic_lr = 1e-3
     critic_lr = 3e-4
     actor_lr = 3e-4
     temperature_lr = 3e-4
     critic_target_update_weight = 0.005
-    utd_ratio = 2
-    # utd_ratio = 8
+    utd_ratio = 1  # If you want to enable UTD, set it to > 1
     state_encoder_hidden_dim = 256
-    latent_dim = 128
+    latent_dim = 256
     target_entropy = None
-    backup_entropy = True
-    # backup_entropy = False
+    use_backup_entropy = True
     critic_network_kwargs = {
         "hidden_dims": [256, 256],
         "activate_final": True,
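For readers unfamiliar with the pattern above: a minimal, hypothetical sketch (MiniSACConfig is invented for illustration, not part of the diff) of why these dict defaults use default_factory — mutable defaults on a dataclass must be built per instance, not shared across instances.

from dataclasses import dataclass, field

@dataclass
class MiniSACConfig:
    # Per-instance dict defaults, mirroring the SACConfig fields above
    output_shapes: dict[str, list[int]] = field(default_factory=lambda: {"action": [2]})
    input_normalization_modes: dict[str, str] = field(
        default_factory=lambda: {
            "observation.image": "mean_std",
            "observation.state": "min_max",
        }
    )

a, b = MiniSACConfig(), MiniSACConfig()
a.output_shapes["action"] = [4]
assert b.output_shapes["action"] == [2]  # instances do not share state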
@@ -57,19 +57,22 @@ class SACPolicy(
         else:
             self.normalize_inputs = nn.Identity()
-        # HACK: we need to pass the dataset_stats to the normalization functions
-        dataset_stats = dataset_stats or {
-            "action": {
-                "min": torch.tensor([-1.0, -1.0, -1.0, -1.0]),
-                "max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
-            }
-        }
+        # HACK: we need to pass the dataset_stats to the normalization functions
+
+        # NOTE: This is for the BipedalWalker environment
+        dataset_stats = dataset_stats or {
+            "action": {
+                "min": torch.tensor([-1.0, -1.0, -1.0, -1.0]),
+                "max": torch.tensor([1.0, 1.0, 1.0, 1.0]),
+            }
+        }
+
+        # NOTE: This is for the pusht environment
+        # dataset_stats = dataset_stats or {
+        #     "action": {
+        #         "min": torch.tensor([0, 0]),
+        #         "max": torch.tensor([512, 512]),
+        #     }
+        # }
         self.normalize_targets = Normalize(
             config.output_shapes, config.output_normalization_modes, dataset_stats
         )
@@ -77,8 +80,12 @@ class SACPolicy(
             config.output_shapes, config.output_normalization_modes, dataset_stats
         )

-        encoder_critic = SACObservationEncoder(config)
-        encoder_actor = SACObservationEncoder(config)
+        if config.shared_encoder:
+            encoder_critic = SACObservationEncoder(config)
+            encoder_actor = encoder_critic
+        else:
+            encoder_critic = SACObservationEncoder(config)
+            encoder_actor = SACObservationEncoder(config)
         # Define networks
         critic_nets = []
         for _ in range(config.num_critics):
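What config.shared_encoder changes, in a self-contained sketch (DummyEncoder is a hypothetical stand-in for SACObservationEncoder): with sharing, actor and critic hold the same module object, so one set of encoder weights receives gradients from both losses.

import torch.nn as nn

class DummyEncoder(nn.Module):
    """Hypothetical stand-in for SACObservationEncoder."""
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 16)

    def forward(self, x):
        return self.net(x)

shared_encoder = True
encoder_critic = DummyEncoder()
encoder_actor = encoder_critic if shared_encoder else DummyEncoder()
assert encoder_actor is encoder_critic  # same object when shared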
@@ -105,7 +112,6 @@ class SACPolicy(
         self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
         self.critic_target = create_critic_ensemble(target_critic_nets, config.num_critics)
         self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
-        self.critic_target.load_state_dict(self.critic_ensemble.state_dict())

         self.actor = Policy(
             encoder=encoder_actor,
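The hunk above drops a duplicated load_state_dict call; a single hard copy at construction time suffices. A minimal sketch of the target-initialization pattern, with an invented make_critic helper:

import torch.nn as nn

def make_critic() -> nn.Module:
    # Hypothetical critic factory; both networks must share an architecture
    return nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 1))

critic_ensemble = make_critic()
critic_target = make_critic()
critic_target.load_state_dict(critic_ensemble.state_dict())  # one copy is enough
for p in critic_target.parameters():
    p.requires_grad_(False)  # targets move by EMA, not by gradient steps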
@@ -159,100 +165,7 @@ class SACPolicy(
         q_values = torch.stack([critic(observations, actions) for critic in critics])
         return q_values

-    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
-        """Run the batch through the model and compute the loss.
-
-        Returns a dictionary with loss as a tensor, and other information as native floats.
-        """
-        # We have to refresh the value of the temperature because log_alpha may have
-        # changed in the previous optimization step.
-        self.temperature = self.log_alpha.exp().item()
-        temperature = self.temperature
-
-        batch = self.normalize_inputs(batch)
-        # batch shape is (b, 2, ...) where index 1 returns the current observation and
-        # the next observation for calculating the right td index.
-        # actions = batch["action"][:, 0]
-        actions = batch["action"]
-        rewards = batch["next.reward"][:, 0]
-        observations = {}
-        next_observations = {}
-        for k in batch:
-            if k.startswith("observation."):
-                observations[k] = batch[k][:, 0]
-                next_observations[k] = batch[k][:, 1]
-        done = batch["next.done"]
-
-        with torch.no_grad():
-            next_action_preds, next_log_probs, _ = self.actor(next_observations)
-
-            # 2- compute q targets
-            q_targets = self.critic_forward(next_observations, next_action_preds, use_target=True)
-
-            # subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
-            if self.config.num_subsample_critics is not None:
-                indices = torch.randperm(self.config.num_critics)
-                indices = indices[: self.config.num_subsample_critics]
-                q_targets = q_targets[indices]
-
-            min_q, _ = q_targets.min(dim=0)  # min over the (subsampled) critic ensemble
-            if self.config.use_backup_entropy:
-                min_q -= self.temperature * next_log_probs
-            td_target = rewards + self.config.discount * min_q * ~done
-
-        # 3- compute predicted qs
-        q_preds = self.critic_forward(observations, actions, use_target=False)
-
-        # 4- Calculate loss
-        # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
-        td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
-        # Compute the mean loss of the batch for each critic, then sum over critics for the final loss
-        critics_loss = (
-            F.mse_loss(
-                input=q_preds,
-                target=td_target_duplicate,
-                reduction="none",
-            ).mean(1)
-        ).sum()
-
-        actions_pi, log_probs, _ = self.actor(observations)
-        with torch.inference_mode():
-            q_preds = self.critic_forward(observations, actions_pi, use_target=False)
-        min_q_preds = q_preds.min(dim=0)[0]
-
-        actor_loss = ((temperature * log_probs) - min_q_preds).mean()
-
-        # calculate temperature loss
-        with torch.no_grad():
-            _, log_probs, _ = self.actor(observations)
-        temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
-
-        loss = critics_loss + actor_loss + temperature_loss
-
-        return {
-            "critics_loss": critics_loss.item(),
-            "actor_loss": actor_loss.item(),
-            "mean_q_predicts": min_q_preds.mean().item(),
-            "min_q_predicts": min_q_preds.min().item(),
-            "max_q_predicts": min_q_preds.max().item(),
-            "temperature_loss": temperature_loss.item(),
-            "temperature": temperature,
-            "mean_log_probs": log_probs.mean().item(),
-            "min_log_probs": log_probs.min().item(),
-            "max_log_probs": log_probs.max().item(),
-            "td_target_mean": td_target.mean().item(),
-            "td_target_max": td_target.max().item(),
-            "action_mean": actions.mean().item(),
-            "entropy": log_probs.mean().item(),
-            "loss": loss,
-        }
+    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...

     def update_target_networks(self):
         """Update target networks with exponential moving average"""
         for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
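The removed forward() combined the three standard SAC losses. A compact, self-contained sketch of that computation with made-up tensors and shapes (q_preds/q_targets: (num_critics, batch); rewards, done, log_probs: (batch,)); this illustrates the math, it is not the lerobot code:

import torch
import torch.nn.functional as F
import einops

num_critics, batch = 2, 5
rewards = torch.randn(batch)
done = torch.zeros(batch, dtype=torch.bool)
q_targets = torch.randn(num_critics, batch)
next_log_probs = torch.randn(batch)
q_preds = torch.randn(num_critics, batch, requires_grad=True)
log_probs = torch.randn(batch, requires_grad=True)
min_q_pi = torch.randn(batch)  # min over critics at the policy's actions
log_alpha = torch.zeros(1, requires_grad=True)
discount, temperature, target_entropy = 0.99, log_alpha.exp().item(), -2.0

# 1) Critic loss: regress every ensemble member onto the shared TD target.
min_q, _ = q_targets.min(dim=0)                # pessimistic ensemble estimate
min_q = min_q - temperature * next_log_probs   # entropy backup term
td_target = rewards + discount * min_q * ~done  # bootstrap only on non-terminal steps
td_target_dup = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
critics_loss = F.mse_loss(q_preds, td_target_dup, reduction="none").mean(1).sum()

# 2) Actor loss: trade entropy off against the pessimistic Q estimate.
actor_loss = ((temperature * log_probs) - min_q_pi).mean()

# 3) Temperature loss: tune alpha toward the target entropy.
temperature_loss = (-log_alpha.exp() * (log_probs.detach() + target_entropy)).mean()

loss = critics_loss + actor_loss + temperature_loss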
@@ -271,9 +184,6 @@ class SACPolicy(
         q_targets = self.critic_forward(
             observations=next_observations, actions=next_action_preds, use_target=True
         )
-        q_targets = self.critic_forward(
-            observations=next_observations, actions=next_action_preds, use_target=True
-        )

         # subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
         if self.config.num_subsample_critics is not None:
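What the num_subsample_critics branch does, sketched with made-up sizes: draw a random subset of the ensemble's target Q-values before taking the min, which tames overestimation bias at high UTD ratios (the REDQ trick).

import torch

num_critics, num_subsample_critics, batch = 8, 2, 4
q_targets = torch.randn(num_critics, batch)
indices = torch.randperm(num_critics)[:num_subsample_critics]
q_targets = q_targets[indices]  # shape: (num_subsample_critics, batch)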
@@ -440,7 +350,6 @@ class Policy(nn.Module):
             if isinstance(layer, nn.Linear):
                 out_features = layer.out_features
                 break

         # Mean layer
         self.mean_layer = nn.Linear(out_features, action_dim)
         if init_final is not None:
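The loop above infers the width of the policy trunk's last Linear layer so the Gaussian head can be sized from it. A minimal sketch of the same pattern, with an invented trunk:

import torch.nn as nn

trunk = nn.Sequential(nn.Linear(10, 256), nn.ReLU(), nn.Linear(256, 256), nn.ReLU())
out_features = None
for layer in reversed(list(trunk.children())):  # scan from the output end
    if isinstance(layer, nn.Linear):
        out_features = layer.out_features
        break
assert out_features == 256
action_dim = 2  # hypothetical action dimension
mean_layer = nn.Linear(out_features, action_dim)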