ready for review

This commit is contained in:
Alexander Soare
2024-04-12 11:36:52 +01:00
parent 5666ec3ec7
commit 6d0a45a97d
7 changed files with 11 additions and 42 deletions

View File

@@ -18,7 +18,6 @@ def get_device_from_parameters(module: nn.Module) -> torch.device:
"""Get a module's device by checking one of its parameters.
Note: assumes that all parameters have the same device
TODO(now): Add test.
"""
return next(iter(module.parameters())).device
@@ -27,6 +26,5 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
"""Get a module's parameter dtype by checking one of its parameters.
Note: assumes that all parameters have the same dtype.
TODO(now): Add test.
"""
return next(iter(module.parameters())).dtype