Tidy up yaml configs (#121)

Authored by Alexander Soare on 2024-04-30 16:08:59 +01:00; committed by GitHub
parent e4e739f4f8, commit 9d60dce6f3
21 changed files with 142 additions and 207 deletions

View File

@@ -14,12 +14,13 @@ def make_dataset(
     cfg,
     split="train",
 ):
-    if cfg.env.name not in cfg.dataset.repo_id:
+    if cfg.env.name not in cfg.dataset_repo_id:
         logging.warning(
-            f"There might be a mismatch between your training dataset ({cfg.dataset.repo_id=}) and your environment ({cfg.env.name=})."
+            f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
+            f"environment ({cfg.env.name=})."
         )
-    delta_timestamps = cfg.policy.get("delta_timestamps")
+    delta_timestamps = cfg.training.get("delta_timestamps")
     if delta_timestamps is not None:
         for key in delta_timestamps:
             if isinstance(delta_timestamps[key], str):
@@ -28,7 +29,7 @@ def make_dataset(
     # TODO(rcadene): add data augmentations
     dataset = LeRobotDataset(
-        cfg.dataset.repo_id,
+        cfg.dataset_repo_id,
         split=split,
         root=DATA_DIR,
         delta_timestamps=delta_timestamps,
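
In short, make_dataset() now reads the dataset repo id from a top-level dataset_repo_id key and delta_timestamps from the training section rather than the policy section. A minimal sketch of the yaml layout this implies; only the key paths (dataset_repo_id, env.name, training.delta_timestamps) come from the diff, all values are illustrative:

```yaml
# Illustrative sketch only; values and the exact delta_timestamps entries are made up.
dataset_repo_id: lerobot/pusht
env:
  name: pusht
training:
  delta_timestamps:          # optional; string values are handled separately in the code above
    observation.image: [-0.1, 0.0]
    action: [0.0, 0.1]
```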

View File

@@ -29,9 +29,9 @@ class Logger:
         self._job_name = job_name
         self._model_dir = self._log_dir / "models"
         self._buffer_dir = self._log_dir / "buffers"
-        self._save_model = cfg.save_model
+        self._save_model = cfg.training.save_model
         self._disable_wandb_artifact = cfg.wandb.disable_artifact
-        self._save_buffer = cfg.save_buffer
+        self._save_buffer = cfg.training.get("save_buffer", False)
         self._group = cfg_to_group(cfg)
         self._seed = cfg.seed
         self._cfg = cfg
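
The Logger now pulls its save flags from the training section, with save_buffer optional (falling back to False when absent). A hedged sketch of the corresponding yaml keys; only the key paths are taken from the diff, values are examples:

```yaml
seed: 1000
training:
  save_model: true
  # save_buffer may be omitted; Logger then defaults it to False
wandb:
  disable_artifact: false
```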

View File

@@ -112,15 +112,6 @@ class ActionChunkingTransformerConfig:
     dropout: float = 0.1
     kl_weight: float = 10.0
-    # ---
-    # TODO(alexander-soare): Remove these from the policy config.
-    batch_size: int = 8
-    lr: float = 1e-5
-    lr_backbone: float = 1e-5
-    weight_decay: float = 1e-4
-    grad_clip_norm: float = 10
-    utd: int = 1
 
     def __post_init__(self):
         """Input validation (not exhaustive)."""
         if not self.vision_backbone.startswith("resnet"):
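
The optimizer and training hyperparameters are dropped from ActionChunkingTransformerConfig; presumably they are now supplied by the yaml config instead of the policy dataclass. A hedged sketch of what that might look like, reusing the removed defaults (the grouping under training is an assumption, not shown in this hunk, and whether utd resurfaces in the yaml is not shown either):

```yaml
training:
  batch_size: 8
  lr: 1.0e-5
  lr_backbone: 1.0e-5
  weight_decay: 1.0e-4
  grad_clip_norm: 10
```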

View File

@@ -119,15 +119,6 @@ class DiffusionConfig:
     # ---
     # TODO(alexander-soare): Remove these from the policy config.
-    batch_size: int = 64
-    grad_clip_norm: int = 10
-    lr: float = 1.0e-4
-    lr_scheduler: str = "cosine"
-    lr_warmup_steps: int = 500
-    adam_betas: tuple[float, float] = (0.95, 0.999)
-    adam_eps: float = 1.0e-8
-    adam_weight_decay: float = 1.0e-6
-    utd: int = 1
     use_ema: bool = True
     ema_update_after_step: int = 0
     ema_min_alpha: float = 0.0
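
Same pattern for DiffusionConfig: the removed fields are the kind of settings a training section of the diffusion yaml would carry, while the EMA fields stay on the dataclass for now. A hedged sketch, again reusing the removed defaults (the training grouping is an assumption):

```yaml
training:
  batch_size: 64
  grad_clip_norm: 10
  lr: 1.0e-4
  lr_scheduler: cosine
  lr_warmup_steps: 500
  adam_betas: [0.95, 0.999]
  adam_eps: 1.0e-8
  adam_weight_decay: 1.0e-6
```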

View File

@@ -35,7 +35,7 @@ def make_policy(hydra_cfg: DictConfig, dataset_stats=None):
         from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 
         policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
-        policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps, dataset_stats)
+        policy = DiffusionPolicy(policy_cfg, hydra_cfg.training.offline_steps, dataset_stats)
         policy.to(get_safe_torch_device(hydra_cfg.device))
     elif hydra_cfg.policy.name == "act":
         from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
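
Consistent with the hunks above, offline_steps is now read from the training section of the Hydra config. A minimal illustrative sketch of the keys make_policy() touches here; key paths are from the diff, values are placeholders:

```yaml
device: cuda
policy:
  name: diffusion
training:
  offline_steps: 200000
```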