Add Aloha env and ACT policy

WIP Aloha env tests pass

Rendering works (fps look fast tho? TODO action bounding is too wide [-1,1])

Update README

Copy past from act repo

Remove download.py add a WIP for Simxarm

Remove download.py add a WIP for Simxarm

Add act yaml (TODO: try train.py)

Training can runs (TODO: eval)

Add tasks without end_effector that are compatible with dataset, Eval can run (TODO: training and pretrained model)

Add AbstractEnv, Refactor AlohaEnv, Add rendering_hook in env, Minor modifications, (TODO: Refactor Pusht and Simxarm)

poetry lock

fix bug in compute_stats for action normalization

fix more bugs in normalization

fix training

fix import

PushtEnv inheriates AbstractEnv, Improve factory Normalization

Add _make_env to EnvAbstract

Add call_rendering_hooks to pusht env

SimxarmEnv inherites from AbstractEnv (NOT TESTED)

Add aloha tests artifacts + update pusht stats

fix image normalization: before env was in [0,1] but dataset in [0,255], and now both in [0,255]

Small fix on simxarm

Add next to obs

Add top camera to Aloha env (TODO: make it compatible with set of cameras)

Add top camera to Aloha env (TODO: make it compatible with set of cameras)
This commit is contained in:
Remi Cadene
2024-03-08 09:47:39 +00:00
committed by Cadene
parent 060bac7672
commit 9d002032d1
116 changed files with 3658 additions and 301 deletions

View File

@@ -1,5 +1,6 @@
from typing import Sequence
import torch
from tensordict import TensorDictBase
from tensordict.nn import dispatch
from tensordict.utils import NestedKey
@@ -7,19 +8,45 @@ from torchrl.envs.transforms import ObservationTransform, Transform
class Prod(ObservationTransform):
invertible = True
def __init__(self, in_keys: Sequence[NestedKey], prod: float):
super().__init__()
self.in_keys = in_keys
self.prod = prod
self.original_dtypes = {}
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
# _reset is called once when the environment reset to normalize the first observation
tensordict_reset = self._call(tensordict_reset)
return tensordict_reset
@dispatch(source="in_keys", dest="out_keys")
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
return self._call(tensordict)
def _call(self, td):
for key in self.in_keys:
td[key] *= self.prod
if td.get(key, None) is None:
continue
self.original_dtypes[key] = td[key].dtype
td[key] = td[key].type(torch.float32) * self.prod
return td
def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
for key in self.in_keys:
if td.get(key, None) is None:
continue
td[key] = (td[key] / self.prod).type(self.original_dtypes[key])
return td
def transform_observation_spec(self, obs_spec):
for key in self.in_keys:
obs_spec[key].space.high *= self.prod
if obs_spec.get(key, None) is None:
continue
obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
obs_spec[key].dtype = torch.float32
return obs_spec