diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 485fc927..5594d2f9 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -183,6 +183,9 @@ def train(cfg: TrainPipelineConfig): # Only provide dataset_stats when not resuming from saved processor state processor_kwargs["dataset_stats"] = dataset.meta.stats + if cfg.policy.pretrained_path is not None: + processor_kwargs["preprocessor_overrides"] = {"device_processor": {"device": device.type}} + preprocessor, postprocessor = make_pre_post_processors( policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs )