WIP

WIP train.py works, loss going down

WIP eval.py

Fix

WIP (eval running, TODO: verify results reproduced)

Eval works! (testing reproducibility)

WIP

pretrained model pusht reproduces same results as torchrl

Remove AbstractPolicy, move all queues into select_action (see the sketch after the commit metadata)

WIP test_datasets passed (TODO: re-enable NormalizeTransform)
Author: Cadene
Date: 2024-03-31 15:05:25 +00:00
Parent: 920e0d118b
Commit: 1cdfbc8b52
17 changed files with 826 additions and 621 deletions
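
A note on the "move all queues into select_action" commit listed above: the message suggests that queueing, previously handled around AbstractPolicy, now lives entirely inside select_action, so a policy keeps rolling observation and action queues internally and refills them on demand. Below is a minimal sketch of that pattern, with a hypothetical QueuedPolicy and model; it is not the repository's actual implementation.

    from collections import deque

    import torch


    class QueuedPolicy:
        """Sketch of a policy whose queueing lives entirely inside select_action."""

        def __init__(self, model, n_obs_steps: int = 2, n_action_steps: int = 8):
            self.model = model  # hypothetical: maps stacked observations -> a chunk of actions
            self._obs_queue = deque(maxlen=n_obs_steps)
            self._action_queue = deque(maxlen=n_action_steps)

        @torch.no_grad()
        def select_action(self, observation: torch.Tensor) -> torch.Tensor:
            # Keep the latest n_obs_steps observations; pad by repetition at episode start.
            self._obs_queue.append(observation)
            while len(self._obs_queue) < self._obs_queue.maxlen:
                self._obs_queue.append(observation)
            # Run the model only when the action queue is empty: one forward pass
            # produces a chunk of actions that is consumed over the next steps.
            if len(self._action_queue) == 0:
                stacked = torch.stack(tuple(self._obs_queue))  # (n_obs_steps, obs_dim)
                actions = self.model(stacked)  # (n_action_steps, action_dim)
                self._action_queue.extend(actions)
            return self._action_queue.popleft()

With this structure, a rollout loop needs nothing but action = policy.select_action(obs) each step, and the model forward pass is amortized over n_action_steps environment steps.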

@@ -7,7 +7,7 @@ from torchrl.envs import EnvBase
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.envs.factory import make_env
-from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.utils import init_hydra_config
 from .utils import DEVICE, DEFAULT_CONFIG_PATH
@@ -45,13 +45,13 @@ def test_concrete_policy(env_name, policy_name, extra_overrides):
     # Check that we can make the policy object.
     policy = make_policy(cfg)
     # Check that we run select_actions and get the appropriate output.
-    offline_buffer = make_offline_buffer(cfg)
-    env = make_env(cfg, transform=offline_buffer.transform)
+    dataset = make_dataset(cfg)
+    env = make_env(cfg, transform=dataset.transform)
     if env_name != "aloha":
         # TODO(alexander-soare): Fix this part of the test. PrioritizedSliceSampler raises NotImplementedError:
         # seq_length as a list is not supported for now.
-        policy.update(offline_buffer, torch.tensor(0, device=DEVICE))
+        policy.update(dataset, torch.tensor(0, device=DEVICE))
     action = policy(
         env.observation_spec.rand()["observation"].to(DEVICE),
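
Stripped of diff markers, the updated test wiring reads as follows. This is a condensed sketch built only from names visible in the hunks above; the Hydra overrides and the absolute tests.utils import path are illustrative, not the test's exact parametrization.

    import torch

    from lerobot.common.datasets.factory import make_dataset
    from lerobot.common.envs.factory import make_env
    from lerobot.common.policies.factory import make_policy
    from lerobot.common.utils import init_hydra_config
    from tests.utils import DEVICE, DEFAULT_CONFIG_PATH  # imported as ".utils" in the test itself

    # Illustrative overrides; the real test is parametrized over env/policy pairs.
    cfg = init_hydra_config(
        DEFAULT_CONFIG_PATH,
        overrides=["env=pusht", "policy=diffusion", f"device={DEVICE}"],
    )

    policy = make_policy(cfg)
    dataset = make_dataset(cfg)  # replaces the old make_offline_buffer(cfg)
    env = make_env(cfg, transform=dataset.transform)

    # One training step on the dataset (skipped for "aloha", per the TODO above).
    policy.update(dataset, torch.tensor(0, device=DEVICE))

    # The policy is then called on a random observation drawn from the env spec;
    # the remaining arguments of this call are truncated in the diff above.
    action = policy(env.observation_spec.rand()["observation"].to(DEVICE))

Note that the change is purely a rename at the call sites: the dataset object returned by make_dataset still carries the transform that make_env consumes, and policy.update still takes the data source plus a step counter tensor.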