Add Prod transform, Add test_factory

This commit is contained in:
Cadene
2024-02-20 14:22:16 +00:00
parent 3da6ffb2cb
commit 3dc14b5576
5 changed files with 56 additions and 12 deletions

View File

@@ -0,0 +1,22 @@
from typing import Sequence
from tensordict.utils import NestedKey
from torchrl.envs.transforms import ObservationTransform
class Prod(ObservationTransform):
def __init__(self, in_keys: Sequence[NestedKey], prod: float):
super().__init__()
self.in_keys = in_keys
self.prod = prod
def _call(self, td):
for key in self.in_keys:
td[key] *= self.prod
return td
def transform_observation_spec(self, obs_spec):
for key in self.in_keys:
obs_spec[key].space.high *= self.prod
return obs_spec