Files
issacdataengine/policy/openpi-InternData-A1/scripts/training_scripts/multi_node.sh
2026-03-17 23:05:23 +08:00

210 lines
7.5 KiB
Bash
Executable File

#!/usr/bin/env bash
set -ex
cd YOUR_PATH/openpi
export USE_TF=0
export USE_TORCH=0
export USE_JAX=1
export IMAGEIO_FFMPEG_EXE=ffmpeg
# JAX GPU memory fraction
export XLA_PYTHON_CLIENT_MEM_FRACTION="${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.9}"
# ============================================================================
# NCCL Configuration
# ============================================================================
export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_TIMEOUT=3600
export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
# ============================================================================
# Platform-Injected Configuration
# ============================================================================
# The platform automatically injects these when DISTRIBUTED_JOB=true:
# - NCCL_IB_HCA, NCCL_IB_GID_INDEX, NCCL_SOCKET_IFNAME
# - NODE_RANK, NODE_COUNT, MASTER_ADDR, PROC_PER_NODE
# - CUDA_VISIBLE_DEVICES
# We trust and use these platform configurations directly.
# ============================================================================
echo ""
echo "=========================================="
echo "Platform Configuration"
echo "=========================================="
echo "NODE_RANK: ${NODE_RANK:-<not set>}"
echo "NODE_COUNT: ${NODE_COUNT:-<not set>}"
echo "MASTER_ADDR: ${MASTER_ADDR:-<not set>}"
echo "NCCL_IB_HCA: ${NCCL_IB_HCA:-<not set>}"
echo "NCCL_IB_GID_INDEX: ${NCCL_IB_GID_INDEX:-<not set>}"
echo "NCCL_SOCKET_IFNAME: ${NCCL_SOCKET_IFNAME:-<not set>}"
echo "=========================================="
echo ""
# ============================================================================
# NCCL Transport Configuration
# ============================================================================
# Use platform-injected configuration if available, otherwise fallback
# ============================================================================
if [ -n "${NCCL_IB_HCA:-}" ]; then
# Platform has configured InfiniBand
echo "[NCCL] ✓ Using platform-injected InfiniBand configuration"
# Only set NCCL_NET if not already set
if [ -z "${NCCL_NET:-}" ]; then
export NCCL_NET="IB"
fi
# Set IB timeout if not already set
if [ -z "${NCCL_IB_TIMEOUT:-}" ]; then
export NCCL_IB_TIMEOUT=23
fi
echo "[NCCL] NCCL_NET: ${NCCL_NET}"
echo "[NCCL] NCCL_IB_HCA: ${NCCL_IB_HCA}"
echo "[NCCL] NCCL_IB_GID_INDEX: ${NCCL_IB_GID_INDEX}"
echo "[NCCL] NCCL_IB_TIMEOUT: ${NCCL_IB_TIMEOUT}"
elif [ -n "${NCCL_SOCKET_IFNAME:-}" ]; then
# Platform has configured Socket
echo "[NCCL] ✓ Using platform-injected Socket configuration"
if [ -z "${NCCL_NET:-}" ]; then
export NCCL_NET="Socket"
fi
echo "[NCCL] NCCL_NET: ${NCCL_NET}"
echo "[NCCL] NCCL_SOCKET_IFNAME: ${NCCL_SOCKET_IFNAME}"
else
# No platform injection - use OPENPI_NCCL_NET preference
echo "[NCCL] ⚠️ No platform-injected NCCL configuration"
if [ "${OPENPI_NCCL_NET:-IB}" = "IB" ]; then
echo "[NCCL] ✗ InfiniBand requested but not configured by platform"
echo "[NCCL] ✗ Falling back to Socket transport"
export NCCL_NET="Socket"
export NCCL_IB_DISABLE=1
else
export NCCL_NET="Socket"
export NCCL_IB_DISABLE=1
echo "[NCCL] Using Socket transport"
fi
fi
echo ""
# ============================================================================
# JAX Distributed Configuration
# ============================================================================
# Map platform variables to JAX variables
# ============================================================================
echo "=========================================="
echo "JAX Distributed Configuration"
echo "=========================================="
JAX_COORDINATOR_PORT="${JAX_COORDINATOR_PORT:-12345}"
# Set JAX coordinator address
if [ -z "${JAX_COORDINATOR_ADDRESS:-}" ] && [ -n "${MASTER_ADDR:-}" ]; then
export JAX_COORDINATOR_ADDRESS="${MASTER_ADDR}:${JAX_COORDINATOR_PORT}"
echo "[JAX] ✓ Coordinator: ${JAX_COORDINATOR_ADDRESS} (from MASTER_ADDR)"
elif [ -n "${JAX_COORDINATOR_ADDRESS:-}" ]; then
echo "[JAX] ✓ Coordinator: ${JAX_COORDINATOR_ADDRESS}"
else
echo "[JAX] ✗ WARNING: No coordinator address set!"
fi
# Set JAX process count
if [ -z "${JAX_PROCESS_COUNT:-}" ] && [ -n "${NODE_COUNT:-}" ]; then
export JAX_PROCESS_COUNT="${NODE_COUNT}"
echo "[JAX] ✓ Process count: ${JAX_PROCESS_COUNT} (from NODE_COUNT)"
elif [ -n "${JAX_PROCESS_COUNT:-}" ]; then
echo "[JAX] ✓ Process count: ${JAX_PROCESS_COUNT}"
fi
# Set JAX process index
if [ -z "${JAX_PROCESS_INDEX:-}" ] && [ -n "${NODE_RANK:-}" ]; then
export JAX_PROCESS_INDEX="${NODE_RANK}"
echo "[JAX] ✓ Process index: ${JAX_PROCESS_INDEX} (from NODE_RANK)"
elif [ -n "${JAX_PROCESS_INDEX:-}" ]; then
echo "[JAX] ✓ Process index: ${JAX_PROCESS_INDEX}"
fi
echo "=========================================="
echo ""
# ============================================================================
# Python Environment
# ============================================================================
export PYTHONPATH=YOUR_PATH/openpi/src:YOUR_PATH/openpi/packages/openpi-client/src:YOUR_PATH/openpi/third_party/lerobot:${PYTHONPATH}
conda activate pi0
# ============================================================================
# Configuration Summary
# ============================================================================
echo "=========================================="
echo "Configuration Summary"
echo "=========================================="
echo "NCCL_NET: ${NCCL_NET:-<not set>}"
echo "NCCL_IB_HCA: ${NCCL_IB_HCA:-<not set>}"
echo "NCCL_IB_GID_INDEX: ${NCCL_IB_GID_INDEX:-<not set>}"
echo "NCCL_SOCKET_IFNAME: ${NCCL_SOCKET_IFNAME:-<not set>}"
echo "JAX_COORDINATOR: ${JAX_COORDINATOR_ADDRESS:-<not set>}"
echo "JAX_PROCESS_COUNT: ${JAX_PROCESS_COUNT:-<not set>}"
echo "JAX_PROCESS_INDEX: ${JAX_PROCESS_INDEX:-<not set>}"
echo "=========================================="
echo ""
# ============================================================================
# Display Host Information
# ============================================================================
python - <<'EOF'
import socket
import os
import jax
hostname = socket.gethostname()
devices = jax.local_devices()
device_count = len(devices)
device_ids = [d.id for d in devices]
print(f"[JAX] host={hostname}, devices={device_count}xgpu, ids={device_ids}")
print(f"[JAX] JAX_COORDINATOR_ADDRESS={os.environ.get('JAX_COORDINATOR_ADDRESS', '<not set>')}")
print(f"[JAX] JAX_PROCESS_COUNT={os.environ.get('JAX_PROCESS_COUNT', '<not set>')}")
print(f"[JAX] JAX_PROCESS_INDEX={os.environ.get('JAX_PROCESS_INDEX', '<not set>')}")
EOF
# ============================================================================
# Launch Training
# ============================================================================
# Determine experiment name based on transport
if [ "${OPENPI_DEBUG_SINGLE_GPU:-0}" = "1" ]; then
EXP_NAME="${EXP_NAME:-dev_jax_single_gpu}"
echo "[DEBUG] Running in single-GPU mode"
else
EXP_NAME="${EXP_NAME:-dev_jax_multinode_ib}"
fi
echo ""
echo "=========================================="
echo "Starting Training"
echo "=========================================="
echo "Experiment: $EXP_NAME"
echo "=========================================="
echo ""
ulimit -n 1000000
python scripts/train_jax_multinode.py \
pretrain-interndata-a1 \
--exp-name=pretrain-interndata-a1 \
--num_workers=12 \
--fsdp_devices=8 \
--batch_size=512 \
--num_train_steps=2000000 \
--save_interval=5000