feat(train): add accelerate for multi gpu training (#2154)
* Enhance training and logging functionality with accelerator support

  - Added support for multi-GPU training by introducing an `accelerator` parameter in training functions.
  - Updated `update_policy` to handle gradient updates based on the presence of an accelerator.
  - Modified logging to prevent duplicate messages in non-main processes.
  - Enhanced `set_seed` and `get_safe_torch_device` to accommodate accelerator usage.
  - Updated `MetricsTracker` to account for the number of processes when calculating metrics.
  - Added the `accelerate` library as a dependency in `pyproject.toml`.

* Initialize logging in training script for both main and non-main processes

  - Added `init_logging` calls to ensure proper logging setup both when using the accelerator and in standard training mode, improving the clarity and consistency of logging during training.

* add docs and only push model once
* place logging under accelerate and update docs
* fix pre-commit
* only log in main process
* main logging
* try with local rank
* add tests
* change runner
* fix test
* don't push to hub in multi-GPU tests
* pre-download dataset in tests
* small fixes
* fix path of optimizer state
* update docs, and small improvements in train
* simplify accelerate main-process detection
* small improvements in train
* fix OOM bug
* change accelerate detection
* add some debugging
* always use accelerate
* clean up update method
* cleanup
* fix bug
* scale LR decay if we reduce steps
* clean up logging
* fix formatting
* incorporate PR feedback
* add min memory to CPU tests
* use accelerate to determine logging
* fix pre-commit and fix tests
* chore: minor details

---------

Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
.github/workflows/nightly.yml
@@ -119,6 +119,7 @@ jobs:
       TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
     container:
       image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
+      options: --shm-size "16gb"
       credentials:
         username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
         password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
@@ -158,3 +159,35 @@ jobs:
         run: pytest tests -vv --maxfail=10
       - name: Run end-to-end tests
         run: make test-end-to-end
+
+  # This job runs multi-GPU training tests with 4 GPUs
+  nightly-multi-gpu-tests:
+    name: Nightly Multi-GPU Tests
+    needs: [build-docker-gpu-nightly]
+    runs-on:
+      group: aws-g4dn-12xlarge # Instance with 4 GPUs
+    env:
+      HF_HOME: /home/user_lerobot/.cache/huggingface
+      HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
+      TORCH_HOME: /home/user_lerobot/.cache/torch
+      TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
+      CUDA_VISIBLE_DEVICES: "0,1,2,3"
+    container:
+      image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
+      options: --gpus all --shm-size "16gb"
+      credentials:
+        username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
+        password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
+    defaults:
+      run:
+        shell: bash
+        working-directory: /lerobot
+    steps:
+      - name: Verify GPU availability
+        run: |
+          nvidia-smi
+          python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
+
+      - name: Run multi-GPU training tests
+        run: pytest tests/training/test_multi_gpu.py -vv --maxfail=3
+        timeout-minutes: 10
@@ -17,6 +17,8 @@
         title: Train RL in Simulation
       - local: async
         title: Use Async Inference
+      - local: multi_gpu_training
+        title: Multi GPU training
       title: "Tutorials"
     - sections:
       - local: lerobot-dataset-v3
docs/source/multi_gpu_training.mdx (new file, 125 lines)
@@ -0,0 +1,125 @@

# Multi-GPU Training

This guide shows you how to train policies on multiple GPUs using [Hugging Face Accelerate](https://huggingface.co/docs/accelerate).

## Installation

First, ensure you have accelerate installed:

```bash
pip install accelerate
```

## Training with Multiple GPUs

You can launch training in two ways:

### Option 1: Without config (specify parameters directly)

You can specify all parameters directly in the command without running `accelerate config`:

```bash
accelerate launch \
  --multi_gpu \
  --num_processes=2 \
  $(which lerobot-train) \
  --dataset.repo_id=${HF_USER}/my_dataset \
  --policy.type=act \
  --policy.repo_id=${HF_USER}/my_trained_policy \
  --output_dir=outputs/train/act_multi_gpu \
  --job_name=act_multi_gpu \
  --wandb.enable=true
```

**Key accelerate parameters:**

- `--multi_gpu`: Enable multi-GPU training
- `--num_processes=2`: Number of GPUs to use
- `--mixed_precision=fp16`: Use fp16 mixed precision (or `bf16` if supported)

### Option 2: Using accelerate config

If you prefer to save your configuration, you can configure accelerate for your hardware setup by running:

```bash
accelerate config
```

This interactive setup asks questions about your training environment (number of GPUs, mixed precision settings, etc.) and saves the configuration for future use. For a simple multi-GPU setup on a single machine, you can use these recommended settings:

- Compute environment: This machine
- Number of machines: 1
- Number of processes: (number of GPUs you want to use)
- GPU ids to use: (leave empty to use all)
- Mixed precision: fp16 or bf16 (recommended for faster training)

Then launch training with:

```bash
accelerate launch $(which lerobot-train) \
  --dataset.repo_id=${HF_USER}/my_dataset \
  --policy.type=act \
  --policy.repo_id=${HF_USER}/my_trained_policy \
  --output_dir=outputs/train/act_multi_gpu \
  --job_name=act_multi_gpu \
  --wandb.enable=true
```

## How It Works

When you launch training with accelerate:

1. **Automatic detection**: LeRobot automatically detects that it is running under accelerate
2. **Data distribution**: Your batch is automatically split across GPUs
3. **Gradient synchronization**: Gradients are synchronized across GPUs during backpropagation
4. **Single-process logging**: Only the main process logs to wandb and saves checkpoints

## Learning Rate and Training Steps Scaling

**Important:** LeRobot does **NOT** automatically scale learning rates or training steps based on the number of GPUs. This gives you full control over your training hyperparameters.

### Why No Automatic Scaling?

Many distributed training frameworks automatically scale the learning rate by the number of GPUs (e.g., `lr = base_lr × num_gpus`). LeRobot, however, keeps the learning rate exactly as you specify it.

### When and How to Scale

If you want to scale your hyperparameters when using multiple GPUs, do it manually:

**Learning Rate Scaling:**

```bash
# Example: 2 GPUs with linear LR scaling
# Base LR: 1e-4, with 2 GPUs -> 2e-4
accelerate launch --num_processes=2 $(which lerobot-train) \
  --optimizer.lr=2e-4 \
  --dataset.repo_id=lerobot/pusht \
  --policy=act
```

**Training Steps Scaling:**

Since the effective batch size increases with multiple GPUs (batch_size × num_gpus), you may want to reduce the number of training steps proportionally:

```bash
# Example: 2 GPUs with effective batch size 2x larger
# Original: batch_size=8, steps=100000
# With 2 GPUs: batch_size=8 (16 in total), steps=50000
accelerate launch --num_processes=2 $(which lerobot-train) \
  --batch_size=8 \
  --steps=50000 \
  --dataset.repo_id=lerobot/pusht \
  --policy=act
```

## Notes

- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration.
- Training logs, checkpoints, and hub uploads are only done by the main process to avoid conflicts. Non-main processes have console logging disabled to prevent duplicate output.
- The effective batch size is `batch_size × num_gpus`. If you use 4 GPUs with `--batch_size=8`, your effective batch size is 32.
- Learning rate scheduling is handled correctly across multiple processes: LeRobot sets `step_scheduler_with_optimizer=False` to prevent accelerate from adjusting scheduler steps based on the number of processes.
- When saving or pushing models, LeRobot automatically unwraps the model from accelerate's distributed wrapper to ensure compatibility.
- WandB integration initializes only on the main process, preventing multiple runs from being created.

For more advanced configurations and troubleshooting, see the [Accelerate documentation](https://huggingface.co/docs/accelerate). If you want to learn more about training on a large number of GPUs, check out the [Ultrascale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook).
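The manual-scaling guidance above amounts to a bit of arithmetic; a minimal sketch (the helper name is illustrative, not a LeRobot API — linear LR scaling and proportional step reduction are just one common heuristic):

```python
def scale_for_multi_gpu(base_lr: float, base_steps: int, batch_size: int, num_gpus: int):
    """Manually scale hyperparameters for multi-GPU training.

    LeRobot does not do this automatically; this mirrors the docs' advice:
    linear LR scaling, and fewer steps to keep total samples seen constant.
    """
    effective_batch_size = batch_size * num_gpus
    scaled_lr = base_lr * num_gpus           # linear LR scaling
    scaled_steps = base_steps // num_gpus    # same total samples processed
    return effective_batch_size, scaled_lr, scaled_steps


# Matches the examples above: 2 GPUs, base LR 1e-4, 100k steps, batch 8
print(scale_for_multi_gpu(1e-4, 100_000, 8, 2))  # (16, 0.0002, 50000)
```

These are the numbers plugged into the two `accelerate launch` commands above (`--optimizer.lr=2e-4`, `--steps=50000`).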
pyproject.toml

@@ -62,6 +62,7 @@ dependencies = [
     "datasets>=4.0.0,<4.2.0",
     "diffusers>=0.27.2,<0.36.0",
     "huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
+    "accelerate>=1.10.0,<2.0.0",

     # Core dependencies
     "setuptools>=71.0.0,<81.0.0",
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import abc
+import logging
 import math
 from dataclasses import asdict, dataclass
 from pathlib import Path
@@ -79,7 +80,11 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
 @LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
 @dataclass
 class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
-    """Used by Physical Intelligence to train Pi0"""
+    """Used by Physical Intelligence to train Pi0.
+
+    Automatically scales warmup and decay steps if num_training_steps < num_decay_steps.
+    This ensures the learning rate schedule completes properly even with shorter training runs.
+    """
+
     num_warmup_steps: int
     num_decay_steps: int
@@ -87,23 +92,39 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
     decay_lr: float

     def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
-        del num_training_steps
+        # Auto-scale scheduler parameters if training steps are shorter than configured decay steps
+        actual_warmup_steps = self.num_warmup_steps
+        actual_decay_steps = self.num_decay_steps
+
+        if num_training_steps < self.num_decay_steps:
+            # Calculate scaling factor to fit the schedule into the available training steps
+            scale_factor = num_training_steps / self.num_decay_steps
+            actual_warmup_steps = int(self.num_warmup_steps * scale_factor)
+            actual_decay_steps = num_training_steps
+
+            logging.info(
+                f"Auto-scaling LR scheduler: "
+                f"num_training_steps ({num_training_steps}) < num_decay_steps ({self.num_decay_steps}). "
+                f"Scaling warmup: {self.num_warmup_steps} → {actual_warmup_steps}, "
+                f"decay: {self.num_decay_steps} → {actual_decay_steps} "
+                f"(scale factor: {scale_factor:.3f})"
+            )
+
         def lr_lambda(current_step):
             def linear_warmup_schedule(current_step):
                 if current_step <= 0:
-                    return 1 / (self.num_warmup_steps + 1)
+                    return 1 / (actual_warmup_steps + 1)
-                frac = 1 - current_step / self.num_warmup_steps
+                frac = 1 - current_step / actual_warmup_steps
-                return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1
+                return (1 / (actual_warmup_steps + 1) - 1) * frac + 1

             def cosine_decay_schedule(current_step):
-                step = min(current_step, self.num_decay_steps)
+                step = min(current_step, actual_decay_steps)
-                cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
+                cosine_decay = 0.5 * (1 + math.cos(math.pi * step / actual_decay_steps))
                 alpha = self.decay_lr / self.peak_lr
                 decayed = (1 - alpha) * cosine_decay + alpha
                 return decayed

-            if current_step < self.num_warmup_steps:
+            if current_step < actual_warmup_steps:
                 return linear_warmup_schedule(current_step)

             return cosine_decay_schedule(current_step)
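The auto-scaling behavior added to `build` above can be exercised standalone. A sketch mirroring the same arithmetic with plain functions (field names trimmed to the essentials; the `peak_lr` value used in the example is illustrative, not a LeRobot default):

```python
import math


def build_schedule(num_warmup_steps, num_decay_steps, peak_lr, decay_lr, num_training_steps):
    """Mirror CosineDecayWithWarmupSchedulerConfig.build: returns the scaled
    warmup/decay step counts and an lr_lambda multiplier of the peak LR."""
    actual_warmup = num_warmup_steps
    actual_decay = num_decay_steps
    if num_training_steps < num_decay_steps:
        # Shrink the schedule to fit the available training steps
        scale = num_training_steps / num_decay_steps
        actual_warmup = int(num_warmup_steps * scale)
        actual_decay = num_training_steps

    def lr_lambda(step):
        if step < actual_warmup:  # linear warmup toward the peak LR
            if step <= 0:
                return 1 / (actual_warmup + 1)
            frac = 1 - step / actual_warmup
            return (1 / (actual_warmup + 1) - 1) * frac + 1
        # cosine decay from peak_lr down to decay_lr (as a fraction of peak_lr)
        s = min(step, actual_decay)
        cosine = 0.5 * (1 + math.cos(math.pi * s / actual_decay))
        alpha = decay_lr / peak_lr
        return (1 - alpha) * cosine + alpha

    return actual_warmup, actual_decay, lr_lambda


# Defaults warmup=1000, decay=30000 with --steps=3000 scale to 100 / 3000,
# as the Pi0 config comment above describes
warmup, decay, fn = build_schedule(1_000, 30_000, 2.5e-5, 2.5e-6, 3_000)
print(warmup, decay)  # 100 3000
```

At `step == actual_decay` the multiplier bottoms out at `decay_lr / peak_lr`, so the schedule always completes even for short runs.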
@@ -75,6 +75,8 @@ class PI0Config(PreTrainedConfig):
     optimizer_grad_clip_norm: float = 1.0

     # Scheduler settings: see openpi `CosineDecaySchedule`
+    # Note: These will auto-scale if --steps < scheduler_decay_steps
+    # For example, --steps=3000 will scale warmup to 100 and decay to 3000
     scheduler_warmup_steps: int = 1_000
     scheduler_decay_steps: int = 30_000
     scheduler_decay_lr: float = 2.5e-6
@@ -75,6 +75,8 @@ class PI05Config(PreTrainedConfig):
     optimizer_grad_clip_norm: float = 1.0

     # Scheduler settings: see openpi `CosineDecaySchedule`
+    # Note: These will auto-scale if --steps < scheduler_decay_steps
+    # For example, --steps=3000 will scale warmup to 100 and decay to 3000
     scheduler_warmup_steps: int = 1_000
     scheduler_decay_steps: int = 30_000
     scheduler_decay_lr: float = 2.5e-6
@@ -99,7 +99,7 @@ class WandBLogger:
         cfg.wandb.run_id = run_id
         # Handle custom step key for rl asynchronous training.
         self._wandb_custom_step_key: set[str] | None = None
-        print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
+        logging.info(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
         logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
         self._wandb = wandb
@@ -20,8 +20,8 @@ from pprint import pformat
 from typing import Any

 import torch
+from accelerate import Accelerator
 from termcolor import colored
-from torch.amp import GradScaler
 from torch.optim import Optimizer

 from lerobot.configs import parser
@@ -34,7 +34,6 @@ from lerobot.envs.utils import close_envs
 from lerobot.optim.factory import make_optimizer_and_scheduler
 from lerobot.policies.factory import make_policy, make_pre_post_processors
 from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.policies.utils import get_device_from_parameters
 from lerobot.rl.wandb_utils import WandBLogger
 from lerobot.scripts.lerobot_eval import eval_policy_all
 from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
@@ -48,7 +47,6 @@ from lerobot.utils.train_utils import (
 )
 from lerobot.utils.utils import (
     format_big_number,
-    get_safe_torch_device,
     has_method,
     init_logging,
 )
@@ -60,16 +58,15 @@ def update_policy(
     batch: Any,
     optimizer: Optimizer,
     grad_clip_norm: float,
-    grad_scaler: GradScaler,
+    accelerator: Accelerator,
     lr_scheduler=None,
-    use_amp: bool = False,
     lock=None,
 ) -> tuple[MetricsTracker, dict]:
     """
     Performs a single training step to update the policy's weights.

     This function executes the forward and backward passes, clips gradients, and steps the optimizer and
-    learning rate scheduler. It also handles mixed-precision training via a GradScaler.
+    learning rate scheduler. Accelerator handles mixed-precision training automatically.

     Args:
         train_metrics: A MetricsTracker instance to record training statistics.
@@ -77,9 +74,8 @@ def update_policy(
         batch: A batch of training data.
         optimizer: The optimizer used to update the policy's parameters.
         grad_clip_norm: The maximum norm for gradient clipping.
-        grad_scaler: The GradScaler for automatic mixed-precision training.
+        accelerator: The Accelerator instance for distributed training and mixed precision.
         lr_scheduler: An optional learning rate scheduler.
-        use_amp: A boolean indicating whether to use automatic mixed precision.
         lock: An optional lock for thread-safe optimizer updates.

     Returns:
@@ -88,28 +84,27 @@ def update_policy(
         - A dictionary of outputs from the policy's forward pass, for logging purposes.
     """
     start_time = time.perf_counter()
-    device = get_device_from_parameters(policy)
     policy.train()
-    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
+    # Let accelerator handle mixed precision
+    with accelerator.autocast():
         loss, output_dict = policy.forward(batch)
         # TODO(rcadene): policy.unnormalize_outputs(out_dict)
-    grad_scaler.scale(loss).backward()

-    # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
-    grad_scaler.unscale_(optimizer)
+    # Use accelerator's backward method
+    accelerator.backward(loss)

-    grad_norm = torch.nn.utils.clip_grad_norm_(
-        policy.parameters(),
-        grad_clip_norm,
-        error_if_nonfinite=False,
-    )
+    # Clip gradients if specified
+    if grad_clip_norm > 0:
+        grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
+    else:
+        grad_norm = torch.nn.utils.clip_grad_norm_(
+            policy.parameters(), float("inf"), error_if_nonfinite=False
+        )

-    # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
-    # although it still skips optimizer.step() if the gradients contain infs or NaNs.
+    # Optimizer step
     with lock if lock is not None else nullcontext():
-        grad_scaler.step(optimizer)
-        # Updates the scale for next iteration.
-        grad_scaler.update()
+        optimizer.step()

     optimizer.zero_grad()

@@ -117,9 +112,9 @@ def update_policy(
     if lr_scheduler is not None:
         lr_scheduler.step()

-    if has_method(policy, "update"):
-        # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
-        policy.update()
+    # Update internal buffers if policy has update method
+    # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
+    if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
+        accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()

     train_metrics.loss = loss.item()
     train_metrics.grad_norm = grad_norm.item()
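The fallback branch above clips to `float("inf")` purely to reuse the norm computation when no clipping is requested. The semantics of that call can be sketched in pure Python (a hypothetical helper over a flat list of gradient values, not the torch implementation; torch's version also adds a small epsilon to the denominator, assumed here to be `1e-6`):

```python
import math


def clip_grad_norm(grads, max_norm):
    """Sketch of torch.nn.utils.clip_grad_norm_ semantics: return the total
    L2 norm of `grads`, scaling them in place only if the norm exceeds max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:  # with max_norm == inf this never triggers
        scale = max_norm / (total_norm + 1e-6)
        for i, g in enumerate(grads):
            grads[i] = g * scale
    return total_norm


grads = [3.0, 4.0]
norm = clip_grad_norm(grads, max_norm=1.0)
print(norm, grads)  # 5.0, grads rescaled to roughly [0.6, 0.8]
```

This is why the `else` branch still yields a meaningful `grad_norm` metric: clipping against infinity is a no-op, but the norm is computed either way.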
@@ -129,7 +124,7 @@


 @parser.wrap()
-def train(cfg: TrainPipelineConfig):
+def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
     """
     Main function to train a policy.

@@ -143,41 +138,76 @@ def train(cfg: TrainPipelineConfig):

     Args:
         cfg: A `TrainPipelineConfig` object containing all training configurations.
+        accelerator: Optional Accelerator instance. If None, one will be created automatically.
     """
     cfg.validate()
-    logging.info(pformat(cfg.to_dict()))

-    if cfg.wandb.enable and cfg.wandb.project:
+    # Create Accelerator if not provided
+    # It will automatically detect if running in distributed mode or single-process mode
+    # We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
+    # We set find_unused_parameters=True to handle models with conditional computation
+    if accelerator is None:
+        from accelerate.utils import DistributedDataParallelKwargs
+
+        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+        accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
+
+    init_logging(accelerator=accelerator)
+
+    # Determine if this is the main process (for logging and checkpointing)
+    # When using accelerate, only the main process should log to avoid duplicate outputs
+    is_main_process = accelerator.is_main_process
+
+    # Only log on main process
+    if is_main_process:
+        logging.info(pformat(cfg.to_dict()))
+
+    # Initialize wandb only on main process
+    if cfg.wandb.enable and cfg.wandb.project and is_main_process:
         wandb_logger = WandBLogger(cfg)
     else:
         wandb_logger = None
-        logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
+        if is_main_process:
+            logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))

     if cfg.seed is not None:
-        set_seed(cfg.seed)
+        set_seed(cfg.seed, accelerator=accelerator)

-    # Check device is available
-    device = get_safe_torch_device(cfg.policy.device, log=True)
+    # Use accelerator's device
+    device = accelerator.device
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True

-    logging.info("Creating dataset")
-    dataset = make_dataset(cfg)
+    # Dataset loading synchronization: main process downloads first to avoid race conditions
+    if is_main_process:
+        logging.info("Creating dataset")
+        dataset = make_dataset(cfg)
+
+    accelerator.wait_for_everyone()
+
+    # Now all other processes can safely load the dataset
+    if not is_main_process:
+        dataset = make_dataset(cfg)

     # Create environment used for evaluating checkpoints during training on simulation data.
     # On real-world data, no need to create an environment as evaluations are done outside train.py,
     # using the eval.py instead, with gym_dora environment and dora-rs.
     eval_env = None
     if cfg.eval_freq > 0 and cfg.env is not None:
-        logging.info("Creating env")
+        if is_main_process:
+            logging.info("Creating env")
         eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)

-    logging.info("Creating policy")
+    if is_main_process:
+        logging.info("Creating policy")
     policy = make_policy(
         cfg=cfg.policy,
         ds_meta=dataset.meta,
     )
+
+    # Wait for all processes to finish policy creation before continuing
+    accelerator.wait_for_everyone()

     # Create processors - only provide dataset_stats if not resuming from saved processors
     processor_kwargs = {}
     postprocessor_kwargs = {}
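The download-then-wait ordering above can be simulated with threads standing in for processes; a sketch (all names illustrative) of why only the main rank touches the network while the other ranks read from an already-warm cache after the barrier:

```python
import threading

cache = {}                       # stands in for the local HF dataset cache
barrier = threading.Barrier(4)   # stands in for accelerator.wait_for_everyone()
downloads = []                   # records which ranks performed a "download"


def make_dataset(rank):
    if rank == 0:
        # Main process downloads first, exactly once
        downloads.append(rank)
        cache["dataset"] = list(range(10))
    barrier.wait()               # every rank blocks here until rank 0 is done
    return cache["dataset"]      # now safe: the cache is guaranteed populated


threads = [threading.Thread(target=make_dataset, args=(r,)) for r in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(downloads)  # [0]
```

Without the barrier, all four ranks would race to fetch the dataset at once, which is exactly the race condition the training code avoids.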
@@ -209,9 +239,9 @@ def train(cfg: TrainPipelineConfig):
         **postprocessor_kwargs,
     )

-    logging.info("Creating optimizer and scheduler")
+    if is_main_process:
+        logging.info("Creating optimizer and scheduler")
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
-    grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)

     step = 0  # number of policy updates (forward + backward + optim)

@@ -221,14 +251,18 @@ def train(cfg: TrainPipelineConfig):
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())

-    logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
-    if cfg.env is not None:
-        logging.info(f"{cfg.env.task=}")
-    logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
-    logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
-    logging.info(f"{dataset.num_episodes=}")
-    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
-    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
+    if is_main_process:
+        logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
+        if cfg.env is not None:
+            logging.info(f"{cfg.env.task=}")
+        logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
+        logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
+        logging.info(f"{dataset.num_episodes=}")
+        num_processes = accelerator.num_processes
+        effective_bs = cfg.batch_size * num_processes
+        logging.info(f"Effective batch size: {cfg.batch_size} x {num_processes} = {effective_bs}")
+        logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
+        logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

     # create dataloader for offline training
     if hasattr(cfg.policy, "drop_n_last_frames"):
@@ -251,7 +285,13 @@ def train(cfg: TrainPipelineConfig):
        sampler=sampler,
         pin_memory=device.type == "cuda",
         drop_last=False,
-        prefetch_factor=2,
+        prefetch_factor=2 if cfg.num_workers > 0 else None,
     )
+
+    # Prepare everything with accelerator
+    accelerator.wait_for_everyone()
+    policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
+        policy, optimizer, dataloader, lr_scheduler
+    )
     dl_iter = cycle(dataloader)
 
@@ -265,11 +305,20 @@ def train(cfg: TrainPipelineConfig):
        "dataloading_s": AverageMeter("data_s", ":.3f"),
     }
 
+    # Use effective batch size for proper epoch calculation in distributed training
+    effective_batch_size = cfg.batch_size * accelerator.num_processes
     train_tracker = MetricsTracker(
-        cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
+        effective_batch_size,
+        dataset.num_frames,
+        dataset.num_episodes,
+        train_metrics,
+        initial_step=step,
+        accelerator=accelerator,
     )
 
-    logging.info("Start offline training on a fixed dataset")
+    if is_main_process:
+        logging.info("Start offline training on a fixed dataset")
 
     for _ in range(step, cfg.steps):
         start_time = time.perf_counter()
         batch = next(dl_iter)
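The effective-batch-size bookkeeping above is plain arithmetic; a small stdlib sketch (the function name and the numbers are illustrative, not from the PR) of how per-process batch size, process count, and dataset size combine into epochs:

```python
def epochs_seen(steps: int, batch_size: int, num_processes: int, num_frames: int) -> float:
    """Each optimizer step consumes `batch_size` frames on every process."""
    effective_batch_size = batch_size * num_processes
    return steps * effective_batch_size / num_frames


# 100 steps at batch 8 on 4 GPUs over a 3200-frame dataset is one full epoch;
# the same run on a single GPU covers only a quarter of the data.
print(epochs_seen(100, 8, 4, 3200))  # 1.0
print(epochs_seen(100, 8, 1, 3200))  # 0.25
```

This is why the tracker is constructed with `effective_batch_size` rather than `cfg.batch_size`: otherwise epoch counts would be understated by a factor of `num_processes`.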
@@ -282,16 +331,15 @@ def train(cfg: TrainPipelineConfig):
            batch,
             optimizer,
             cfg.optimizer.grad_clip_norm,
-            grad_scaler=grad_scaler,
+            accelerator=accelerator,
             lr_scheduler=lr_scheduler,
-            use_amp=cfg.policy.use_amp,
         )
 
         # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
         # increment `step` here.
         step += 1
         train_tracker.step()
-        is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
+        is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
         is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
         is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
 
@@ -305,69 +353,90 @@ def train(cfg: TrainPipelineConfig):
            train_tracker.reset_averages()
 
         if cfg.save_checkpoint and is_saving_step:
-            logging.info(f"Checkpoint policy after step {step}")
-            checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
-            save_checkpoint(
-                checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor, postprocessor
-            )
-            update_last_checkpoint(checkpoint_dir)
-            if wandb_logger:
-                wandb_logger.log_policy(checkpoint_dir)
+            if is_main_process:
+                logging.info(f"Checkpoint policy after step {step}")
+                checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
+                save_checkpoint(
+                    checkpoint_dir=checkpoint_dir,
+                    step=step,
+                    cfg=cfg,
+                    policy=accelerator.unwrap_model(policy),
+                    optimizer=optimizer,
+                    scheduler=lr_scheduler,
+                    preprocessor=preprocessor,
+                    postprocessor=postprocessor,
+                )
+                update_last_checkpoint(checkpoint_dir)
+                if wandb_logger:
+                    wandb_logger.log_policy(checkpoint_dir)
+
+            accelerator.wait_for_everyone()
 
         if cfg.env and is_eval_step:
-            step_id = get_step_identifier(step, cfg.steps)
-            logging.info(f"Eval policy at step {step}")
-            with (
-                torch.no_grad(),
-                torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
-            ):
-                eval_info = eval_policy_all(
-                    envs=eval_env,  # dict[suite][task_id] -> vec_env
-                    policy=policy,
-                    preprocessor=preprocessor,
-                    postprocessor=postprocessor,
-                    n_episodes=cfg.eval.n_episodes,
-                    videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
-                    max_episodes_rendered=4,
-                    start_seed=cfg.seed,
-                    max_parallel_tasks=cfg.env.max_parallel_tasks,
-                )
-            # overall metrics (suite-agnostic)
-            aggregated = eval_info["overall"]
-
-            # optional: per-suite logging
-            for suite, suite_info in eval_info.items():
-                logging.info("Suite %s aggregated: %s", suite, suite_info)
-
-            # meters/tracker
-            eval_metrics = {
-                "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
-                "pc_success": AverageMeter("success", ":.1f"),
-                "eval_s": AverageMeter("eval_s", ":.3f"),
-            }
-            eval_tracker = MetricsTracker(
-                cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
-            )
-            eval_tracker.eval_s = aggregated.pop("eval_s")
-            eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward")
-            eval_tracker.pc_success = aggregated.pop("pc_success")
-            if wandb_logger:
-                wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
-                wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
-                wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")
+            if is_main_process:
+                step_id = get_step_identifier(step, cfg.steps)
+                logging.info(f"Eval policy at step {step}")
+                with torch.no_grad(), accelerator.autocast():
+                    eval_info = eval_policy_all(
+                        envs=eval_env,  # dict[suite][task_id] -> vec_env
+                        policy=accelerator.unwrap_model(policy),
+                        preprocessor=preprocessor,
+                        postprocessor=postprocessor,
+                        n_episodes=cfg.eval.n_episodes,
+                        videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
+                        max_episodes_rendered=4,
+                        start_seed=cfg.seed,
+                        max_parallel_tasks=cfg.env.max_parallel_tasks,
+                    )
+                # overall metrics (suite-agnostic)
+                aggregated = eval_info["overall"]
+
+                # optional: per-suite logging
+                for suite, suite_info in eval_info.items():
+                    logging.info("Suite %s aggregated: %s", suite, suite_info)
+
+                # meters/tracker
+                eval_metrics = {
+                    "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
+                    "pc_success": AverageMeter("success", ":.1f"),
+                    "eval_s": AverageMeter("eval_s", ":.3f"),
+                }
+                eval_tracker = MetricsTracker(
+                    cfg.batch_size,
+                    dataset.num_frames,
+                    dataset.num_episodes,
+                    eval_metrics,
+                    initial_step=step,
+                    accelerator=accelerator,
+                )
+                eval_tracker.eval_s = aggregated.pop("eval_s")
+                eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward")
+                eval_tracker.pc_success = aggregated.pop("pc_success")
+                if wandb_logger:
+                    wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
+                    wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
+                    wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")
+
+            accelerator.wait_for_everyone()
 
     if eval_env:
         close_envs(eval_env)
-    logging.info("End of training")
 
-    if cfg.policy.push_to_hub:
-        policy.push_model_to_hub(cfg)
-        preprocessor.push_to_hub(cfg.policy.repo_id)
-        postprocessor.push_to_hub(cfg.policy.repo_id)
+    if is_main_process:
+        logging.info("End of training")
+
+        if cfg.policy.push_to_hub:
+            unwrapped_policy = accelerator.unwrap_model(policy)
+            unwrapped_policy.push_model_to_hub(cfg)
+            preprocessor.push_to_hub(cfg.policy.repo_id)
+            postprocessor.push_to_hub(cfg.policy.repo_id)
+
+    # Properly clean up the distributed process group
+    accelerator.wait_for_everyone()
+    accelerator.end_training()
 
 
 def main():
-    init_logging()
     train()
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Callable
 from typing import Any
 
 from lerobot.utils.utils import format_big_number
@@ -84,6 +85,7 @@ class MetricsTracker:
        "samples",
         "episodes",
         "epochs",
+        "accelerator",
     ]
 
     def __init__(
@@ -93,6 +95,7 @@ class MetricsTracker:
        num_episodes: int,
         metrics: dict[str, AverageMeter],
         initial_step: int = 0,
+        accelerator: Callable | None = None,
     ):
         self.__dict__.update(dict.fromkeys(self.__keys__))
         self._batch_size = batch_size
@@ -106,6 +109,7 @@ class MetricsTracker:
        self.samples = self.steps * self._batch_size
         self.episodes = self.samples / self._avg_samples_per_ep
         self.epochs = self.samples / self._num_frames
+        self.accelerator = accelerator
 
     def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
         if name in self.__dict__:
@@ -128,7 +132,7 @@ class MetricsTracker:
        Updates metrics that depend on 'step' for one step.
         """
         self.steps += 1
-        self.samples += self._batch_size
+        self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
         self.episodes = self.samples / self._avg_samples_per_ep
         self.epochs = self.samples / self._num_frames
 
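The `step()` accounting change can be exercised without accelerate installed; a minimal sketch (class names are illustrative, and the stub stands in for `accelerate.Accelerator`) reproducing the samples arithmetic above:

```python
class StubAccelerator:
    """Stand-in for accelerate.Accelerator, exposing only num_processes."""

    def __init__(self, num_processes: int):
        self.num_processes = num_processes


class TrackerSketch:
    """Minimal sketch of the MetricsTracker per-step sample accounting."""

    def __init__(self, batch_size: int, accelerator=None):
        self._batch_size = batch_size
        self.accelerator = accelerator
        self.steps = 0
        self.samples = 0

    def step(self) -> None:
        self.steps += 1
        # One optimizer step consumes batch_size samples on every process.
        self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)


tracker = TrackerSketch(batch_size=8, accelerator=StubAccelerator(num_processes=4))
tracker.step()
print(tracker.samples)  # 32
```

With `accelerator=None` the tracker falls back to single-process accounting, so existing single-GPU runs are unchanged.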
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import random
-from collections.abc import Generator
+from collections.abc import Callable, Generator
 from contextlib import contextmanager
 from pathlib import Path
 from typing import Any
@@ -164,14 +164,20 @@ def set_rng_state(random_state_dict: dict[str, Any]):
        torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
 
 
-def set_seed(seed) -> None:
+def set_seed(seed, accelerator: Callable | None = None) -> None:
     """Set seed for reproducibility."""
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
 
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(seed)
 
+    if accelerator:
+        from accelerate.utils import set_seed as _accelerate_set_seed
+
+        _accelerate_set_seed(seed)
+
 
 @contextmanager
 def seeded_context(seed: int) -> Generator[None, None, None]:
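The reproducibility contract behind `set_seed` can be demonstrated with the stdlib slice alone (numpy, torch, and accelerate are omitted here; the helper name is illustrative):

```python
import random


def set_seed_sketch(seed: int) -> None:
    """Stdlib-only slice of set_seed: reseeding restores the RNG stream."""
    random.seed(seed)


set_seed_sketch(42)
first = [random.randint(0, 9) for _ in range(5)]
set_seed_sketch(42)
second = [random.randint(0, 9) for _ in range(5)]
print(first == second)  # True: identical seeds give identical draws
```

In the multi-GPU path, `accelerate.utils.set_seed` additionally seeds every process, which is what makes per-rank data shuffling reproducible across launches.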
@@ -27,6 +27,7 @@ from statistics import mean
 
 import numpy as np
 import torch
+from accelerate import Accelerator
 from datasets.utils.logging import disable_progress_bar, enable_progress_bar
 
 
@@ -110,36 +111,50 @@ def init_logging(
    display_pid: bool = False,
     console_level: str = "INFO",
     file_level: str = "DEBUG",
+    accelerator: Accelerator | None = None,
 ):
+    """Initialize logging configuration for LeRobot.
+
+    In multi-GPU training, only the main process logs to console to avoid duplicate output.
+    Non-main processes have console logging suppressed but can still log to file.
+
+    Args:
+        log_file: Optional file path to write logs to
+        display_pid: Include process ID in log messages (useful for debugging multi-process)
+        console_level: Logging level for console output
+        file_level: Logging level for file output
+        accelerator: Optional Accelerator instance (for multi-GPU detection)
+    """
+
     def custom_format(record: logging.LogRecord) -> str:
         dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         fnameline = f"{record.pathname}:{record.lineno}"
-        # NOTE: Display PID is useful for multi-process logging.
-        if display_pid:
-            pid_str = f"[PID: {os.getpid()}]"
-            message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
-        else:
-            message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
-        return message
+        pid_str = f"[PID: {os.getpid()}] " if display_pid else ""
+        return f"{record.levelname} {pid_str}{dt} {fnameline[-15:]:>15} {record.getMessage()}"
 
     formatter = logging.Formatter()
     formatter.format = custom_format
 
     logger = logging.getLogger()
-    logger.setLevel(logging.NOTSET)  # Set the logger to the lowest level to capture all messages
+    logger.setLevel(logging.NOTSET)
 
-    # Remove unused default handlers
-    for handler in logger.handlers[:]:
-        logger.removeHandler(handler)
+    # Clear any existing handlers
+    logger.handlers.clear()
 
-    # Write logs to console
-    console_handler = logging.StreamHandler()
-    console_handler.setFormatter(formatter)
-    console_handler.setLevel(console_level.upper())
-    logger.addHandler(console_handler)
+    # Determine if this is a non-main process in distributed training
+    is_main_process = accelerator.is_main_process if accelerator is not None else True
+
+    # Console logging (main process only)
+    if is_main_process:
+        console_handler = logging.StreamHandler()
+        console_handler.setFormatter(formatter)
+        console_handler.setLevel(console_level.upper())
+        logger.addHandler(console_handler)
+    else:
+        # Suppress console output for non-main processes
+        logger.addHandler(logging.NullHandler())
+        logger.setLevel(logging.ERROR)
 
-    # Additionally write logs to file
     if log_file is not None:
         file_handler = logging.FileHandler(log_file)
         file_handler.setFormatter(formatter)
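The handler wiring above can be sketched with the stdlib `logging` module alone (the helper name and logger names are illustrative; the file handler and the Accelerator lookup are omitted):

```python
import logging


def wire_console_handler(logger: logging.Logger, is_main_process: bool) -> None:
    """Sketch of init_logging's gating: console output only on the main process."""
    logger.handlers.clear()
    if is_main_process:
        handler = logging.StreamHandler()
        logger.addHandler(handler)
    else:
        # Non-main ranks get a NullHandler so records are swallowed silently.
        logger.addHandler(logging.NullHandler())
        logger.setLevel(logging.ERROR)


main_logger = logging.getLogger("rank0")
worker_logger = logging.getLogger("rank1")
wire_console_handler(main_logger, is_main_process=True)
wire_console_handler(worker_logger, is_main_process=False)
print(type(main_logger.handlers[0]).__name__)   # StreamHandler
print(type(worker_logger.handlers[0]).__name__)  # NullHandler
```

Raising non-main ranks to `ERROR` keeps genuine failures visible while eliminating the N-way duplication of every info line.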
tests/training/test_multi_gpu.py (new file, 211 lines)
@@ -0,0 +1,211 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Multi-GPU Training Tests
+
+This module tests multi-GPU training functionality with accelerate.
+These tests are designed to run on machines with 2+ GPUs and are executed
+in the nightly CI workflow.
+
+The tests automatically generate accelerate configs and launch training
+with subprocess to properly test the distributed training environment.
+"""
+
+import os
+import subprocess
+import tempfile
+from pathlib import Path
+
+import pytest
+import torch
+
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+
+
+def get_num_available_gpus():
+    """Returns the number of available GPUs."""
+    if not torch.cuda.is_available():
+        return 0
+    return torch.cuda.device_count()
+
+
+def download_dataset(repo_id, episodes):
+    """
+    Pre-download dataset to avoid race conditions in multi-GPU training.
+
+    Args:
+        repo_id: HuggingFace dataset repository ID
+        episodes: List of episode indices to download
+    """
+    # Simply instantiating the dataset will download it
+    _ = LeRobotDataset(repo_id, episodes=episodes)
+    print(f"Dataset {repo_id} downloaded successfully")
+
+
+def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
+    """
+    Helper function to run training with accelerate launch.
+
+    Args:
+        config_args: List of config arguments to pass to lerobot_train.py
+        num_processes: Number of processes (GPUs) to use
+        temp_dir: Temporary directory for outputs
+
+    Returns:
+        subprocess.CompletedProcess result
+    """
+    config_path = Path(temp_dir) / "accelerate_config.yaml"
+
+    # Write YAML config
+    with open(config_path, "w") as f:
+        f.write("compute_environment: LOCAL_MACHINE\n")
+        f.write("distributed_type: MULTI_GPU\n")
+        f.write("mixed_precision: 'no'\n")
+        f.write(f"num_processes: {num_processes}\n")
+        f.write("use_cpu: false\n")
+        f.write("gpu_ids: all\n")
+        f.write("downcast_bf16: 'no'\n")
+        f.write("machine_rank: 0\n")
+        f.write("main_training_function: main\n")
+        f.write("num_machines: 1\n")
+        f.write("rdzv_backend: static\n")
+        f.write("same_network: true\n")
+
+    cmd = [
+        "accelerate",
+        "launch",
+        "--config_file",
+        str(config_path),
+        "-m",
+        "lerobot.scripts.lerobot_train",
+    ] + config_args
+
+    result = subprocess.run(
+        cmd,
+        capture_output=True,
+        text=True,
+        env={**os.environ, "CUDA_VISIBLE_DEVICES": ",".join(map(str, range(num_processes)))},
+    )
+
+    return result
+
+
+@pytest.mark.skipif(
+    get_num_available_gpus() < 2,
+    reason="Multi-GPU tests require at least 2 GPUs",
+)
+class TestMultiGPUTraining:
+    """Test suite for multi-GPU training functionality."""
+
+    def test_basic_multi_gpu_training(self):
+        """
+        Test that basic multi-GPU training runs successfully.
+        Verifies that the training completes without errors.
+        """
+        # Pre-download dataset to avoid race conditions
+        download_dataset("lerobot/pusht", episodes=[0])
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            output_dir = Path(temp_dir) / "outputs"
+
+            config_args = [
+                "--dataset.repo_id=lerobot/pusht",
+                "--dataset.episodes=[0]",
+                "--policy.type=act",
+                "--policy.device=cuda",
+                "--policy.push_to_hub=false",
+                f"--output_dir={output_dir}",
+                "--batch_size=4",
+                "--steps=10",
+                "--eval_freq=-1",
+                "--log_freq=5",
+                "--save_freq=10",
+                "--seed=42",
+                "--num_workers=0",
+            ]
+
+            result = run_accelerate_training(config_args, num_processes=4, temp_dir=temp_dir)
+
+            # Check that training completed successfully
+            assert result.returncode == 0, (
+                f"Multi-GPU training failed with return code {result.returncode}\n"
+                f"STDOUT:\n{result.stdout}\n"
+                f"STDERR:\n{result.stderr}"
+            )
+
+            # Verify checkpoint was saved
+            checkpoints_dir = output_dir / "checkpoints"
+            assert checkpoints_dir.exists(), "Checkpoints directory was not created"
+
+            # Verify that training completed
+            assert "End of training" in result.stdout or "End of training" in result.stderr
+
+    def test_checkpoint_saving_multi_gpu(self):
+        """
+        Test that checkpoints are correctly saved during multi-GPU training.
+        Only the main process (rank 0) should save checkpoints.
+        """
+        # Pre-download dataset to avoid race conditions
+        download_dataset("lerobot/pusht", episodes=[0])
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            output_dir = Path(temp_dir) / "outputs"
+
+            config_args = [
+                "--dataset.repo_id=lerobot/pusht",
+                "--dataset.episodes=[0]",
+                "--policy.type=act",
+                "--policy.device=cuda",
+                "--policy.push_to_hub=false",
+                f"--output_dir={output_dir}",
+                "--batch_size=4",
+                "--steps=20",
+                "--eval_freq=-1",
+                "--log_freq=5",
+                "--save_freq=10",
+                "--seed=42",
+                "--num_workers=0",
+            ]
+
+            result = run_accelerate_training(config_args, num_processes=2, temp_dir=temp_dir)
+
+            assert result.returncode == 0, (
+                f"Training failed:\nSTDOUT:\n{result.stdout}\n\nSTDERR:\n{result.stderr}"
+            )
+
+            # Verify checkpoint directory exists
+            checkpoints_dir = output_dir / "checkpoints"
+            assert checkpoints_dir.exists(), "Checkpoints directory not created"
+
+            # Count checkpoint directories (should have checkpoint at step 10 and 20)
+            checkpoint_dirs = [d for d in checkpoints_dir.iterdir() if d.is_dir()]
+            assert len(checkpoint_dirs) >= 1, f"Expected at least 1 checkpoint, found {len(checkpoint_dirs)}"
+
+            # Verify checkpoint contents
+            for checkpoint_dir in checkpoint_dirs:
+                # Check for model files
+                model_files = list(checkpoint_dir.rglob("*.safetensors"))
+                assert len(model_files) > 0, f"No model files in checkpoint {checkpoint_dir}"
+
+                # Check for training state
+                training_state_dir = checkpoint_dir / "training_state"
+                assert training_state_dir.exists(), f"No training state in checkpoint {checkpoint_dir}"
+
+                # Verify optimizer state exists
+                optimizer_state = training_state_dir / "optimizer_state.safetensors"
+                assert optimizer_state.exists(), f"No optimizer state in checkpoint {checkpoint_dir}"
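The test helper above shells out after writing an accelerate config file; the config-writing step can be sketched standalone with the stdlib (the helper name is illustrative, and only a subset of the keys the tests write is shown):

```python
import tempfile
from pathlib import Path


def write_accelerate_config(temp_dir: str, num_processes: int) -> Path:
    """Write a minimal accelerate config like run_accelerate_training does."""
    config_path = Path(temp_dir) / "accelerate_config.yaml"
    lines = [
        "compute_environment: LOCAL_MACHINE",
        "distributed_type: MULTI_GPU",
        "mixed_precision: 'no'",
        f"num_processes: {num_processes}",
        "use_cpu: false",
        "num_machines: 1",
    ]
    config_path.write_text("\n".join(lines) + "\n")
    return config_path


with tempfile.TemporaryDirectory() as d:
    path = write_accelerate_config(d, num_processes=2)
    print("num_processes: 2" in path.read_text())  # True
```

Generating the config per test run (instead of relying on a cached `accelerate config` setup) keeps the process count under the test's control.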