backup wip

This commit is contained in:
Alexander Soare
2024-03-19 16:02:09 +00:00
parent 88347965c2
commit ea17f4ce50
11 changed files with 71 additions and 46 deletions

Binary file not shown.

View File

@@ -1,4 +1,5 @@
from omegaconf import open_dict
import pytest
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
@@ -7,7 +8,8 @@ from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.envs import EnvBase
from lerobot.common.policies.factory import make_policy
from lerobot.common.envs.factory import make_env
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.policies.abstract import AbstractPolicy
from .utils import DEVICE, init_config
@@ -30,7 +32,19 @@ def test_factory(env_name, policy_name):
f"device={DEVICE}",
]
)
# Check that we can make the policy object.
policy = make_policy(cfg)
# Check that we run select_action and get the appropriate output.
if env_name == "simxarm":
# TODO(rcadene): Not implemented
return
if policy_name == "tdmpc":
# TODO(alexander-soare): TDMPC does not use n_obs_steps but the environment requires this.
with open_dict(cfg):
cfg['n_obs_steps'] = 1
offline_buffer = make_offline_buffer(cfg)
env = make_env(cfg, transform=offline_buffer.transform)
policy.select_action(env.observation_spec.rand()['observation'].to(DEVICE), torch.tensor(0, device=DEVICE))
def test_abstract_policy_forward():