WIP

WIP train.py works, loss going down

WIP eval.py

Fix

WIP (eval running, TODO: verify results reproduced)

Eval works! (testing reproducibility)

WIP

pretrained model pusht reproduces same results as torchrl

pretrained model pusht reproduces same results as torchrl

Remove AbstractPolicy, Move all queues in select_action

WIP test_datasets passed (TODO: re-enable NormalizeTransform)
This commit is contained in:
Cadene
2024-03-31 15:05:25 +00:00
parent 920e0d118b
commit 1cdfbc8b52
17 changed files with 826 additions and 621 deletions

View File

@@ -1,53 +1,49 @@
from typing import Sequence
import torch
from tensordict import TensorDictBase
from tensordict.nn import dispatch
from tensordict.utils import NestedKey
from torchrl.envs.transforms import ObservationTransform, Transform
from torchvision.transforms.v2 import Compose, Transform
class Prod(ObservationTransform):
def apply_inverse_transform(item, transform):
transforms = transform.transforms if isinstance(transform, Compose) else [transform]
for tf in transforms[::-1]:
if tf.invertible:
item = tf.inverse_transform(item)
else:
raise ValueError(f"Inverse transform called on a non invertible transform ({tf}).")
return item
class Prod(Transform):
invertible = True
def __init__(self, in_keys: Sequence[NestedKey], prod: float):
def __init__(self, in_keys: list[str], prod: float):
super().__init__()
self.in_keys = in_keys
self.prod = prod
self.original_dtypes = {}
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
# _reset is called once when the environment reset to normalize the first observation
tensordict_reset = self._call(tensordict_reset)
return tensordict_reset
@dispatch(source="in_keys", dest="out_keys")
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
return self._call(tensordict)
def _call(self, td):
def forward(self, item):
for key in self.in_keys:
if td.get(key, None) is None:
if key not in item:
continue
self.original_dtypes[key] = td[key].dtype
td[key] = td[key].type(torch.float32) * self.prod
return td
self.original_dtypes[key] = item[key].dtype
item[key] = item[key].type(torch.float32) * self.prod
return item
def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
def inverse_transform(self, item):
for key in self.in_keys:
if td.get(key, None) is None:
if key not in item:
continue
td[key] = (td[key] / self.prod).type(self.original_dtypes[key])
return td
item[key] = (item[key] / self.prod).type(self.original_dtypes[key])
return item
def transform_observation_spec(self, obs_spec):
for key in self.in_keys:
if obs_spec.get(key, None) is None:
continue
obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
obs_spec[key].dtype = torch.float32
return obs_spec
# def transform_observation_spec(self, obs_spec):
# for key in self.in_keys:
# if obs_spec.get(key, None) is None:
# continue
# obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
# obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
# obs_spec[key].dtype = torch.float32
# return obs_spec
class NormalizeTransform(Transform):
@@ -55,65 +51,50 @@ class NormalizeTransform(Transform):
def __init__(
self,
stats: TensorDictBase,
in_keys: Sequence[NestedKey] = None,
out_keys: Sequence[NestedKey] | None = None,
in_keys_inv: Sequence[NestedKey] | None = None,
out_keys_inv: Sequence[NestedKey] | None = None,
stats: dict,
in_keys: list[str] = None,
out_keys: list[str] | None = None,
in_keys_inv: list[str] | None = None,
out_keys_inv: list[str] | None = None,
mode="mean_std",
):
if out_keys is None:
out_keys = in_keys
if in_keys_inv is None:
in_keys_inv = out_keys
if out_keys_inv is None:
out_keys_inv = in_keys
super().__init__(
in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
)
super().__init__()
self.in_keys = in_keys
self.out_keys = in_keys if out_keys is None else out_keys
self.in_keys_inv = self.out_keys if in_keys_inv is None else in_keys_inv
self.out_keys_inv = self.in_keys if out_keys_inv is None else out_keys_inv
self.stats = stats
assert mode in ["mean_std", "min_max"]
self.mode = mode
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
# _reset is called once when the environment reset to normalize the first observation
tensordict_reset = self._call(tensordict_reset)
return tensordict_reset
@dispatch(source="in_keys", dest="out_keys")
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
return self._call(tensordict)
def _call(self, td: TensorDictBase) -> TensorDictBase:
def forward(self, item):
for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
# TODO(rcadene): don't know how to do `inkey not in td`
if td.get(inkey, None) is None:
if inkey not in item:
continue
if self.mode == "mean_std":
mean = self.stats[inkey]["mean"]
std = self.stats[inkey]["std"]
td[outkey] = (td[inkey] - mean) / (std + 1e-8)
item[outkey] = (item[inkey] - mean) / (std + 1e-8)
else:
min = self.stats[inkey]["min"]
max = self.stats[inkey]["max"]
# normalize to [0,1]
td[outkey] = (td[inkey] - min) / (max - min)
item[outkey] = (item[inkey] - min) / (max - min)
# normalize to [-1, 1]
td[outkey] = td[outkey] * 2 - 1
return td
item[outkey] = item[outkey] * 2 - 1
return item
def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
def inverse_transform(self, item):
for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
# TODO(rcadene): don't know how to do `inkey not in td`
if td.get(inkey, None) is None:
if inkey not in item:
continue
if self.mode == "mean_std":
mean = self.stats[inkey]["mean"]
std = self.stats[inkey]["std"]
td[outkey] = td[inkey] * std + mean
item[outkey] = item[inkey] * std + mean
else:
min = self.stats[inkey]["min"]
max = self.stats[inkey]["max"]
td[outkey] = (td[inkey] + 1) / 2
td[outkey] = td[outkey] * (max - min) + min
return td
item[outkey] = (item[inkey] + 1) / 2
item[outkey] = item[outkey] * (max - min) + min
return item