Add policies/factory, Add test, Add _self_ in config

This commit is contained in:
Cadene
2024-02-25 10:50:23 +00:00
parent 64b5920e94
commit 598bb496b0
13 changed files with 61 additions and 38 deletions

View File

View File

@@ -0,0 +1,21 @@
from lerobot.common.policies.tdmpc import TDMPC
def make_policy(cfg):
if cfg.policy == "tdmpc":
policy = TDMPC(cfg)
else:
raise ValueError(cfg.policy)
if cfg.pretrained_model_path:
# TODO(rcadene): hack for old pretrained models from fowm
if cfg.policy == "tdmpc" and "fowm" in cfg.pretrained_model_path:
if "offline" in cfg.pretrained_model_path:
policy.step[0] = 25000
elif "final" in cfg.pretrained_model_path:
policy.step[0] = 100000
else:
raise NotImplementedError()
policy.load(cfg.pretrained_model_path)
return policy

View File

@@ -5,7 +5,7 @@ import numpy as np
import torch
import torch.nn as nn
import lerobot.common.tdmpc_helper as h
import lerobot.common.policies.tdmpc_helper as h
class TOLD(nn.Module):