This commit is contained in:
AdilZouitine
2025-04-25 17:22:02 +02:00
parent 1d4f660075
commit 50e9a8ed6a
3 changed files with 24 additions and 63 deletions

View File

@@ -79,46 +79,28 @@ def create_stats_buffers(
)
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
if stats and key in stats:
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" not in stats[key] or "std" not in stats[key]:
raise ValueError(
f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization"
)
if isinstance(stats[key]["mean"], np.ndarray):
if stats:
if isinstance(stats[key]["mean"], np.ndarray):
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
elif isinstance(stats[key]["mean"], torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
else:
type_ = type(stats[key]["mean"])
raise ValueError(
f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead."
)
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" not in stats[key] or "max" not in stats[key]:
raise ValueError(
f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization"
)
if isinstance(stats[key]["min"], np.ndarray):
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
elif isinstance(stats[key]["min"], torch.Tensor):
elif isinstance(stats[key]["mean"], torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
else:
type_ = type(stats[key]["min"])
raise ValueError(
f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead."
)
else:
type_ = type(stats[key]["mean"])
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
stats_buffers[key] = buffer
return stats_buffers
@@ -166,14 +148,11 @@ class Normalize(nn.Module):
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
# @torch.no_grad
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items():
if key not in batch:
# FIXME(aliberts, rcadene): This might lead to silent fail!
# NOTE: (azouitine) This continues help us for instantiation SACPolicy
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
@@ -241,8 +220,6 @@ class Unnormalize(nn.Module):
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad?
# @torch.no_grad
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items():