WIP

WIP train.py works, loss going down

WIP eval.py

Fix

WIP (eval running, TODO: verify results reproduced)

Eval works! (testing reproducibility)

WIP

pretrained model pusht reproduces same results as torchrl

pretrained model pusht reproduces same results as torchrl

Remove AbstractPolicy, Move all queues in select_action

WIP test_datasets passed (TODO: re-enable NormalizeTransform)
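A hypothetical sketch of the pattern named in "Remove AbstractPolicy, Move all queues in select_action" above: instead of a wrapper class owning the rollout queues, the policy itself keeps an action queue inside select_action. All names below are illustrative only, not code from this commit.

from collections import deque

import torch


class DummyChunkPolicy(torch.nn.Module):
    """Predicts a chunk of future actions and replays them one step at a time."""

    def __init__(self, obs_dim=3, action_dim=2, chunk_size=4):
        super().__init__()
        self.chunk_size = chunk_size
        self.net = torch.nn.Linear(obs_dim, action_dim * chunk_size)
        self._action_queue = deque(maxlen=chunk_size)

    @torch.no_grad()
    def select_action(self, observation):
        # Refill the queue only when it is empty; otherwise replay the queued actions.
        if len(self._action_queue) == 0:
            chunk = self.net(observation).reshape(self.chunk_size, -1)
            self._action_queue.extend(chunk)
        return self._action_queue.popleft()


policy = DummyChunkPolicy()
obs = torch.zeros(3)
actions = [policy.select_action(obs) for _ in range(6)]  # 2 forward passes, 6 actions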
Author: Cadene
Date: 2024-03-31 15:05:25 +00:00
parent 920e0d118b
commit 1cdfbc8b52
17 changed files with 826 additions and 621 deletions
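All three test files below apply the same migration: the torchrl-style replay buffer returned by make_offline_buffer(cfg) is replaced by the dataset returned by make_dataset(cfg), which is indexed directly. A minimal before/after sketch using only the calls visible in the diffs; the config path and overrides are placeholders (the tests pass DEFAULT_CONFIG_PATH from tests/utils.py and Hydra overrides):

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.utils import init_hydra_config

CONFIG_PATH = "path/to/default.yaml"  # placeholder; the tests use DEFAULT_CONFIG_PATH
cfg = init_hydra_config(CONFIG_PATH, overrides=["env=pusht", "device=cpu"])

# Before: offline_buffer = make_offline_buffer(cfg); img = offline_buffer[0].get(key)
# After: the dataset is indexed directly and returns one item per frame.
dataset = make_dataset(cfg)
item = dataset[0]
for key in dataset.image_keys:
    img = item.get(key)  # float32 images, expected to lie in [0, 1]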


@@ -6,6 +6,8 @@ from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
-from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.utils import init_hydra_config
+import logging
+from lerobot.common.datasets.factory import make_dataset
 from .utils import DEVICE, DEFAULT_CONFIG_PATH
@@ -26,14 +28,29 @@ def test_factory(env_name, dataset_id):
         DEFAULT_CONFIG_PATH,
         overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"]
     )
-    offline_buffer = make_offline_buffer(cfg)
-    for key in offline_buffer.image_keys:
-        img = offline_buffer[0].get(key)
+    dataset = make_dataset(cfg)
+    item = dataset[0]
+    assert "action" in item
+    assert "episode" in item
+    assert "frame_id" in item
+    assert "timestamp" in item
+    assert "next.done" in item
+    # TODO(rcadene): should we rename it agent_pos?
+    assert "observation.state" in item
+    for key in dataset.image_keys:
+        img = item.get(key)
         assert img.dtype == torch.float32
         # TODO(rcadene): we assume for now that image normalization takes place in the model
         assert img.max() <= 1.0
         assert img.min() >= 0.0
+    if "next.reward" not in item:
+        logging.warning(f'Missing "next.reward" key in dataset {dataset}.')
+    if "next.done" not in item:
+        logging.warning(f'Missing "next.done" key in dataset {dataset}.')
 def test_compute_stats():
     """Check that the statistics are computed correctly according to the stats_patterns property.


@@ -2,7 +2,7 @@ import pytest
 from tensordict import TensorDict
 import torch
 from torchrl.envs.utils import check_env_specs, step_mdp
-from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.envs.aloha.env import AlohaEnv
 from lerobot.common.envs.factory import make_env
@@ -116,15 +116,15 @@ def test_factory(env_name):
         overrides=[f"env={env_name}", f"device={DEVICE}"],
     )
-    offline_buffer = make_offline_buffer(cfg)
+    dataset = make_dataset(cfg)
     env = make_env(cfg)
-    for key in offline_buffer.image_keys:
+    for key in dataset.image_keys:
         assert env.reset().get(key).dtype == torch.uint8
     check_env_specs(env)
-    env = make_env(cfg, transform=offline_buffer.transform)
-    for key in offline_buffer.image_keys:
+    env = make_env(cfg, transform=dataset.transform)
+    for key in dataset.image_keys:
         img = env.reset().get(key)
         assert img.dtype == torch.float32
         # TODO(rcadene): we assume for now that image normalization takes place in the model
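The env test above checks the same dataset/env pairing from both sides: the raw env returns uint8 images, while the env built with the dataset's transform returns float32 images in [0, 1]. A condensed sketch of that pairing, reusing cfg from the earlier sketch and using only the calls shown in the diff:

import torch
from torchrl.envs.utils import check_env_specs
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env

dataset = make_dataset(cfg)

env = make_env(cfg)  # raw env: image observations are uint8
for key in dataset.image_keys:
    assert env.reset().get(key).dtype == torch.uint8
check_env_specs(env)

env = make_env(cfg, transform=dataset.transform)  # images now come out as float32 in [0, 1]
for key in dataset.image_keys:
    img = env.reset().get(key)
    assert img.dtype == torch.float32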


@@ -7,7 +7,7 @@ from torchrl.envs import EnvBase
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.envs.factory import make_env
-from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.utils import init_hydra_config
 from .utils import DEVICE, DEFAULT_CONFIG_PATH
@@ -45,13 +45,13 @@ def test_concrete_policy(env_name, policy_name, extra_overrides):
     # Check that we can make the policy object.
     policy = make_policy(cfg)
     # Check that we run select_actions and get the appropriate output.
-    offline_buffer = make_offline_buffer(cfg)
-    env = make_env(cfg, transform=offline_buffer.transform)
+    dataset = make_dataset(cfg)
+    env = make_env(cfg, transform=dataset.transform)
     if env_name != "aloha":
         # TODO(alexander-soare): Fix this part of the test. PrioritizedSliceSampler raises NotImplementedError:
         # seq_length as a list is not supported for now.
-        policy.update(offline_buffer, torch.tensor(0, device=DEVICE))
+        policy.update(dataset, torch.tensor(0, device=DEVICE))
     action = policy(
         env.observation_spec.rand()["observation"].to(DEVICE),