"""
|
|
PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support.
|
|
This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs
|
|
entirely in PyTorch using the `PI0Pytorch` model and your existing config/data
|
|
pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`.
|
|
|
|
Usage
|
|
Single GPU:
|
|
python scripts/train_pytorch.py <config_name> --exp_name <run_name> --save_interval <interval>
|
|
Example:
|
|
python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test
|
|
python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test --resume # Resume from latest checkpoint
|
|
Multi-GPU (single node):
|
|
torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_pytorch.py <config_name> --exp_name <run_name>
|
|
Example:
|
|
torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test
|
|
torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume
|
|
Multi-Node Training:
|
|
torchrun \
|
|
--nnodes=<num_nodes> --nproc_per_node=<gpus_per_node> --node_rank=<rank_of_node> \
|
|
--master_addr=<master_ip> --master_port=<port> \
|
|
scripts/train_pytorch.py <config_name> --exp_name=<run_name> --save_interval <interval>
|
|
|
|
"""
|
|
|
|
import dataclasses
import gc
import logging
import os
import platform
import shutil
import time

import jax
import numpy as np
import safetensors.torch
import torch
import torch.distributed as dist
import torch.nn.parallel
import tqdm
import wandb

import openpi.models.pi0_config
import openpi.models_pytorch.pi0_pytorch
import openpi.shared.normalize as _normalize
import openpi.training.config as _config
import openpi.training.data_loader as _data


def init_logging():
    level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class CustomFormatter(logging.Formatter):
        def format(self, record):
            record.levelname = level_mapping.get(record.levelname, record.levelname)
            return super().format(record)

    formatter = CustomFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        ch = logging.StreamHandler()
        ch.setFormatter(formatter)
        logger.addHandler(ch)
    else:
        logger.handlers[0].setFormatter(formatter)


def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True):
    """Initialize wandb logging."""
    if not enabled:
        wandb.init(mode="disabled")
        return

    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")

    if resuming:
        run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
        wandb.init(id=run_id, resume="must", project=config.project_name)
    else:
        wandb.init(
            name=config.exp_name,
            config=dataclasses.asdict(config),
            project=config.project_name,
        )
        (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)


def setup_ddp():
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    use_ddp = world_size > 1
    if use_ddp and not torch.distributed.is_initialized():
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        torch.distributed.init_process_group(backend=backend, init_method="env://")

        # Set up debugging environment variables for DDP issues
        if os.environ.get("TORCH_DISTRIBUTED_DEBUG") is None:
            os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO"

    local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0")))
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        torch.cuda.set_device(device)
    return use_ddp, local_rank, device


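# Illustrative note: with the two-node, two-GPU-per-node torchrun invocation from the
# module docstring, torchrun launches four processes and sets (roughly):
#
#   node 0: RANK=0 LOCAL_RANK=0 / RANK=1 LOCAL_RANK=1   (WORLD_SIZE=4 on every rank)
#   node 1: RANK=2 LOCAL_RANK=0 / RANK=3 LOCAL_RANK=1
#
# setup_ddp() above reads WORLD_SIZE and LOCAL_RANK from the environment to choose the
# backend and pin each process to its local GPU.

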
def cleanup_ddp():
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
        torch.distributed.destroy_process_group()


def set_seed(seed: int, local_rank: int):
    torch.manual_seed(seed + local_rank)
    np.random.seed(seed + local_rank)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed + local_rank)


def build_datasets(config: _config.TrainConfig):
    # Use the unified data loader with PyTorch framework
    data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True)
    return data_loader, data_loader.data_config()


def get_model_state_dict(model):
    """Get state dict from model, handling DDP wrapper."""
    return (
        model.module.state_dict()
        if isinstance(model, torch.nn.parallel.DistributedDataParallel)
        else model.state_dict()
    )


def get_model_parameters(model):
    """Get parameters from model, handling DDP wrapper."""
    return (
        model.module.parameters()
        if isinstance(model, torch.nn.parallel.DistributedDataParallel)
        else model.parameters()
    )


def save_checkpoint(model, optimizer, global_step, config, is_main, data_config):
    """Save a checkpoint with model state, optimizer state, and metadata."""
    if not is_main:
        return

    # Only save if it's time to save or if it's the final step
    if (global_step % config.save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1:
        # Create temporary directory for atomic checkpoint saving
        final_ckpt_dir = config.checkpoint_dir / f"{global_step}"
        tmp_ckpt_dir = config.checkpoint_dir / f"tmp_{global_step}"

        # Remove any existing temp directory and create a new one
        if tmp_ckpt_dir.exists():
            shutil.rmtree(tmp_ckpt_dir)
        tmp_ckpt_dir.mkdir(parents=True, exist_ok=True)

        # Save model state using safetensors (handles shared tensors)
        model_to_save = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
        safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / "model.safetensors")

        # Save optimizer state using PyTorch format
        torch.save(optimizer.state_dict(), tmp_ckpt_dir / "optimizer.pt")

        # Save training metadata (the config is stored as a plain dict to avoid JAX/Flax pickling issues)
        metadata = {
            "global_step": global_step,
            "config": dataclasses.asdict(config),
            "timestamp": time.time(),
        }
        torch.save(metadata, tmp_ckpt_dir / "metadata.pt")

        # Save norm stats
        norm_stats = data_config.norm_stats
        if norm_stats is not None and data_config.asset_id is not None:
            _normalize.save(tmp_ckpt_dir / "assets" / data_config.asset_id, norm_stats)

        # Atomically move the temp directory to the final location
        if final_ckpt_dir.exists():
            shutil.rmtree(final_ckpt_dir)
        tmp_ckpt_dir.rename(final_ckpt_dir)

        logging.info(f"Saved checkpoint at step {global_step} -> {final_ckpt_dir}")

        # Log checkpoint to wandb
        if config.wandb_enabled:
            wandb.log({"checkpoint_step": global_step}, step=global_step)


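# For reference, the on-disk layout produced by save_checkpoint() at, say, step 1000 is
# roughly:
#
#   <config.checkpoint_dir>/
#       wandb_id.txt            # written by init_wandb() and reused on --resume
#       1000/
#           model.safetensors   # model weights (DDP wrapper stripped)
#           optimizer.pt        # optimizer.state_dict()
#           metadata.pt         # {"global_step", "config", "timestamp"}
#           assets/<asset_id>/  # normalization stats, when available

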
def load_checkpoint(model, optimizer, checkpoint_dir, device):
    """Load the latest checkpoint and return the global step."""
    checkpoint_steps = [
        int(d.name)
        for d in checkpoint_dir.iterdir()
        if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
    ]

    if not checkpoint_steps:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")

    latest_step = max(checkpoint_steps)
    ckpt_dir = checkpoint_dir / f"{latest_step}"

    # Clear memory before loading checkpoints
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    log_memory_usage(device, latest_step, "before_loading_checkpoint")

    try:
        # Load model state with error handling
        logging.info("Loading model state...")
        safetensors_path = ckpt_dir / "model.safetensors"

        if safetensors_path.exists():
            model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
            safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device))
            logging.info("Loaded model state from safetensors format")
        else:
            raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}")

        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "after_loading_model")

        # Load optimizer state with error handling
        logging.info("Loading optimizer state...")
        optimizer_path = ckpt_dir / "optimizer.pt"

        if optimizer_path.exists():
            optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False)
            logging.info("Loaded optimizer state from pt format")
        else:
            raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}")

        optimizer.load_state_dict(optimizer_state_dict)
        del optimizer_state_dict
        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "after_loading_optimizer")

        # Load metadata
        logging.info("Loading metadata...")
        metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False)
        global_step = metadata.get("global_step", latest_step)
        del metadata
        torch.cuda.empty_cache()
        gc.collect()
        log_memory_usage(device, latest_step, "after_loading_metadata")

        logging.info(f"Successfully loaded all checkpoint components from step {latest_step}")
        return global_step

    except RuntimeError as e:
        if "out of memory" in str(e):
            # Clear memory and provide a detailed error message
            torch.cuda.empty_cache()
            gc.collect()
            logging.error(f"Out of memory error while loading checkpoint: {e!s}")
            log_memory_usage(device, latest_step, "after_oom_error")
            raise RuntimeError(
                "Out of memory while loading checkpoint. Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
            ) from e
        raise


def get_latest_checkpoint_step(checkpoint_dir):
    """Get the latest checkpoint step number from a checkpoint directory."""
    checkpoint_steps = [
        int(d.name)
        for d in checkpoint_dir.iterdir()
        if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
    ]
    return max(checkpoint_steps) if checkpoint_steps else None


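# For illustration: given checkpoint subdirectories named "1000", "2000", and "tmp_3000",
# get_latest_checkpoint_step() returns 2000; in-progress "tmp_*" directories are ignored.

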
def log_memory_usage(device, step, phase="unknown"):
    """Log detailed memory usage information."""
    if not torch.cuda.is_available():
        return

    memory_allocated = torch.cuda.memory_allocated(device) / 1e9
    memory_reserved = torch.cuda.memory_reserved(device) / 1e9
    memory_free = (torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device)) / 1e9

    # Get more detailed memory info
    memory_stats = torch.cuda.memory_stats(device)
    max_memory_allocated = memory_stats.get("allocated_bytes.all.peak", 0) / 1e9
    max_memory_reserved = memory_stats.get("reserved_bytes.all.peak", 0) / 1e9

    # Get DDP info if available
    ddp_info = ""
    if dist.is_initialized():
        ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}"

    logging.info(
        f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, "
        f"free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, "
        f"peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}"
    )


def train_loop(config: _config.TrainConfig):
    use_ddp, local_rank, device = setup_ddp()
    is_main = (not use_ddp) or (dist.get_rank() == 0)
    set_seed(config.seed, local_rank)

    # Initialize checkpoint directory and wandb
    resuming = False
    if config.resume:
        # Find checkpoint directory based on experiment name
        exp_checkpoint_dir = config.checkpoint_dir
        if exp_checkpoint_dir.exists():
            # Use validation to find the latest working checkpoint
            latest_step = get_latest_checkpoint_step(exp_checkpoint_dir)
            if latest_step is not None:
                resuming = True
                logging.info(
                    f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}"
                )
            else:
                raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume")
        else:
            raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume")
    elif config.overwrite and config.checkpoint_dir.exists():
        shutil.rmtree(config.checkpoint_dir)
        logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}")

    # Create checkpoint directory with experiment name
    if not resuming:
        # For new runs, create an experiment-specific checkpoint directory
        exp_checkpoint_dir = config.checkpoint_dir
        exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)
        logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}")
    else:
        # For resume, checkpoint_dir is already set to the experiment directory
        logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}")

    # Initialize wandb (only on main process)
    if is_main:
        init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)

    # Build data loader using the unified data loader.
    # Calculate effective batch size per GPU for DDP:
    # for N GPUs, each GPU should get batch_size/N samples, so the total across all GPUs is batch_size.
    world_size = torch.distributed.get_world_size() if use_ddp else 1
    effective_batch_size = config.batch_size // world_size
    logging.info(
        f"Using batch size per GPU: {effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})"
    )

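    # For example, batch_size=32 with world_size=4 gives effective_batch_size=8 samples
    # per process per step, so the global batch size stays at 32.
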
    # Pass the original batch size to the data loader - it will handle DDP splitting internally
    loader, data_config = build_datasets(config)

    # Log sample images to wandb on first batch
    if is_main and config.wandb_enabled and not resuming:
        # Create a separate data loader for the sample batch to avoid consuming the main loader
        sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False)
        sample_batch = next(iter(sample_data_loader))
        # Convert observation and actions to torch tensors
        observation, actions = sample_batch
        sample_batch = observation.to_dict()
        sample_batch["actions"] = actions

        # Create sample images for wandb
        images_to_log = []
        # Get batch size from the first image tensor
        batch_size = next(iter(sample_batch["image"].values())).shape[0]
        for i in range(min(5, batch_size)):
            # Concatenate all camera views horizontally for this batch item
            # Convert from NCHW to NHWC format for wandb
            img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], dim=1)
            img_concatenated = img_concatenated.cpu().numpy()
            images_to_log.append(wandb.Image(img_concatenated))

        wandb.log({"camera_views": images_to_log}, step=0)

        # Clear sample batch from memory aggressively
        del sample_batch, observation, actions, images_to_log, img_concatenated
        del sample_data_loader  # Also delete the sample data loader
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logging.info("Cleared sample batch and data loader from memory")

    # Build model
    if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):
        # Convert dataclass to Pi0Config if needed
        model_cfg = openpi.models.pi0_config.Pi0Config(
            dtype=config.pytorch_training_precision,
            action_dim=config.model.action_dim,
            action_horizon=config.model.action_horizon,
            max_token_len=config.model.max_token_len,
            paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
            action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
            pi05=getattr(config.model, "pi05", False),
        )
    else:
        model_cfg = config.model
        # Update dtype to match pytorch_training_precision
        object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision)

    model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device)

    if hasattr(model, "gradient_checkpointing_enable"):
        enable_gradient_checkpointing = True
        model.gradient_checkpointing_enable()
        logging.info("Enabled gradient checkpointing for memory optimization")
    else:
        enable_gradient_checkpointing = False
        logging.info("Gradient checkpointing is not supported for this model")

    # Log initial memory usage after model creation
    if is_main and torch.cuda.is_available():
        log_memory_usage(device, 0, "after_model_creation")

    # Enable memory optimizations for large-scale training
    if world_size >= 8:
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        # Set memory allocation configuration
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
        logging.info("Enabled memory optimizations for 8+ GPU training")

    if use_ddp:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[device.index] if device.type == "cuda" else None,
            find_unused_parameters=False,  # Disable for memory efficiency
            gradient_as_bucket_view=True,  # Enable for memory efficiency
            static_graph=world_size >= 8,  # Enable for 8+ GPUs
        )

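    # Note on the DDP flags above: gradient_as_bucket_view=True lets gradients share storage
    # with DDP's communication buckets (saving one gradient-sized copy), and static_graph=True
    # tells DDP that the autograd graph does not change across iterations, which allows
    # additional optimizations on larger runs.
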
    # Load weights from weight_loader if specified (for fine-tuning)
    # if config.pytorch_weight_path is not None:
    #     logging.info(f"Loading weights from: {config.pytorch_weight_path}")

    #     model_path = os.path.join(config.pytorch_weight_path, "model.safetensors")
    #     safetensors.torch.load_model(
    #         (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), model_path
    #     )
    #     logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}")

    # Optimizer + learning rate schedule from config
    warmup_steps = config.lr_schedule.warmup_steps
    peak_lr = config.lr_schedule.peak_lr
    decay_steps = config.lr_schedule.decay_steps
    end_lr = config.lr_schedule.decay_lr

    # Create optimizer with config parameters
    optim = torch.optim.AdamW(
        model.parameters(),
        lr=peak_lr,
        betas=(config.optimizer.b1, config.optimizer.b2),
        eps=config.optimizer.eps,
        weight_decay=config.optimizer.weight_decay,
    )

    # Load checkpoint if resuming
    global_step = 0
    if resuming:
        global_step = load_checkpoint(model, optim, config.checkpoint_dir, device)
        logging.info(f"Resumed training from step {global_step}")

    def lr_schedule(step: int):
        if step < warmup_steps:
            # Match JAX behavior: start from peak_lr / (warmup_steps + 1)
            init_lr = peak_lr / (warmup_steps + 1)
            return init_lr + (peak_lr - init_lr) * step / warmup_steps
        # cosine decay
        progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
        cos = 0.5 * (1 + np.cos(np.pi * progress))
        return end_lr + (peak_lr - end_lr) * cos

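    # Worked example of the schedule with illustrative numbers (warmup_steps=1000,
    # peak_lr=2.5e-5, decay_steps=30_000, end_lr=2.5e-6):
    #   step 0      -> ~2.5e-8   (peak_lr / (warmup_steps + 1))
    #   step 500    -> ~1.25e-5  (midway through the linear warmup)
    #   step 1000   -> 2.5e-5    (peak, start of the cosine decay)
    #   step 30000+ -> 2.5e-6    (fully decayed to end_lr)
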
    model.train()
    start_time = time.time()
    infos = []  # Collect stats over log interval
    if is_main:
        logging.info(
            f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}"
        )
        logging.info(
            f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}"
        )
        logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}")
        logging.info(
            f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}"
        )
        logging.info(
            f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}"
        )
        logging.info("EMA is not supported for PyTorch training")
        logging.info(f"Training precision: {model_cfg.dtype}")

    # Training loop - iterate until we reach num_train_steps
    pbar = (
        tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main)
        if is_main
        else None
    )

    while global_step < config.num_train_steps:
        # Set epoch for distributed training
        if use_ddp and hasattr(loader, "set_epoch"):
            loader.set_epoch(global_step // len(loader))

        for observation, actions in loader:
            # Check if we've reached the target number of steps
            if global_step >= config.num_train_steps:
                break

            # The unified data loader returns an (observation, actions) tuple
            observation = jax.tree.map(lambda x: x.to(device), observation)  # noqa: PLW2901
            actions = actions.to(torch.float32)  # noqa: PLW2901
            actions = actions.to(device)  # noqa: PLW2901

            # Update LR
            for pg in optim.param_groups:
                pg["lr"] = lr_schedule(global_step)

            # Forward pass
            losses = model(observation, actions)
            # Ensure losses is a tensor and handle different return types
            if isinstance(losses, list | tuple):
                losses = torch.stack(losses)
            elif not isinstance(losses, torch.Tensor):
                losses = torch.tensor(losses, device=device, dtype=torch.float32)

            loss = losses.mean()

            # Backward pass
            loss.backward()

            # Log memory usage after backward pass
            if global_step < 5 and is_main and torch.cuda.is_available():
                log_memory_usage(device, global_step, "after_backward")

            # Gradient clipping
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm)

            # Optimizer step
            optim.step()
            optim.zero_grad(set_to_none=True)

            # Clear gradients more aggressively
            for param in model.parameters():
                if param.grad is not None:
                    param.grad.detach_()
                    param.grad = None

            # Collect stats
            if is_main:
                infos.append(
                    {
                        "loss": loss.item(),
                        "learning_rate": optim.param_groups[0]["lr"],
                        "grad_norm": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm,
                    }
                )

            if is_main and (global_step % config.log_interval == 0):
                elapsed = time.time() - start_time

                # Average stats over log interval
                avg_loss = sum(info["loss"] for info in infos) / len(infos)
                avg_lr = sum(info["learning_rate"] for info in infos) / len(infos)

                avg_grad_norm = None
                if any("grad_norm" in info for info in infos):
                    vals = [
                        info["grad_norm"] for info in infos if "grad_norm" in info and info["grad_norm"] is not None
                    ]
                    if len(vals) > 0:
                        avg_grad_norm = sum(vals) / len(vals)
                logging.info(
                    f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s"
                    if avg_grad_norm is not None
                    else f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s"
                )

                # Log to wandb
                if config.wandb_enabled and len(infos) > 0:
                    log_payload = {
                        "loss": avg_loss,
                        "learning_rate": avg_lr,
                        "step": global_step,
                        "time_per_step": elapsed / config.log_interval,
                    }
                    if avg_grad_norm is not None:
                        log_payload["grad_norm"] = avg_grad_norm
                    wandb.log(log_payload, step=global_step)

                start_time = time.time()
                infos = []  # Reset stats collection

            global_step += 1
            # Save checkpoint using the new mechanism
            save_checkpoint(model, optim, global_step, config, is_main, data_config)

            # Update progress bar
            if pbar is not None:
                pbar.update(1)
                pbar.set_postfix(
                    {"loss": f"{loss.item():.4f}", "lr": f"{optim.param_groups[0]['lr']:.2e}", "step": global_step}
                )

    # Close progress bar
    if pbar is not None:
        pbar.close()

    # Finish wandb run
    if is_main and config.wandb_enabled:
        wandb.finish()

    cleanup_ddp()


def main():
    init_logging()
    config = _config.cli()
    train_loop(config)


if __name__ == "__main__":
    main()