forked from tangger/lerobot
Use PytorchModelHubMixin to save models as safetensors (#125)
Co-authored-by: Remi <re.cadene@gmail.com>
This commit is contained in:
@@ -57,17 +57,28 @@ def create_stats_buffers(
|
||||
)
|
||||
|
||||
if stats is not None:
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
if mode == "mean_std":
|
||||
buffer["mean"].data = stats[key]["mean"]
|
||||
buffer["std"].data = stats[key]["std"]
|
||||
buffer["mean"].data = stats[key]["mean"].clone()
|
||||
buffer["std"].data = stats[key]["std"].clone()
|
||||
elif mode == "min_max":
|
||||
buffer["min"].data = stats[key]["min"]
|
||||
buffer["max"].data = stats[key]["max"]
|
||||
buffer["min"].data = stats[key]["min"].clone()
|
||||
buffer["max"].data = stats[key]["max"].clone()
|
||||
|
||||
stats_buffers[key] = buffer
|
||||
return stats_buffers
|
||||
|
||||
|
||||
def _no_stats_error_str(name: str) -> str:
|
||||
return (
|
||||
f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
|
||||
"pretrained model."
|
||||
)
|
||||
|
||||
|
||||
class Normalize(nn.Module):
|
||||
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
|
||||
|
||||
@@ -99,7 +110,6 @@ class Normalize(nn.Module):
|
||||
self.shapes = shapes
|
||||
self.modes = modes
|
||||
self.stats = stats
|
||||
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
|
||||
stats_buffers = create_stats_buffers(shapes, modes, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
@@ -113,26 +123,14 @@ class Normalize(nn.Module):
|
||||
if mode == "mean_std":
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), (
|
||||
"`mean` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
assert not torch.isinf(std).any(), (
|
||||
"`std` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
elif mode == "min_max":
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), (
|
||||
"`min` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
assert not torch.isinf(max).any(), (
|
||||
"`max` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||
# normalize to [0,1]
|
||||
batch[key] = (batch[key] - min) / (max - min)
|
||||
# normalize to [-1, 1]
|
||||
@@ -190,26 +188,14 @@ class Unnormalize(nn.Module):
|
||||
if mode == "mean_std":
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), (
|
||||
"`mean` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
assert not torch.isinf(std).any(), (
|
||||
"`std` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = batch[key] * std + mean
|
||||
elif mode == "min_max":
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), (
|
||||
"`min` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
assert not torch.isinf(max).any(), (
|
||||
"`max` is infinity. You forgot to initialize with `stats` as argument, or called "
|
||||
"`policy.load_state_dict`."
|
||||
)
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max - min) + min
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user