Remove numpy array support
This commit is contained in:
@@ -254,7 +254,7 @@ class Unnormalize(nn.Module):
|
|||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|
||||||
# TODO: We should replace all normalization on the policies with register_buffer normalization
|
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
|
||||||
# and remove the `Normalize` and `Unnormalize` classes.
|
# and remove the `Normalize` and `Unnormalize` classes.
|
||||||
def _initialize_stats_buffers(
|
def _initialize_stats_buffers(
|
||||||
module: nn.Module,
|
module: nn.Module,
|
||||||
@@ -287,10 +287,11 @@ def _initialize_stats_buffers(
|
|||||||
if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
|
if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
|
||||||
mean_data = stats[key]["mean"]
|
mean_data = stats[key]["mean"]
|
||||||
std_data = stats[key]["std"]
|
std_data = stats[key]["std"]
|
||||||
if isinstance(mean_data, np.ndarray):
|
if isinstance(mean_data, torch.Tensor):
|
||||||
mean = torch.from_numpy(mean_data).to(dtype=torch.float32)
|
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||||
std = torch.from_numpy(std_data).to(dtype=torch.float32)
|
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||||
elif isinstance(mean_data, torch.Tensor):
|
# unnormalization). See the logic here
|
||||||
|
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||||
mean = mean_data.clone().to(dtype=torch.float32)
|
mean = mean_data.clone().to(dtype=torch.float32)
|
||||||
std = std_data.clone().to(dtype=torch.float32)
|
std = std_data.clone().to(dtype=torch.float32)
|
||||||
else:
|
else:
|
||||||
@@ -307,10 +308,7 @@ def _initialize_stats_buffers(
|
|||||||
if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
|
if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
|
||||||
min_data = stats[key]["min"]
|
min_data = stats[key]["min"]
|
||||||
max_data = stats[key]["max"]
|
max_data = stats[key]["max"]
|
||||||
if isinstance(min_data, np.ndarray):
|
if isinstance(min_data, torch.Tensor):
|
||||||
min_val = torch.from_numpy(min_data).to(dtype=torch.float32)
|
|
||||||
max_val = torch.from_numpy(max_data).to(dtype=torch.float32)
|
|
||||||
elif isinstance(min_data, torch.Tensor):
|
|
||||||
min_val = min_data.clone().to(dtype=torch.float32)
|
min_val = min_data.clone().to(dtype=torch.float32)
|
||||||
max_val = max_data.clone().to(dtype=torch.float32)
|
max_val = max_data.clone().to(dtype=torch.float32)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user