Remove offline training, refactor train.py and logging/checkpointing (#670)

Co-authored-by: Remi <remi.cadene@huggingface.co>
Author: Simon Alibert
Date: 2025-02-11 10:36:06 +01:00
Committed by: GitHub
Parent: 334deb985d
Commit: 90e099b39f
40 changed files with 1515 additions and 935 deletions

tests/fixtures/optimizers.py (new file, 26 lines)

@@ -0,0 +1,26 @@
import pytest
import torch

from lerobot.common.optim.optimizers import AdamConfig
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig


@pytest.fixture
def model_params():
    return [torch.nn.Parameter(torch.randn(10, 10))]


@pytest.fixture
def optimizer(model_params):
    optimizer = AdamConfig().build(model_params)
    # Dummy step to populate the optimizer state
    loss = sum(param.sum() for param in model_params)
    loss.backward()
    optimizer.step()
    return optimizer


@pytest.fixture
def scheduler(optimizer):
    config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
    return config.build(optimizer, num_training_steps=100)
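
For reference, a minimal sketch of how these fixtures might be exercised in a test, assuming they are registered with pytest (e.g. through a conftest.py plugin) and that the objects built by AdamConfig and VQBeTSchedulerConfig follow the standard torch.optim optimizer and lr_scheduler interfaces. The test names below are hypothetical and not part of this commit:

def test_optimizer_fixture_has_state(optimizer, model_params):
    # The dummy backward/step in the fixture should leave one state entry per parameter.
    assert len(optimizer.state_dict()["state"]) == len(model_params)


def test_scheduler_fixture_steps(scheduler):
    # Assuming the built scheduler exposes the usual step()/state_dict() interface,
    # advancing through a few warmup steps should not raise and its state should serialize.
    for _ in range(3):
        scheduler.step()
    assert isinstance(scheduler.state_dict(), dict)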