Remove offline training, refactor train.py and logging/checkpointing (#670)

Co-authored-by: Remi <remi.cadene@huggingface.co>
Simon Alibert
2025-02-11 10:36:06 +01:00
committed by GitHub
parent 334deb985d
commit 90e099b39f
40 changed files with 1515 additions and 935 deletions


@@ -25,7 +25,7 @@ from lerobot.common.datasets.transforms import (
     ImageTransformsConfig,
     make_transform_from_config,
 )
-from lerobot.common.utils.utils import seeded_context
+from lerobot.common.utils.random_utils import seeded_context
 ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
 DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
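Note on the change above: callers of seeded_context keep the same call site and only switch the import path. A minimal sketch of the intended usage, assuming seeded_context is still a context manager that takes a seed and restores RNG state on exit (the exact signature is not shown in this diff):

    import torch
    from lerobot.common.utils.random_utils import seeded_context

    with seeded_context(1337):  # assumed signature: seeded_context(seed)
        sample = torch.rand(3)  # draws are reproducible inside the block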


@@ -22,14 +22,14 @@ from safetensors.torch import save_file
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.optim.factory import make_optimizer_and_scheduler
 from lerobot.common.policies.factory import make_policy, make_policy_config
-from lerobot.common.utils.utils import set_global_seed
+from lerobot.common.utils.random_utils import set_seed
 from lerobot.configs.default import DatasetConfig
 from lerobot.configs.train import TrainPipelineConfig
 def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs):
     # TODO(rcadene, aliberts): env_name?
-    set_global_seed(1337)
+    set_seed(1337)
     train_cfg = TrainPipelineConfig(
         # TODO(rcadene, aliberts): remove dataset download
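Note on the rename above: set_global_seed from lerobot.common.utils.utils becomes set_seed in lerobot.common.utils.random_utils. A minimal sketch of the updated call, assuming set_seed(seed) seeds the same global RNGs (Python, NumPy, Torch) the old helper did:

    from lerobot.common.utils.random_utils import set_seed

    set_seed(1337)  # assumption: same effect as the former set_global_seed(1337)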
@@ -53,9 +53,9 @@ def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs):
     )
     batch = next(iter(dataloader))
-    output_dict = policy.forward(batch)
+    loss, output_dict = policy.forward(batch)
     output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
-    loss = output_dict["loss"]
+    output_dict["loss"] = loss
     loss.backward()
     grad_stats = {}
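The hunk above reflects the policy API used by the refactored training code: forward(batch) now returns a (loss, output_dict) tuple instead of packing the loss into output_dict. A minimal training-step sketch under that assumption (the helper name and logging filter are illustrative, not the exact train.py code):

    import torch

    def training_step(policy, batch, optimizer):
        # Assumed new contract from this diff: policy.forward(batch) -> (loss, output_dict).
        loss, output_dict = policy.forward(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Keep only scalar tensors for logging, mirroring the filtering in the test above.
        metrics = {k: v.item() for k, v in output_dict.items() if isinstance(v, torch.Tensor) and v.ndim == 0}
        metrics["loss"] = loss.item()
        return metrics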