forked from tangger/lerobot
add fixes for reproducibility only try to start env if it is closed revision fix normalization and data type Improve README Improve README Tests are passing, Eval pretrained model works, Add gif Update gif Update gif Update gif Update gif Update README Update README update minor Update README.md Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Update README.md Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Address suggestions Update thumbnail + stats Update thumbnail + stats Update README.md Co-authored-by: Alexander Soare <alexander.soare159@gmail.com> Add more comments Add test_examples.py
109 lines
2.9 KiB
Python
109 lines
2.9 KiB
Python
import pytest
|
|
from tensordict import TensorDict
|
|
import torch
|
|
from torchrl.envs.utils import check_env_specs, step_mdp
|
|
from lerobot.common.datasets.factory import make_offline_buffer
|
|
|
|
from lerobot.common.envs.factory import make_env
|
|
from lerobot.common.envs.pusht.env import PushtEnv
|
|
from lerobot.common.envs.simxarm.env import SimxarmEnv
|
|
|
|
from .utils import DEVICE, init_config
|
|
|
|
|
|
def print_spec_rollout(env):
    """Print ``env``'s specs, a reset/random-step pair, and a short random rollout.

    Debugging helper: the tests in this file only reference it through
    commented-out calls. Output, in order:
      * the observation / action / reward / done specs,
      * the tensordict returned by ``env.reset()``,
      * that tensordict after one ``env.rand_step``,
      * a 100-step rollout whose actions are sampled from ``env.action_spec``.
    """
    print("observation_spec:", env.observation_spec)
    print("action_spec:", env.action_spec)
    print("reward_spec:", env.reward_spec)
    print("done_spec:", env.done_spec)

    reset_td = env.reset()
    print("reset tensordict", reset_td)

    stepped_td = env.rand_step(reset_td)
    print("random step tensordict", stepped_td)

    def _random_rollout(num_steps=100):
        """Run ``num_steps`` random actions and stack each step's tensordict."""
        # Empty TensorDict with batch size [num_steps]; filled one index per step.
        trajectory = TensorDict({}, [num_steps])
        current = env.reset()
        for step_idx in range(num_steps):
            current["action"] = env.action_spec.rand()
            current = env.step(current)
            trajectory[step_idx] = current
            # Advance to the next MDP step via torchrl's step_mdp, keeping the
            # other (non-"next") entries as requested by keep_other=True.
            current = step_mdp(current, keep_other=True)
        return trajectory

    print("data from rollout:", _random_rollout(100))
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("task", "from_pixels", "pixels_only"),
    [
        ("lift", False, False),
        ("lift", True, False),
        ("lift", True, True),
        # TODO(aliberts): Add simxarm other tasks
        # ("reach", False, False),
        # ("reach", True, False),
        # ("push", False, False),
        # ("push", True, False),
        # ("peg_in_box", False, False),
        # ("peg_in_box", True, False),
    ],
)
def test_simxarm(task, from_pixels, pixels_only):
    """Construct a SimxarmEnv for ``task`` and validate it with torchrl's spec checker."""
    # An image size is only passed when pixel observations are requested.
    image_size = None
    if from_pixels:
        image_size = 84

    env = SimxarmEnv(
        task,
        from_pixels=from_pixels,
        pixels_only=pixels_only,
        image_size=image_size,
    )
    # When debugging, print_spec_rollout(env) can be called here to dump specs.
    check_env_specs(env)
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("from_pixels", "pixels_only"),
    [
        (True, False),
    ],
)
def test_pusht(from_pixels, pixels_only):
    """Construct a PushtEnv and validate it with torchrl's spec checker."""
    # An image size is only passed when pixel observations are requested.
    image_size = None
    if from_pixels:
        image_size = 96

    env = PushtEnv(
        from_pixels=from_pixels,
        pixels_only=pixels_only,
        image_size=image_size,
    )
    # When debugging, print_spec_rollout(env) can be called here to dump specs.
    check_env_specs(env)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "env_name",
    [
        "simxarm",
        "pusht",
        "aloha",
    ],
)
def test_factory(env_name):
    """Build ``env_name`` through the config factories and check image observations.

    Without a transform, every image key of the offline buffer must come back
    from ``env.reset()`` as ``torch.uint8``. With ``offline_buffer.transform``
    applied, the same keys must be ``torch.float32`` with values in [0, 1].
    Both environments must also pass torchrl's ``check_env_specs``.
    """
    cfg = init_config(overrides=[f"env={env_name}", f"device={DEVICE}"])

    offline_buffer = make_offline_buffer(cfg)

    # Raw environment: images are expected untransformed (uint8).
    env = make_env(cfg)
    # Reset once and inspect every image key of that single observation instead
    # of resetting the environment again for each key.
    td = env.reset()
    for key in offline_buffer.image_keys:
        img = td.get(key)
        assert img.dtype == torch.uint8, f"{key}: expected torch.uint8, got {img.dtype}"
    check_env_specs(env)

    # Environment wrapped with the offline buffer's transform: float images in [0, 1].
    env = make_env(cfg, transform=offline_buffer.transform)
    td = env.reset()
    for key in offline_buffer.image_keys:
        img = td.get(key)
        assert img.dtype == torch.float32, f"{key}: expected torch.float32, got {img.dtype}"
        # TODO(rcadene): we assume for now that image normalization takes place in the model
        assert img.max() <= 1.0, f"{key}: max value {img.max()} exceeds 1.0"
        assert img.min() >= 0.0, f"{key}: min value {img.min()} is below 0.0"
    check_env_specs(env)
|