This commit is contained in:
Thomas Wolf
2024-06-18 11:26:49 +02:00
parent b72d574891
commit cd9ace20b6
2 changed files with 212 additions and 8 deletions

View File

@@ -21,6 +21,7 @@ def create_stats_buffers(
shapes: dict[str, list[int]], shapes: dict[str, list[int]],
modes: dict[str, str], modes: dict[str, str],
stats: dict[str, dict[str, Tensor]] | None = None, stats: dict[str, dict[str, Tensor]] | None = None,
std_epsilon: float = 1e-5,
) -> dict[str, dict[str, nn.ParameterDict]]: ) -> dict[str, dict[str, nn.ParameterDict]]:
""" """
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
@@ -78,10 +79,14 @@ def create_stats_buffers(
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
if mode == "mean_std": if mode == "mean_std":
buffer["mean"].data = stats[key]["mean"].clone() buffer["mean"].data = stats[key]["mean"].clone()
buffer["std"].data = stats[key]["std"].clone() buffer["std"].data = stats[key]["std"].clone().clamp_min(std_epsilon)
elif mode == "min_max": elif mode == "min_max":
buffer["min"].data = stats[key]["min"].clone() buffer["min"].data = stats[key]["min"].clone()
buffer["max"].data = stats[key]["max"].clone() buffer["max"].data = stats[key]["max"].clone()
epsilon = (std_epsilon - (stats[key]["max"] - stats[key]["min"]).abs()).clamp_min(
0
) # Amount to add to max so that (max - min) is at least std_epsilon
buffer["max"].data += epsilon
stats_buffers[key] = buffer stats_buffers[key] = buffer
return stats_buffers return stats_buffers
@@ -102,6 +107,7 @@ class Normalize(nn.Module):
shapes: dict[str, list[int]], shapes: dict[str, list[int]],
modes: dict[str, str], modes: dict[str, str],
stats: dict[str, dict[str, Tensor]] | None = None, stats: dict[str, dict[str, Tensor]] | None = None,
std_epsilon: float = 1e-5,
): ):
""" """
Args: Args:
@@ -120,18 +126,22 @@ class Normalize(nn.Module):
not provided, as expected for finetuning or evaluation, the default buffers should be not provided, as expected for finetuning or evaluation, the default buffers should be
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
dataset is not needed to get the stats, since they are already in the policy state_dict. dataset is not needed to get the stats, since they are already in the policy state_dict.
std_epsilon (float, optional): Lower bound applied to the standard deviation to avoid division by
zero. Default is `1e-5`. In "mean_std" mode the std is clamped with `clamp_min`; in "min_max" mode
`max` is raised so that the difference between min and max is at least `std_epsilon`.
""" """
super().__init__() super().__init__()
self.shapes = shapes self.shapes = shapes
self.modes = modes self.modes = modes
self.stats = stats self.stats = stats
stats_buffers = create_stats_buffers(shapes, modes, stats) stats_buffers = create_stats_buffers(shapes, modes, stats, std_epsilon=std_epsilon)
for key, buffer in stats_buffers.items(): for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer) setattr(self, "buffer_" + key.replace(".", "_"), buffer)
# TODO(rcadene): should we remove torch.no_grad? # TODO(rcadene): should we remove torch.no_grad?
@torch.no_grad @torch.no_grad
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
output_batch = {}
for key, mode in self.modes.items(): for key, mode in self.modes.items():
buffer = getattr(self, "buffer_" + key.replace(".", "_")) buffer = getattr(self, "buffer_" + key.replace(".", "_"))
@@ -140,19 +150,19 @@ class Normalize(nn.Module):
std = buffer["std"] std = buffer["std"]
assert not torch.isinf(mean).any(), _no_stats_error_str("mean") assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std") assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = (batch[key] - mean) / (std + 1e-8) output_batch[key] = (batch[key] - mean) / std
elif mode == "min_max": elif mode == "min_max":
min = buffer["min"] min = buffer["min"]
max = buffer["max"] max = buffer["max"]
assert not torch.isinf(min).any(), _no_stats_error_str("min") assert not torch.isinf(min).any(), _no_stats_error_str("min")
assert not torch.isinf(max).any(), _no_stats_error_str("max") assert not torch.isinf(max).any(), _no_stats_error_str("max")
# normalize to [0,1] # normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min + 1e-8) output_batch[key] = (batch[key] - min) / (max - min)
# normalize to [-1, 1] # normalize to [-1, 1]
batch[key] = batch[key] * 2 - 1 output_batch[key] = output_batch[key] * 2 - 1
else: else:
raise ValueError(mode) raise ValueError(mode)
return batch return output_batch
class Unnormalize(nn.Module): class Unnormalize(nn.Module):

