Use the mean instead of a sampled action for inference

Author:       KeWang1017
Date:         2024-12-31 10:48:06 +00:00
Committed by: Ke-Wang1017
Parent:       77a7f92139
Commit:       f1f04eb4f9

3 changed files with 7 additions and 6 deletions
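
Context for the diffs below: SAC trains on stochastic, reparameterized samples from a tanh-squashed Gaussian, but evaluation is usually deterministic, taking the squashed mean of the distribution so inference is not perturbed by sampling noise. A minimal, self-contained sketch of an actor head with the three-value return the hunks rely on (class and layer names here are illustrative, not the repository's):

import torch
import torch.nn as nn


class TanhGaussianHead(nn.Module):
    """Illustrative tanh-squashed Gaussian policy head (not the repo's exact class)."""

    def __init__(self, feature_dim: int = 256, action_dim: int = 4):
        super().__init__()
        self.mean_layer = nn.Linear(feature_dim, action_dim)
        self.log_std_layer = nn.Linear(feature_dim, action_dim)

    def forward(self, features: torch.Tensor):
        means = self.mean_layer(features)
        log_stds = self.log_std_layer(features).clamp(-20, 2)
        dist = torch.distributions.Normal(means, log_stds.exp())
        x_t = dist.rsample()                 # reparameterized sample for training
        actions = torch.tanh(x_t)
        # change-of-variables correction for the tanh squash
        log_probs = dist.log_prob(x_t) - torch.log((1 - actions.pow(2)) + 1e-6)
        log_probs = log_probs.sum(-1)        # sum over action dim
        means = torch.tanh(means)            # deterministic action for evaluation
        return actions, log_probs, means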


@@ -54,7 +54,7 @@ class SACConfig:
     critic_target_update_weight = 0.005
     utd_ratio = 2
     state_encoder_hidden_dim = 256
-    latent_dim = 128
+    latent_dim = 256
     target_entropy = None
     backup_entropy = True
     critic_network_kwargs = {


@@ -111,7 +111,7 @@ class SACPolicy(
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
-        actions, _ = self.actor(batch)
+        _, _, actions = self.actor(batch)
         actions = self.unnormalize_outputs({"action": actions})["action"]
         return actions
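
With the three-tuple return, inference unpacks the squashed mean rather than the sample. A hedged usage sketch against the TanhGaussianHead sketch above (dimensions are arbitrary):

head = TanhGaussianHead(feature_dim=256, action_dim=4)
features = torch.randn(1, 256)

with torch.no_grad():
    _, _, action = head(features)          # deterministic mean action for evaluation
sampled, log_prob, _ = head(features)      # stochastic sample for training updates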
@@ -155,7 +155,7 @@ class SACPolicy(
         # calculate critics loss
         # 1- compute actions from policy
-        action_preds, log_probs = self.actor(next_observations)
+        action_preds, log_probs, _ = self.actor(next_observations)
         # 2- compute q targets
         q_targets = self.critic_forward(next_observations, action_preds, use_target=True)
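
For reference, these next-state policy actions feed the standard entropy-regularized SAC backup. A hedged sketch of the target computation, assuming backup_entropy is enabled; all names below are assumptions, not the repository's:

import torch

# Sketch of the soft TD target; argument names are assumptions.
# rewards, dones, log_probs, q1_t, q2_t: tensors of shape (batch,)
def soft_td_target(rewards, dones, q1_t, q2_t, log_probs, gamma=0.99, alpha=0.2):
    min_q = torch.min(q1_t, q2_t)                  # clipped double-Q backup
    return rewards + gamma * (1.0 - dones) * (min_q - alpha * log_probs)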
@@ -195,7 +195,7 @@ class SACPolicy(
         # 1- temperature
         temperature = self.temperature()
         # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
-        actions, log_probs = self.actor(observations)
+        actions, log_probs, _ = self.actor(observations)
         # 3- get q-value predictions
         with torch.inference_mode():
             q_preds = self.critic_forward(observations, actions, use_target=False)
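
The sampled actions and log-probs above enter the usual SAC actor objective, minimizing temperature * log_prob minus the minimum critic value. A hedged sketch with assumed names:

import torch

# Sketch of the SAC actor loss; argument names are assumptions.
def actor_loss(log_probs, q1_preds, q2_preds, temperature):
    min_q = torch.min(q1_preds, q2_preds)
    return (temperature * log_probs - min_q).mean()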
@@ -405,8 +405,9 @@ class Policy(nn.Module):
         actions = torch.tanh(x_t)
         log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
         log_probs = log_probs.sum(-1)  # sum over action dim
-        return actions, log_probs
+        means = torch.tanh(means)
+        return actions, log_probs, means

     def get_features(self, observations: torch.Tensor) -> torch.Tensor:
         """Get encoded features from observations"""


@@ -15,7 +15,7 @@ training:
   # Offline training dataloader
   num_workers: 4
-  batch_size: 128
+  batch_size: 256
   grad_clip_norm: 10.0
   lr: 3e-4