WIP
WIP WIP train.py works, loss going down WIP eval.py Fix WIP (eval running, TODO: verify results reproduced) Eval works! (testing reproducibility) WIP pretrained model pusht reproduces same results as torchrl pretrained model pusht reproduces same results as torchrl Remove AbstractPolicy, Move all queues in select_action WIP test_datasets passed (TODO: re-enable NormalizeTransform)
This commit is contained in:
@@ -2,7 +2,7 @@ import pytest
|
||||
from tensordict import TensorDict
|
||||
import torch
|
||||
from torchrl.envs.utils import check_env_specs, step_mdp
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
|
||||
from lerobot.common.envs.aloha.env import AlohaEnv
|
||||
from lerobot.common.envs.factory import make_env
|
||||
@@ -116,15 +116,15 @@ def test_factory(env_name):
|
||||
overrides=[f"env={env_name}", f"device={DEVICE}"],
|
||||
)
|
||||
|
||||
offline_buffer = make_offline_buffer(cfg)
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
env = make_env(cfg)
|
||||
for key in offline_buffer.image_keys:
|
||||
for key in dataset.image_keys:
|
||||
assert env.reset().get(key).dtype == torch.uint8
|
||||
check_env_specs(env)
|
||||
|
||||
env = make_env(cfg, transform=offline_buffer.transform)
|
||||
for key in offline_buffer.image_keys:
|
||||
env = make_env(cfg, transform=dataset.transform)
|
||||
for key in dataset.image_keys:
|
||||
img = env.reset().get(key)
|
||||
assert img.dtype == torch.float32
|
||||
# TODO(rcadene): we assume for now that image normalization takes place in the model
|
||||
|
||||
Reference in New Issue
Block a user