Remove numpy array support

This commit is contained in:
AdilZouitine
2025-05-15 18:42:59 +02:00
parent 1a936113c2
commit fa72aed5b6

View File

@@ -254,7 +254,7 @@ class Unnormalize(nn.Module):
return batch
# TODO: We should replace all normalization on the policies with register_buffer normalization
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
# and remove the `Normalize` and `Unnormalize` classes.
def _initialize_stats_buffers(
module: nn.Module,
@@ -287,10 +287,11 @@ def _initialize_stats_buffers(
if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
mean_data = stats[key]["mean"]
std_data = stats[key]["std"]
if isinstance(mean_data, np.ndarray):
mean = torch.from_numpy(mean_data).to(dtype=torch.float32)
std = torch.from_numpy(std_data).to(dtype=torch.float32)
elif isinstance(mean_data, torch.Tensor):
if isinstance(mean_data, torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
mean = mean_data.clone().to(dtype=torch.float32)
std = std_data.clone().to(dtype=torch.float32)
else:
@@ -307,10 +308,7 @@ def _initialize_stats_buffers(
if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
min_data = stats[key]["min"]
max_data = stats[key]["max"]
if isinstance(min_data, np.ndarray):
min_val = torch.from_numpy(min_data).to(dtype=torch.float32)
max_val = torch.from_numpy(max_data).to(dtype=torch.float32)
elif isinstance(min_data, torch.Tensor):
if isinstance(min_data, torch.Tensor):
min_val = min_data.clone().to(dtype=torch.float32)
max_val = max_data.clone().to(dtype=torch.float32)
else: