lerobot_piper/lerobot/common/transforms.py

import torch
from torchvision.transforms.v2 import Compose, Transform


def apply_inverse_transform(item, transform):
    """Apply the inverse of `transform` to `item`, walking a Compose in reverse order.

    Each transform must set `invertible = True` and implement `inverse_transform`.
    """
    transforms = transform.transforms if isinstance(transform, Compose) else [transform]
    for tf in transforms[::-1]:
        if tf.invertible:
            item = tf.inverse_transform(item)
        else:
            raise ValueError(f"Inverse transform called on a non-invertible transform ({tf}).")
    return item


class Prod(Transform):
    """Cast the tensors at `in_keys` to float32 and scale them by a constant factor."""

    invertible = True

    def __init__(self, in_keys: list[str], prod: float):
        super().__init__()
        self.in_keys = in_keys
        self.prod = prod
        self.original_dtypes = {}

    def forward(self, item):
        for key in self.in_keys:
            if key not in item:
                continue
            # Remember the original dtype so inverse_transform can restore it.
            self.original_dtypes[key] = item[key].dtype
            item[key] = item[key].type(torch.float32) * self.prod
        return item

    def inverse_transform(self, item):
        for key in self.in_keys:
            if key not in item:
                continue
            item[key] = (item[key] / self.prod).type(self.original_dtypes[key])
        return item

    # def transform_observation_spec(self, obs_spec):
    #     for key in self.in_keys:
    #         if obs_spec.get(key, None) is None:
    #             continue
    #         obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
    #         obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
    #         obs_spec[key].dtype = torch.float32
    #     return obs_spec
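

# Note: Prod records the original dtypes as instance state during forward(), so
# inverse_transform() is only meaningful after forward() has been called on the
# same keys at least once.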


class NormalizeTransform(Transform):
    """Normalize the tensors at `in_keys` using precomputed dataset statistics.

    mode="mean_std" standardizes to zero mean and unit variance;
    mode="min_max" rescales to the range [-1, 1].
    """

    invertible = True

    def __init__(
        self,
        stats: dict,
        in_keys: list[str] | None = None,
        out_keys: list[str] | None = None,
        in_keys_inv: list[str] | None = None,
        out_keys_inv: list[str] | None = None,
        mode="mean_std",
    ):
        super().__init__()
        self.in_keys = in_keys
        self.out_keys = in_keys if out_keys is None else out_keys
        self.in_keys_inv = self.out_keys if in_keys_inv is None else in_keys_inv
        self.out_keys_inv = self.in_keys if out_keys_inv is None else out_keys_inv
        self.stats = stats
        assert mode in ["mean_std", "min_max"]
        self.mode = mode

    def forward(self, item):
        for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
            if inkey not in item:
                continue
            if self.mode == "mean_std":
                mean = self.stats[inkey]["mean"]
                std = self.stats[inkey]["std"]
                item[outkey] = (item[inkey] - mean) / (std + 1e-8)
            else:
                min = self.stats[inkey]["min"]
                max = self.stats[inkey]["max"]
                # normalize to [0, 1]
                item[outkey] = (item[inkey] - min) / (max - min)
                # then rescale to [-1, 1]
                item[outkey] = item[outkey] * 2 - 1
        return item

    def inverse_transform(self, item):
        for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
            if inkey not in item:
                continue
            if self.mode == "mean_std":
                mean = self.stats[inkey]["mean"]
                std = self.stats[inkey]["std"]
                item[outkey] = item[inkey] * std + mean
            else:
                min = self.stats[inkey]["min"]
                max = self.stats[inkey]["max"]
                # undo the [-1, 1] rescaling, then the [0, 1] normalization
                item[outkey] = (item[inkey] + 1) / 2
                item[outkey] = item[outkey] * (max - min) + min
        return item
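

# A minimal usage sketch (not part of the original module), assuming dict-style
# items with "observation.image" and "observation.state" keys and made-up
# statistics; both are illustrative placeholders, not values from a real dataset.
# It composes Prod and NormalizeTransform, runs a dummy item through the pipeline,
# then undoes it with apply_inverse_transform.
if __name__ == "__main__":
    stats = {"observation.state": {"mean": torch.tensor([0.5]), "std": torch.tensor([2.0])}}
    pipeline = Compose(
        [
            # Rescale uint8 image pixels to [0, 1] floats.
            Prod(in_keys=["observation.image"], prod=1 / 255.0),
            # Standardize the state vector with the (assumed) dataset statistics.
            NormalizeTransform(stats, in_keys=["observation.state"], mode="mean_std"),
        ]
    )

    item = {
        "observation.image": torch.randint(0, 256, (3, 96, 96), dtype=torch.uint8),
        "observation.state": torch.tensor([1.0]),
    }
    item = pipeline(item)
    item = apply_inverse_transform(item, pipeline)
    print(item["observation.image"].dtype)  # restored to torch.uint8
    print(item["observation.state"])  # close to 1.0, up to the 1e-8 epsilon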