forked from tangger/lerobot
Remove torch.no_grad decorator and optimize next action prediction in SAC policy
- Removed `@torch.no_grad` decorator from the `Unnormalize` forward method
- Added TODO comment for optimizing next action prediction in SAC policy
- Minor formatting adjustment in NaN assertion for log standard deviation

Co-authored-by: Yoel Chornton <yoel.chornton@gmail.com>
committed by Michel Aractingi
parent 3dfb37e976
commit e002c5ec56
@@ -223,7 +223,7 @@ class Unnormalize(nn.Module):
             setattr(self, "buffer_" + key.replace(".", "_"), buffer)

     # TODO(rcadene): should we remove torch.no_grad?
-    @torch.no_grad
+    # @torch.no_grad
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         batch = dict(batch)  # shallow copy avoids mutating the input batch
         for key, ft in self.features.items():
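The practical effect of commenting out `@torch.no_grad` is that the unnormalized actions stay on the autograd graph, so a loss computed on them (as in SAC's actor update) can backpropagate into the policy that produced them. A minimal sketch, using a simplified stand-in for the `Unnormalize` module (the scalar mean/std buffers here are hypothetical, not the actual lerobot implementation):

```python
import torch
from torch import nn


class Unnormalize(nn.Module):
    """Simplified unnormalization module: rescales an action back to its
    original range using registered mean/std buffers (hypothetical values)."""

    def __init__(self, mean: torch.Tensor, std: torch.Tensor):
        super().__init__()
        self.register_buffer("buffer_mean", mean)
        self.register_buffer("buffer_std", std)

    # No @torch.no_grad here, so the output keeps its autograd history.
    def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        batch = dict(batch)  # shallow copy avoids mutating the input batch
        batch["action"] = batch["action"] * self.buffer_std + self.buffer_mean
        return batch


mod = Unnormalize(torch.tensor([1.0]), torch.tensor([2.0]))
x = torch.tensor([0.5], requires_grad=True)  # stand-in for a policy output

out = mod({"action": x})
out["action"].sum().backward()
print(x.grad)  # d/dx (x * 2 + 1) = 2, so the gradient reaches the policy input
```

Had `forward` been wrapped in `torch.no_grad`, `out["action"]` would be detached and the `backward()` call would fail, which is exactly why the decorator blocks gradient-based action optimization in the SAC actor update.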