Remove torch.no_grad decorator and optimize next action prediction in SAC policy
- Removed `@torch.no_grad` decorator from Unnormalize forward method - Added TODO comment for optimizing next action prediction in SAC policy - Minor formatting adjustment in NaN assertion for log standard deviation Co-authored-by: Yoel Chornton <yoel.chornton@gmail.com>
This commit is contained in:
@@ -210,6 +210,11 @@ class SACPolicy(
|
||||
next_observations, next_observation_features
|
||||
)
|
||||
|
||||
# TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way
|
||||
next_action_preds = self.unnormalize_outputs({"action": next_action_preds})[
|
||||
"action"
|
||||
]
|
||||
|
||||
# 2- compute q targets
|
||||
q_targets = self.critic_forward(
|
||||
observations=next_observations,
|
||||
@@ -512,9 +517,9 @@ class Policy(nn.Module):
|
||||
# Compute standard deviations
|
||||
if self.fixed_std is None:
|
||||
log_std = self.std_layer(outputs)
|
||||
assert not torch.isnan(
|
||||
log_std
|
||||
).any(), "[ERROR] log_std became NaN after std_layer!"
|
||||
assert not torch.isnan(log_std).any(), (
|
||||
"[ERROR] log_std became NaN after std_layer!"
|
||||
)
|
||||
|
||||
if self.use_tanh_squash:
|
||||
log_std = torch.tanh(log_std)
|
||||
|
||||
Reference in New Issue
Block a user