from typing import Sequence
from tensordict.utils import NestedKey
from torchrl.envs.transforms import ObservationTransform


class Prod(ObservationTransform):
    """Multiply selected observation entries by a constant factor.

    Each entry listed in ``in_keys`` is replaced, under the same key, by its
    value multiplied by ``prod``. The observation spec bounds of those entries
    are rescaled accordingly.

    Args:
        in_keys: Keys (possibly nested) of the entries to scale.
        prod: Scalar factor applied to every selected entry.
    """

    def __init__(self, in_keys: Sequence[NestedKey], prod: float):
        # Hand the keys to the base constructor instead of assigning them
        # afterwards. The base Transform then also derives ``out_keys`` from
        # them, so base-class code that pairs in/out keys (e.g. ``forward``)
        # acts on the same entries instead of silently doing nothing.
        super().__init__(in_keys=in_keys)
        self.prod = prod

    def _apply_transform(self, obs):
        """Return ``obs`` scaled by ``self.prod``; the input is not modified."""
        return obs * self.prod

    def _call(self, td):
        """Scale every ``in_keys`` entry of ``td`` and return ``td``.

        Raises:
            KeyError: if one of ``in_keys`` is missing from ``td``.
        """
        for key in self.in_keys:
            # Out-of-place product written back with ``set``. An in-place
            # ``*=`` would mutate a tensor that may be shared elsewhere. It
            # would also fail for integer entries (e.g. uint8 pixels) when
            # ``prod`` is a float, since in-place ops cannot change dtype.
            td.set(key, self._apply_transform(td[key]))
        return td

    def _reset(self, tensordict, tensordict_reset):
        """Apply the same scaling to the observations produced by a reset."""
        # Without this override, torchrl's default ``_reset`` returns the
        # reset data untouched. The first observation of each episode would
        # then be unscaled and inconsistent with the transformed spec.
        return self._call(tensordict_reset)

    def transform_observation_spec(self, obs_spec):
        """Rescale the bounds of each transformed entry's spec.

        Both ``low`` and ``high`` are multiplied by ``prod``. For a negative
        factor the two products are swapped so that ``low`` stays below
        ``high``.
        """
        for key in self.in_keys:
            space = obs_spec[key].space
            low = space.low * self.prod
            high = space.high * self.prod
            if self.prod < 0:
                # Multiplying by a negative number reverses the ordering.
                low, high = high, low
            space.low = low
            space.high = high
            # NOTE(review): if the entry has an integer dtype and ``prod`` is a
            # float, the spec's ``dtype`` is not updated here - confirm whether
            # it should follow the (float) transformed data.
        return obs_spec