ready for review
This commit is contained in:
@@ -18,7 +18,6 @@ def get_device_from_parameters(module: nn.Module) -> torch.device:
|
||||
"""Get a module's device by checking one of its parameters.
|
||||
|
||||
Note: assumes that all parameters have the same device
|
||||
TODO(now): Add test.
|
||||
"""
|
||||
return next(iter(module.parameters())).device
|
||||
|
||||
@@ -27,6 +26,5 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
|
||||
"""Get a module's parameter dtype by checking one of its parameters.
|
||||
|
||||
Note: assumes that all parameters have the same dtype.
|
||||
TODO(now): Add test.
|
||||
"""
|
||||
return next(iter(module.parameters())).dtype
|
||||
|
||||
Reference in New Issue
Block a user