forked from tangger/lerobot
fix(scripts): parser instead of draccus in record + add __get_path_fields__() to RecordConfig (#1155)
This commit is contained in:
@@ -38,7 +38,6 @@ from dataclasses import asdict, dataclass
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
|
|
||||||
import draccus
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import rerun as rr
|
import rerun as rr
|
||||||
|
|
||||||
@@ -151,6 +150,11 @@ class RecordConfig:
|
|||||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||||
self.policy.pretrained_path = policy_path
|
self.policy.pretrained_path = policy_path
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_path_fields__(cls) -> list[str]:
|
||||||
|
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||||
|
return ["policy"]
|
||||||
|
|
||||||
|
|
||||||
@safe_stop_image_writer
|
@safe_stop_image_writer
|
||||||
def record_loop(
|
def record_loop(
|
||||||
@@ -220,7 +224,7 @@ def record_loop(
|
|||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
@draccus.wrap()
|
@parser.wrap()
|
||||||
def record(cfg: RecordConfig) -> LeRobotDataset:
|
def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||||
init_logging()
|
init_logging()
|
||||||
logging.info(pformat(asdict(cfg)))
|
logging.info(pformat(asdict(cfg)))
|
||||||
|
|||||||
Reference in New Issue
Block a user