Initial commit

scripts/__init__.py (Normal file, 0 lines)

scripts/compute_norm_stats.py (Normal file, 75 lines)
@@ -0,0 +1,75 @@
"""Compute normalization statistics for a config.

This script is used to compute the normalization statistics for a given config. It
will compute the mean and standard deviation of the data in the dataset and save it
to the config assets directory.
"""

import numpy as np
import tqdm
import tyro

import openpi.shared.normalize as normalize
import openpi.training.config as _config
import openpi.training.data_loader as _data_loader
import openpi.transforms as transforms


class RemoveStrings(transforms.DataTransformFn):
    def __call__(self, x: dict) -> dict:
        return {k: v for k, v in x.items() if not np.issubdtype(np.asarray(v).dtype, np.str_)}


def create_dataset(config: _config.TrainConfig) -> tuple[_config.DataConfig, _data_loader.Dataset]:
    data_config = config.data.create(config.assets_dirs, config.model)
    if data_config.repo_id is None:
        raise ValueError("Data config must have a repo_id")
    dataset = _data_loader.create_dataset(data_config, config.model)
    dataset = _data_loader.TransformedDataset(
        dataset,
        [
            *data_config.repack_transforms.inputs,
            *data_config.data_transforms.inputs,
            # Remove strings since they are not supported by JAX and are not needed to compute norm stats.
            RemoveStrings(),
        ],
    )
    return data_config, dataset


def main(config_name: str, max_frames: int | None = None):
    config = _config.get_config(config_name)
    data_config, dataset = create_dataset(config)

    num_frames = len(dataset)
    shuffle = False

    if max_frames is not None and max_frames < num_frames:
        num_frames = max_frames
        shuffle = True

    data_loader = _data_loader.TorchDataLoader(
        dataset,
        local_batch_size=1,
        num_workers=8,
        shuffle=shuffle,
        num_batches=num_frames,
    )

    keys = ["state", "actions"]
    stats = {key: normalize.RunningStats() for key in keys}

    for batch in tqdm.tqdm(data_loader, total=num_frames, desc="Computing stats"):
        for key in keys:
            values = np.asarray(batch[key][0])
            stats[key].update(values.reshape(-1, values.shape[-1]))

    norm_stats = {key: stats.get_statistics() for key, stats in stats.items()}

    output_path = config.assets_dirs / data_config.repo_id
    print(f"Writing stats to: {output_path}")
    normalize.save(output_path, norm_stats)


if __name__ == "__main__":
    tyro.cli(main)
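
Note: normalize.RunningStats and normalize.save live in openpi.shared.normalize and are not part of this diff. As a rough sketch only (not the actual openpi implementation; the class name, method names, and the mean/std dictionary returned here are assumptions that simply mirror the usage above), per-dimension running statistics over (N, D) batches can be accumulated with Welford-style updates:

import numpy as np


class RunningStatsSketch:
    """Illustrative only: accumulates per-dimension mean and std over 2D batches."""

    def __init__(self):
        self._count = 0
        self._mean = None  # running mean, shape (D,)
        self._m2 = None    # running sum of squared deviations, shape (D,)

    def update(self, batch: np.ndarray) -> None:
        # batch has shape (N, D), matching stats[key].update(values.reshape(-1, D)) above.
        for row in np.asarray(batch, dtype=np.float64):
            if self._mean is None:
                self._mean = np.zeros_like(row)
                self._m2 = np.zeros_like(row)
            self._count += 1
            delta = row - self._mean
            self._mean += delta / self._count
            self._m2 += delta * (row - self._mean)

    def get_statistics(self) -> dict[str, np.ndarray]:
        std = np.sqrt(self._m2 / max(self._count - 1, 1))
        return {"mean": self._mean, "std": std}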

scripts/docker/compose.yml (Normal file, 29 lines)
@@ -0,0 +1,29 @@
# Run with:
# docker compose -f scripts/docker/compose.yml up --build
services:
  openpi_server:
    image: openpi_server
    build:
      context: ../..
      dockerfile: scripts/docker/serve_policy.Dockerfile
    init: true
    tty: true
    network_mode: host
    # Populate configured openpi data home to /openpi_assets inside the container.
    # Populate AWS credentials inside the container.
    volumes:
      - $PWD:/app
      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
    environment:
      - SERVER_ARGS
      - OPENPI_DATA_HOME=/openpi_assets
      - IS_DOCKER=true

    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

scripts/docker/install_docker_ubuntu22.sh (Executable file, 37 lines)
@@ -0,0 +1,37 @@
#!/bin/bash

# Add Docker's official GPG key:
sudo apt-get update
sudo apt-get install -y ca-certificates curl
sudo install -m 0755 -d /etc/apt/keyrings
sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
sudo chmod a+r /etc/apt/keyrings/docker.asc

# Add the repository to Apt sources:
echo \
  "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
  $(. /etc/os-release && echo "$VERSION_CODENAME") stable" |
  sudo tee /etc/apt/sources.list.d/docker.list >/dev/null
sudo apt-get update

sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin

# Add current user to the 'docker' group, which allows them to use docker commands (docker build, docker run, etc).
# See https://docs.docker.com/engine/install/linux-postinstall/
username=$(whoami)
sudo usermod -aG docker "$username"

# Configure docker to start automatically on system boot.
sudo systemctl enable docker.service
sudo systemctl enable containerd.service

# https://forums.docker.com/t/docker-credential-desktop-exe-executable-file-not-found-in-path-using-wsl2/100225/5
if [ -f ~/.docker/config.json ]; then
  sed -i 's/credsStore/credStore/g' ~/.docker/config.json
fi

echo ""
echo "********************************************************************"
echo "**** Restart to allow Docker permission changes to take effect. ****"
echo "********************************************************************"
echo ""

scripts/docker/install_nvidia_container_toolkit.sh (Executable file, 17 lines)
@@ -0,0 +1,17 @@
#!/bin/bash

# Installs the NVIDIA Container Toolkit, which allows Docker containers to access NVIDIA GPUs.
# NVIDIA's official documentation: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html

curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg &&
  curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list |
  sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' |
  sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list

# NVIDIA's documentation omits 'sudo' in the following command, but it is required.
sudo sed -i -e '/experimental/ s/^#//g' /etc/apt/sources.list.d/nvidia-container-toolkit.list
sudo apt-get update
sudo apt-get install -y nvidia-container-toolkit

sudo nvidia-ctk runtime configure --runtime=docker
sudo systemctl restart docker

scripts/docker/serve_policy.Dockerfile (Normal file, 34 lines)
@@ -0,0 +1,34 @@
# Dockerfile for serving a PI policy.
# Based on UV's instructions: https://docs.astral.sh/uv/guides/integration/docker/#developing-in-a-container

# Build the container:
# docker build . -t openpi_server -f scripts/docker/serve_policy.Dockerfile

# Run the container:
# docker run --rm -it --network=host -v .:/app --gpus=all openpi_server /bin/bash

FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/

WORKDIR /app

# Needed because LeRobot uses git-lfs.
RUN apt-get update && apt-get install -y git git-lfs

# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy

# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv

# Install the project's dependencies using the lockfile and settings
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
RUN --mount=type=cache,target=/root/.cache/uv \
    --mount=type=bind,source=uv.lock,target=uv.lock \
    --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
    --mount=type=bind,source=packages/openpi-client/pyproject.toml,target=packages/openpi-client/pyproject.toml \
    --mount=type=bind,source=packages/openpi-client/src,target=packages/openpi-client/src \
    GIT_LFS_SKIP_SMUDGE=1 uv sync --frozen --no-install-project --no-dev

CMD /bin/bash -c "uv run scripts/serve_policy.py $SERVER_ARGS"

scripts/serve_policy.py (Normal file, 122 lines)
@@ -0,0 +1,122 @@
import dataclasses
import enum
import logging
import socket

import tyro

from openpi.policies import policy as _policy
from openpi.policies import policy_config as _policy_config
from openpi.serving import websocket_policy_server
from openpi.training import config as _config


class EnvMode(enum.Enum):
    """Supported environments."""

    ALOHA = "aloha"
    ALOHA_SIM = "aloha_sim"
    DROID = "droid"
    LIBERO = "libero"


@dataclasses.dataclass
class Checkpoint:
    """Load a policy from a trained checkpoint."""

    # Training config name (e.g., "pi0_aloha_sim").
    config: str
    # Checkpoint directory (e.g., "checkpoints/pi0_aloha_sim/exp/10000").
    dir: str


@dataclasses.dataclass
class Default:
    """Use the default policy for the given environment."""


@dataclasses.dataclass
class Args:
    """Arguments for the serve_policy script."""

    # Environment to serve the policy for. This is only used when serving default policies.
    env: EnvMode = EnvMode.ALOHA_SIM

    # If provided, will be used in case the "prompt" key is not present in the data, or if the model doesn't have a default
    # prompt.
    default_prompt: str | None = None

    # Port to serve the policy on.
    port: int = 8000
    # Record the policy's behavior for debugging.
    record: bool = False

    # Specifies how to load the policy. If not provided, the default policy for the environment will be used.
    policy: Checkpoint | Default = dataclasses.field(default_factory=Default)


# Default checkpoints that should be used for each environment.
DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = {
    EnvMode.ALOHA: Checkpoint(
        config="pi0_aloha",
        dir="s3://openpi-assets/checkpoints/pi0_base",
    ),
    EnvMode.ALOHA_SIM: Checkpoint(
        config="pi0_aloha_sim",
        dir="s3://openpi-assets/checkpoints/pi0_aloha_sim",
    ),
    EnvMode.DROID: Checkpoint(
        config="pi0_fast_droid",
        dir="s3://openpi-assets/checkpoints/pi0_fast_droid",
    ),
    EnvMode.LIBERO: Checkpoint(
        config="pi0_fast_libero",
        dir="s3://openpi-assets/checkpoints/pi0_fast_libero",
    ),
}


def create_default_policy(env: EnvMode, *, default_prompt: str | None = None) -> _policy.Policy:
    """Create a default policy for the given environment."""
    if checkpoint := DEFAULT_CHECKPOINT.get(env):
        return _policy_config.create_trained_policy(
            _config.get_config(checkpoint.config), checkpoint.dir, default_prompt=default_prompt
        )
    raise ValueError(f"Unsupported environment mode: {env}")


def create_policy(args: Args) -> _policy.Policy:
    """Create a policy from the given arguments."""
    match args.policy:
        case Checkpoint():
            return _policy_config.create_trained_policy(
                _config.get_config(args.policy.config), args.policy.dir, default_prompt=args.default_prompt
            )
        case Default():
            return create_default_policy(args.env, default_prompt=args.default_prompt)


def main(args: Args) -> None:
    policy = create_policy(args)
    policy_metadata = policy.metadata

    # Record the policy's behavior.
    if args.record:
        policy = _policy.PolicyRecorder(policy, "policy_records")

    hostname = socket.gethostname()
    local_ip = socket.gethostbyname(hostname)
    logging.info("Creating server (host: %s, ip: %s)", hostname, local_ip)

    server = websocket_policy_server.WebsocketPolicyServer(
        policy=policy,
        host="0.0.0.0",
        port=args.port,
        metadata=policy_metadata,
    )
    server.serve_forever()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, force=True)
    main(tyro.cli(Args))
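
Note: the helpers above can also be used outside the server process. The sketch below (not part of this commit) builds a policy object directly from a Checkpoint spec using only functions defined in this file; the config name and checkpoint directory are copied from DEFAULT_CHECKPOINT, and the import path assumes the sketch is run from the repository root, where scripts/ is a package in this commit.

# Illustrative sketch: construct a policy with the same helpers serve_policy.py uses.
from scripts.serve_policy import Args, Checkpoint, EnvMode, create_policy

args = Args(
    env=EnvMode.ALOHA_SIM,  # only consulted when policy=Default()
    policy=Checkpoint(
        config="pi0_aloha_sim",
        dir="s3://openpi-assets/checkpoints/pi0_aloha_sim",
    ),
)
policy = create_policy(args)
print(policy.metadata)  # the same metadata the websocket server hands to clients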

scripts/train.py (Normal file, 274 lines)
@@ -0,0 +1,274 @@
import dataclasses
import functools
import logging
import platform
from typing import Any

import etils.epath as epath
import flax.nnx as nnx
from flax.training import common_utils
import flax.traverse_util as traverse_util
import jax
import jax.experimental
import jax.numpy as jnp
import optax
import tqdm_loggable.auto as tqdm
import wandb

import openpi.models.model as _model
import openpi.shared.array_typing as at
import openpi.shared.nnx_utils as nnx_utils
import openpi.training.checkpoints as _checkpoints
import openpi.training.config as _config
import openpi.training.data_loader as _data_loader
import openpi.training.optimizer as _optimizer
import openpi.training.sharding as sharding
import openpi.training.utils as training_utils
import openpi.training.weight_loaders as _weight_loaders


def init_logging():
    """Custom logging format for better readability."""
    level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class CustomFormatter(logging.Formatter):
        def format(self, record):
            record.levelname = level_mapping.get(record.levelname, record.levelname)
            return super().format(record)

    formatter = CustomFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logger.handlers[0].setFormatter(formatter)


def init_wandb(config: _config.TrainConfig, *, resuming: bool, log_code: bool = False, enabled: bool = True):
    if not enabled:
        wandb.init(mode="disabled")
        return

    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
    if resuming:
        run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
        wandb.init(id=run_id, resume="must", project=config.project_name)
    else:
        wandb.init(
            name=config.exp_name,
            config=dataclasses.asdict(config),
            project=config.project_name,
        )
        (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)

    if log_code:
        wandb.run.log_code(epath.Path(__file__).parent.parent)


def _load_weights_and_validate(loader: _weight_loaders.WeightLoader, params_shape: at.Params) -> at.Params:
    """Loads and validates the weights. Returns a loaded subset of the weights."""
    loaded_params = loader.load(params_shape)
    at.check_pytree_equality(expected=params_shape, got=loaded_params, check_shapes=True, check_dtypes=True)

    # Remove jax.ShapeDtypeStruct from the loaded params. This makes sure that only the loaded params are returned.
    return traverse_util.unflatten_dict(
        {k: v for k, v in traverse_util.flatten_dict(loaded_params).items() if not isinstance(v, jax.ShapeDtypeStruct)}
    )


@at.typecheck
def init_train_state(
    config: _config.TrainConfig, init_rng: at.KeyArrayLike, mesh: jax.sharding.Mesh, *, resume: bool
) -> tuple[training_utils.TrainState, Any]:
    tx = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask=None)

    def init(rng: at.KeyArrayLike, partial_params: at.Params | None = None) -> training_utils.TrainState:
        rng, model_rng = jax.random.split(rng)
        # initialize the model (and its parameters).
        model = config.model.create(model_rng)

        # Merge the partial params into the model.
        if partial_params is not None:
            graphdef, state = nnx.split(model)
            # This will produce an error if the partial params are not a subset of the state.
            state.replace_by_pure_dict(partial_params)
            model = nnx.merge(graphdef, state)

        params = nnx.state(model)
        # Convert frozen params to bfloat16.
        params = nnx_utils.state_map(params, config.freeze_filter, lambda p: p.replace(p.value.astype(jnp.bfloat16)))

        return training_utils.TrainState(
            step=0,
            params=params,
            model_def=nnx.graphdef(model),
            tx=tx,
            opt_state=tx.init(params.filter(config.trainable_filter)),
            ema_decay=config.ema_decay,
            ema_params=None if config.ema_decay is None else params,
        )

    train_state_shape = jax.eval_shape(init, init_rng)
    state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True)

    if resume:
        return train_state_shape, state_sharding

    partial_params = _load_weights_and_validate(config.weight_loader, train_state_shape.params.to_pure_dict())
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    # Initialize the train state and mix in the partial params.
    train_state = jax.jit(
        init,
        donate_argnums=(1,),  # donate the partial params buffer.
        in_shardings=replicated_sharding,
        out_shardings=state_sharding,
    )(init_rng, partial_params)

    return train_state, state_sharding


@at.typecheck
def train_step(
    config: _config.TrainConfig,
    rng: at.KeyArrayLike,
    state: training_utils.TrainState,
    batch: tuple[_model.Observation, _model.Actions],
) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
    model = nnx.merge(state.model_def, state.params)
    model.train()

    @at.typecheck
    def loss_fn(
        model: _model.BaseModel, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions
    ):
        chunked_loss = model.compute_loss(rng, observation, actions, train=True)
        return jnp.mean(chunked_loss)

    train_rng = jax.random.fold_in(rng, state.step)
    observation, actions = batch

    # Filter out frozen params.
    diff_state = nnx.DiffState(0, config.trainable_filter)
    loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, train_rng, observation, actions)

    params = state.params.filter(config.trainable_filter)
    updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
    new_params = optax.apply_updates(params, updates)

    # Update the model in place and return the new full state.
    nnx.update(model, new_params)
    new_params = nnx.state(model)

    new_state = dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state)
    if state.ema_decay is not None:
        new_state = dataclasses.replace(
            new_state,
            ema_params=jax.tree.map(
                lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new, state.ema_params, new_params
            ),
        )

    # Filter out params that aren't kernels.
    kernel_params = nnx.state(
        model,
        nnx.All(
            nnx.Param,
            nnx.Not(nnx_utils.PathRegex(".*/(bias|scale|pos_embedding|input_embedding)")),
            lambda _, x: x.value.ndim > 1,
        ),
    )
    info = {
        "loss": loss,
        "grad_norm": optax.global_norm(grads),
        "param_norm": optax.global_norm(kernel_params),
    }
    return new_state, info


def main(config: _config.TrainConfig):
    init_logging()
    logging.info(f"Running on: {platform.node()}")

    if config.batch_size % jax.device_count() != 0:
        raise ValueError(
            f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
        )

    jax.config.update("jax_threefry_partitionable", True)  # noqa: FBT003
    jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))

    rng = jax.random.key(config.seed)
    train_rng, init_rng = jax.random.split(rng)

    mesh = sharding.make_mesh(config.fsdp_devices)
    data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS))
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
        config.checkpoint_dir,
        keep_period=config.keep_period,
        overwrite=config.overwrite,
        resume=config.resume,
    )
    init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)

    data_loader = _data_loader.create_data_loader(
        config,
        sharding=data_sharding,
        num_workers=config.num_workers,
        shuffle=True,
    )
    data_iter = iter(data_loader)
    batch = next(data_iter)
    logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}")

    train_state, train_state_sharding = init_train_state(config, init_rng, mesh, resume=resuming)
    jax.block_until_ready(train_state)
    logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}")

    if resuming:
        train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader)

    ptrain_step = jax.jit(
        functools.partial(train_step, config),
        in_shardings=(replicated_sharding, train_state_sharding, data_sharding),
        out_shardings=(train_state_sharding, replicated_sharding),
        donate_argnums=(1,),
    )

    start_step = int(train_state.step)
    pbar = tqdm.tqdm(
        range(start_step, config.num_train_steps),
        initial=start_step,
        total=config.num_train_steps,
        dynamic_ncols=True,
    )

    infos = []
    for step in pbar:
        with sharding.set_mesh(mesh):
            train_state, info = ptrain_step(train_rng, train_state, batch)
        infos.append(info)
        if step % config.log_interval == 0:
            stacked_infos = common_utils.stack_forest(infos)
            reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos))
            info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items())
            pbar.write(f"Step {step}: {info_str}")
            wandb.log(reduced_info, step=step)
            infos = []
        batch = next(data_iter)

        if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1:
            _checkpoints.save_state(checkpoint_manager, train_state, data_loader, step)

    logging.info("Waiting for checkpoint manager to finish")
    checkpoint_manager.wait_until_finished()


if __name__ == "__main__":
    main(_config.cli())
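
Note on init_train_state above: the train state is first built abstractly with jax.eval_shape, a sharding is derived from those shapes, and only then does the real initialization run under jax.jit with out_shardings, so parameters are materialized directly in their target layout. Below is a minimal, self-contained sketch of that pattern (not openpi code; plain replication stands in for sharding.fsdp_sharding, and the init function is a toy stand-in for config.model.create):

import jax
import jax.numpy as jnp
import numpy as np


def init(rng):
    # Stand-in for config.model.create(model_rng): returns a parameter pytree.
    return {"w": jax.random.normal(rng, (8, 4)), "b": jnp.zeros((4,))}


rng = jax.random.key(0)

# 1. Trace init abstractly: no memory is allocated, we only get shapes/dtypes.
abstract_state = jax.eval_shape(init, rng)

# 2. Derive a sharding per leaf from the abstract shapes. Here every leaf is simply
#    replicated; openpi's sharding.fsdp_sharding picks a real partitioning instead.
mesh = jax.sharding.Mesh(np.array(jax.devices()), ("fsdp",))
replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
out_shardings = jax.tree.map(lambda _: replicated, abstract_state)

# 3. Run the real init under jit so parameters are created in their target layout.
state = jax.jit(init, out_shardings=out_shardings)(rng)
print(jax.tree.map(lambda x: x.shape, state))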

scripts/train_test.py (Normal file, 30 lines)
@@ -0,0 +1,30 @@
import dataclasses
import os
import pathlib

import pytest

os.environ["JAX_PLATFORMS"] = "cpu"

from openpi.training import config as _config

from . import train


@pytest.mark.parametrize("config_name", ["debug"])
def test_train(tmp_path: pathlib.Path, config_name: str):
    config = dataclasses.replace(
        _config._CONFIGS_DICT[config_name],  # noqa: SLF001
        batch_size=2,
        checkpoint_base_dir=tmp_path / "checkpoint",
        exp_name="test",
        overwrite=False,
        resume=False,
        num_train_steps=2,
        log_interval=1,
    )
    train.main(config)

    # test resuming
    config = dataclasses.replace(config, resume=True, num_train_steps=4)
    train.main(config)