pre-commit run -a

This commit is contained in:
Remi Cadene
2024-03-02 15:58:21 +00:00
parent 1ae6205269
commit 45b4ecb727
6 changed files with 44 additions and 43 deletions

View File

@@ -1,10 +1,9 @@
from typing import Sequence
from tensordict import TensorDictBase
from tensordict.utils import NestedKey
from torchrl.envs.transforms import ObservationTransform
from torchrl.envs.transforms import Transform
from tensordict import TensorDictBase
from tensordict.nn import dispatch
from tensordict.utils import NestedKey
from torchrl.envs.transforms import ObservationTransform, Transform
class Prod(ObservationTransform):
@@ -27,28 +26,31 @@ class Prod(ObservationTransform):
class NormalizeTransform(Transform):
invertible = True
def __init__(self,
mean_std: TensorDictBase,
in_keys: Sequence[NestedKey] = None,
out_keys: Sequence[NestedKey] | None = None,
in_keys_inv: Sequence[NestedKey] | None = None,
out_keys_inv: Sequence[NestedKey] | None = None,
):
def __init__(
self,
mean_std: TensorDictBase,
in_keys: Sequence[NestedKey] = None,
out_keys: Sequence[NestedKey] | None = None,
in_keys_inv: Sequence[NestedKey] | None = None,
out_keys_inv: Sequence[NestedKey] | None = None,
):
if out_keys is None:
out_keys = in_keys
if in_keys_inv is None:
in_keys_inv = out_keys
if out_keys_inv is None:
out_keys_inv = in_keys
super().__init__(in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv)
super().__init__(
in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
)
self.mean_std = mean_std
@dispatch(source="in_keys", dest="out_keys")
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
return self._call(tensordict)
def _call(self, td: TensorDictBase) -> TensorDictBase:
for inkey, outkey in zip(self.in_keys, self.out_keys):
for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
# TODO(rcadene): don't know how to do `inkey not in td`
if td.get(inkey, None) is None:
continue
@@ -58,7 +60,7 @@ class NormalizeTransform(Transform):
return td
def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv):
for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
# TODO(rcadene): don't know how to do `inkey not in td`
if td.get(inkey, None) is None:
continue