use mean instead of sampled action for the inference
This commit is contained in:
@@ -54,7 +54,7 @@ class SACConfig:
|
||||
critic_target_update_weight = 0.005
|
||||
utd_ratio = 2
|
||||
state_encoder_hidden_dim = 256
|
||||
latent_dim = 128
|
||||
latent_dim = 256
|
||||
target_entropy = None
|
||||
backup_entropy = True
|
||||
critic_network_kwargs = {
|
||||
|
||||
@@ -111,7 +111,7 @@ class SACPolicy(
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select action for inference/evaluation"""
|
||||
actions, _ = self.actor(batch)
|
||||
_, _, actions = self.actor(batch)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
return actions
|
||||
|
||||
@@ -155,7 +155,7 @@ class SACPolicy(
|
||||
|
||||
# calculate critics loss
|
||||
# 1- compute actions from policy
|
||||
action_preds, log_probs = self.actor(next_observations)
|
||||
action_preds, log_probs, _ = self.actor(next_observations)
|
||||
|
||||
# 2- compute q targets
|
||||
q_targets = self.critic_forward(next_observations, action_preds, use_target=True)
|
||||
@@ -195,7 +195,7 @@ class SACPolicy(
|
||||
# 1- temperature
|
||||
temperature = self.temperature()
|
||||
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
|
||||
actions, log_probs = self.actor(observations)
|
||||
actions, log_probs, _ = self.actor(observations)
|
||||
# 3- get q-value predictions
|
||||
with torch.inference_mode():
|
||||
q_preds = self.critic_forward(observations, actions, use_target=False)
|
||||
@@ -405,8 +405,9 @@ class Policy(nn.Module):
|
||||
actions = torch.tanh(x_t)
|
||||
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
|
||||
log_probs = log_probs.sum(-1) # sum over action dim
|
||||
means = torch.tanh(means)
|
||||
|
||||
return actions, log_probs
|
||||
return actions, log_probs, means
|
||||
|
||||
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
|
||||
"""Get encoded features from observations"""
|
||||
|
||||
@@ -15,7 +15,7 @@ training:
|
||||
# Offline training dataloader
|
||||
num_workers: 4
|
||||
|
||||
batch_size: 128
|
||||
batch_size: 256
|
||||
grad_clip_norm: 10.0
|
||||
lr: 3e-4
|
||||
|
||||
|
||||
Reference in New Issue
Block a user