Remove torch.no_grad decorator and optimize next action prediction in SAC policy

- Removed `@torch.no_grad` decorator from `Unnormalize` forward method
- Added TODO comment for optimizing next action prediction in SAC policy
- Minor formatting adjustment in NaN assertion for log standard deviation

Co-authored-by: Yoel Chornton <yoel.chornton@gmail.com>
Author: AdilZouitine
Date: 2025-03-10 10:31:38 +00:00
Committed by: Michel Aractingi
Parent: 3dfb37e976
Commit: e002c5ec56
2 changed files with 9 additions and 4 deletions


@@ -223,7 +223,7 @@ class Unnormalize(nn.Module):
             setattr(self, "buffer_" + key.replace(".", "_"), buffer)

     # TODO(rcadene): should we remove torch.no_grad?
-    @torch.no_grad
+    # @torch.no_grad
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         batch = dict(batch)  # shallow copy avoids mutating the input batch
         for key, ft in self.features.items():
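
The effect of commenting out the decorator can be illustrated with a minimal sketch (the simplified `Unnormalize` below, with scalar `mean`/`std` buffers, is an assumption for illustration, not the actual class from the repo): with `@torch.no_grad` on `forward`, autograd does not track the unnormalization and gradients cannot flow back through it; with the decorator removed, the op stays on the autograd graph.

```python
import torch
import torch.nn as nn


class Unnormalize(nn.Module):
    """Minimal sketch of an unnormalization module (assumed shapes/names)."""

    def __init__(self, mean: torch.Tensor, std: torch.Tensor):
        super().__init__()
        # Buffers are saved with the module but are not trainable parameters.
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    # Decorator removed so autograd tracks the op and gradients can flow:
    # @torch.no_grad
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.std + self.mean


unnorm = Unnormalize(mean=torch.tensor([1.0]), std=torch.tensor([2.0]))
x = torch.tensor([3.0], requires_grad=True)
y = unnorm(x).sum()  # y = 3 * 2 + 1 = 7
y.backward()
print(x.grad)  # dy/dx = std = 2.0, so the gradient reaches x
```

Had `forward` kept the `@torch.no_grad` decorator, `y` would be detached from the graph and `y.backward()` would raise, since no tensor in the computation would require grad.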