View File

@@ -204,17 +204,33 @@ def test_normalize(insert_temporal_dim):
input_shapes = { input_shapes = {
"observation.image": [3, 96, 96], "observation.image": [3, 96, 96],
"observation.state": [10], "observation.state": [10],
"action_test_std": [1],
"action_test_min_max": [1],
"action_test_std_cap": [2],
"action_test_min_max_cap": [2],
} }
output_shapes = { output_shapes = {
"action": [5], "action": [5],
"action_test_std": [1],
"action_test_min_max": [1],
"action_test_std_cap": [2],
"action_test_min_max_cap": [2],
} }
normalize_input_modes = { normalize_input_modes = {
"observation.image": "mean_std", "observation.image": "mean_std",
"observation.state": "min_max", "observation.state": "min_max",
"action_test_std": "mean_std",
"action_test_min_max": "min_max",
"action_test_std_cap": "mean_std",
"action_test_min_max_cap": "min_max",
} }
unnormalize_output_modes = { unnormalize_output_modes = {
"action": "min_max", "action": "min_max",
"action_test_std": "mean_std",
"action_test_min_max": "min_max",
"action_test_std_cap": "mean_std",
"action_test_min_max_cap": "min_max",
} }
dataset_stats = { dataset_stats = {
@@ -236,15 +252,43 @@ def test_normalize(insert_temporal_dim):
"min": torch.randn(5), "min": torch.randn(5),
"max": torch.randn(5), "max": torch.randn(5),
}, },
"action_test_std": {
"mean": torch.ones(1) * 2,
"std": torch.ones(1) * 0.2,
},
"action_test_min_max": {
"min": torch.ones(1) * 1,
"max": torch.ones(1) * 3,
},
"action_test_std_cap": {
"mean": torch.ones(2) * 2,
"std": torch.ones(2) * 0.2,
},
"action_test_min_max_cap": {
"min": torch.ones(2) * 1.0,
"max": torch.ones(2) * 3.0,
},
} }
# Set some values to 0 to test the case where the std is 0 - for max we set it to min
dataset_stats["action_test_std_cap"]["std"][1] = 0.0
dataset_stats["action_test_min_max_cap"]["max"][1] = dataset_stats["action_test_min_max_cap"]["min"][1]
bsize = 2 bsize = 2
input_batch = { input_batch = {
"observation.image": torch.randn(bsize, 3, 96, 96), "observation.image": torch.randn(bsize, 3, 96, 96),
"observation.state": torch.randn(bsize, 10), "observation.state": torch.randn(bsize, 10),
"action_test_std": torch.ones(bsize, 1) * 2.5,
"action_test_min_max": torch.ones(bsize, 1) * 2.5,
"action_test_std_cap": torch.ones(bsize, 2) * 2.5,
"action_test_min_max_cap": torch.ones(bsize, 2) * 2.5,
} }
output_batch = { output_batch = {
"action": torch.randn(bsize, 5), "action": torch.randn(bsize, 5),
"action_test_std": torch.ones(bsize, 1) * 2.5,
"action_test_min_max": torch.ones(bsize, 1) * 2.5,
"action_test_std_cap": torch.ones(bsize, 2) * 2.5,
"action_test_min_max_cap": torch.ones(bsize, 2) * 2.5,
} }
if insert_temporal_dim: if insert_temporal_dim:
@@ -263,8 +307,158 @@ def test_normalize(insert_temporal_dim):
normalize(input_batch) normalize(input_batch)
# test with stats # test with stats
normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats) std_epsilon = 1e-2
normalize(input_batch) normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats, std_epsilon=std_epsilon)
# check that the stats are correctly set including the min capping
assert torch.isclose(
normalize.buffer_action_test_std.mean, dataset_stats["action_test_std"]["mean"], rtol=0.1, atol=1e-7
).all()
assert torch.isclose(
normalize.buffer_action_test_std.std, dataset_stats["action_test_std"]["std"], rtol=0.1, atol=1e-7
).all()
assert torch.isclose(
normalize.buffer_action_test_min_max.min,
dataset_stats["action_test_min_max"]["min"],
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
normalize.buffer_action_test_min_max.max,
dataset_stats["action_test_min_max"]["max"],
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
normalize.buffer_action_test_std_cap.std[0],
dataset_stats["action_test_std_cap"]["std"][0],
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
normalize.buffer_action_test_std_cap.std[1], torch.ones(1) * std_epsilon, rtol=0.1, atol=1e-7
).all()
assert torch.isclose(
normalize.buffer_action_test_min_max_cap.max[0] - normalize.buffer_action_test_min_max_cap.min[0],
dataset_stats["action_test_min_max_cap"]["max"][0]
- dataset_stats["action_test_min_max_cap"]["min"][0],
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
normalize.buffer_action_test_min_max_cap.max[1] - normalize.buffer_action_test_min_max_cap.min[1],
torch.ones(1) * std_epsilon,
rtol=0.1,
atol=1e-7,
).all()
normalized_output = normalize(input_batch)
# check that the normalization is correct
assert torch.isclose(
normalized_output["action_test_std"],
(input_batch["action_test_std"] - dataset_stats["action_test_std"]["mean"])
/ dataset_stats["action_test_std"]["std"],
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
normalized_output["action_test_min_max"],
(input_batch["action_test_min_max"] - dataset_stats["action_test_min_max"]["min"])
/ (dataset_stats["action_test_min_max"]["max"] - dataset_stats["action_test_min_max"]["min"])
* 2
- 1,
rtol=0.1,
atol=1e-7,
).all()
if insert_temporal_dim:
assert torch.isclose(
normalized_output["action_test_std_cap"][0, 0, 0],
(input_batch["action_test_std_cap"][0, 0, 0] - dataset_stats["action_test_std_cap"]["mean"][0])
/ dataset_stats["action_test_std_cap"]["std"][0],
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
normalized_output["action_test_std_cap"][0, 0, 1],
(input_batch["action_test_std_cap"][0, 0, 1] - dataset_stats["action_test_std_cap"]["mean"][1])
/ std_epsilon,
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
normalized_output["action_test_min_max_cap"][0, 0, 0],
(
input_batch["action_test_min_max_cap"][0, 0, 0]
- dataset_stats["action_test_min_max_cap"]["min"][0]
)
/ (
dataset_stats["action_test_min_max_cap"]["max"][0]
- dataset_stats["action_test_min_max_cap"]["min"][0]
)
* 2
- 1,
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
normalized_output["action_test_min_max_cap"][0, 0, 1],
(
input_batch["action_test_min_max_cap"][0, 0, 1]
- dataset_stats["action_test_min_max_cap"]["min"][1]
)
/ std_epsilon
* 2
- 1,
rtol=0.1,
atol=1e-7,
).all()
else:
assert torch.isclose(
normalized_output["action_test_std_cap"][0, 0],
(input_batch["action_test_std_cap"][0, 0] - dataset_stats["action_test_std_cap"]["mean"][0])
/ dataset_stats["action_test_std_cap"]["std"][0],
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
normalized_output["action_test_std_cap"][0, 1],
(input_batch["action_test_std_cap"][0, 1] - dataset_stats["action_test_std_cap"]["mean"][1])
/ std_epsilon,
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
normalized_output["action_test_min_max_cap"][0, 0],
(
input_batch["action_test_min_max_cap"][0, 0]
- dataset_stats["action_test_min_max_cap"]["min"][0]
)
/ (
dataset_stats["action_test_min_max_cap"]["max"][0]
- dataset_stats["action_test_min_max_cap"]["min"][0]
)
* 2
- 1,
rtol=0.1,
atol=1e-7,
).all()
assert torch.isclose(
normalized_output["action_test_min_max_cap"][0, 1],
(
input_batch["action_test_min_max_cap"][0, 1]
- dataset_stats["action_test_min_max_cap"]["min"][1]
)
/ std_epsilon
* 2
- 1,
rtol=0.1,
atol=1e-7,
).all()
# test loading pretrained models # test loading pretrained models
new_normalize = Normalize(input_shapes, normalize_input_modes, stats=None) new_normalize = Normalize(input_shapes, normalize_input_modes, stats=None)