backup wip
This commit is contained in:
Binary file not shown.
Binary file not shown.
@@ -1,4 +1,5 @@
|
||||
|
||||
from omegaconf import open_dict
|
||||
import pytest
|
||||
from tensordict import TensorDict
|
||||
from tensordict.nn import TensorDictModule
|
||||
@@ -7,7 +8,8 @@ from torchrl.data import UnboundedContinuousTensorSpec
|
||||
from torchrl.envs import EnvBase
|
||||
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
|
||||
from .utils import DEVICE, init_config
|
||||
@@ -30,7 +32,19 @@ def test_factory(env_name, policy_name):
|
||||
f"device={DEVICE}",
|
||||
]
|
||||
)
|
||||
# Check that we can make the policy object.
|
||||
policy = make_policy(cfg)
|
||||
# Check that we run select_action and get the appropriate output.
|
||||
if env_name == "simxarm":
|
||||
# TODO(rcadene): Not implemented
|
||||
return
|
||||
if policy_name == "tdmpc":
|
||||
# TODO(alexander-soare): TDMPC does not use n_obs_steps but the environment requires this.
|
||||
with open_dict(cfg):
|
||||
cfg['n_obs_steps'] = 1
|
||||
offline_buffer = make_offline_buffer(cfg)
|
||||
env = make_env(cfg, transform=offline_buffer.transform)
|
||||
policy.select_action(env.observation_spec.rand()['observation'].to(DEVICE), torch.tensor(0, device=DEVICE))
|
||||
|
||||
|
||||
def test_abstract_policy_forward():
|
||||
|
||||
Reference in New Issue
Block a user