WIP

WIP train.py works, loss going down

WIP eval.py

Fix

WIP (eval running, TODO: verify results reproduced)

Eval works! (testing reproducibility)

WIP

pretrained model pusht reproduces same results as torchrl

pretrained model pusht reproduces same results as torchrl

Remove AbstractPolicy, Move all queues in select_action

WIP test_datasets passed (TODO: re-enable NormalizeTransform)
This commit is contained in:
Cadene
2024-03-31 15:05:25 +00:00
parent 920e0d118b
commit 1cdfbc8b52
17 changed files with 826 additions and 621 deletions

View File

@@ -6,6 +6,8 @@ from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.utils import init_hydra_config
import logging
from lerobot.common.datasets.factory import make_dataset
from .utils import DEVICE, DEFAULT_CONFIG_PATH
@@ -26,14 +28,29 @@ def test_factory(env_name, dataset_id):
DEFAULT_CONFIG_PATH,
overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"]
)
offline_buffer = make_offline_buffer(cfg)
for key in offline_buffer.image_keys:
img = offline_buffer[0].get(key)
dataset = make_dataset(cfg)
item = dataset[0]
assert "action" in item
assert "episode" in item
assert "frame_id" in item
assert "timestamp" in item
assert "next.done" in item
# TODO(rcadene): should we rename it agent_pos?
assert "observation.state" in item
for key in dataset.image_keys:
img = item.get(key)
assert img.dtype == torch.float32
# TODO(rcadene): we assume for now that image normalization takes place in the model
assert img.max() <= 1.0
assert img.min() >= 0.0
if "next.reward" not in item:
logging.warning(f'Missing "next.reward" key in dataset {dataset}.')
if "next.done" not in item:
logging.warning(f'Missing "next.done" key in dataset {dataset}.')
def test_compute_stats():
"""Check that the statistics are computed correctly according to the stats_patterns property.