use mean instead of sampled action for the inference
This commit is contained in:
@@ -54,7 +54,7 @@ class SACConfig:
|
|||||||
critic_target_update_weight = 0.005
|
critic_target_update_weight = 0.005
|
||||||
utd_ratio = 2
|
utd_ratio = 2
|
||||||
state_encoder_hidden_dim = 256
|
state_encoder_hidden_dim = 256
|
||||||
latent_dim = 128
|
latent_dim = 256
|
||||||
target_entropy = None
|
target_entropy = None
|
||||||
backup_entropy = True
|
backup_entropy = True
|
||||||
critic_network_kwargs = {
|
critic_network_kwargs = {
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ class SACPolicy(
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Select action for inference/evaluation"""
|
"""Select action for inference/evaluation"""
|
||||||
actions, _ = self.actor(batch)
|
_, _, actions = self.actor(batch)
|
||||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
@@ -155,7 +155,7 @@ class SACPolicy(
|
|||||||
|
|
||||||
# calculate critics loss
|
# calculate critics loss
|
||||||
# 1- compute actions from policy
|
# 1- compute actions from policy
|
||||||
action_preds, log_probs = self.actor(next_observations)
|
action_preds, log_probs, _ = self.actor(next_observations)
|
||||||
|
|
||||||
# 2- compute q targets
|
# 2- compute q targets
|
||||||
q_targets = self.critic_forward(next_observations, action_preds, use_target=True)
|
q_targets = self.critic_forward(next_observations, action_preds, use_target=True)
|
||||||
@@ -195,7 +195,7 @@ class SACPolicy(
|
|||||||
# 1- temperature
|
# 1- temperature
|
||||||
temperature = self.temperature()
|
temperature = self.temperature()
|
||||||
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
|
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
|
||||||
actions, log_probs = self.actor(observations)
|
actions, log_probs, _ = self.actor(observations)
|
||||||
# 3- get q-value predictions
|
# 3- get q-value predictions
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
q_preds = self.critic_forward(observations, actions, use_target=False)
|
q_preds = self.critic_forward(observations, actions, use_target=False)
|
||||||
@@ -405,8 +405,9 @@ class Policy(nn.Module):
|
|||||||
actions = torch.tanh(x_t)
|
actions = torch.tanh(x_t)
|
||||||
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
|
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
|
||||||
log_probs = log_probs.sum(-1) # sum over action dim
|
log_probs = log_probs.sum(-1) # sum over action dim
|
||||||
|
means = torch.tanh(means)
|
||||||
|
|
||||||
return actions, log_probs
|
return actions, log_probs, means
|
||||||
|
|
||||||
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
|
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
|
||||||
"""Get encoded features from observations"""
|
"""Get encoded features from observations"""
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ training:
|
|||||||
# Offline training dataloader
|
# Offline training dataloader
|
||||||
num_workers: 4
|
num_workers: 4
|
||||||
|
|
||||||
batch_size: 128
|
batch_size: 256
|
||||||
grad_clip_norm: 10.0
|
grad_clip_norm: 10.0
|
||||||
lr: 3e-4
|
lr: 3e-4
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user