Remove offline training, refactor train.py and logging/checkpointing (#670)
Co-authored-by: Remi <remi.cadene@huggingface.co>
@@ -21,68 +21,6 @@ from lerobot.configs.policies import PreTrainedConfig
 
 TRAIN_CONFIG_NAME = "train_config.json"
 
 
-@dataclass
-class OfflineConfig:
-    steps: int = 100_000
-
-
-@dataclass
-class OnlineConfig:
-    """
-    The online training loop looks something like:
-
-    ```python
-    for i in range(steps):
-        do_online_rollout_and_update_online_buffer()
-        for j in range(steps_between_rollouts):
-            batch = next(dataloader_with_offline_and_online_data)
-            loss = policy(batch)
-            loss.backward()
-            optimizer.step()
-    ```
-
-    Note that the online training loop adopts most of the options from the offline loop unless specified
-    otherwise.
-    """
-
-    steps: int = 0
-    # How many episodes to collect at once when we reach the online rollout part of the training loop.
-    rollout_n_episodes: int = 1
-    # The number of environments to use in the gym.vector.VectorEnv. This ends up also being the batch size for
-    # the policy. Ideally you should set this to be an even divisor of rollout_n_episodes.
-    rollout_batch_size: int = 1
-    # How many optimization steps (forward, backward, optimizer step) to do between running rollouts.
-    steps_between_rollouts: int | None = None
-    # The proportion of online samples (vs offline samples) to include in the online training batches.
-    sampling_ratio: float = 0.5
-    # First seed to use for the online rollout environment. Seeds for subsequent rollouts are incremented by 1.
-    env_seed: int | None = None
-    # Sets the maximum number of frames that are stored in the online buffer for online training. The buffer is
-    # FIFO.
-    buffer_capacity: int | None = None
-    # The minimum number of frames to have in the online buffer before commencing online training.
-    # If buffer_seed_size > rollout_n_episodes, the rollout will be run multiple times until the
-    # seed size condition is satisfied.
-    buffer_seed_size: int = 0
-    # Whether to run the online rollouts asynchronously. This means we can run the online training steps in
-    # parallel with the rollouts. This might be advised if your GPU has the bandwidth to handle training
-    # + eval + environment rendering simultaneously.
-    do_rollout_async: bool = False
-
-    def __post_init__(self):
-        if self.steps == 0:
-            return
-
-        if self.steps_between_rollouts is None:
-            raise ValueError(
-                "'steps_between_rollouts' must be set to a positive integer, but it is currently None."
-            )
-        if self.env_seed is None:
-            raise ValueError("'env_seed' must be set to a positive integer, but it is currently None.")
-        if self.buffer_capacity is None:
-            raise ValueError("'buffer_capacity' must be set to a positive integer, but it is currently None.")
 
 
 @dataclass
 class TrainPipelineConfig(HubMixin):
     dataset: DatasetConfig
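For context on what this hunk deletes: the OnlineConfig docstring describes rollouts feeding a FIFO buffer whose frames are mixed with offline data at `sampling_ratio`. Below is a minimal, self-contained sketch of that mechanism; `OnlineBuffer`, `mixed_batch`, and uniform sampling with replacement are illustrative assumptions, not lerobot's actual implementation.

```python
import random
from collections import deque


class OnlineBuffer:
    """FIFO frame store: once `capacity` frames are held, adding more drops the oldest."""

    def __init__(self, capacity: int):
        self.frames: deque[dict] = deque(maxlen=capacity)

    def add_rollout(self, frames: list[dict]) -> None:
        self.frames.extend(frames)

    def __len__(self) -> int:
        return len(self.frames)

    def sample(self, n: int) -> list[dict]:
        # Uniform sampling with replacement, for brevity.
        return random.choices(list(self.frames), k=n)


def mixed_batch(offline: list[dict], online: OnlineBuffer, batch_size: int, sampling_ratio: float) -> list[dict]:
    # `sampling_ratio` of the batch comes from the online buffer, the rest from offline data.
    n_online = int(batch_size * sampling_ratio)
    return online.sample(n_online) + random.choices(offline, k=batch_size - n_online)
```

Under this reading, `buffer_seed_size` simply means: keep running rollouts until `len(buffer) >= buffer_seed_size` before entering the mixed-batch loop.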
@@ -107,13 +45,12 @@ class TrainPipelineConfig(HubMixin):
     # Number of workers for the dataloader.
     num_workers: int = 4
     batch_size: int = 8
+    steps: int = 100_000
     eval_freq: int = 20_000
     log_freq: int = 200
     save_checkpoint: bool = True
     # Checkpoint is saved every `save_freq` training iterations and after the last training step.
     save_freq: int = 20_000
-    offline: OfflineConfig = field(default_factory=OfflineConfig)
-    online: OnlineConfig = field(default_factory=OnlineConfig)
     use_policy_training_preset: bool = True
     optimizer: OptimizerConfig | None = None
     scheduler: LRSchedulerConfig | None = None
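With the offline/online split removed, the schedule now lives directly on `TrainPipelineConfig`. Here is a sketch of how such flattened fields (`steps`, `log_freq`, `save_freq`, `eval_freq`, `save_checkpoint`) typically drive a single training loop; the stubs and the loop itself are illustrative, not the refactored `train.py`.

```python
from itertools import cycle


def write_checkpoint(policy, optimizer, step):  # placeholder stub
    pass


def run_eval(policy):  # placeholder stub
    pass


def train(cfg, policy, optimizer, dataloader):
    # cfg carries the flattened fields from TrainPipelineConfig above:
    # steps, log_freq, save_checkpoint, save_freq, eval_freq.
    for step, batch in zip(range(1, cfg.steps + 1), cycle(dataloader)):
        loss = policy(batch)  # assumes the policy's forward pass returns the loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if step % cfg.log_freq == 0:
            print(f"step={step} loss={loss.item():.4f}")
        if cfg.save_checkpoint and (step % cfg.save_freq == 0 or step == cfg.steps):
            write_checkpoint(policy, optimizer, step)
        if step % cfg.eval_freq == 0:
            run_eval(policy)
```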
@@ -168,11 +105,8 @@ class TrainPipelineConfig(HubMixin):
             train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
             self.output_dir = Path("outputs/train") / train_dir
 
-        if self.online.steps > 0:
-            if isinstance(self.dataset.repo_id, list):
-                raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
-            if self.env is None:
-                raise ValueError("An environment is required for online training")
+        if isinstance(self.dataset.repo_id, list):
+            raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
 
         if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
             raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
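The preset guard retained at the end of this hunk is easy to exercise in isolation. The toy dataclass below mirrors the check; it is a stand-alone analogue, not the real `TrainPipelineConfig`.

```python
from dataclasses import dataclass


@dataclass
class PresetGuardDemo:
    use_policy_training_preset: bool = True
    optimizer: str | None = None
    scheduler: str | None = None

    def __post_init__(self):
        if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
            raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")


PresetGuardDemo()  # fine: the policy preset is expected to supply optimizer/scheduler
try:
    PresetGuardDemo(use_policy_training_preset=False)
except ValueError as e:
    print(e)  # Optimizer and Scheduler must be set when the policy presets are not used.
```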
@@ -185,6 +119,9 @@ class TrainPipelineConfig(HubMixin):
         """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
         return ["policy"]
 
+    def to_dict(self) -> dict:
+        return draccus.encode(self)
+
     def _save_pretrained(self, save_directory: Path) -> None:
         with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
             draccus.dump(self, f, indent=4)
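The new `to_dict` and the existing `_save_pretrained` both serialize the config through draccus. A toy round trip with the same helpers (`encode`, `dump`, `config_type`); `ToyConfig` and the expected output are illustrative, not the real `TrainPipelineConfig`.

```python
import json
from dataclasses import dataclass
from pathlib import Path

import draccus


@dataclass
class ToyConfig:
    batch_size: int = 8
    steps: int = 100_000


cfg = ToyConfig(batch_size=32)
path = Path("train_config.json")
with open(path, "w") as f, draccus.config_type("json"):
    draccus.dump(cfg, f, indent=4)  # same pattern as _save_pretrained above

print(json.loads(path.read_text()))  # {'batch_size': 32, 'steps': 100000}
print(draccus.encode(cfg))           # the to_dict path: same content, no disk I/O
```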