Add get_safe_torch_device in policies

This commit is contained in:
Simon Alibert
2024-03-20 18:38:55 +01:00
parent ec536ef0fa
commit 4631d36c05
6 changed files with 39 additions and 18 deletions

View File

@@ -7,6 +7,7 @@ import torchvision.transforms as transforms
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.act.detr_vae import build
from lerobot.common.utils import get_safe_torch_device
def build_act_model_and_optimizer(cfg):
@@ -45,7 +46,7 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
super().__init__(n_action_steps)
self.cfg = cfg
self.n_action_steps = n_action_steps
self.device = device
self.device = get_safe_torch_device(device)
self.model, self.optimizer = build_act_model_and_optimizer(cfg)
self.kl_weight = self.cfg.kl_weight
logging.info(f"KL Weight {self.kl_weight}")