diff --git a/lerobot/record.py b/lerobot/record.py index 6ddeb23b7..733955b1b 100644 --- a/lerobot/record.py +++ b/lerobot/record.py @@ -38,7 +38,6 @@ from dataclasses import asdict, dataclass from pathlib import Path from pprint import pformat -import draccus import numpy as np import rerun as rr @@ -151,6 +150,11 @@ class RecordConfig: self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) self.policy.pretrained_path = policy_path + @classmethod + def __get_path_fields__(cls) -> list[str]: + """This enables the parser to load config from the policy using `--policy.path=local/dir`""" + return ["policy"] + @safe_stop_image_writer def record_loop( @@ -220,7 +224,7 @@ def record_loop( break -@draccus.wrap() +@parser.wrap() def record(cfg: RecordConfig) -> LeRobotDataset: init_logging() logging.info(pformat(asdict(cfg)))