Add regression tests (#119)

- Add `tests/scripts/save_policy_to_safetensor.py` to generate test artifacts
- Add `test_backward_compatibility to test generated outputs from the policies against artifacts
This commit is contained in:
Simon Alibert
2024-05-04 16:20:30 +02:00
committed by GitHub
parent 19812ca470
commit c77633c38c
15 changed files with 236 additions and 43 deletions

View File

@@ -80,7 +80,8 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
self.config = config
self.model = TDMPCTOLD(config)
self.model_target = deepcopy(self.model)
self.model_target.eval()
for param in self.model_target.parameters():
param.requires_grad = False
if config.input_normalization_modes is not None:
self.normalize_inputs = Normalize(