WIP

WIP train.py works, loss going down

WIP eval.py

Fix

WIP (eval running, TODO: verify results reproduced)

Eval works! (testing reproducibility)

WIP

pretrained model pusht reproduces same results as torchrl

pretrained model pusht reproduces same results as torchrl

Remove AbstractPolicy, Move all queues in select_action

WIP test_datasets passed (TODO: re-enable NormalizeTransform)
This commit is contained in:
Cadene
2024-03-31 15:05:25 +00:00
parent 920e0d118b
commit 1cdfbc8b52
17 changed files with 826 additions and 621 deletions

View File

@@ -2,7 +2,7 @@ import pytest
from tensordict import TensorDict
import torch
from torchrl.envs.utils import check_env_specs, step_mdp
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.aloha.env import AlohaEnv
from lerobot.common.envs.factory import make_env
@@ -116,15 +116,15 @@ def test_factory(env_name):
overrides=[f"env={env_name}", f"device={DEVICE}"],
)
offline_buffer = make_offline_buffer(cfg)
dataset = make_dataset(cfg)
env = make_env(cfg)
for key in offline_buffer.image_keys:
for key in dataset.image_keys:
assert env.reset().get(key).dtype == torch.uint8
check_env_specs(env)
env = make_env(cfg, transform=offline_buffer.transform)
for key in offline_buffer.image_keys:
env = make_env(cfg, transform=dataset.transform)
for key in dataset.image_keys:
img = env.reset().get(key)
assert img.dtype == torch.float32
# TODO(rcadene): we assume for now that image normalization takes place in the model