Update torch.load calls in network_utils.py to include weights_only=False, to ensure no regression with torch 2.6 update

This commit is contained in:
AdilZouitine
2025-04-29 18:23:51 +02:00
parent 4257fe5045
commit fb7c288c94

View File

@@ -111,7 +111,7 @@ def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
return torch.load(buffer) # nosec B614: Safe usage of torch.load
return torch.load(buffer, weights_only=False) # nosec B614: Safe usage of torch.load
def python_object_to_bytes(python_object: Any) -> bytes:
@@ -129,7 +129,7 @@ def bytes_to_python_object(buffer: bytes) -> Any:
def bytes_to_transitions(buffer: bytes) -> list[Transition]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
transitions = torch.load(buffer) # nosec B614: Safe usage of torch.load
transitions = torch.load(buffer, weights_only=False) # nosec B614: Safe usage of torch.load
# Add validation checks here
return transitions