update: apply std_epsilon at normalization time instead of mutating the stored stats buffers
@@ -21,7 +21,6 @@ def create_stats_buffers(
     shapes: dict[str, list[int]],
     modes: dict[str, str],
     stats: dict[str, dict[str, Tensor]] | None = None,
-    std_epsilon: float = 1e-5,
 ) -> dict[str, dict[str, nn.ParameterDict]]:
     """
     Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
@@ -79,14 +78,10 @@ def create_stats_buffers(
             # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
             if mode == "mean_std":
                 buffer["mean"].data = stats[key]["mean"].clone()
-                buffer["std"].data = stats[key]["std"].clone().clamp_min(std_epsilon)
+                buffer["std"].data = stats[key]["std"].clone()
             elif mode == "min_max":
                 buffer["min"].data = stats[key]["min"].clone()
                 buffer["max"].data = stats[key]["max"].clone()
-                epsilon = (std_epsilon - (stats[key]["max"] - stats[key]["min"]).abs()).clamp_min(
-                    0
-                )  # To add to have at least std_epsilon between min and max
-                buffer["max"].data += epsilon

         stats_buffers[key] = buffer
     return stats_buffers
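The net effect of these two hunks: create_stats_buffers now stores the dataset statistics verbatim, and the epsilon guard moves to the point of use. A minimal standalone sketch of the before/after behavior (tensors and values here are invented for illustration, not taken from the repository):

import torch

std_epsilon = 1e-5
std = torch.tensor([0.5, 0.0])  # second dimension has zero variance

# Old behavior: the buffer itself was clamped once, at creation time.
buffer_std_old = std.clone().clamp_min(std_epsilon)  # tensor([5.0000e-01, 1.0000e-05])

# New behavior: the buffer keeps the raw statistic unchanged...
buffer_std_new = std.clone()  # tensor([0.5000, 0.0000])

# ...and the clamp happens only when dividing during normalization.
x = torch.tensor([1.0, 1.0])
normalized = x / buffer_std_new.clamp_min(std_epsilon)
print(normalized)  # tensor([2.0000e+00, 1.0000e+05]) -- finite, no division by zero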
@@ -134,7 +129,8 @@ class Normalize(nn.Module):
         self.shapes = shapes
         self.modes = modes
         self.stats = stats
-        stats_buffers = create_stats_buffers(shapes, modes, stats, std_epsilon=std_epsilon)
+        self.std_epsilon = std_epsilon
+        stats_buffers = create_stats_buffers(shapes, modes, stats)
         for key, buffer in stats_buffers.items():
             setattr(self, "buffer_" + key.replace(".", "_"), buffer)

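The constructor now keeps the epsilon on the module rather than baking it into the buffers, so saved checkpoints carry unmodified statistics. A self-contained toy module illustrating the pattern (ToyNormalize is a made-up stand-in, not the repository's Normalize class):

import torch
from torch import Tensor, nn

class ToyNormalize(nn.Module):
    """Minimal stand-in showing the pattern: store raw stats, clamp at use time."""

    def __init__(self, mean: Tensor, std: Tensor, std_epsilon: float = 1e-5):
        super().__init__()
        self.std_epsilon = std_epsilon
        # Buffers hold the raw statistics, untouched by the epsilon.
        self.register_buffer("mean", mean.clone())
        self.register_buffer("std", std.clone())

    def forward(self, x: Tensor) -> Tensor:
        # The clamp guards the division without mutating the stored std.
        return (x - self.mean) / self.std.clamp_min(self.std_epsilon)

toy = ToyNormalize(mean=torch.zeros(2), std=torch.tensor([0.5, 0.0]))
print(toy(torch.ones(2)))  # tensor([2.0000e+00, 1.0000e+05])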
@@ -150,12 +146,15 @@ class Normalize(nn.Module):
                 std = buffer["std"]
                 assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                 assert not torch.isinf(std).any(), _no_stats_error_str("std")
-                output_batch[key] = (batch[key] - mean) / std
+                output_batch[key] = (batch[key] - mean) / std.clamp_min(self.std_epsilon)
             elif mode == "min_max":
                 min = buffer["min"]
                 max = buffer["max"]
                 assert not torch.isinf(min).any(), _no_stats_error_str("min")
                 assert not torch.isinf(max).any(), _no_stats_error_str("max")
+                # Pad the range so that (max - min) is at least std_epsilon
+                epsilon = (self.std_epsilon - (max - min).abs()).clamp_min(0)
+                max = max + epsilon
                 # normalize to [0,1]
                 output_batch[key] = (batch[key] - min) / (max - min)
                 # normalize to [-1, 1]
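For "min_max" mode the guard is analogous: instead of padding buffer["max"] at creation, the forward pass widens the range just enough that the denominator is at least std_epsilon. A short standalone sketch of that computation (values invented for illustration):

import torch

std_epsilon = 1e-5
min_ = torch.tensor([0.0, 1.0])  # second dimension is constant: min == max
max_ = torch.tensor([2.0, 1.0])

# Widen the range only where it is narrower than std_epsilon.
epsilon = (std_epsilon - (max_ - min_).abs()).clamp_min(0)  # tensor([0.0000e+00, 1.0000e-05])
max_ = max_ + epsilon

x = torch.tensor([1.0, 1.0])
normalized = (x - min_) / (max_ - min_)  # denominator is >= std_epsilon everywhere
print(normalized)  # tensor([0.5000, 0.0000])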
@@ -207,8 +206,9 @@ class Unnormalize(nn.Module):
         self.shapes = shapes
         self.modes = modes
         self.stats = stats
+        self.std_epsilon = std_epsilon
         # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
-        stats_buffers = create_stats_buffers(shapes, modes, stats, std_epsilon=std_epsilon)
+        stats_buffers = create_stats_buffers(shapes, modes, stats)
         for key, buffer in stats_buffers.items():
             setattr(self, "buffer_" + key.replace(".", "_"), buffer)

@@ -224,12 +224,15 @@ class Unnormalize(nn.Module):
                 std = buffer["std"]
                 assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                 assert not torch.isinf(std).any(), _no_stats_error_str("std")
-                output_batch[key] = batch[key] * std + mean
+                output_batch[key] = batch[key] * std.clamp_min(self.std_epsilon) + mean
             elif mode == "min_max":
                 min = buffer["min"]
                 max = buffer["max"]
                 assert not torch.isinf(min).any(), _no_stats_error_str("min")
                 assert not torch.isinf(max).any(), _no_stats_error_str("max")
+                # Pad the range so that (max - min) is at least std_epsilon
+                epsilon = (self.std_epsilon - (max - min).abs()).clamp_min(0)
+                max = max + epsilon
                 output_batch[key] = (batch[key] + 1) / 2
                 output_batch[key] = output_batch[key] * (max - min) + min
             else:
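Since Normalize and Unnormalize now apply identical epsilon handling on the fly, a normalize/unnormalize round trip stays exact even on degenerate dimensions. A quick check of that property with plain tensors (a sketch mirroring the two forward passes, not a test from the repository):

import torch

std_epsilon = 1e-5
min_ = torch.tensor([0.0, 1.0])
max_ = torch.tensor([2.0, 1.0])
epsilon = (std_epsilon - (max_ - min_).abs()).clamp_min(0)
max_ = max_ + epsilon

x = torch.tensor([1.5, 1.0])
norm = (x - min_) / (max_ - min_) * 2 - 1          # normalize to [-1, 1], as in Normalize
roundtrip = (norm + 1) / 2 * (max_ - min_) + min_  # invert it, as in Unnormalize
assert torch.allclose(roundtrip, x)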
@@ -331,24 +331,14 @@ def test_normalize(insert_temporal_dim):
     ).all()

     assert torch.isclose(
-        normalize.buffer_action_test_std_cap.std[0],
-        dataset_stats["action_test_std_cap"]["std"][0],
+        normalize.buffer_action_test_std_cap.std,
+        dataset_stats["action_test_std_cap"]["std"],
         rtol=0.1,
         atol=1e-7,
     ).all()
     assert torch.isclose(
-        normalize.buffer_action_test_std_cap.std[1], torch.ones(1) * std_epsilon, rtol=0.1, atol=1e-7
-    ).all()
-    assert torch.isclose(
-        normalize.buffer_action_test_min_max_cap.max[0] - normalize.buffer_action_test_min_max_cap.min[0],
-        dataset_stats["action_test_min_max_cap"]["max"][0]
-        - dataset_stats["action_test_min_max_cap"]["min"][0],
-        rtol=0.1,
-        atol=1e-7,
-    ).all()
-    assert torch.isclose(
-        normalize.buffer_action_test_min_max_cap.max[1] - normalize.buffer_action_test_min_max_cap.min[1],
-        torch.ones(1) * std_epsilon,
+        normalize.buffer_action_test_min_max_cap.max - normalize.buffer_action_test_min_max_cap.min,
+        dataset_stats["action_test_min_max_cap"]["max"] - dataset_stats["action_test_min_max_cap"]["min"],
         rtol=0.1,
         atol=1e-7,
     ).all()
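The reworked assertions encode the new invariant directly: the buffers match the raw dataset statistics as whole tensors, so there is no longer a per-element split between clamped and unclamped entries. A minimal illustration of the comparison pattern (dataset_std is an invented stand-in for the dataset stats entry); the hunk below mirrors the same change for unnormalize:

import torch

# Invented stand-in for dataset_stats["action_test_std_cap"]["std"].
dataset_std = torch.tensor([0.25, 0.0])
buffer_std = dataset_std.clone()  # buffers now store the raw statistic verbatim

# Whole-tensor comparison, as in the updated assertions.
assert torch.isclose(buffer_std, dataset_std, rtol=0.1, atol=1e-7).all()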
@@ -496,24 +486,14 @@ def test_normalize(insert_temporal_dim):
     ).all()

     assert torch.isclose(
-        unnormalize.buffer_action_test_std_cap.std[0],
-        dataset_stats["action_test_std_cap"]["std"][0],
+        unnormalize.buffer_action_test_std_cap.std,
+        dataset_stats["action_test_std_cap"]["std"],
         rtol=0.1,
         atol=1e-7,
     ).all()
     assert torch.isclose(
-        unnormalize.buffer_action_test_std_cap.std[1], torch.ones(1) * std_epsilon, rtol=0.1, atol=1e-7
-    ).all()
-    assert torch.isclose(
-        unnormalize.buffer_action_test_min_max_cap.max[0] - unnormalize.buffer_action_test_min_max_cap.min[0],
-        dataset_stats["action_test_min_max_cap"]["max"][0]
-        - dataset_stats["action_test_min_max_cap"]["min"][0],
-        rtol=0.1,
-        atol=1e-7,
-    ).all()
-    assert torch.isclose(
-        unnormalize.buffer_action_test_min_max_cap.max[1] - unnormalize.buffer_action_test_min_max_cap.min[1],
-        torch.ones(1) * std_epsilon,
+        unnormalize.buffer_action_test_min_max_cap.max - unnormalize.buffer_action_test_min_max_cap.min,
+        dataset_stats["action_test_min_max_cap"]["max"] - dataset_stats["action_test_min_max_cap"]["min"],
         rtol=0.1,
         atol=1e-7,
     ).all()