Add Aloha env and ACT policy

WIP Aloha env tests pass

Rendering works (fps look fast tho? TODO action bounding is too wide [-1,1])

Update README

Copy past from act repo

Remove download.py add a WIP for Simxarm

Remove download.py add a WIP for Simxarm

Add act yaml (TODO: try train.py)

Training can runs (TODO: eval)

Add tasks without end_effector that are compatible with dataset, Eval can run (TODO: training and pretrained model)

Add AbstractEnv, Refactor AlohaEnv, Add rendering_hook in env, Minor modifications, (TODO: Refactor Pusht and Simxarm)

poetry lock

fix bug in compute_stats for action normalization

fix more bugs in normalization

fix training

fix import

PushtEnv inheriates AbstractEnv, Improve factory Normalization

Add _make_env to EnvAbstract

Add call_rendering_hooks to pusht env

SimxarmEnv inherites from AbstractEnv (NOT TESTED)

Add aloha tests artifacts + update pusht stats

fix image normalization: before env was in [0,1] but dataset in [0,255], and now both in [0,255]

Small fix on simxarm

Add next to obs

Add top camera to Aloha env (TODO: make it compatible with set of cameras)

Add top camera to Aloha env (TODO: make it compatible with set of cameras)
This commit is contained in:
Remi Cadene
2024-03-08 09:47:39 +00:00
committed by Cadene
parent 060bac7672
commit 9d002032d1
116 changed files with 3658 additions and 301 deletions

View File

@@ -81,7 +81,10 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
def set_transform(self, transform):
if not isinstance(transform, Compose):
# required since torchrl calls `len(self._transform)` downstream
self._transform = Compose(transform)
if isinstance(transform, list):
self._transform = Compose(*transform)
else:
self._transform = Compose(transform)
else:
self._transform = transform

View File

@@ -73,11 +73,11 @@ def download(data_dir, dataset_id):
data_dir.mkdir(parents=True, exist_ok=True)
gdown.download_folder(FOLDER_URLS[dataset_id], output=data_dir)
gdown.download_folder(FOLDER_URLS[dataset_id], output=str(data_dir))
# because of the 50 files limit per directory, two files episode 48 and 49 were missing
gdown.download(EP48_URLS[dataset_id], output=data_dir / "episode_48.hdf5", fuzzy=True)
gdown.download(EP49_URLS[dataset_id], output=data_dir / "episode_49.hdf5", fuzzy=True)
gdown.download(EP48_URLS[dataset_id], output=str(data_dir / "episode_48.hdf5"), fuzzy=True)
gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)
class AlohaExperienceReplay(AbstractExperienceReplay):
@@ -124,9 +124,6 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
def image_keys(self) -> list:
return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
# def _is_downloaded(self) -> bool:
# return False
def _download_and_preproc(self):
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw"
if not raw_dir.is_dir():

View File

@@ -5,7 +5,7 @@ from pathlib import Path
import torch
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
from lerobot.common.envs.transforms import NormalizeTransform
from lerobot.common.envs.transforms import NormalizeTransform, Prod
DATA_DIR = Path(os.environ.get("DATA_DIR", "data"))
@@ -84,6 +84,16 @@ def make_offline_buffer(
prefetch=prefetch if isinstance(prefetch, int) else None,
)
if cfg.policy.name == "tdmpc":
img_keys = []
for key in offline_buffer.image_keys:
img_keys.append(("next", *key))
img_keys += offline_buffer.image_keys
else:
img_keys = offline_buffer.image_keys
transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
if normalize:
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
stats = offline_buffer.compute_or_load_stats()
@@ -92,11 +102,10 @@ def make_offline_buffer(
in_keys = [("observation", "state"), ("action")]
if cfg.policy.name == "tdmpc":
for key in offline_buffer.image_keys:
# TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
in_keys.append(key)
# since we use next observations in tdmpc
in_keys.append(("next", *key))
# TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now
in_keys += img_keys
# TODO(racdene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now.
in_keys += [("next", *key) for key in img_keys]
in_keys.append(("next", "observation", "state"))
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
@@ -106,8 +115,11 @@ def make_offline_buffer(
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
transform = NormalizeTransform(stats, in_keys, mode="min_max")
offline_buffer.set_transform(transform)
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
offline_buffer.set_transform(transforms)
if not overwrite_sampler:
index = torch.arange(0, offline_buffer.num_samples, 1)

View File

@@ -1,4 +1,5 @@
import pickle
import zipfile
from pathlib import Path
from typing import Callable
@@ -15,6 +16,22 @@ from torchrl.data.replay_buffers.writers import Writer
from lerobot.common.datasets.abstract import AbstractExperienceReplay
def download():
raise NotImplementedError()
import gdown
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
download_path = "data.zip"
gdown.download(url, download_path, quiet=False)
print("Extracting...")
with zipfile.ZipFile(download_path, "r") as zip_f:
for member in zip_f.namelist():
if member.startswith("data/xarm") and member.endswith(".pkl"):
print(member)
zip_f.extract(member=member)
Path(download_path).unlink()
class SimxarmExperienceReplay(AbstractExperienceReplay):
available_datasets = [
"xarm_lift_medium",
@@ -48,8 +65,8 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
)
def _download_and_preproc(self):
# download
# TODO(rcadene)
# TODO(rcadene): finish download
download()
dataset_path = self.data_dir / "buffer.pkl"
print(f"Using offline dataset '{dataset_path}'")