fix(trainer): overrides device to the target device, for the device processor on the preprocessor (#1993)
* fix(trainer): overiddes device to the target defice, for device processor on preprocessor * Update src/lerobot/scripts/train.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com> --------- Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -183,6 +183,9 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
# Only provide dataset_stats when not resuming from saved processor state
|
# Only provide dataset_stats when not resuming from saved processor state
|
||||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||||
|
|
||||||
|
if cfg.policy.pretrained_path is not None:
|
||||||
|
processor_kwargs["preprocessor_overrides"] = {"device_processor": {"device": device.type}}
|
||||||
|
|
||||||
preprocessor, postprocessor = make_pre_post_processors(
|
preprocessor, postprocessor = make_pre_post_processors(
|
||||||
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs
|
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user