Remove offline training, refactor train.py and logging/checkpointing (#670)
Co-authored-by: Remi <remi.cadene@huggingface.co>
This commit is contained in:
@@ -86,8 +86,7 @@ def main():
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||
output_dict = policy.forward(batch)
|
||||
loss = output_dict["loss"]
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
@@ -161,13 +161,13 @@ python lerobot/scripts/train.py \
|
||||
```
|
||||
You should see from the logging that your training picks up from where it left off.
|
||||
|
||||
Another reason for which you might want to resume a run is simply to extend training and add more training steps. The number of training steps is set by the option `--offline.steps`, which is 100 000 by default.
|
||||
Another reason for which you might want to resume a run is simply to extend training and add more training steps. The number of training steps is set by the option `--steps`, which is 100 000 by default.
|
||||
You could double the number of steps of the previous run with:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
|
||||
--resume=true \
|
||||
--offline.steps=200000
|
||||
--steps=200000
|
||||
```
|
||||
|
||||
## Outputs of a run
|
||||
@@ -175,12 +175,16 @@ In the output directory, there will be a folder called `checkpoints` with the fo
|
||||
```bash
|
||||
outputs/train/run_resumption/checkpoints
|
||||
├── 000100 # checkpoint_dir for training step 100
|
||||
│ ├── pretrained_model
|
||||
│ │ ├── config.json # pretrained policy config
|
||||
│ │ ├── model.safetensors # model weights
|
||||
│ │ ├── train_config.json # train config
|
||||
│ │ └── README.md # model card
|
||||
│ └── training_state.pth # optimizer/scheduler/rng state and training step
|
||||
│ ├── pretrained_model/
|
||||
│ │ ├── config.json # policy config
|
||||
│ │ ├── model.safetensors # policy weights
|
||||
│ │ └── train_config.json # train config
|
||||
│ └── training_state/
|
||||
│ ├── optimizer_param_groups.json # optimizer param groups
|
||||
│ ├── optimizer_state.safetensors # optimizer state
|
||||
│ ├── rng_state.safetensors # rng states
|
||||
│ ├── scheduler_state.json # scheduler state
|
||||
│ └── training_step.json # training step
|
||||
├── 000200
|
||||
└── last -> 000200 # symlink to the last available checkpoint
|
||||
```
|
||||
@@ -250,7 +254,7 @@ python lerobot/scripts/train.py \
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=checkpoint/pretrained_model/ \
|
||||
--resume=true \
|
||||
--offline.steps=200000 # <- you can change some training parameters
|
||||
--steps=200000 # <- you can change some training parameters
|
||||
```
|
||||
|
||||
#### Fine-tuning
|
||||
|
||||
@@ -75,9 +75,9 @@ def main():
|
||||
n_examples_evaluated = 0
|
||||
for batch in val_dataloader:
|
||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||
output_dict = policy.forward(batch)
|
||||
loss, _ = policy.forward(batch)
|
||||
|
||||
loss_cumsum += output_dict["loss"].item()
|
||||
loss_cumsum += loss.item()
|
||||
n_examples_evaluated += batch["index"].shape[0]
|
||||
|
||||
# Calculate the average loss over the validation set.
|
||||
|
||||
Reference in New Issue
Block a user