Remove offline training, refactor train.py and logging/checkpointing (#670)

Co-authored-by: Remi <remi.cadene@huggingface.co>
This commit is contained in:
Simon Alibert
2025-02-11 10:36:06 +01:00
committed by GitHub
parent 334deb985d
commit 90e099b39f
40 changed files with 1515 additions and 935 deletions

View File

@@ -21,68 +21,6 @@ from lerobot.configs.policies import PreTrainedConfig
TRAIN_CONFIG_NAME = "train_config.json"
@dataclass
class OfflineConfig:
steps: int = 100_000
@dataclass
class OnlineConfig:
"""
The online training loop looks something like:
```python
for i in range(steps):
do_online_rollout_and_update_online_buffer()
for j in range(steps_between_rollouts):
batch = next(dataloader_with_offline_and_online_data)
loss = policy(batch)
loss.backward()
optimizer.step()
```
Note that the online training loop adopts most of the options from the offline loop unless specified
otherwise.
"""
steps: int = 0
# How many episodes to collect at once when we reach the online rollout part of the training loop.
rollout_n_episodes: int = 1
# The number of environments to use in the gym.vector.VectorEnv. This ends up also being the batch size for
# the policy. Ideally you should set this to by an even divisor of rollout_n_episodes.
rollout_batch_size: int = 1
# How many optimization steps (forward, backward, optimizer step) to do between running rollouts.
steps_between_rollouts: int | None = None
# The proportion of online samples (vs offline samples) to include in the online training batches.
sampling_ratio: float = 0.5
# First seed to use for the online rollout environment. Seeds for subsequent rollouts are incremented by 1.
env_seed: int | None = None
# Sets the maximum number of frames that are stored in the online buffer for online training. The buffer is
# FIFO.
buffer_capacity: int | None = None
# The minimum number of frames to have in the online buffer before commencing online training.
# If buffer_seed_size > rollout_n_episodes, the rollout will be run multiple times until the
# seed size condition is satisfied.
buffer_seed_size: int = 0
# Whether to run the online rollouts asynchronously. This means we can run the online training steps in
# parallel with the rollouts. This might be advised if your GPU has the bandwidth to handle training
# + eval + environment rendering simultaneously.
do_rollout_async: bool = False
def __post_init__(self):
if self.steps == 0:
return
if self.steps_between_rollouts is None:
raise ValueError(
"'steps_between_rollouts' must be set to a positive integer, but it is currently None."
)
if self.env_seed is None:
raise ValueError("'env_seed' must be set to a positive integer, but it is currently None.")
if self.buffer_capacity is None:
raise ValueError("'buffer_capacity' must be set to a positive integer, but it is currently None.")
@dataclass
class TrainPipelineConfig(HubMixin):
dataset: DatasetConfig
@@ -107,13 +45,12 @@ class TrainPipelineConfig(HubMixin):
# Number of workers for the dataloader.
num_workers: int = 4
batch_size: int = 8
steps: int = 100_000
eval_freq: int = 20_000
log_freq: int = 200
save_checkpoint: bool = True
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
save_freq: int = 20_000
offline: OfflineConfig = field(default_factory=OfflineConfig)
online: OnlineConfig = field(default_factory=OnlineConfig)
use_policy_training_preset: bool = True
optimizer: OptimizerConfig | None = None
scheduler: LRSchedulerConfig | None = None
@@ -168,11 +105,8 @@ class TrainPipelineConfig(HubMixin):
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
self.output_dir = Path("outputs/train") / train_dir
if self.online.steps > 0:
if isinstance(self.dataset.repo_id, list):
raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
if self.env is None:
raise ValueError("An environment is required for online training")
if isinstance(self.dataset.repo_id, list):
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
@@ -185,6 +119,9 @@ class TrainPipelineConfig(HubMixin):
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
def to_dict(self) -> dict:
return draccus.encode(self)
def _save_pretrained(self, save_directory: Path) -> None:
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
draccus.dump(self, f, indent=4)