WIP
WIP WIP train.py works, loss going down WIP eval.py Fix WIP (eval running, TODO: verify results reproduced) Eval works! (testing reproducibility) WIP pretrained model pusht reproduces same results as torchrl pretrained model pusht reproduces same results as torchrl Remove AbstractPolicy, Move all queues in select_action WIP test_datasets passed (TODO: re-enable NormalizeTransform)
This commit is contained in:
@@ -6,6 +6,8 @@ from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
|
||||
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.utils import init_hydra_config
|
||||
import logging
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
|
||||
from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
||||
|
||||
@@ -26,14 +28,29 @@ def test_factory(env_name, dataset_id):
|
||||
DEFAULT_CONFIG_PATH,
|
||||
overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"]
|
||||
)
|
||||
offline_buffer = make_offline_buffer(cfg)
|
||||
for key in offline_buffer.image_keys:
|
||||
img = offline_buffer[0].get(key)
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
item = dataset[0]
|
||||
|
||||
assert "action" in item
|
||||
assert "episode" in item
|
||||
assert "frame_id" in item
|
||||
assert "timestamp" in item
|
||||
assert "next.done" in item
|
||||
# TODO(rcadene): should we rename it agent_pos?
|
||||
assert "observation.state" in item
|
||||
for key in dataset.image_keys:
|
||||
img = item.get(key)
|
||||
assert img.dtype == torch.float32
|
||||
# TODO(rcadene): we assume for now that image normalization takes place in the model
|
||||
assert img.max() <= 1.0
|
||||
assert img.min() >= 0.0
|
||||
|
||||
if "next.reward" not in item:
|
||||
logging.warning(f'Missing "next.reward" key in dataset {dataset}.')
|
||||
if "next.done" not in item:
|
||||
logging.warning(f'Missing "next.done" key in dataset {dataset}.')
|
||||
|
||||
|
||||
def test_compute_stats():
|
||||
"""Check that the statistics are computed correctly according to the stats_patterns property.
|
||||
|
||||
@@ -2,7 +2,7 @@ import pytest
|
||||
from tensordict import TensorDict
|
||||
import torch
|
||||
from torchrl.envs.utils import check_env_specs, step_mdp
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
|
||||
from lerobot.common.envs.aloha.env import AlohaEnv
|
||||
from lerobot.common.envs.factory import make_env
|
||||
@@ -116,15 +116,15 @@ def test_factory(env_name):
|
||||
overrides=[f"env={env_name}", f"device={DEVICE}"],
|
||||
)
|
||||
|
||||
offline_buffer = make_offline_buffer(cfg)
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
env = make_env(cfg)
|
||||
for key in offline_buffer.image_keys:
|
||||
for key in dataset.image_keys:
|
||||
assert env.reset().get(key).dtype == torch.uint8
|
||||
check_env_specs(env)
|
||||
|
||||
env = make_env(cfg, transform=offline_buffer.transform)
|
||||
for key in offline_buffer.image_keys:
|
||||
env = make_env(cfg, transform=dataset.transform)
|
||||
for key in dataset.image_keys:
|
||||
img = env.reset().get(key)
|
||||
assert img.dtype == torch.float32
|
||||
# TODO(rcadene): we assume for now that image normalization takes place in the model
|
||||
|
||||
@@ -7,7 +7,7 @@ from torchrl.envs import EnvBase
|
||||
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.utils import init_hydra_config
|
||||
from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
||||
@@ -45,13 +45,13 @@ def test_concrete_policy(env_name, policy_name, extra_overrides):
|
||||
# Check that we can make the policy object.
|
||||
policy = make_policy(cfg)
|
||||
# Check that we run select_actions and get the appropriate output.
|
||||
offline_buffer = make_offline_buffer(cfg)
|
||||
env = make_env(cfg, transform=offline_buffer.transform)
|
||||
dataset = make_dataset(cfg)
|
||||
env = make_env(cfg, transform=dataset.transform)
|
||||
|
||||
if env_name != "aloha":
|
||||
# TODO(alexander-soare): Fix this part of the test. PrioritizedSliceSampler raises NotImplementedError:
|
||||
# seq_length as a list is not supported for now.
|
||||
policy.update(offline_buffer, torch.tensor(0, device=DEVICE))
|
||||
policy.update(dataset, torch.tensor(0, device=DEVICE))
|
||||
|
||||
action = policy(
|
||||
env.observation_spec.rand()["observation"].to(DEVICE),
|
||||
|
||||
Reference in New Issue
Block a user