fix(trainer): overrides device to the target device, for the device processor on the preprocessor (#1993)
* fix(trainer): overiddes device to the target defice, for device processor on preprocessor * Update src/lerobot/scripts/train.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com> --------- Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -183,6 +183,9 @@ def train(cfg: TrainPipelineConfig):
|
||||
# Only provide dataset_stats when not resuming from saved processor state
|
||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||
|
||||
if cfg.policy.pretrained_path is not None:
|
||||
processor_kwargs["preprocessor_overrides"] = {"device_processor": {"device": device.type}}
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user