forked from tangger/lerobot
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Michel Aractingi
parent
cdcf346061
commit
1c8daf11fd
@@ -82,25 +82,43 @@ def create_stats_buffers(
|
||||
if stats:
|
||||
if isinstance(stats[key]["mean"], np.ndarray):
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(
|
||||
dtype=torch.float32
|
||||
)
|
||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(
|
||||
dtype=torch.float32
|
||||
)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(
|
||||
dtype=torch.float32
|
||||
)
|
||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(
|
||||
dtype=torch.float32
|
||||
)
|
||||
elif isinstance(stats[key]["mean"], torch.Tensor):
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
||||
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
||||
buffer["mean"].data = (
|
||||
stats[key]["mean"].clone().to(dtype=torch.float32)
|
||||
)
|
||||
buffer["std"].data = (
|
||||
stats[key]["std"].clone().to(dtype=torch.float32)
|
||||
)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
||||
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
||||
buffer["min"].data = (
|
||||
stats[key]["min"].clone().to(dtype=torch.float32)
|
||||
)
|
||||
buffer["max"].data = (
|
||||
stats[key]["max"].clone().to(dtype=torch.float32)
|
||||
)
|
||||
else:
|
||||
type_ = type(stats[key]["mean"])
|
||||
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
|
||||
raise ValueError(
|
||||
f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead."
|
||||
)
|
||||
|
||||
stats_buffers[key] = buffer
|
||||
return stats_buffers
|
||||
|
||||
Reference in New Issue
Block a user