checkout normalize.py to prev commit
This commit is contained in:
@@ -79,28 +79,48 @@ def create_stats_buffers(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
||||||
if stats:
|
if stats and key in stats:
|
||||||
if isinstance(stats[key]["mean"], np.ndarray):
|
# NOTE(maractingi, azouitine): Changed the order of these conditions because in online environments we don't have dataset stats.
|
||||||
if norm_mode is NormalizationMode.MEAN_STD:
|
# Therefore, we don't have access to the full stats of the data; some elements have only min-max or only mean-std.
|
||||||
|
if norm_mode is NormalizationMode.MEAN_STD:
|
||||||
|
if "mean" not in stats[key] or "std" not in stats[key]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization"
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(stats[key]["mean"], np.ndarray):
|
||||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
||||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
||||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
elif isinstance(stats[key]["mean"], torch.Tensor):
|
||||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||||
elif isinstance(stats[key]["mean"], torch.Tensor):
|
# unnormalization). See the logic here
|
||||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
|
||||||
# unnormalization). See the logic here
|
|
||||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
|
||||||
if norm_mode is NormalizationMode.MEAN_STD:
|
|
||||||
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
||||||
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
||||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
else:
|
||||||
|
type_ = type(stats[key]["mean"])
|
||||||
|
raise ValueError(
|
||||||
|
f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
|
if "min" not in stats[key] or "max" not in stats[key]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization"
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(stats[key]["min"], np.ndarray):
|
||||||
|
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
||||||
|
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
||||||
|
elif isinstance(stats[key]["min"], torch.Tensor):
|
||||||
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
||||||
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
||||||
else:
|
else:
|
||||||
type_ = type(stats[key]["mean"])
|
type_ = type(stats[key]["min"])
|
||||||
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
|
raise ValueError(
|
||||||
|
f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead."
|
||||||
|
)
|
||||||
|
|
||||||
stats_buffers[key] = buffer
|
stats_buffers[key] = buffer
|
||||||
return stats_buffers
|
return stats_buffers
|
||||||
@@ -148,11 +168,14 @@ class Normalize(nn.Module):
|
|||||||
for key, buffer in stats_buffers.items():
|
for key, buffer in stats_buffers.items():
|
||||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||||
|
|
||||||
|
# TODO(rcadene): should we remove torch.no_grad?
|
||||||
|
# @torch.no_grad
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||||
for key, ft in self.features.items():
|
for key, ft in self.features.items():
|
||||||
if key not in batch:
|
if key not in batch:
|
||||||
# FIXME(aliberts, rcadene): This might lead to silent fail!
|
# FIXME(aliberts, rcadene): This might lead to silent fail!
|
||||||
|
# NOTE(azouitine): This `continue` helps us when instantiating SACPolicy.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||||
@@ -220,6 +243,8 @@ class Unnormalize(nn.Module):
|
|||||||
for key, buffer in stats_buffers.items():
|
for key, buffer in stats_buffers.items():
|
||||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||||
|
|
||||||
|
# TODO(rcadene): should we remove torch.no_grad?
|
||||||
|
# @torch.no_grad
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||||
for key, ft in self.features.items():
|
for key, ft in self.features.items():
|
||||||
|
|||||||
Reference in New Issue
Block a user