Compare commits

6 Commits (qgallouede ... thomwolf_2)

- 8efe947def
- f9295e2c8f
- 33166e1d43
- 1cd7ca71a1
- c9150c361b
- cd9ace20b6
@@ -102,6 +102,7 @@ class Normalize(nn.Module):
         shapes: dict[str, list[int]],
         modes: dict[str, str],
         stats: dict[str, dict[str, Tensor]] | None = None,
+        std_epsilon: float = 1e-5,
     ):
         """
         Args:
@@ -120,11 +121,15 @@ class Normalize(nn.Module):
                 not provided, as expected for finetuning or evaluation, the default buffers should be
                 overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
                 dataset is not needed to get the stats, since they are already in the policy state_dict.
+            std_epsilon (float, optional): A small minimal value for the standard deviation, to avoid division by
+                zero. Default is `1e-5`. We use `clamp_min` to make sure the standard deviation (or the difference
+                between min and max) is at least `std_epsilon`.
         """
         super().__init__()
         self.shapes = shapes
         self.modes = modes
         self.stats = stats
+        self.std_epsilon = std_epsilon
         stats_buffers = create_stats_buffers(shapes, modes, stats)
         for key, buffer in stats_buffers.items():
             setattr(self, "buffer_" + key.replace(".", "_"), buffer)
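For intuition on the `clamp_min` rule the docstring describes, here is a minimal standalone sketch with hypothetical values (plain torch, not part of the diff):

```python
import torch

# clamp_min replaces every std entry below std_epsilon (including exact
# zeros) with std_epsilon, so the later division can never blow up.
std = torch.tensor([0.2, 0.0, 1e-7])
std_epsilon = 1e-5
print(std.clamp_min(std_epsilon))  # ~[0.2, 1e-5, 1e-5]
```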
@@ -132,6 +137,7 @@ class Normalize(nn.Module):
     # TODO(rcadene): should we remove torch.no_grad?
     @torch.no_grad
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+        output_batch = {}
         for key, mode in self.modes.items():
             buffer = getattr(self, "buffer_" + key.replace(".", "_"))

@@ -140,19 +146,25 @@ class Normalize(nn.Module):
                 std = buffer["std"]
                 assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                 assert not torch.isinf(std).any(), _no_stats_error_str("std")
-                batch[key] = (batch[key] - mean) / (std + 1e-8)
+                output_batch[key] = (batch[key] - mean) / std.clamp_min(self.std_epsilon)
             elif mode == "min_max":
                 min = buffer["min"]
                 max = buffer["max"]
                 assert not torch.isinf(min).any(), _no_stats_error_str("min")
                 assert not torch.isinf(max).any(), _no_stats_error_str("max")
+                # Pad `max` so that there is at least std_epsilon between min and max
+                epsilon = (self.std_epsilon - (max - min).abs()).clamp_min(0)
+                max = max + epsilon
                 # normalize to [0,1]
-                batch[key] = (batch[key] - min) / (max - min + 1e-8)
+                output_batch[key] = (batch[key] - min) / (max - min)
                 # normalize to [-1, 1]
-                batch[key] = batch[key] * 2 - 1
+                output_batch[key] = output_batch[key] * 2 - 1
             else:
                 raise ValueError(mode)
-        return batch
+        for key in batch:
+            if key not in output_batch:
+                output_batch[key] = batch[key]
+        return output_batch


 class Unnormalize(nn.Module):
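The min/max branch uses a padding trick instead of a plain clamp; a short standalone sketch with hypothetical numbers (not part of the diff) shows how a degenerate range is widened to exactly `std_epsilon` while a healthy range is untouched:

```python
import torch

std_epsilon = 1e-5
min_ = torch.tensor([1.0, 1.0])
max_ = torch.tensor([3.0, 1.0])  # second feature: max == min (degenerate)
# epsilon is 0 where the range is already wide enough, otherwise the shortfall
epsilon = (std_epsilon - (max_ - min_).abs()).clamp_min(0)
max_ = max_ + epsilon
print(max_ - min_)  # ~[2.0, 1e-5]
```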
@@ -166,6 +178,7 @@ class Unnormalize(nn.Module):
         shapes: dict[str, list[int]],
         modes: dict[str, str],
         stats: dict[str, dict[str, Tensor]] | None = None,
+        std_epsilon: float = 1e-5,
     ):
         """
         Args:
@@ -184,11 +197,16 @@ class Unnormalize(nn.Module):
                 not provided, as expected for finetuning or evaluation, the default buffers should be
                 overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
                 dataset is not needed to get the stats, since they are already in the policy state_dict.
+            std_epsilon (float, optional): A small minimal value for the standard deviation, to avoid division by
+                zero in the Normalize step. We use the same value for unnormalization here to get consistent
+                behavior. Default is `1e-5`. We use `clamp_min` to make sure the standard deviation (or the
+                difference between min and max) is at least `std_epsilon`.
         """
         super().__init__()
         self.shapes = shapes
         self.modes = modes
         self.stats = stats
+        self.std_epsilon = std_epsilon
         # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
         stats_buffers = create_stats_buffers(shapes, modes, stats)
         for key, buffer in stats_buffers.items():
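The docstring's point about consistency can be made concrete: if Normalize divides by the clamped std, Unnormalize must multiply by the same clamped value, or the round trip collapses capped features onto the mean. A standalone sketch with hypothetical numbers (not part of the diff):

```python
import torch

x = torch.tensor([2.5])
mean, std, eps = torch.tensor([2.0]), torch.tensor([0.0]), 1e-5
x_norm = (x - mean) / std.clamp_min(eps)    # what Normalize computes
x_rec = x_norm * std.clamp_min(eps) + mean  # what Unnormalize computes
print(torch.allclose(x_rec, x))  # True; multiplying by the raw std would give `mean`
```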
@@ -197,6 +215,7 @@ class Unnormalize(nn.Module):
     # TODO(rcadene): should we remove torch.no_grad?
     @torch.no_grad
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+        output_batch = {}
         for key, mode in self.modes.items():
             buffer = getattr(self, "buffer_" + key.replace(".", "_"))

@@ -205,14 +224,20 @@ class Unnormalize(nn.Module):
                 std = buffer["std"]
                 assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                 assert not torch.isinf(std).any(), _no_stats_error_str("std")
-                batch[key] = batch[key] * std + mean
+                output_batch[key] = batch[key] * std.clamp_min(self.std_epsilon) + mean
             elif mode == "min_max":
                 min = buffer["min"]
                 max = buffer["max"]
                 assert not torch.isinf(min).any(), _no_stats_error_str("min")
                 assert not torch.isinf(max).any(), _no_stats_error_str("max")
-                batch[key] = (batch[key] + 1) / 2
-                batch[key] = batch[key] * (max - min) + min
+                # Pad `max` so that there is at least std_epsilon between min and max
+                epsilon = (self.std_epsilon - (max - min).abs()).clamp_min(0)
+                max = max + epsilon
+                output_batch[key] = (batch[key] + 1) / 2
+                output_batch[key] = output_batch[key] * (max - min) + min
             else:
                 raise ValueError(mode)
-        return batch
+        for key in batch:
+            if key not in output_batch:
+                output_batch[key] = batch[key]
+        return output_batch

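Putting the two halves together, a hypothetical round-trip sketch through the classes in this diff (the import path and the stats layout are assumptions, not taken from the diff): with the same `std_epsilon` on both sides, `unnormalize(normalize(batch))` recovers the input even for a zero-std feature.

```python
import torch
from lerobot.common.policies.normalize import Normalize, Unnormalize  # path assumed

shapes = {"action": [2]}
modes = {"action": "mean_std"}
# second action dimension has std == 0, the case the capping targets
stats = {"action": {"mean": torch.zeros(2), "std": torch.tensor([1.0, 0.0])}}

normalize = Normalize(shapes, modes, stats=stats, std_epsilon=1e-5)
unnormalize = Unnormalize(shapes, modes, stats=stats, std_epsilon=1e-5)

batch = {"action": torch.tensor([[0.3, 0.7]])}
out = unnormalize(normalize(batch))
print(torch.allclose(out["action"], batch["action"]))  # True
```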
@@ -232,17 +232,33 @@ def test_normalize(insert_temporal_dim):
     input_shapes = {
         "observation.image": [3, 96, 96],
         "observation.state": [10],
+        "action_test_std": [1],
+        "action_test_min_max": [1],
+        "action_test_std_cap": [2],
+        "action_test_min_max_cap": [2],
     }
     output_shapes = {
         "action": [5],
+        "action_test_std": [1],
+        "action_test_min_max": [1],
+        "action_test_std_cap": [2],
+        "action_test_min_max_cap": [2],
     }

     normalize_input_modes = {
         "observation.image": "mean_std",
         "observation.state": "min_max",
+        "action_test_std": "mean_std",
+        "action_test_min_max": "min_max",
+        "action_test_std_cap": "mean_std",
+        "action_test_min_max_cap": "min_max",
     }
     unnormalize_output_modes = {
         "action": "min_max",
+        "action_test_std": "mean_std",
+        "action_test_min_max": "min_max",
+        "action_test_std_cap": "mean_std",
+        "action_test_min_max_cap": "min_max",
     }

     dataset_stats = {
@@ -264,15 +280,43 @@ def test_normalize(insert_temporal_dim):
             "min": torch.randn(5),
             "max": torch.randn(5),
         },
+        "action_test_std": {
+            "mean": torch.ones(1) * 2,
+            "std": torch.ones(1) * 0.2,
+        },
+        "action_test_min_max": {
+            "min": torch.ones(1) * 1,
+            "max": torch.ones(1) * 3,
+        },
+        "action_test_std_cap": {
+            "mean": torch.ones(2) * 2,
+            "std": torch.ones(2) * 0.2,
+        },
+        "action_test_min_max_cap": {
+            "min": torch.ones(2) * 1.0,
+            "max": torch.ones(2) * 3.0,
+        },
     }

+    # Zero out some stats to exercise the capping: one std entry is set to 0,
+    # and for min_max one entry gets max set equal to min.
+    dataset_stats["action_test_std_cap"]["std"][1] = 0.0
+    dataset_stats["action_test_min_max_cap"]["max"][1] = dataset_stats["action_test_min_max_cap"]["min"][1]
+
     bsize = 2
     input_batch = {
         "observation.image": torch.randn(bsize, 3, 96, 96),
         "observation.state": torch.randn(bsize, 10),
+        "action_test_std": torch.ones(bsize, 1) * 2.5,
+        "action_test_min_max": torch.ones(bsize, 1) * 2.5,
+        "action_test_std_cap": torch.ones(bsize, 2) * 2.5,
+        "action_test_min_max_cap": torch.ones(bsize, 2) * 2.5,
     }
     output_batch = {
         "action": torch.randn(bsize, 5),
+        "action_test_std": torch.ones(bsize, 1) * 1.5,
+        "action_test_min_max": torch.ones(bsize, 1) * 1.5,
+        "action_test_std_cap": torch.ones(bsize, 2) * 1.5,
+        "action_test_min_max_cap": torch.ones(bsize, 2) * 1.5,
     }

     if insert_temporal_dim:
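Before reading the assertions in the next hunk, it may help to spell out the numbers they encode (input value 2.5, and the `std_epsilon = 1e-2` the hunk introduces). A standalone sanity check of the arithmetic, not part of the diff:

```python
# capped entries (std = 0 -> clamped to 1e-2; max == min -> padded by 1e-2)
assert abs((2.5 - 2.0) / 1e-2 - 50.0) < 1e-9
assert abs((2.5 - 1.0) / 1e-2 * 2 - 1 - 299.0) < 1e-9
# regular entries (std = 0.2; min = 1, max = 3)
assert abs((2.5 - 2.0) / 0.2 - 2.5) < 1e-9
assert abs((2.5 - 1.0) / (3.0 - 1.0) * 2 - 1 - 0.5) < 1e-9
```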
@@ -291,8 +335,148 @@ def test_normalize(insert_temporal_dim):
     normalize(input_batch)

     # test with stats
-    normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats)
-    normalize(input_batch)
+    std_epsilon = 1e-2
+    normalize = Normalize(input_shapes, normalize_input_modes, stats=dataset_stats, std_epsilon=std_epsilon)
+
+    # check that the stats are correctly set, including the min capping
+    assert torch.isclose(
+        normalize.buffer_action_test_std.mean, dataset_stats["action_test_std"]["mean"], rtol=0.1, atol=1e-7
+    ).all()
+    assert torch.isclose(
+        normalize.buffer_action_test_std.std, dataset_stats["action_test_std"]["std"], rtol=0.1, atol=1e-7
+    ).all()
+    assert torch.isclose(
+        normalize.buffer_action_test_min_max.min,
+        dataset_stats["action_test_min_max"]["min"],
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+    assert torch.isclose(
+        normalize.buffer_action_test_min_max.max,
+        dataset_stats["action_test_min_max"]["max"],
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+
+    assert torch.isclose(
+        normalize.buffer_action_test_std_cap.std,
+        dataset_stats["action_test_std_cap"]["std"],
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+    assert torch.isclose(
+        normalize.buffer_action_test_min_max_cap.max - normalize.buffer_action_test_min_max_cap.min,
+        dataset_stats["action_test_min_max_cap"]["max"] - dataset_stats["action_test_min_max_cap"]["min"],
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+
+    normalized_output = normalize(input_batch)
+
+    # check that the normalization is correct
+    assert torch.isclose(
+        normalized_output["action_test_std"],
+        (input_batch["action_test_std"] - dataset_stats["action_test_std"]["mean"])
+        / dataset_stats["action_test_std"]["std"],
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+    assert torch.isclose(
+        normalized_output["action_test_min_max"],
+        (input_batch["action_test_min_max"] - dataset_stats["action_test_min_max"]["min"])
+        / (dataset_stats["action_test_min_max"]["max"] - dataset_stats["action_test_min_max"]["min"])
+        * 2
+        - 1,
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+
+    if insert_temporal_dim:
+        assert torch.isclose(
+            normalized_output["action_test_std_cap"][0, 0, 0],
+            (input_batch["action_test_std_cap"][0, 0, 0] - dataset_stats["action_test_std_cap"]["mean"][0])
+            / dataset_stats["action_test_std_cap"]["std"][0],
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+        assert torch.isclose(
+            normalized_output["action_test_std_cap"][0, 0, 1],
+            (input_batch["action_test_std_cap"][0, 0, 1] - dataset_stats["action_test_std_cap"]["mean"][1])
+            / std_epsilon,
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+
+        assert torch.isclose(
+            normalized_output["action_test_min_max_cap"][0, 0, 0],
+            (
+                input_batch["action_test_min_max_cap"][0, 0, 0]
+                - dataset_stats["action_test_min_max_cap"]["min"][0]
+            )
+            / (
+                dataset_stats["action_test_min_max_cap"]["max"][0]
+                - dataset_stats["action_test_min_max_cap"]["min"][0]
+            )
+            * 2
+            - 1,
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+        assert torch.isclose(
+            normalized_output["action_test_min_max_cap"][0, 0, 1],
+            (
+                input_batch["action_test_min_max_cap"][0, 0, 1]
+                - dataset_stats["action_test_min_max_cap"]["min"][1]
+            )
+            / std_epsilon
+            * 2
+            - 1,
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+    else:
+        assert torch.isclose(
+            normalized_output["action_test_std_cap"][0, 0],
+            (input_batch["action_test_std_cap"][0, 0] - dataset_stats["action_test_std_cap"]["mean"][0])
+            / dataset_stats["action_test_std_cap"]["std"][0],
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+        assert torch.isclose(
+            normalized_output["action_test_std_cap"][0, 1],
+            (input_batch["action_test_std_cap"][0, 1] - dataset_stats["action_test_std_cap"]["mean"][1])
+            / std_epsilon,
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+
+        assert torch.isclose(
+            normalized_output["action_test_min_max_cap"][0, 0],
+            (
+                input_batch["action_test_min_max_cap"][0, 0]
+                - dataset_stats["action_test_min_max_cap"]["min"][0]
+            )
+            / (
+                dataset_stats["action_test_min_max_cap"]["max"][0]
+                - dataset_stats["action_test_min_max_cap"]["min"][0]
+            )
+            * 2
+            - 1,
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+        assert torch.isclose(
+            normalized_output["action_test_min_max_cap"][0, 1],
+            (
+                input_batch["action_test_min_max_cap"][0, 1]
+                - dataset_stats["action_test_min_max_cap"]["min"][1]
+            )
+            / std_epsilon
+            * 2
+            - 1,
+            rtol=0.1,
+            atol=1e-7,
+        ).all()

     # test loading pretrained models
     new_normalize = Normalize(input_shapes, normalize_input_modes, stats=None)
@@ -305,8 +489,133 @@ def test_normalize(insert_temporal_dim):
     unnormalize(output_batch)

     # test with stats
-    unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats)
-    unnormalize(output_batch)
+    unnormalize = Unnormalize(
+        output_shapes, unnormalize_output_modes, stats=dataset_stats, std_epsilon=std_epsilon
+    )
+
+    # check that the stats are correctly set, including the min capping
+    assert torch.isclose(
+        unnormalize.buffer_action_test_std.mean, dataset_stats["action_test_std"]["mean"], rtol=0.1, atol=1e-7
+    ).all()
+    assert torch.isclose(
+        unnormalize.buffer_action_test_std.std, dataset_stats["action_test_std"]["std"], rtol=0.1, atol=1e-7
+    ).all()
+    assert torch.isclose(
+        unnormalize.buffer_action_test_min_max.min,
+        dataset_stats["action_test_min_max"]["min"],
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+    assert torch.isclose(
+        unnormalize.buffer_action_test_min_max.max,
+        dataset_stats["action_test_min_max"]["max"],
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+
+    assert torch.isclose(
+        unnormalize.buffer_action_test_std_cap.std,
+        dataset_stats["action_test_std_cap"]["std"],
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+    assert torch.isclose(
+        unnormalize.buffer_action_test_min_max_cap.max - unnormalize.buffer_action_test_min_max_cap.min,
+        dataset_stats["action_test_min_max_cap"]["max"] - dataset_stats["action_test_min_max_cap"]["min"],
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+
+    unnormalize_output = unnormalize(output_batch)
+
+    # check that the unnormalization is correct
+    assert torch.isclose(
+        unnormalize_output["action_test_std"],
+        output_batch["action_test_std"] * dataset_stats["action_test_std"]["std"]
+        + dataset_stats["action_test_std"]["mean"],
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+    assert torch.isclose(
+        unnormalize_output["action_test_min_max"],
+        (output_batch["action_test_min_max"] + 1)
+        / 2
+        * (dataset_stats["action_test_min_max"]["max"] - dataset_stats["action_test_min_max"]["min"])
+        + dataset_stats["action_test_min_max"]["min"],
+        rtol=0.1,
+        atol=1e-7,
+    ).all()
+
+    if insert_temporal_dim:
+        assert torch.isclose(
+            unnormalize_output["action_test_std_cap"][0, 0, 0],
+            output_batch["action_test_std_cap"][0, 0, 0] * dataset_stats["action_test_std_cap"]["std"][0]
+            + dataset_stats["action_test_std_cap"]["mean"][0],
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+        assert torch.isclose(
+            unnormalize_output["action_test_std_cap"][0, 0, 1],
+            output_batch["action_test_std_cap"][0, 0, 1] * std_epsilon
+            + dataset_stats["action_test_std_cap"]["mean"][1],
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+
+        assert torch.isclose(
+            unnormalize_output["action_test_min_max_cap"][0, 0, 0],
+            (output_batch["action_test_min_max_cap"][0, 0, 0] + 1)
+            / 2
+            * (
+                dataset_stats["action_test_min_max_cap"]["max"][0]
+                - dataset_stats["action_test_min_max_cap"]["min"][0]
+            )
+            + dataset_stats["action_test_min_max_cap"]["min"][0],
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+        assert torch.isclose(
+            unnormalize_output["action_test_min_max_cap"][0, 0, 1],
+            (output_batch["action_test_min_max_cap"][0, 0, 1] + 1) / 2 * std_epsilon
+            + dataset_stats["action_test_min_max_cap"]["min"][1],
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+    else:
+        assert torch.isclose(
+            unnormalize_output["action_test_std_cap"][0, 0],
+            output_batch["action_test_std_cap"][0, 0] * dataset_stats["action_test_std_cap"]["std"][0]
+            + dataset_stats["action_test_std_cap"]["mean"][0],
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+        assert torch.isclose(
+            unnormalize_output["action_test_std_cap"][0, 1],
+            output_batch["action_test_std_cap"][0, 1] * std_epsilon
+            + dataset_stats["action_test_std_cap"]["mean"][1],
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+
+        assert torch.isclose(
+            unnormalize_output["action_test_min_max_cap"][0, 0],
+            (output_batch["action_test_min_max_cap"][0, 0] + 1)
+            / 2
+            * (
+                dataset_stats["action_test_min_max_cap"]["max"][0]
+                - dataset_stats["action_test_min_max_cap"]["min"][0]
+            )
+            + dataset_stats["action_test_min_max_cap"]["min"][0],
+            rtol=0.1,
+            atol=1e-7,
+        ).all()
+        assert torch.isclose(
+            unnormalize_output["action_test_min_max_cap"][0, 1],
+            (output_batch["action_test_min_max_cap"][0, 1] + 1) / 2 * std_epsilon
+            + dataset_stats["action_test_min_max_cap"]["min"][1],
+            rtol=0.1,
+            atol=1e-7,
+        ).all()

     # test loading pretrained models
     new_unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
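And the same spelled-out arithmetic for the unnormalize assertions above (output value 1.5, `std_epsilon = 1e-2`); a standalone check, not part of the diff:

```python
# capped entries: multiply by the clamped std / padded range, then shift
assert abs(1.5 * 1e-2 + 2.0 - 2.015) < 1e-9             # mean_std
assert abs((1.5 + 1) / 2 * 1e-2 + 1.0 - 1.0125) < 1e-9  # min_max
```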