fix(scripts): parser instead of draccus in record + add __get_path_fields__() to RecordConfig (#1155)

This commit is contained in:
Steven Palma
2025-05-26 10:51:05 +02:00
committed by GitHub
parent 809a9c6de0
commit fb4bfaf029

View File

@@ -38,7 +38,6 @@ from dataclasses import asdict, dataclass
from pathlib import Path from pathlib import Path
from pprint import pformat from pprint import pformat
import draccus
import numpy as np import numpy as np
import rerun as rr import rerun as rr
@@ -151,6 +150,11 @@ class RecordConfig:
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path self.policy.pretrained_path = policy_path
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
@safe_stop_image_writer @safe_stop_image_writer
def record_loop( def record_loop(
@@ -220,7 +224,7 @@ def record_loop(
break break
@draccus.wrap() @parser.wrap()
def record(cfg: RecordConfig) -> LeRobotDataset: def record(cfg: RecordConfig) -> LeRobotDataset:
init_logging() init_logging()
logging.info(pformat(asdict(cfg))) logging.info(pformat(asdict(cfg)))