Hardcoded some normalization parameters. TODO refactor

Added masking actions on the level of the intervention actions and offline dataset

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi
2025-02-13 14:27:14 +01:00
committed by AdilZouitine
parent a0e0a9a9b1
commit eb7e28d9d9
6 changed files with 36 additions and 8 deletions

View File

@@ -137,7 +137,7 @@ class SACPolicy(
"""Override .to(device) method to involve moving the log_alpha fixed_std"""
if self.actor.fixed_std is not None:
self.actor.fixed_std = self.actor.fixed_std.to(*args, **kwargs)
self.log_alpha = self.log_alpha.to(*args, **kwargs)
# self.log_alpha = self.log_alpha.to(*args, **kwargs)
super().to(*args, **kwargs)
@torch.no_grad()