update unnormalize
This commit is contained in:
@@ -176,6 +176,7 @@ class Unnormalize(nn.Module):
|
|||||||
shapes: dict[str, list[int]],
|
shapes: dict[str, list[int]],
|
||||||
modes: dict[str, str],
|
modes: dict[str, str],
|
||||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||||
|
std_epsilon: float = 1e-5,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -194,19 +195,24 @@ class Unnormalize(nn.Module):
|
|||||||
not provided, as expected for finetuning or evaluation, the default buffers should be
|
not provided, as expected for finetuning or evaluation, the default buffers should be
|
||||||
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
|
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
|
||||||
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
||||||
|
std_epsilon (float, optional): A small minimum value for the standard deviation to avoid division by
|
||||||
|
zero in the Normalize step. We use the same value for unnormalization here to have a consistent
|
||||||
|
behavior. Default is `1e-5`. We use `clamp_min` to make sure the standard deviation (or the difference
|
||||||
|
between min and max) is at least `std_epsilon`.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.shapes = shapes
|
self.shapes = shapes
|
||||||
self.modes = modes
|
self.modes = modes
|
||||||
self.stats = stats
|
self.stats = stats
|
||||||
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
|
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
|
||||||
stats_buffers = create_stats_buffers(shapes, modes, stats)
|
stats_buffers = create_stats_buffers(shapes, modes, stats, std_epsilon=std_epsilon)
|
||||||
for key, buffer in stats_buffers.items():
|
for key, buffer in stats_buffers.items():
|
||||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||||
|
|
||||||
# TODO(rcadene): should we remove torch.no_grad?
|
# TODO(rcadene): should we remove torch.no_grad?
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
|
output_batch = {}
|
||||||
for key, mode in self.modes.items():
|
for key, mode in self.modes.items():
|
||||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||||
|
|
||||||
@@ -215,14 +221,14 @@ class Unnormalize(nn.Module):
|
|||||||
std = buffer["std"]
|
std = buffer["std"]
|
||||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||||
batch[key] = batch[key] * std + mean
|
output_batch[key] = batch[key] * std + mean
|
||||||
elif mode == "min_max":
|
elif mode == "min_max":
|
||||||
min = buffer["min"]
|
min = buffer["min"]
|
||||||
max = buffer["max"]
|
max = buffer["max"]
|
||||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||||
batch[key] = (batch[key] + 1) / 2
|
output_batch[key] = (batch[key] + 1) / 2
|
||||||
batch[key] = batch[key] * (max - min) + min
|
output_batch[key] = output_batch[key] * (max - min) + min
|
||||||
else:
|
else:
|
||||||
raise ValueError(mode)
|
raise ValueError(mode)
|
||||||
return batch
|
return output_batch
|
||||||
|
|||||||
@@ -285,10 +285,10 @@ def test_normalize(insert_temporal_dim):
|
|||||||
}
|
}
|
||||||
output_batch = {
|
output_batch = {
|
||||||
"action": torch.randn(bsize, 5),
|
"action": torch.randn(bsize, 5),
|
||||||
"action_test_std": torch.ones(bsize, 1) * 2.5,
|
"action_test_std": torch.ones(bsize, 1) * 1.5,
|
||||||
"action_test_min_max": torch.ones(bsize, 1) * 2.5,
|
"action_test_min_max": torch.ones(bsize, 1) * 1.5,
|
||||||
"action_test_std_cap": torch.ones(bsize, 2) * 2.5,
|
"action_test_std_cap": torch.ones(bsize, 2) * 1.5,
|
||||||
"action_test_min_max_cap": torch.ones(bsize, 2) * 2.5,
|
"action_test_min_max_cap": torch.ones(bsize, 2) * 1.5,
|
||||||
}
|
}
|
||||||
|
|
||||||
if insert_temporal_dim:
|
if insert_temporal_dim:
|
||||||
@@ -471,8 +471,143 @@ def test_normalize(insert_temporal_dim):
|
|||||||
unnormalize(output_batch)
|
unnormalize(output_batch)
|
||||||
|
|
||||||
# test with stats
|
# test with stats
|
||||||
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats)
|
unnormalize = Unnormalize(
|
||||||
unnormalize(output_batch)
|
output_shapes, unnormalize_output_modes, stats=dataset_stats, std_epsilon=std_epsilon
|
||||||
|
)
|
||||||
|
|
||||||
|
# check that the stats are correctly set including the min capping
|
||||||
|
assert torch.isclose(
|
||||||
|
unnormalize.buffer_action_test_std.mean, dataset_stats["action_test_std"]["mean"], rtol=0.1, atol=1e-7
|
||||||
|
).all()
|
||||||
|
assert torch.isclose(
|
||||||
|
unnormalize.buffer_action_test_std.std, dataset_stats["action_test_std"]["std"], rtol=0.1, atol=1e-7
|
||||||
|
).all()
|
||||||
|
assert torch.isclose(
|
||||||
|
unnormalize.buffer_action_test_min_max.min,
|
||||||
|
dataset_stats["action_test_min_max"]["min"],
|
||||||
|
rtol=0.1,
|
||||||
|
atol=1e-7,
|
||||||
|
).all()
|
||||||
|
assert torch.isclose(
|
||||||
|
unnormalize.buffer_action_test_min_max.max,
|
||||||
|
dataset_stats["action_test_min_max"]["max"],
|
||||||
|
rtol=0.1,
|
||||||
|
atol=1e-7,
|
||||||
|
).all()
|
||||||
|
|
||||||
|
assert torch.isclose(
|
||||||
|
unnormalize.buffer_action_test_std_cap.std[0],
|
||||||
|
dataset_stats["action_test_std_cap"]["std"][0],
|
||||||
|
rtol=0.1,
|
||||||
|
atol=1e-7,
|
||||||
|
).all()
|
||||||
|
assert torch.isclose(
|
||||||
|
unnormalize.buffer_action_test_std_cap.std[1], torch.ones(1) * std_epsilon, rtol=0.1, atol=1e-7
|
||||||
|
).all()
|
||||||
|
assert torch.isclose(
|
||||||
|
unnormalize.buffer_action_test_min_max_cap.max[0] - unnormalize.buffer_action_test_min_max_cap.min[0],
|
||||||
|
dataset_stats["action_test_min_max_cap"]["max"][0]
|
||||||
|
- dataset_stats["action_test_min_max_cap"]["min"][0],
|
||||||
|
rtol=0.1,
|
||||||
|
atol=1e-7,
|
||||||
|
).all()
|
||||||
|
assert torch.isclose(
|
||||||
|
unnormalize.buffer_action_test_min_max_cap.max[1] - unnormalize.buffer_action_test_min_max_cap.min[1],
|
||||||
|
torch.ones(1) * std_epsilon,
|
||||||
|
rtol=0.1,
|
||||||
|
atol=1e-7,
|
||||||
|
).all()
|
||||||
|
|
||||||
|
unnormalize_output = unnormalize(output_batch)
|
||||||
|
|
||||||
|
# check that the unnormalization is correct
|
||||||
|
assert torch.isclose(
|
||||||
|
unnormalize_output["action_test_std"],
|
||||||
|
output_batch["action_test_std"] * dataset_stats["action_test_std"]["std"]
|
||||||
|
+ dataset_stats["action_test_std"]["mean"],
|
||||||
|
rtol=0.1,
|
||||||
|
atol=1e-7,
|
||||||
|
).all()
|
||||||
|
assert torch.isclose(
|
||||||
|
unnormalize_output["action_test_min_max"],
|
||||||
|
(output_batch["action_test_min_max"] + 1)
|
||||||
|
/ 2
|
||||||
|
* (dataset_stats["action_test_min_max"]["max"] - dataset_stats["action_test_min_max"]["min"])
|
||||||
|
+ dataset_stats["action_test_min_max"]["min"],
|
||||||
|
rtol=0.1,
|
||||||
|
atol=1e-7,
|
||||||
|
).all()
|
||||||
|
|
||||||
|
if insert_temporal_dim:
|
||||||
|
assert torch.isclose(
|
||||||
|
unnormalize_output["action_test_std_cap"][0, 0, 0],
|
||||||
|
output_batch["action_test_std_cap"][0, 0, 0] * dataset_stats["action_test_std_cap"]["std"][0]
|
||||||
|
+ dataset_stats["action_test_std_cap"]["mean"][0],
|
||||||
|
rtol=0.1,
|
||||||
|
atol=1e-7,
|
||||||
|
).all()
|
||||||
|
assert torch.isclose(
|
||||||
|
unnormalize_output["action_test_std_cap"][0, 0, 1],
|
||||||
|
output_batch["action_test_std_cap"][0, 0, 1] * std_epsilon
|
||||||
|
+ dataset_stats["action_test_std_cap"]["mean"][1],
|
||||||
|
rtol=0.1,
|
||||||
|
atol=1e-7,
|
||||||
|
).all()
|
||||||
|
|
||||||
|
assert torch.isclose(
|
||||||
|
unnormalize_output["action_test_min_max_cap"][0, 0, 0],
|
||||||
|
(output_batch["action_test_min_max_cap"][0, 0, 0] + 1)
|
||||||
|
/ 2
|
||||||
|
* (
|
||||||
|
dataset_stats["action_test_min_max_cap"]["max"][0]
|
||||||
|
- dataset_stats["action_test_min_max_cap"]["min"][0]
|
||||||
|
)
|
||||||
|
+ dataset_stats["action_test_min_max_cap"]["min"][0],
|
||||||
|
rtol=0.1,
|
||||||
|
atol=1e-7,
|
||||||
|
).all()
|
||||||
|
assert torch.isclose(
|
||||||
|
unnormalize_output["action_test_min_max_cap"][0, 0, 1],
|
||||||
|
(output_batch["action_test_min_max_cap"][0, 0, 1] + 1) / 2 * std_epsilon
|
||||||
|
+ dataset_stats["action_test_min_max_cap"]["min"][0],
|
||||||
|
rtol=0.1,
|
||||||
|
atol=1e-7,
|
||||||
|
).all()
|
||||||
|
else:
|
||||||
|
assert torch.isclose(
|
||||||
|
unnormalize_output["action_test_std_cap"][0, 0],
|
||||||
|
output_batch["action_test_std_cap"][0, 0] * dataset_stats["action_test_std_cap"]["std"][0]
|
||||||
|
+ dataset_stats["action_test_std_cap"]["mean"][0],
|
||||||
|
rtol=0.1,
|
||||||
|
atol=1e-7,
|
||||||
|
).all()
|
||||||
|
assert torch.isclose(
|
||||||
|
unnormalize_output["action_test_std_cap"][0, 1],
|
||||||
|
output_batch["action_test_std_cap"][0, 1] * std_epsilon
|
||||||
|
+ dataset_stats["action_test_std_cap"]["mean"][1],
|
||||||
|
rtol=0.1,
|
||||||
|
atol=1e-7,
|
||||||
|
).all()
|
||||||
|
|
||||||
|
assert torch.isclose(
|
||||||
|
unnormalize_output["action_test_min_max_cap"][0, 0],
|
||||||
|
(output_batch["action_test_min_max_cap"][0, 0] + 1)
|
||||||
|
/ 2
|
||||||
|
* (
|
||||||
|
dataset_stats["action_test_min_max_cap"]["max"][0]
|
||||||
|
- dataset_stats["action_test_min_max_cap"]["min"][0]
|
||||||
|
)
|
||||||
|
+ dataset_stats["action_test_min_max_cap"]["min"][0],
|
||||||
|
rtol=0.1,
|
||||||
|
atol=1e-7,
|
||||||
|
).all()
|
||||||
|
assert torch.isclose(
|
||||||
|
unnormalize_output["action_test_min_max_cap"][0, 1],
|
||||||
|
(output_batch["action_test_min_max_cap"][0, 1] + 1) / 2 * std_epsilon
|
||||||
|
+ dataset_stats["action_test_min_max_cap"]["min"][1],
|
||||||
|
rtol=0.1,
|
||||||
|
atol=1e-7,
|
||||||
|
).all()
|
||||||
|
|
||||||
# test loading pretrained models
|
# test loading pretrained models
|
||||||
new_unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
|
new_unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=None)
|
||||||
|
|||||||
Reference in New Issue
Block a user