Follow transformers single file naming conventions (#124)

This commit is contained in:
Alexander Soare
2024-05-01 13:09:42 +01:00
committed by GitHub
parent 986583dc5c
commit 01d5490d44
6 changed files with 62 additions and 58 deletions

View File

@@ -38,11 +38,11 @@ def make_policy(hydra_cfg: DictConfig, dataset_stats=None):
policy = DiffusionPolicy(policy_cfg, hydra_cfg.training.offline_steps, dataset_stats)
policy.to(get_safe_torch_device(hydra_cfg.device))
elif hydra_cfg.policy.name == "act":
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.act.modeling_act import ACTPolicy
policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg)
policy = ActionChunkingTransformerPolicy(policy_cfg, dataset_stats)
policy_cfg = _policy_cfg_from_hydra_cfg(ACTConfig, hydra_cfg)
policy = ACTPolicy(policy_cfg, dataset_stats)
policy.to(get_safe_torch_device(hydra_cfg.device))
else:
raise ValueError(hydra_cfg.policy.name)