# openpi/scripts/train.py
import dataclasses
from functools import partial
import logging
import platform
from typing import Any

import etils.epath as epath
from flax.training import common_utils
import jax
import jax._src.tree_util as private_tree_util
import jax.experimental
import jax.numpy as jnp
import optax
import tqdm_loggable.auto as tqdm
import wandb

import openpi.models.common as _common
import openpi.models.model as _model
import openpi.shared.array_typing as at
import openpi.training.checkpoints as _checkpoints
import openpi.training.config as _config
import openpi.training.data_loader as _data_loader
import openpi.training.optimizer as _optimizer
import openpi.training.sharding as sharding
import openpi.training.utils as training_utils
import openpi.training.weight_loaders as _weight_loaders


def init_logging():
    """Custom logging format for better readability."""
    level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class CustomFormatter(logging.Formatter):
        def format(self, record):
            record.levelname = level_mapping.get(record.levelname, record.levelname)
            return super().format(record)

    formatter = CustomFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        # Guard against a bare root logger (nothing has configured a handler yet).
        logger.addHandler(logging.StreamHandler())
    logger.handlers[0].setFormatter(formatter)


def init_wandb(config: _config.TrainConfig, *, resuming: bool, log_code: bool = False, enabled: bool = True):
    if not enabled:
        wandb.init(mode="disabled")
        return

    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
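    # The W&B run id is stored next to the checkpoint so an interrupted run resumes into the same W&B run.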
    if resuming:
        run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
        wandb.init(id=run_id, resume="must", project=config.project_name)
    else:
        wandb.init(
            name=config.exp_name,
            config=dataclasses.asdict(config),
            project=config.project_name,
        )
        (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)

    if log_code:
        wandb.run.log_code(epath.Path(__file__).parent.parent)


def _load_weights_and_validate(weight_loader: _weight_loaders.WeightLoader, params: at.Params) -> at.Params:
    """Runs the weight loader and validates that the params structure, shapes, and dtypes are unchanged."""
    new_params = weight_loader.load(jax.tree.map(lambda x: x, params))
    if errors := list(private_tree_util.equality_errors(params, new_params)):
        raise ValueError(
            "Weight loading changed the params structure:\n"
            + (
                "\n".join(
                    f" - {jax.tree_util.keystr(path)} changed from {thing1} to {thing2}, so {explanation}.\n"
                    for path, thing1, thing2, explanation in errors
                )
            )
        )

    def check(kp, x, y):
        if (x := jax.ShapeDtypeStruct(x.shape, x.dtype)) != (y := jax.ShapeDtypeStruct(y.shape, y.dtype)):
            raise ValueError(
                f"Weight loading changed the params structure: expected {x}, got {y} at {jax.tree_util.keystr(kp)}"
            )

    jax.tree_util.tree_map_with_path(check, params, new_params)
    return new_params


@at.typecheck
def init_train_state(
    config: _config.TrainConfig,
    model: _model.Model,
    init_rng: at.KeyArrayLike,
    batch: tuple[_common.Observation, _common.Actions],
    mesh: jax.sharding.Mesh,
    data_sharding: jax.sharding.Sharding,
    *,
    resume: bool,
) -> tuple[training_utils.TrainState, Any]:
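    # Weight-decay and freeze masks are not wired up yet, so the optimizer treats all parameters uniformly.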
    weight_decay_mask = None
    freeze_mask = None
    tx = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask, freeze_mask)

    def init(
        rng: at.KeyArrayLike,
        data: tuple[_common.Observation, _common.Actions],
        params_sharding: jax.sharding.Sharding | None = None,
    ) -> training_utils.TrainState:
        rng, model_rng = jax.random.split(rng)
        observation, actions = data
        params = model.init_params(model_rng, observation, actions)
        # jax.experimental.io_callback raises SPMD partitioning warnings; constrain the params to be
        # replicated to silence them. The returned train state is still sharded, because the FSDP
        # sharding is passed as the output sharding when this function is jitted.
        if params_sharding is not None:
            params = jax.lax.with_sharding_constraint(params, params_sharding)
        params = jax.experimental.io_callback(
            partial(_load_weights_and_validate, config.weight_loader),
            params,
            params,
            ordered=True,
        )
        if params_sharding is not None:
            params = jax.lax.with_sharding_constraint(params, params_sharding)

        return training_utils.TrainState(
            step=0,
            params=params,
            opt_state=tx.init(params),
            tx=tx,
            ema_decay=config.ema_decay,
            ema_params=None if config.ema_decay is None else params,
        )

    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
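    # Evaluate init abstractly to get the train-state shapes, then derive the FSDP sharding spec from them.
    # When resuming, only the shapes and shardings are needed; the real state is restored from the checkpoint.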
    train_state_shape = jax.eval_shape(init, init_rng, batch)
    state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True)
    if resume:
        return train_state_shape, state_sharding

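    # params_sharding is passed as a static argument (shardings are hashable), so in_shardings only
    # covers the rng and the batch.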
    train_state = jax.jit(
        init,
        in_shardings=(replicated_sharding, data_sharding),
        out_shardings=state_sharding,
        static_argnums=(2,),
    )(init_rng, batch, replicated_sharding)
    return train_state, state_sharding


@at.typecheck
def train_step(
    rng: at.KeyArrayLike,
    state: training_utils.TrainState,
    model: _model.Model,
    batch: tuple[_common.Observation, _common.Actions],
) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
    def loss_fn(params: at.Params, rng: at.KeyArrayLike, observation: _common.Observation, actions: _common.Actions):
        chunked_loss = model.compute_loss(rng, observation, actions, params=params, train=True)
        return jnp.mean(chunked_loss)

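    # Fold the step index into the base rng so every step gets a distinct, yet reproducible, key.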
    train_rng = jax.random.fold_in(rng, state.step)
    observation, actions = batch

    loss, grads = jax.value_and_grad(loss_fn)(state.params, train_rng, observation, actions)

    updates, new_opt_state = state.tx.update(grads, state.opt_state, state.params)
    new_params = optax.apply_updates(state.params, updates)
    new_state = state.replace(step=state.step + 1, params=new_params, opt_state=new_opt_state)

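    # Maintain an exponential moving average of the parameters when ema_decay is configured.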
    if state.ema_decay is not None:
        new_state = new_state.replace(
            ema_params=jax.tree.map(
                lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new, state.ema_params, new_params
            )
        )

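    # param_norm covers kernel (weight-matrix) parameters only; other leaves are replaced with None,
    # which pytrees treat as empty, so they do not contribute to the norm.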
    kernel_mask = training_utils.mask_from_regex(r".*\['kernel'\]", state.params)
    kernel_params = jax.tree.map(lambda p, m: p if m else None, state.params, kernel_mask)
    info = {
        "loss": loss,
        "grad_norm": optax.global_norm(grads),  # TODO: do not compute norm for frozen params
        "param_norm": optax.global_norm(kernel_params),
    }
    return new_state, info


def main(config: _config.TrainConfig):
    init_logging()
    logging.info(f"Running on: {platform.node()}")

    if config.batch_size % jax.device_count() != 0:
        raise ValueError(
            f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
        )

    jax.config.update("jax_threefry_partitionable", True)  # noqa: FBT003
    jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))

    rng = jax.random.key(config.seed)
    train_rng, init_rng = jax.random.split(rng)

    if jax.device_count() % config.fsdp_devices != 0:
        raise ValueError(
            f"Number of devices {jax.device_count()} must be divisible by the number of FSDP devices {config.fsdp_devices}."
        )
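    # Two-dimensional device mesh: "batch" is the data-parallel axis and "model" (size fsdp_devices) is the
    # axis used for FSDP parameter sharding; the data itself is sharded across both axes.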
    mesh_shape = (jax.device_count() // config.fsdp_devices, config.fsdp_devices)
    mesh = jax.make_mesh(mesh_shape, ("batch", "model"))
    data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(("batch", "model")))
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
        config.checkpoint_dir,
        keep_interval=config.keep_interval,
        overwrite=config.overwrite,
        resume=config.resume,
    )
    init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)

    model = config.create_model()
    data_loader = _data_loader.create_data_loader(
        config,
        model,
        sharding=data_sharding,
        num_workers=config.num_workers,
        shuffle=True,
    )
    data_iter = iter(data_loader)
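    # Pull one batch eagerly: its shapes are needed below to initialize (or abstractly shape) the train state.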
    batch = next(data_iter)
    logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}")

    train_state, train_state_sharding = init_train_state(
        config, model, init_rng, batch, mesh, data_sharding, resume=resuming
    )
    jax.block_until_ready(train_state)
    logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}")

    if resuming:
        train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader)

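    # The previous TrainState (argument 1) is donated, so its device buffers are reused for the updated state.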
    ptrain_step = jax.jit(
        train_step,
        in_shardings=(replicated_sharding, train_state_sharding, None, data_sharding),
        out_shardings=(train_state_sharding, replicated_sharding),
        donate_argnums=(1,),
    )

    start_step = int(train_state.step)
    pbar = tqdm.tqdm(
        range(start_step, config.num_train_steps),
        initial=start_step,
        total=config.num_train_steps,
        dynamic_ncols=True,
    )

    infos = []
    for step in pbar:
        train_state, info = ptrain_step(train_rng, train_state, model, batch)
        infos.append(info)
        if step % config.log_interval == 0:
            stacked_infos = common_utils.stack_forest(infos)
            reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos))
            info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items())
            pbar.write(f"Step {step}: {info_str}")
            wandb.log(reduced_info, step=step)
            infos = []
        batch = next(data_iter)
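        # Save on the configured interval (skipping the resumed start step) and always at the final step.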
        if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1:
            _checkpoints.save_state(checkpoint_manager, train_state, data_loader, step)

    logging.info("Waiting for checkpoint manager to finish")
    checkpoint_manager.wait_until_finished()


if __name__ == "__main__":
    main(_config.cli())