Remove numpy array support
```diff
@@ -254,7 +254,7 @@ class Unnormalize(nn.Module):
         return batch


-# TODO: We should replace all normalization on the policies with register_buffer normalization
+# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
 # and remove the `Normalize` and `Unnormalize` classes.
 def _initialize_stats_buffers(
     module: nn.Module,
```
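For orientation, here is a minimal sketch of the `register_buffer` approach the TODO alludes to. Nothing in it is from the repo: the class name `BufferNormalize` and the epsilon are made up, and a real replacement would also have to cover min/max normalization.

```python
import torch
from torch import nn


class BufferNormalize(nn.Module):  # hypothetical name, not in lerobot
    """Mean/std normalization with stats held in buffers.

    Buffers travel with .to(device)/.cuda() and are serialized in the
    state dict, which is the appeal over ad-hoc Normalize modules.
    """

    def __init__(self, mean: torch.Tensor, std: torch.Tensor):
        super().__init__()
        self.register_buffer("mean", mean.clone().to(dtype=torch.float32))
        self.register_buffer("std", std.clone().to(dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean) / (self.std + 1e-8)


norm = BufferNormalize(torch.zeros(6), torch.ones(6))
out = norm(torch.randn(2, 6))  # normalized batch, same shape as input
```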
```diff
@@ -287,10 +287,11 @@ def _initialize_stats_buffers(
     if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
         mean_data = stats[key]["mean"]
         std_data = stats[key]["std"]
-        if isinstance(mean_data, np.ndarray):
-            mean = torch.from_numpy(mean_data).to(dtype=torch.float32)
-            std = torch.from_numpy(std_data).to(dtype=torch.float32)
-        elif isinstance(mean_data, torch.Tensor):
+        if isinstance(mean_data, torch.Tensor):
+            # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
+            # tensors anywhere (for example, when we use the same stats for normalization and
+            # unnormalization). See the logic here
+            # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
             mean = mean_data.clone().to(dtype=torch.float32)
             std = std_data.clone().to(dtype=torch.float32)
         else:
```
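The new comment is about safetensors' shared-storage check: when the same stats tensor backs two buffers (say, one in a `Normalize` and one in an `Unnormalize`), `save_file` refuses to serialize the state dict, which is what `save_pretrained` would trip over without the `.clone()`. A minimal sketch of that failure mode, assuming only that `safetensors` is installed (the dict keys and filename are illustrative):

```python
import torch
from safetensors.torch import save_file

stats = torch.zeros(3)

# Same tensor registered twice -> two entries aliasing one storage.
shared = {"normalize.mean": stats, "unnormalize.mean": stats}
try:
    save_file(shared, "stats.safetensors")
except RuntimeError as err:
    print(f"rejected, tensors share memory: {err}")

# Cloning gives each entry its own storage, so serialization succeeds.
cloned = {"normalize.mean": stats, "unnormalize.mean": stats.clone()}
save_file(cloned, "stats.safetensors")
```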
```diff
@@ -307,10 +308,7 @@ def _initialize_stats_buffers(
     if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
         min_data = stats[key]["min"]
         max_data = stats[key]["max"]
-        if isinstance(min_data, np.ndarray):
-            min_val = torch.from_numpy(min_data).to(dtype=torch.float32)
-            max_val = torch.from_numpy(max_data).to(dtype=torch.float32)
-        elif isinstance(min_data, torch.Tensor):
+        if isinstance(min_data, torch.Tensor):
             min_val = min_data.clone().to(dtype=torch.float32)
             max_val = max_data.clone().to(dtype=torch.float32)
         else:
```
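With the `np.ndarray` branches gone, stats that arrive as numpy arrays are no longer converted implicitly; they now fall through to the `else:` (presumably an error) unless the caller converts them first. A hedged caller-side sketch: the nested `{key: {"mean": ..., "std": ...}}` layout mirrors the accesses above, but the feature key is illustrative:

```python
import numpy as np
import torch

raw_stats = {"observation.state": {"mean": np.zeros(6), "std": np.ones(6)}}

# Convert every leaf to a float32 tensor before it reaches
# _initialize_stats_buffers; torch.as_tensor also accepts tensors unchanged.
stats = {
    key: {name: torch.as_tensor(value, dtype=torch.float32) for name, value in leaf.items()}
    for key, leaf in raw_stats.items()
}
```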