refactor(config): Move device & amp args to PreTrainedConfig (#812)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Steven Palma
2025-03-06 17:59:28 +01:00
committed by GitHub
parent 10706ed753
commit 5e9473806c
19 changed files with 62 additions and 136 deletions

View File

@@ -33,12 +33,11 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
policy=make_policy_config(policy_name, **policy_kwargs),
device="cpu",
)
train_cfg.validate() # Needed for auto-setting some parameters
dataset = make_dataset(train_cfg)
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta, device=train_cfg.device)
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
policy.train()
optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy)