Use the distribution mean instead of a sampled action for inference

This commit is contained in:
KeWang1017
2024-12-31 10:48:06 +00:00
committed by Ke-Wang1017
parent 77a7f92139
commit f1f04eb4f9
3 changed files with 7 additions and 6 deletions

View File

@@ -54,7 +54,7 @@ class SACConfig:
critic_target_update_weight = 0.005 critic_target_update_weight = 0.005
utd_ratio = 2 utd_ratio = 2
state_encoder_hidden_dim = 256 state_encoder_hidden_dim = 256
latent_dim = 128 latent_dim = 256
target_entropy = None target_entropy = None
backup_entropy = True backup_entropy = True
critic_network_kwargs = { critic_network_kwargs = {

View File

@@ -111,7 +111,7 @@ class SACPolicy(
@torch.no_grad() @torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor: def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select action for inference/evaluation""" """Select action for inference/evaluation"""
actions, _ = self.actor(batch) _, _, actions = self.actor(batch)
actions = self.unnormalize_outputs({"action": actions})["action"] actions = self.unnormalize_outputs({"action": actions})["action"]
return actions return actions
@@ -155,7 +155,7 @@ class SACPolicy(
# calculate critics loss # calculate critics loss
# 1- compute actions from policy # 1- compute actions from policy
action_preds, log_probs = self.actor(next_observations) action_preds, log_probs, _ = self.actor(next_observations)
# 2- compute q targets # 2- compute q targets
q_targets = self.critic_forward(next_observations, action_preds, use_target=True) q_targets = self.critic_forward(next_observations, action_preds, use_target=True)
@@ -195,7 +195,7 @@ class SACPolicy(
# 1- temperature # 1- temperature
temperature = self.temperature() temperature = self.temperature()
# 2- get actions (batch_size, action_dim) and log probs (batch_size,) # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
actions, log_probs = self.actor(observations) actions, log_probs, _ = self.actor(observations)
# 3- get q-value predictions # 3- get q-value predictions
with torch.inference_mode(): with torch.inference_mode():
q_preds = self.critic_forward(observations, actions, use_target=False) q_preds = self.critic_forward(observations, actions, use_target=False)
@@ -405,8 +405,9 @@ class Policy(nn.Module):
actions = torch.tanh(x_t) actions = torch.tanh(x_t)
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
log_probs = log_probs.sum(-1) # sum over action dim log_probs = log_probs.sum(-1) # sum over action dim
means = torch.tanh(means)
return actions, log_probs return actions, log_probs, means
def get_features(self, observations: torch.Tensor) -> torch.Tensor: def get_features(self, observations: torch.Tensor) -> torch.Tensor:
"""Get encoded features from observations""" """Get encoded features from observations"""

View File

@@ -15,7 +15,7 @@ training:
# Offline training dataloader # Offline training dataloader
num_workers: 4 num_workers: 4
batch_size: 128 batch_size: 256
grad_clip_norm: 10.0 grad_clip_norm: 10.0
lr: 3e-4 lr: 3e-4