#!/usr/bin/env bash
#
# Launches scripts/train_jax_multinode.py for the pretrain-interndata-a1 config.
#
# Steps performed, in order:
#   1. Export framework / NCCL environment defaults.
#   2. Pick an NCCL transport from the platform-injected variables.
#   3. Map MASTER_ADDR / NODE_COUNT / NODE_RANK onto the JAX_* variables.
#   4. Set PYTHONPATH and activate the "pi0" conda environment.
#   5. Print a per-host JAX device summary and start training.
#
# Optional environment overrides read by this script:
#   OPENPI_ROOT                      repo checkout (default: YOUR_PATH/openpi)
#   XLA_PYTHON_CLIENT_MEM_FRACTION   default 0.9
#   NCCL_DEBUG                       default WARN
#   NCCL_NET, NCCL_IB_TIMEOUT        only defaulted when empty
#   OPENPI_NCCL_NET                  default IB (affects log messages only)
#   JAX_COORDINATOR_PORT             default 12345
#   JAX_COORDINATOR_ADDRESS, JAX_PROCESS_COUNT, JAX_PROCESS_INDEX
#   OPENPI_DEBUG_SINGLE_GPU          "1" selects the single-GPU experiment name
#   EXP_NAME                         experiment name passed to the trainer
set -ex

# Single place to configure the checkout location; every path below derives
# from it. After the cd it is replaced by $PWD so PYTHONPATH entries are
# absolute even if a relative path was supplied.
OPENPI_ROOT="${OPENPI_ROOT:-YOUR_PATH/openpi}"
cd "${OPENPI_ROOT}"
OPENPI_ROOT="${PWD}"

export USE_TF=0
export USE_TORCH=0
export USE_JAX=1
export IMAGEIO_FFMPEG_EXE=ffmpeg

# JAX GPU memory fraction
export XLA_PYTHON_CLIENT_MEM_FRACTION="${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.9}"

# ============================================================================
# NCCL Configuration
# ============================================================================
export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_TIMEOUT=3600
export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"

# ============================================================================
# Platform-Injected Configuration
# ============================================================================
# The platform automatically injects these when DISTRIBUTED_JOB=true:
#   - NCCL_IB_HCA, NCCL_IB_GID_INDEX, NCCL_SOCKET_IFNAME
#   - NODE_RANK, NODE_COUNT, MASTER_ADDR, PROC_PER_NODE
#   - CUDA_VISIBLE_DEVICES
# We trust and use these platform configurations directly.
# ============================================================================
echo ""
echo "=========================================="
echo "Platform Configuration"
echo "=========================================="
echo "NODE_RANK: ${NODE_RANK:-}"
echo "NODE_COUNT: ${NODE_COUNT:-}"
echo "MASTER_ADDR: ${MASTER_ADDR:-}"
echo "NCCL_IB_HCA: ${NCCL_IB_HCA:-}"
echo "NCCL_IB_GID_INDEX: ${NCCL_IB_GID_INDEX:-}"
echo "NCCL_SOCKET_IFNAME: ${NCCL_SOCKET_IFNAME:-}"
echo "=========================================="
echo ""

# ============================================================================
# NCCL Transport Configuration
# ============================================================================
# Use platform-injected configuration if available, otherwise fallback
# ============================================================================
if [[ -n "${NCCL_IB_HCA:-}" ]]; then
  # Platform has configured InfiniBand
  echo "[NCCL] ✓ Using platform-injected InfiniBand configuration"

  # Only fill in NCCL_NET / NCCL_IB_TIMEOUT when the caller left them empty.
  if [[ -z "${NCCL_NET:-}" ]]; then
    export NCCL_NET="IB"
  fi
  if [[ -z "${NCCL_IB_TIMEOUT:-}" ]]; then
    export NCCL_IB_TIMEOUT=23
  fi

  echo "[NCCL] NCCL_NET: ${NCCL_NET}"
  echo "[NCCL] NCCL_IB_HCA: ${NCCL_IB_HCA}"
  echo "[NCCL] NCCL_IB_GID_INDEX: ${NCCL_IB_GID_INDEX:-}"
  echo "[NCCL] NCCL_IB_TIMEOUT: ${NCCL_IB_TIMEOUT}"
elif [[ -n "${NCCL_SOCKET_IFNAME:-}" ]]; then
  # Platform has configured Socket
  echo "[NCCL] ✓ Using platform-injected Socket configuration"

  if [[ -z "${NCCL_NET:-}" ]]; then
    export NCCL_NET="Socket"
  fi

  echo "[NCCL] NCCL_NET: ${NCCL_NET}"
  echo "[NCCL] NCCL_SOCKET_IFNAME: ${NCCL_SOCKET_IFNAME}"
else
  # No platform injection: Socket transport is forced in every case.
  # OPENPI_NCCL_NET only decides which explanation is logged.
  echo "[NCCL] ⚠️ No platform-injected NCCL configuration"

  export NCCL_NET="Socket"
  export NCCL_IB_DISABLE=1

  if [[ "${OPENPI_NCCL_NET:-IB}" == "IB" ]]; then
    echo "[NCCL] ✗ InfiniBand requested but not configured by platform"
    echo "[NCCL] ✗ Falling back to Socket transport"
  else
    echo "[NCCL] Using Socket transport"
  fi
fi
echo ""

# ============================================================================
# JAX Distributed Configuration
# ============================================================================
# Map platform variables to JAX variables. A JAX_* value that is already set
# always takes precedence over the platform value.
# ============================================================================
echo "=========================================="
echo "JAX Distributed Configuration"
echo "=========================================="

JAX_COORDINATOR_PORT="${JAX_COORDINATOR_PORT:-12345}"

