Move normalization to policy for act and diffusion (#90)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
@@ -20,7 +20,7 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
|
||||
return policy_cfg
|
||||
|
||||
|
||||
def make_policy(hydra_cfg: DictConfig):
|
||||
def make_policy(hydra_cfg: DictConfig, dataset_stats=None):
|
||||
if hydra_cfg.policy.name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
||||
|
||||
@@ -35,14 +35,14 @@ def make_policy(hydra_cfg: DictConfig):
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
|
||||
policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps)
|
||||
policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps, dataset_stats)
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
elif hydra_cfg.policy.name == "act":
|
||||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
||||
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
|
||||
|
||||
policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg)
|
||||
policy = ActionChunkingTransformerPolicy(policy_cfg)
|
||||
policy = ActionChunkingTransformerPolicy(policy_cfg, dataset_stats)
|
||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||
else:
|
||||
raise ValueError(hydra_cfg.policy.name)
|
||||
|
||||
Reference in New Issue
Block a user