Files
lerobot_piper/lerobot/common/envs/transforms.py
2024-02-20 14:22:16 +00:00

23 lines
582 B
Python

from typing import Sequence
from tensordict.utils import NestedKey
from torchrl.envs.transforms import ObservationTransform
class Prod(ObservationTransform):
    """Multiply selected observation entries by a constant factor.

    The factor is applied both to the data passing through the transform
    (``_call``) and to the bounds of the matching observation specs
    (``transform_observation_spec``), so data and specs stay consistent.
    For example, a factor of ``1 / 255`` maps values in ``[0, 255]`` to
    ``[0, 1]``.

    NOTE(review): only ``_call`` is overridden here. Whether observations
    produced by an environment reset are also routed through ``_call``
    depends on the torchrl version's ``Transform`` reset handling -- confirm
    that reset observations get scaled as well.
    """

    def __init__(self, in_keys: Sequence[NestedKey], prod: float):
        """
        Args:
            in_keys: keys of the observation entries (and specs) to scale.
            prod: multiplicative factor applied to every entry in ``in_keys``.
        """
        super().__init__()
        self.in_keys = in_keys
        self.prod = prod

    def _call(self, td):
        """Scale each entry listed in ``in_keys`` by ``prod``; return ``td``.

        The product is computed out of place and written back under the same
        key. An in-place ``*=`` would raise for integer tensors (e.g.
        ``uint8``) multiplied by a non-integer factor, because torch cannot
        cast the floating-point result back into integer storage. It would
        also overwrite the caller's tensor. Out-of-place multiplication
        follows torch type promotion instead: an integer tensor times a
        Python float yields a floating-point tensor.
        """
        for key in self.in_keys:
            td[key] = td[key] * self.prod
        return td

    def transform_observation_spec(self, obs_spec):
        """Rescale the spec bounds of every key in ``in_keys`` to match ``_call``.

        Both ``low`` and ``high`` are multiplied by ``prod``. The element-wise
        min/max of the two scaled bounds keeps ``low <= high`` even when
        ``prod`` is negative. The spec dtype is set to the dtype of the
        scaled bounds, so it agrees with the promoted dtype produced by
        ``_call``.
        """
        for key in self.in_keys:
            spec = obs_spec[key]
            # Out of place: in-place ops on integer bounds with a float
            # factor would raise, exactly as in ``_call``.
            scaled_low = spec.space.low * self.prod
            scaled_high = spec.space.high * self.prod
            # Compute both new bounds before assigning either one.
            new_low = scaled_low.minimum(scaled_high)
            new_high = scaled_low.maximum(scaled_high)
            spec.space.low = new_low
            spec.space.high = new_high
            spec.dtype = new_high.dtype
        return obs_spec