# Set JAX coordinator address
if [[ -z "${JAX_COORDINATOR_ADDRESS:-}" && -n "${MASTER_ADDR:-}" ]]; then
  export JAX_COORDINATOR_ADDRESS="${MASTER_ADDR}:${JAX_COORDINATOR_PORT}"
  echo "[JAX] ✓ Coordinator: ${JAX_COORDINATOR_ADDRESS} (from MASTER_ADDR)"
elif [[ -n "${JAX_COORDINATOR_ADDRESS:-}" ]]; then
  echo "[JAX] ✓ Coordinator: ${JAX_COORDINATOR_ADDRESS}"
else
  # Only a warning: the script continues without a coordinator address.
  echo "[JAX] ✗ WARNING: No coordinator address set!"
fi

# Set JAX process count
if [[ -z "${JAX_PROCESS_COUNT:-}" && -n "${NODE_COUNT:-}" ]]; then
  export JAX_PROCESS_COUNT="${NODE_COUNT}"
  echo "[JAX] ✓ Process count: ${JAX_PROCESS_COUNT} (from NODE_COUNT)"
elif [[ -n "${JAX_PROCESS_COUNT:-}" ]]; then
  echo "[JAX] ✓ Process count: ${JAX_PROCESS_COUNT}"
fi

# Set JAX process index
if [[ -z "${JAX_PROCESS_INDEX:-}" && -n "${NODE_RANK:-}" ]]; then
  export JAX_PROCESS_INDEX="${NODE_RANK}"
  echo "[JAX] ✓ Process index: ${JAX_PROCESS_INDEX} (from NODE_RANK)"
elif [[ -n "${JAX_PROCESS_INDEX:-}" ]]; then
  echo "[JAX] ✓ Process index: ${JAX_PROCESS_INDEX}"
fi

echo "=========================================="
echo ""

# ============================================================================
# Python Environment
# ============================================================================
# ${PYTHONPATH:+:...} appends the previous value only when it is non-empty, so
# an unset PYTHONPATH does not leave a trailing ':' (an empty entry, which
# Python resolves to the current directory).
export PYTHONPATH="${OPENPI_ROOT}/src:${OPENPI_ROOT}/packages/openpi-client/src:${OPENPI_ROOT}/third_party/lerobot${PYTHONPATH:+:${PYTHONPATH}}"

# `conda activate` needs conda's shell function, which is normally installed
# by the interactive shell profile and therefore absent when this file is run
# as a script. Load the hook explicitly when the function is not defined.
# Tracing is paused because the hook and activation scripts are very verbose.
set +x
if ! declare -F conda >/dev/null 2>&1; then
  # Assign first so that a failing `conda` aborts the script under `set -e`
  # (eval "$(...)" would hide the failure).
  conda_hook="$(conda shell.bash hook)"
  eval "${conda_hook}"
fi
conda activate pi0
set -x

# ============================================================================
# Configuration Summary
# ============================================================================
echo "=========================================="
echo "Configuration Summary"
echo "=========================================="
echo "NCCL_NET: ${NCCL_NET:-}"
echo "NCCL_IB_HCA: ${NCCL_IB_HCA:-}"
echo "NCCL_IB_GID_INDEX: ${NCCL_IB_GID_INDEX:-}"
echo "NCCL_SOCKET_IFNAME: ${NCCL_SOCKET_IFNAME:-}"
echo "JAX_COORDINATOR: ${JAX_COORDINATOR_ADDRESS:-}"
echo "JAX_PROCESS_COUNT: ${JAX_PROCESS_COUNT:-}"
echo "JAX_PROCESS_INDEX: ${JAX_PROCESS_INDEX:-}"
echo "=========================================="
echo ""

# ============================================================================
# Display Host Information
# ============================================================================
# Quoted 'EOF' delimiter: the Python source is passed through literally, with
# no shell expansion; the JAX_* values are read from the environment instead.
python - <<'EOF'
import socket
import os
import jax

hostname = socket.gethostname()
devices = jax.local_devices()
device_count = len(devices)
device_ids = [d.id for d in devices]

print(f"[JAX] host={hostname}, devices={device_count}xgpu, ids={device_ids}")
print(f"[JAX] JAX_COORDINATOR_ADDRESS={os.environ.get('JAX_COORDINATOR_ADDRESS', '')}")
print(f"[JAX] JAX_PROCESS_COUNT={os.environ.get('JAX_PROCESS_COUNT', '')}")
print(f"[JAX] JAX_PROCESS_INDEX={os.environ.get('JAX_PROCESS_INDEX', '')}")
EOF

# ============================================================================
# Launch Training
# ============================================================================
# Determine the experiment name. An explicit EXP_NAME always wins. Otherwise
# the single-GPU debug mode gets its own name, and a regular run keeps
# "pretrain-interndata-a1", the name this script has always passed to the
# trainer. The value logged below is the value actually used.
if [[ "${OPENPI_DEBUG_SINGLE_GPU:-0}" == "1" ]]; then
  EXP_NAME="${EXP_NAME:-dev_jax_single_gpu}"
  echo "[DEBUG] Running in single-GPU mode"
else
  EXP_NAME="${EXP_NAME:-pretrain-interndata-a1}"
fi

echo ""
echo "=========================================="
echo "Starting Training"
echo "=========================================="
echo "Experiment: $EXP_NAME"
echo "=========================================="
echo ""

# Raise the open-file limit for this shell and the trainer it starts.
# Under `set -e` the script stops here if the limit cannot be raised.
ulimit -n 1000000

python scripts/train_jax_multinode.py \
  pretrain-interndata-a1 \
  --exp-name="${EXP_NAME}" \
  --num_workers=12 \
  --fsdp_devices=8 \
  --batch_size=512 \
  --num_train_steps=2000000 \
  --save_interval=5000