forked from tangger/lerobot
Add Prod transform, Add test_factory
This commit is contained in:
22
lerobot/common/envs/transforms.py
Normal file
22
lerobot/common/envs/transforms.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Sequence
|
||||
|
||||
from tensordict.utils import NestedKey
|
||||
from torchrl.envs.transforms import ObservationTransform
|
||||
|
||||
|
||||
class Prod(ObservationTransform):
|
||||
|
||||
def __init__(self, in_keys: Sequence[NestedKey], prod: float):
|
||||
super().__init__()
|
||||
self.in_keys = in_keys
|
||||
self.prod = prod
|
||||
|
||||
def _call(self, td):
|
||||
for key in self.in_keys:
|
||||
td[key] *= self.prod
|
||||
return td
|
||||
|
||||
def transform_observation_spec(self, obs_spec):
|
||||
for key in self.in_keys:
|
||||
obs_spec[key].space.high *= self.prod
|
||||
return obs_spec
|
||||
Reference in New Issue
Block a user