Remove latency, tdmpc policy passes tests (TODO: make it work with online RL)

This commit is contained in:
Cadene
2024-04-07 16:01:22 +00:00
parent 44656d2706
commit 4371a5570d
8 changed files with 123 additions and 133 deletions

View File

@@ -1,11 +1,10 @@
def make_policy(cfg):
if cfg.policy.name != "diffusion" and cfg.rollout_batch_size > 1:
raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")
if cfg.policy.name == "tdmpc":
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
policy = TDMPCPolicy(cfg.policy, cfg.device)
policy = TDMPCPolicy(
cfg.policy, n_obs_steps=cfg.n_obs_steps, n_action_steps=cfg.n_action_steps, device=cfg.device
)
elif cfg.policy.name == "diffusion":
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
@@ -17,14 +16,18 @@ def make_policy(cfg):
cfg_obs_encoder=cfg.obs_encoder,
cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema,
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
n_obs_steps=cfg.n_obs_steps,
n_action_steps=cfg.n_action_steps,
**cfg.policy,
)
elif cfg.policy.name == "act":
from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
policy = ActionChunkingTransformerPolicy(
cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
cfg.policy,
cfg.device,
n_obs_steps=cfg.n_obs_steps,
n_action_steps=cfg.n_action_steps,
)
else:
raise ValueError(cfg.policy.name)