Add mode to NormalizeTransform with mean_std or min_max (Not fully tested)

This commit is contained in:
Remi Cadene
2024-03-03 13:19:02 +00:00
parent 48ded3dbc7
commit cbbed590a9
4 changed files with 75 additions and 33 deletions

View File

@@ -28,11 +28,12 @@ class NormalizeTransform(Transform):
def __init__(
self,
mean_std: TensorDictBase,
stats: TensorDictBase,
in_keys: Sequence[NestedKey] = None,
out_keys: Sequence[NestedKey] | None = None,
in_keys_inv: Sequence[NestedKey] | None = None,
out_keys_inv: Sequence[NestedKey] | None = None,
mode="mean_std",
):
if out_keys is None:
out_keys = in_keys
@@ -43,7 +44,14 @@ class NormalizeTransform(Transform):
super().__init__(
in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
)
self.mean_std = mean_std
self.stats = stats
assert mode in ["mean_std", "min_max"]
self.mode = mode
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
# _reset is called once when the environment resets, to normalize the first observation
tensordict_reset = self._call(tensordict_reset)
return tensordict_reset
@dispatch(source="in_keys", dest="out_keys")
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
@@ -54,9 +62,17 @@ class NormalizeTransform(Transform):
# TODO(rcadene): don't know how to do `inkey not in td`
if td.get(inkey, None) is None:
continue
mean = self.mean_std[inkey]["mean"]
std = self.mean_std[inkey]["std"]
td[outkey] = (td[inkey] - mean) / (std + 1e-8)
if self.mode == "mean_std":
mean = self.stats[inkey]["mean"]
std = self.stats[inkey]["std"]
td[outkey] = (td[inkey] - mean) / (std + 1e-8)
else:
min = self.stats[inkey]["min"]
max = self.stats[inkey]["max"]
# normalize to [0, 1] (NOTE: unlike the mean_std branch, no epsilon guards the division when max == min)
td[outkey] = (td[inkey] - min) / (max - min)
# normalize to [-1, 1]
td[outkey] = td[outkey] * 2 - 1
return td
def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
@@ -64,7 +80,13 @@ class NormalizeTransform(Transform):
# TODO(rcadene): don't know how to do `inkey not in td`
if td.get(inkey, None) is None:
continue
mean = self.mean_std[inkey]["mean"]
std = self.mean_std[inkey]["std"]
td[outkey] = td[inkey] * std + mean
if self.mode == "mean_std":
mean = self.stats[inkey]["mean"]
std = self.stats[inkey]["std"]
td[outkey] = td[inkey] * std + mean
else:
min = self.stats[inkey]["min"]
max = self.stats[inkey]["max"]
td[outkey] = (td[inkey] + 1) / 2
td[outkey] = td[outkey] * (max - min) + min
return td