multi-node openpi commit

This commit is contained in:
Leon998
2026-03-17 23:05:23 +08:00
parent 28833f0c0f
commit 7411e0e004
156 changed files with 33951 additions and 1 deletions

View File

@@ -0,0 +1,159 @@
from __future__ import annotations
import asyncio
import concurrent.futures as futures
import dataclasses
import logging
from typing import Protocol
from etils import epath
import jax
import orbax.checkpoint as ocp
import orbax.checkpoint.future as future
from openpi.shared import array_typing as at
import openpi.shared.normalize as _normalize
import openpi.training.data_loader as _data_loader
import openpi.training.utils as training_utils
def initialize_checkpoint_dir(
    checkpoint_dir: epath.Path | str, *, keep_period: int | None, overwrite: bool, resume: bool
) -> tuple[ocp.CheckpointManager, bool]:
    """Prepare the checkpoint directory and construct its CheckpointManager.

    Returns:
        (manager, resuming): the orbax manager, plus a flag telling the caller
        whether training should restore from an existing checkpoint.

    Raises:
        FileExistsError: the directory exists and neither `overwrite` nor
            `resume` was requested.
    """
    checkpoint_dir = epath.Path(checkpoint_dir).resolve()
    resuming = False
    if checkpoint_dir.exists():
        if not overwrite and not resume:
            raise FileExistsError(
                f"Checkpoint directory {checkpoint_dir} already exists. Use --overwrite or --resume "
                "to indicate how to handle it."
            )
        if overwrite:
            # Start clean: drop everything previously saved under this path.
            checkpoint_dir.rmtree()
            checkpoint_dir.mkdir(parents=True, exist_ok=True)
            logging.info(f"Wiped checkpoint directory {checkpoint_dir}")
        else:
            resuming = True
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    mngr = ocp.CheckpointManager(
        checkpoint_dir,
        item_handlers={
            "assets": CallbackHandler(),
            "train_state": ocp.PyTreeCheckpointHandler(),
            "params": ocp.PyTreeCheckpointHandler(),
        },
        options=ocp.CheckpointManagerOptions(
            max_to_keep=1,
            keep_period=keep_period,
            create=False,
            async_options=ocp.AsyncOptions(timeout_secs=7200),
        ),
    )
    # A resume request against a directory that never reached its first real
    # checkpoint (no steps, or only step 0) would fail on restore -- start fresh.
    if resuming and tuple(mngr.all_steps()) in [(), (0,)]:
        logging.info("Checkpoint directory exists, but does not contain any checkpoints. Aborting resume.")
        resuming = False
    return mngr, resuming
def save_state(
    checkpoint_manager: ocp.CheckpointManager,
    state: training_utils.TrainState,
    data_loader: _data_loader.DataLoader,
    step: int,
):
    """Persist the train state, inference params, and assets for `step`."""

    def save_assets(directory: epath.Path):
        # Persist normalization stats alongside the checkpoint so inference can
        # reproduce the training-time data pipeline.
        cfg = data_loader.data_config()
        if cfg.norm_stats is not None and cfg.asset_id is not None:
            _normalize.save(directory / cfg.asset_id, cfg.norm_stats)

    # Inference-ready params are stored as their own checkpoint item so they
    # can later be loaded without the rest of the train state.
    with at.disable_typechecking():
        train_state, params = _split_params(state)
    checkpoint_manager.save(
        step,
        {
            "assets": save_assets,
            "train_state": train_state,
            "params": {"params": params},
        },
    )
def restore_state(
    checkpoint_manager: ocp.CheckpointManager,
    state: training_utils.TrainState,
    data_loader: _data_loader.DataLoader,
    step: int | None = None,
) -> training_utils.TrainState:
    """Restore a TrainState at `step` (or the manager's default step when None)."""
    del data_loader
    with at.disable_typechecking():
        # Mirror the save-side split so the restore targets line up item-by-item.
        train_state, params = _split_params(state)
        restored = checkpoint_manager.restore(
            step,
            items={"train_state": train_state, "params": {"params": params}},
        )
    return _merge_params(restored["train_state"], restored["params"])
def load_norm_stats(assets_dir: epath.Path | str, asset_id: str) -> dict[str, _normalize.NormStats] | None:
    """Load normalization stats saved under `assets_dir/asset_id`."""
    stats_dir = epath.Path(assets_dir) / asset_id
    stats = _normalize.load(stats_dir)
    logging.info(f"Loaded norm stats from {stats_dir}")
    return stats
class Callback(Protocol):
    """Callable saved as a checkpoint item; receives the item's output directory."""

    def __call__(self, directory: epath.Path) -> None: ...
class CallbackHandler(ocp.AsyncCheckpointHandler):
    """A CheckpointHandler for calling an arbitrary function asynchronously. Only for saving, not for restoring."""

    def save(self, directory: epath.Path, args: CallbackSave):
        # Only the primary process writes; other processes would write the same files.
        if jax.process_index() == 0:
            args.callback(directory)

    async def async_save(self, directory: epath.Path, args: CallbackSave) -> list[futures.Future]:
        # Run the synchronous save in a background thread, wrapped in the
        # future type orbax expects from async handlers.
        return [future.CommitFutureAwaitingContractedSignals(asyncio.to_thread(self.save, directory, args))]

    def restore(self, *args, **kwargs):
        raise NotImplementedError("CallbackHandler does not support restore")
@ocp.args.register_with_handler(CallbackHandler, for_save=True)
@dataclasses.dataclass
class CallbackSave(ocp.args.CheckpointArgs):
    """Save-args wrapper that routes a `Callback` to `CallbackHandler`."""

    # Invoked with the item's checkpoint directory at save time.
    callback: Callback
@ocp.args.register_with_handler(CallbackHandler, for_restore=True)
class CallbackRestore(ocp.args.CheckpointArgs):
    """Restore-args placeholder; `CallbackHandler.restore` always raises."""

    ...
def _split_params(state: training_utils.TrainState) -> tuple[training_utils.TrainState, at.Params]:
if state.ema_params is not None:
params = state.ema_params
train_state = dataclasses.replace(state, ema_params=None)
else:
params = state.params
train_state = dataclasses.replace(state, params={})
return train_state, params
def _merge_params(train_state: training_utils.TrainState, params: dict[str, at.Params]) -> training_utils.TrainState:
# Revert the logic inside `_split_params`. Assumes that existence of `params` means that EMA params were used during the split.
if train_state.params:
return dataclasses.replace(train_state, ema_params=params["params"])
return dataclasses.replace(train_state, params=params["params"])

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,721 @@
from collections.abc import Iterator, Sequence
import logging
import multiprocessing
import os
import typing
from typing import Literal, Protocol, SupportsIndex, TypeVar, Dict
import sys
import jax
import jax.numpy as jnp
import lerobot.common.datasets.lerobot_dataset as lerobot_dataset
import numpy as np
import torch
from torch.utils.data import dataloader
from torch.multiprocessing import reductions
from multiprocessing.reduction import ForkingPickler
import openpi.shared.normalize as normalize
# Monkeypatch torch's default_collate so worker collation does not place
# batches in shared memory. NOTE(review): presumably this works around
# shared-memory exhaustion/leaks with many persistent workers -- confirm
# before removing.
default_collate_func = dataloader.default_collate
import psutil


def default_collate_override(batch):
    # Turn off torch's module-level shared-memory flag, then delegate to the
    # original collate function captured above.
    dataloader._use_shared_memory = False
    return default_collate_func(batch)


setattr(dataloader, 'default_collate', default_collate_override)
# Strip torch's custom reducers for its storage classes so dataloader workers
# pickle storages with the default (non-shared-memory) machinery, matching the
# default_collate override above.
# The original code also handled Python 2 via `ForkingPickler.dispatch`, but
# this file requires Python 3 (f-strings, keyword-only args), so that branch
# was dead and has been removed.
for t in torch._storage_classes:
    if t in ForkingPickler._extra_reducers:
        del ForkingPickler._extra_reducers[t]
import openpi.models.model as _model
import openpi.training.config as _config
from openpi.training.droid_rlds_dataset import DroidRldsDataset
import openpi.transforms as _transforms
from openpi.training.mixture_dataset import create_mixture_dataset
T_co = TypeVar("T_co", covariant=True)
import copy
from memory_profiler import profile
from pdb import set_trace
class Dataset(Protocol[T_co]):
    """Interface for a dataset with random access."""

    def __getitem__(self, index: SupportsIndex) -> T_co:
        """Return the sample at `index`."""
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    def __len__(self) -> int:
        """Return the total number of samples."""
        raise NotImplementedError("Subclasses of Dataset should implement __len__.")
class IterableDataset(Protocol[T_co]):
    """Interface for an iterable dataset."""

    def __iter__(self) -> Iterator[T_co]:
        """Yield samples in iteration order."""
        raise NotImplementedError("Subclasses of IterableDataset should implement __iter__.")

    def __len__(self) -> int:
        """Return the total number of samples."""
        raise NotImplementedError("Subclasses of Dataset should implement __len__.")
class DataLoader(Protocol[T_co]):
    """Interface for a data loader."""

    def data_config(self) -> _config.DataConfig:
        """Get the data config for this data loader."""
        raise NotImplementedError("Subclasses of DataLoader should implement data_config.")

    def __iter__(self) -> Iterator[T_co]:
        """Yield batches until exhausted (or indefinitely, per implementation)."""
        raise NotImplementedError("Subclasses of DataLoader should implement __iter__.")
class TransformedDataset(Dataset[T_co]):
    """Random-access dataset that applies a composed transform to each sample."""

    def __init__(self, dataset: Dataset, transforms: Sequence[_transforms.DataTransformFn]):
        self._dataset = dataset
        # Compose once so lookups pay a single callable invocation.
        self._transform = _transforms.compose(transforms)

    def __getitem__(self, index: SupportsIndex) -> T_co:
        raw = self._dataset[index]
        return self._transform(raw)

    def __len__(self) -> int:
        return len(self._dataset)
class IterableTransformedDataset(IterableDataset[T_co]):
    """Iterable dataset that applies a composed transform to every yielded sample.

    When `is_batched` is set, incoming samples are whole batches: they are
    unstacked into single samples, transformed one at a time, and re-stacked.
    """

    def __init__(
        self,
        dataset: IterableDataset,
        transforms: Sequence[_transforms.DataTransformFn],
        *,
        is_batched: bool = False,
    ):
        self._dataset = dataset
        self._transform = _transforms.compose(transforms)
        self._is_batched = is_batched

    def __iter__(self):
        if not self._is_batched:
            for sample in self._dataset:
                yield self._transform(sample)
            return
        for batch in self._dataset:
            # Transforms operate on single samples, so unstack the batch,
            # transform each element, then stack the results back together.
            count = next(v.shape[0] for v in batch.values())
            transformed = [self._transform(jax.tree.map(lambda x: x[i], batch)) for i in range(count)]  # noqa: B023
            yield jax.tree.map(lambda *parts: np.stack(parts, axis=0), *transformed)

    def __len__(self) -> int:
        return len(self._dataset)
class FakeDataset(Dataset):
    """Deterministic random dataset matching a model's input spec (for smoke tests)."""

    def __init__(self, model_config: _model.BaseModelConfig, num_samples: int):
        self._num_samples = num_samples
        self._observation_spec, self._action_spec = model_config.inputs_spec()

    def __getitem__(self, index: SupportsIndex) -> dict:
        # Seed from the index so every sample is reproducible.
        rng = jax.random.key(index.__index__())

        def sample_leaf(spec: jax.ShapeDtypeStruct):
            nonlocal rng
            rng, leaf_rng = jax.random.split(rng)
            shape = spec.shape[1:]  # Drop the batch dimension from the spec.
            if spec.dtype == jnp.float32:
                return jax.random.uniform(leaf_rng, shape=shape, minval=-1.0, maxval=1.0)
            if spec.dtype == jnp.int32:
                return jax.random.randint(leaf_rng, shape=shape, minval=0, maxval=2048)
            return jnp.zeros(shape=shape, dtype=spec.dtype)

        observation = jax.tree.map(sample_leaf, self._observation_spec)
        action = jax.tree.map(sample_leaf, self._action_spec)
        return {**observation.to_dict(), "actions": action}

    def __len__(self) -> int:
        return self._num_samples
def create_torch_dataset(
    data_config: _config.DataConfig, action_horizon: int, model_config: _model.BaseModelConfig
) -> Dataset:
    """Create a dataset for training."""
    repo_id = data_config.repo_id
    if repo_id is None:
        raise ValueError("Repo ID is not set. Cannot create dataset.")
    if repo_id == "fake":
        return FakeDataset(model_config, num_samples=1024)

    meta = lerobot_dataset.LeRobotDatasetMetadata(repo_id)
    # Request `action_horizon` consecutive future steps for every action key.
    timestamps = [t / meta.fps for t in range(action_horizon)]
    dataset = lerobot_dataset.LeRobotDataset(
        data_config.repo_id,
        delta_timestamps={key: timestamps for key in data_config.action_sequence_keys},
    )
    if data_config.prompt_from_task:
        dataset = TransformedDataset(dataset, [_transforms.PromptFromLeRobotTask(meta.tasks)])
    return dataset
def create_rlds_dataset(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    *,
    shuffle: bool = False,
) -> Dataset:
    """Build the RLDS-backed dataset. Only DROID is supported at the moment."""
    return DroidRldsDataset(
        data_dir=data_config.rlds_data_dir,
        batch_size=batch_size,
        shuffle=shuffle,
        # One chunk of `action_horizon` future actions per sampled frame.
        action_chunk_size=action_horizon,
        action_space=data_config.action_space,
        filter_dict_path=data_config.filter_dict_path,
    )
def transform_dataset(dataset: Dataset, data_config: _config.DataConfig, *, skip_norm_stats: bool = False) -> Dataset:
    """Wrap `dataset` with the configured repack/data/normalize/model transforms."""
    if skip_norm_stats or data_config.repo_id == "fake":
        norm_stats = {}
    else:
        if data_config.norm_stats is None:
            raise ValueError(
                "Normalization stats not found. "
                "Make sure to run `scripts/compute_norm_stats.py --config-name=<your-config>`."
            )
        norm_stats = data_config.norm_stats
    pipeline = [
        *data_config.repack_transforms.inputs,
        *data_config.data_transforms.inputs,
        _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
        *data_config.model_transforms.inputs,
    ]
    return TransformedDataset(dataset, pipeline)
def transform_iterable_dataset(
    dataset: IterableDataset,
    data_config: _config.DataConfig,
    *,
    skip_norm_stats: bool = False,
    is_batched: bool = False,
) -> IterableDataset:
    """Wrap an iterable `dataset` with the configured transforms."""
    if skip_norm_stats or data_config.repo_id == "fake":
        norm_stats = {}
    else:
        if data_config.norm_stats is None:
            raise ValueError(
                "Normalization stats not found. "
                "Make sure to run `scripts/compute_norm_stats.py --config-name=<your-config>`."
            )
        norm_stats = data_config.norm_stats
    pipeline = [
        *data_config.repack_transforms.inputs,
        *data_config.data_transforms.inputs,
        _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
        *data_config.model_transforms.inputs,
    ]
    return IterableTransformedDataset(dataset, pipeline, is_batched=is_batched)
def create_data_loader(
    config: _config.TrainConfig,
    *,
    sharding: jax.sharding.Sharding | None = None,
    shuffle: bool = False,
    num_batches: int | None = None,
    skip_norm_stats: bool = False,
    framework: Literal["jax", "pytorch"] = "jax",
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader for training.

    Args:
        config: The training configuration.
        sharding: The sharding to use for the data loader (JAX only).
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return.
        skip_norm_stats: Whether to skip data normalization.
        framework: The framework to use ("jax" or "pytorch").
    """
    data_config = config.data.create(config.assets_dirs, config.model)
    logging.info(f"data_config: {data_config}")
    common = dict(
        action_horizon=config.model.action_horizon,
        batch_size=config.batch_size,
        sharding=sharding,
        shuffle=shuffle,
        num_batches=num_batches,
        skip_norm_stats=skip_norm_stats,
        framework=framework,
    )
    # RLDS-backed configs get the streaming loader; everything else goes
    # through the LeRobot/torch path.
    if data_config.rlds_data_dir is not None:
        return create_rlds_data_loader(data_config, **common)
    return create_torch_data_loader(
        data_config,
        model_config=config.model,
        num_workers=config.num_workers,
        seed=config.seed,
        **common,
    )
def create_data_loader_multi(
    config: _config.TrainConfig,
    *,
    sharding: jax.sharding.Sharding | None = None,
    shuffle: bool = False,
    num_batches: int | None = None,
    skip_norm_stats: bool = False,
    framework: Literal["jax", "pytorch"] = "jax",
    global_norm_stats: Dict[str, normalize.NormStats] | None = None,
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader over a mixture of datasets (multi-config training).

    Args:
        config: The training configuration; `config.data` is iterated as a
            collection of data-config factories, one per mixture component.
        sharding: The sharding to use for the data loader (JAX only).
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return.
        skip_norm_stats: Whether to skip data normalization.
        framework: The framework to use ("jax" or "pytorch").
        global_norm_stats: Shared normalization stats passed to each factory.
    """
    configs = []
    for factory in config.data:
        created = factory.create(config.model, global_norm_stats)
        logging.info(f"data_config: {created}")
        configs.append(created)
    return create_torch_data_loader_multi(
        configs,
        model_config=config.model,
        action_horizon=config.model.action_horizon,
        batch_size=config.batch_size,
        sharding=sharding,
        shuffle=shuffle,
        num_batches=num_batches,
        num_workers=config.num_workers,
        seed=config.seed,
        skip_norm_stats=skip_norm_stats,
        framework=framework,
        global_norm_stats=global_norm_stats,
    )
def create_torch_data_loader(
    data_config: _config.DataConfig,
    model_config: _model.BaseModelConfig,
    action_horizon: int,
    batch_size: int,
    *,
    sharding: jax.sharding.Sharding | None = None,
    skip_norm_stats: bool = False,
    shuffle: bool = False,
    num_batches: int | None = None,
    num_workers: int = 0,
    seed: int = 0,
    framework: str = "jax",
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader for training.

    Args:
        data_config: The data configuration.
        model_config: Model config used to build the (possibly fake) dataset.
        action_horizon: The action horizon.
        batch_size: The global batch size; split across processes below.
        sharding: The sharding to use for the data loader. If None, the data loader will
            use a single device sharding.
        skip_norm_stats: Whether to skip data normalization.
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return. If the number exceeds the
            number of batches in the dataset, the data loader will loop over the dataset.
            If not provided, will iterate over the dataset indefinitely.
        num_workers: The number of worker processes to use. If zero, the data loader will
            execute in the main process.
        seed: The seed to use for shuffling the data.
        framework: "jax" or "pytorch"; selects the sampling/sharding strategy.
    """
    dataset = create_torch_dataset(data_config, action_horizon, model_config)
    dataset = transform_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats)
    # Use TorchDataLoader for both frameworks
    # For PyTorch DDP, create DistributedSampler and divide batch size by world size
    # For JAX, divide by process count
    sampler = None
    if framework == "pytorch":
        if torch.distributed.is_initialized():
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset,
                num_replicas=torch.distributed.get_world_size(),
                rank=torch.distributed.get_rank(),
                shuffle=shuffle,
                drop_last=True,
            )
            local_batch_size = batch_size // torch.distributed.get_world_size()
        else:
            # Single-process PyTorch: the whole batch stays local.
            local_batch_size = batch_size
    else:
        # JAX: every process loads only its slice of the global batch.
        local_batch_size = batch_size // jax.process_count()
        if jax.process_count() > 1:
            sampler = JaxProcessDistributedSampler(
                dataset_size=len(dataset),
                num_replicas=jax.process_count(),
                rank=jax.process_index(),
                shuffle=shuffle,
                seed=seed,
            )
    logging.info(f"local_batch_size: {local_batch_size}")
    data_loader = TorchDataLoader(
        dataset,
        local_batch_size=local_batch_size,
        sharding=None if framework == "pytorch" else sharding,
        shuffle=(sampler is None and shuffle),  # Don't shuffle if using sampler
        sampler=sampler,
        num_batches=num_batches,
        num_workers=num_workers,
        seed=seed,
        framework=framework,
    )
    return DataLoaderImpl(data_config, data_loader)
def create_torch_data_loader_multi(
    data_configs_list: list[_config.DataConfig],
    model_config: _model.BaseModelConfig,
    action_horizon: int,
    batch_size: int,
    *,
    sharding: jax.sharding.Sharding | None = None,
    skip_norm_stats: bool = False,
    shuffle: bool = False,
    num_batches: int | None = None,
    num_workers: int = 0,
    seed: int = 0,
    framework: str = "jax",
    global_norm_stats: Dict[str, normalize.NormStats] | None = None,
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader over a mixture of datasets.

    Args:
        data_configs_list: One data config per mixture component.
        model_config: Model config passed through to the mixture dataset.
        action_horizon: The action horizon.
        batch_size: The global batch size; split across processes below.
        sharding: The sharding to use for the data loader. If None, the data loader will
            use a single device sharding.
        skip_norm_stats: NOTE(review): accepted but never used in this function
            (unlike `create_torch_data_loader`, no transform step runs here) --
            presumably normalization is handled inside the mixture dataset; confirm.
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return. If the number exceeds the
            number of batches in the dataset, the data loader will loop over the dataset.
            If not provided, will iterate over the dataset indefinitely.
        num_workers: The number of worker processes to use. If zero, the data loader will
            execute in the main process.
        seed: The seed to use for shuffling the data.
        framework: "jax" or "pytorch"; selects the sampling/sharding strategy.
        global_norm_stats: NOTE(review): accepted but unused here -- presumably
            already baked into `data_configs_list` by the caller; confirm.
    """
    dataset = create_mixture_dataset(data_configs_list, action_horizon, model_config)
    # Use TorchDataLoader for both frameworks
    # For PyTorch DDP, create DistributedSampler and divide batch size by world size
    # For JAX, divide by process count
    sampler = None
    if framework == "pytorch":
        if torch.distributed.is_initialized():
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset,
                num_replicas=torch.distributed.get_world_size(),
                rank=torch.distributed.get_rank(),
                shuffle=shuffle,
                drop_last=True,
            )
            local_batch_size = batch_size // torch.distributed.get_world_size()
        else:
            # Single-process PyTorch: the whole batch stays local.
            local_batch_size = batch_size
    else:
        # JAX: every process loads only its slice of the global batch.
        local_batch_size = batch_size // jax.process_count()
        if jax.process_count() > 1:
            sampler = JaxProcessDistributedSampler(
                dataset_size=len(dataset),
                num_replicas=jax.process_count(),
                rank=jax.process_index(),
                shuffle=shuffle,
                seed=seed,
            )
    logging.info(f"local_batch_size: {local_batch_size}")
    data_loader = TorchDataLoader(
        dataset,
        local_batch_size=local_batch_size,
        sharding=None if framework == "pytorch" else sharding,
        shuffle=(sampler is None and shuffle),  # Don't shuffle if using sampler
        sampler=sampler,
        num_batches=num_batches,
        num_workers=num_workers,
        seed=seed,
        framework=framework,
    )
    # NOTE(review): `data_configs_list[0][0]` implies each list entry is itself
    # a sequence of configs, and only the first component's first config is
    # reported as this loader's data_config -- confirm against the factory's
    # return type (`create_data_loader_multi` appends `factory.create(...)` results).
    return DataLoaderImpl(data_configs_list[0][0], data_loader)
def create_rlds_data_loader(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    *,
    sharding: jax.sharding.Sharding | None = None,
    skip_norm_stats: bool = False,
    shuffle: bool = False,
    num_batches: int | None = None,
    framework: str = "jax",
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create an RLDS data loader for training.

    Note: This data loader requires some extra dependencies -- see examples/droid/README_train.md

    Args:
        data_config: The data configuration.
        action_horizon: Number of future actions chunked per sample.
        batch_size: The batch size.
        sharding: The sharding to use; defaults to data-parallel sharding.
        skip_norm_stats: Whether to skip data normalization.
        shuffle: Whether to shuffle the data.
        num_batches: If set, cap the number of batches yielded, looping over the
            dataset as needed; otherwise iterate indefinitely.
        framework: Only "jax" is supported for RLDS.
    """
    if framework == "pytorch":
        raise NotImplementedError("PyTorch RLDS data loader is not supported yet")
    # RLDS batches internally, so transforms run on already-batched samples.
    raw = create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=shuffle)
    transformed = transform_iterable_dataset(raw, data_config, skip_norm_stats=skip_norm_stats, is_batched=True)
    loader = RLDSDataLoader(transformed, sharding=sharding, num_batches=num_batches)
    return DataLoaderImpl(data_config, loader)
class JaxProcessDistributedSampler(torch.utils.data.Sampler[int]):
    """Simple sampler to split dataset indices across JAX processes.

    Each process sees a disjoint slice of indices using striding by num_replicas.
    Shuffling (if enabled) is deterministic via the provided seed; the same seed
    must be used on every process so the strided slices stay disjoint.
    """

    def __init__(
        self,
        dataset_size: int,
        *,
        num_replicas: int,
        rank: int,
        shuffle: bool,
        seed: int,
    ) -> None:
        """Args:
            dataset_size: Total number of indices to partition.
            num_replicas: Number of processes sharing the dataset.
            rank: This process's index in [0, num_replicas).
            shuffle: Whether to shuffle indices before partitioning.
            seed: Seed for the deterministic shuffle.
        """
        self._dataset_size = max(0, dataset_size)
        self._num_replicas = max(1, num_replicas)
        self._rank = max(0, rank)
        self._shuffle = shuffle
        self._seed = seed

    def __iter__(self):
        indices = list(range(self._dataset_size))
        if self._shuffle and self._dataset_size > 0:
            # Deterministic permutation: identical across processes for a given seed.
            g = torch.Generator()
            g.manual_seed(self._seed)
            indices = torch.randperm(self._dataset_size, generator=g).tolist()
        # Strided split across processes.
        indices = indices[self._rank :: self._num_replicas]
        return iter(indices)

    def __len__(self) -> int:
        # Length of this rank's strided slice. BUGFIX: the previous
        # implementation returned ceil(size / replicas) for every rank, which
        # overstates the length by one for ranks >= size % replicas when the
        # split is uneven, making len() disagree with iteration.
        if self._rank >= self._dataset_size:
            return 0
        return (self._dataset_size - self._rank + self._num_replicas - 1) // self._num_replicas
class TorchDataLoader:
    """Torch data loader implementation.

    Wraps `torch.utils.data.DataLoader`; when a JAX sharding is active, each
    batch is converted into process-local sharded `jax.Array`s, otherwise
    batches are returned as torch tensors.

    BUGFIX: removed a leftover `@profile` decorator (memory_profiler) from
    `__iter__` -- debug instrumentation that added per-batch overhead and
    profiling output in production runs.
    """

    def __init__(
        self,
        dataset,
        local_batch_size: int,
        *,
        sharding: jax.sharding.Sharding | None = None,
        shuffle: bool = False,
        sampler: torch.utils.data.Sampler | None = None,
        num_batches: int | None = None,
        num_workers: int = 0,
        seed: int = 0,
        framework: str = "jax",
    ):
        """Create a PyTorch data loader.

        Args:
            dataset: The dataset to load.
            local_batch_size: The local batch size for each process.
            sharding: The sharding to use for the data loader.
            shuffle: Whether to shuffle the data.
            sampler: Optional sampler; when provided it controls ordering and
                `shuffle` is ignored.
            num_batches: If provided, determines the number of returned batches. If the
                number is larger than the number of batches in the dataset, the data loader
                will loop over the dataset. If not provided, will iterate over the dataset
                indefinitely.
            num_workers: The number of worker processes to use. If zero, the data loader will
                execute in the main process.
            seed: The seed to use for shuffling the data.
            framework: "jax" (default) or "pytorch"; controls the output array type.
        """
        if len(dataset) < local_batch_size:
            raise ValueError(f"Local batch size ({local_batch_size}) is larger than the dataset size ({len(dataset)}).")
        # Store sharding - None for PyTorch, JAX sharding for JAX
        self._sharding = sharding
        if sharding is None and framework == "jax":
            # Use data parallel sharding by default for JAX only.
            self._sharding = jax.sharding.NamedSharding(
                jax.sharding.Mesh(jax.devices(), ("B",)),
                jax.sharding.PartitionSpec("B"),
            )
        self._num_batches = num_batches
        mp_context = None
        if num_workers > 0:
            # "spawn" avoids inheriting JAX/CUDA state into worker processes.
            mp_context = multiprocessing.get_context("spawn")
        generator = torch.Generator()
        generator.manual_seed(seed)
        self._data_loader = torch.utils.data.DataLoader(
            typing.cast(torch.utils.data.Dataset, dataset),
            batch_size=local_batch_size,
            shuffle=(sampler is None and shuffle),  # Don't shuffle if using sampler
            sampler=sampler,
            num_workers=num_workers,
            multiprocessing_context=mp_context,
            persistent_workers=num_workers > 0,
            collate_fn=_collate_fn,
            worker_init_fn=_worker_init_fn,
            drop_last=True,
            generator=generator,
            pin_memory=False,
        )

    @property
    def torch_loader(self) -> torch.utils.data.DataLoader:
        """The underlying `torch.utils.data.DataLoader`."""
        return self._data_loader

    def __iter__(self):
        num_items = 0
        while True:
            data_iter = iter(self._data_loader)
            while True:
                if self._num_batches is not None and num_items >= self._num_batches:
                    return
                try:
                    batch = next(data_iter)
                except StopIteration:
                    break  # We've exhausted the dataset. Create a new iterator and start over.
                num_items += 1
                # For JAX, convert to sharded arrays; for PyTorch, return torch tensors
                if self._sharding is not None:
                    yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)
                else:
                    yield jax.tree.map(torch.as_tensor, batch)
def _collate_fn(items):
"""Collate the batch elements into batched numpy arrays."""
# Make sure to convert to numpy arrays before stacking since some of the incoming elements
# may be JAX arrays.
return jax.tree.map(lambda *xs: np.stack([np.asarray(x) for x in xs], axis=0), *items)
def _worker_init_fn(worker_id: int) -> None:
"""Tell JAX inside the worker process not to preallocate the GPU memory."""
# NOTE: This is called after jax is imported inside the worker process. This
# means that this approach will not work for selecting the backend.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
class RLDSDataLoader:
    """Shallow wrapper around the DROID data loader to make it compatible with openpi.

    All batching already happens in the DROID dataset, so we don't need to do anything here.
    """

    def __init__(
        self,
        dataset: DroidRldsDataset,
        *,
        sharding: jax.sharding.Sharding | None = None,
        num_batches: int | None = None,
    ):
        """Args:
            dataset: The (already batched) DROID RLDS dataset.
            sharding: Sharding for the produced arrays; defaults to data-parallel.
            num_batches: If set, stop after this many batches, looping over the
                dataset as needed; otherwise iterate indefinitely.
        """
        self._dataset = dataset
        # CLEANUP: self._num_batches was previously assigned twice in this
        # constructor; it is now set exactly once.
        self._num_batches = num_batches
        if jax.process_count() > 1:
            raise NotImplementedError("Data loading with multiple processes is not supported.")
        if sharding is None:
            # Use data parallel sharding by default.
            sharding = jax.sharding.NamedSharding(
                jax.sharding.Mesh(jax.devices(), ("B",)),
                jax.sharding.PartitionSpec("B"),
            )
        self._sharding = sharding

    def __iter__(self):
        num_items = 0
        while True:
            data_iter = iter(self._dataset)
            while True:
                if self._num_batches is not None and num_items >= self._num_batches:
                    return
                try:
                    batch = next(data_iter)
                except StopIteration:
                    break  # We've exhausted the dataset. Create a new iterator and start over.
                num_items += 1
                yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)
class DataLoaderImpl(DataLoader):
    """Adapts a TorchDataLoader/RLDSDataLoader to the `DataLoader` protocol."""

    def __init__(self, data_config: _config.DataConfig, data_loader: TorchDataLoader | RLDSDataLoader):
        self._data_config = data_config
        self._data_loader = data_loader

    def data_config(self) -> _config.DataConfig:
        """The data config used to build the underlying loader."""
        return self._data_config

    def __iter__(self):
        # Repack each raw batch dict into (Observation, actions) pairs.
        for batch in self._data_loader:
            observation = _model.Observation.from_dict(batch)
            yield observation, batch["actions"]

View File

@@ -0,0 +1,221 @@
"""
RLDS-based data loader for DROID.
While openpi typically uses LeRobot's data loader, it is not currently scalable enough for larger datasets like DROID.
Thus, we provide a data loader example here that uses the RLDS data format.
The data loader also applies a few DROID-specific data filters / transformations.
"""
from enum import Enum
from enum import auto
import json
import logging
from pathlib import Path
import tqdm
import openpi.shared.download as download
class DroidActionSpace(Enum):
    """Action space for DROID dataset."""

    # Absolute joint positions (usable for policy evaluation in simulation).
    JOINT_POSITION = auto()
    # Joint velocities.
    JOINT_VELOCITY = auto()
class DroidRldsDataset:
def __init__(
self,
data_dir: str,
batch_size: int,
*, # Force keyword-only arguments
shuffle: bool = True,
action_chunk_size: int = 16,
# We default to joint position actions, since they allow policy evaluation in simulation.
action_space: DroidActionSpace = DroidActionSpace.JOINT_POSITION,
max_loaded_steps_per_episode: int = 100,
# Reduce this if you are running out of memory, but careful -- below ~100k shuffling is not sufficiently random.
shuffle_buffer_size: int = 250_000,
num_parallel_reads: int = -1, # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level
num_parallel_calls: int = -1, # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level
filter_dict_path=None, # Path to json file with indices to sample during training
):
# Import tensorflow here to not make it mandatory in case RLDS data loader is not used.
import dlimp as dl
import tensorflow as tf
import tensorflow_datasets as tfds
# Configure Tensorflow with *no GPU devices* (to prevent clobber with PyTorch / JAX)
tf.config.set_visible_devices([], "GPU")
builder = tfds.builder("droid", data_dir=data_dir, version="1.0.1")
dataset = dl.DLataset.from_rlds(builder, split="train", shuffle=shuffle, num_parallel_reads=num_parallel_reads)
# Filter out any unsuccessful trajectories -- we use the file name to check this
dataset = dataset.filter(
lambda traj: tf.strings.regex_full_match(
traj["traj_metadata"]["episode_metadata"]["file_path"][0], ".*success.*"
)
)
# # Repeat dataset so we never run out of data.
dataset = dataset.repeat()
# Load the filter dictionary if provided.
# The filter dictionary is a JSON file that maps episode keys to ranges of frames to sample
# (e.g.,
# {
# "<episode key>": [[0, 100], [200, 300]]
# }
# means keep frames 0-99 and 200-299).
if filter_dict_path is not None:
cached_filter_dict_path = download.maybe_download(filter_dict_path)
with Path(cached_filter_dict_path).open("r") as f:
filter_dict = json.load(f)
logging.info(f"Using filter dictionary with {len(filter_dict)} episodes")
keys_tensor = []
values_tensor = []
for episode_key, ranges in tqdm.tqdm(filter_dict.items(), desc="Creating idle filter hash table..."):
for start, end in ranges:
for t in range(start, end):
frame_key = f"{episode_key}--{t}"
keys_tensor.append(frame_key)
values_tensor.append(True)
self.filter_table = tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor), default_value=False
)
logging.info("Filter hash table initialized")
else:
self.filter_table = tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer([""], [True]), default_value=True
)
def restructure(traj):
"""Reformat observation and action keys, sample language instruction."""
# Important: we use joint *position* action space -- easier to simulate!
actions = tf.concat(
(
(
traj["action_dict"]["joint_position"]
if action_space == DroidActionSpace.JOINT_POSITION
else traj["action_dict"]["joint_velocity"]
),
traj["action_dict"]["gripper_position"],
),
axis=-1,
)
# Randomly samples one of the two exterior images in DROID during training (we only train with one at a time).
# Note: the "left" refers to the left camera in the stereo pair, we only train on the left camera.
exterior_img = tf.cond(
tf.random.uniform(shape=[]) > 0.5,
lambda: traj["observation"]["exterior_image_1_left"],
lambda: traj["observation"]["exterior_image_2_left"],
)
wrist_img = traj["observation"]["wrist_image_left"]
# Randomly sample one of the three language instructions
instruction = tf.random.shuffle(
[traj["language_instruction"], traj["language_instruction_2"], traj["language_instruction_3"]]
)[0]
traj_len = tf.shape(traj["action"])[0]
indices = tf.as_string(tf.range(traj_len))
# Data filtering:
# Compute a uniquely-identifying step ID by concatenating the recording folderpath, file path,
# and each step's time step index. This will index into the filter hash table, and if it returns true,
# then the frame passes the filter.
step_id = (
traj["traj_metadata"]["episode_metadata"]["recording_folderpath"]
+ "--"
+ traj["traj_metadata"]["episode_metadata"]["file_path"]
+ "--"
+ indices
)
passes_filter = self.filter_table.lookup(step_id)
return {
"actions": actions,
"observation": {
"image": exterior_img,
"wrist_image": wrist_img,
"joint_position": traj["observation"]["joint_position"],
"gripper_position": traj["observation"]["gripper_position"],
},
"prompt": instruction,
"step_id": step_id,
"passes_filter": passes_filter,
}
dataset = dataset.traj_map(restructure, num_parallel_calls)
def chunk_actions(traj):
    """Splits episode into action chunks.

    For each step t, ``traj["actions"][t]`` becomes the stack of the next
    ``action_chunk_size`` actions, clamped at the final step.
    """
    traj_len = tf.shape(traj["actions"])[0]
    # Broadcasting (traj_len, 1) + (1, chunk) yields a (traj_len, chunk)
    # matrix where row t holds [t, t+1, ..., t+chunk-1].
    start_steps = tf.range(traj_len)[:, None]
    chunk_offsets = tf.range(action_chunk_size)[None, :]
    # Cap to length of the sequence --> final chunks will repeat the last action.
    # This makes sense, since we are using absolute joint + gripper position actions.
    gather_indices = tf.minimum(start_steps + chunk_offsets, traj_len - 1)
    traj["actions"] = tf.gather(traj["actions"], gather_indices)
    return traj
dataset = dataset.traj_map(chunk_actions, num_parallel_calls)
# Flatten: map from trajectory dataset to dataset of individual action chunks
dataset = dataset.flatten(num_parallel_calls=num_parallel_calls)
# Filter data that doesn't pass the filter
def filter_from_dict(frame):
    """Predicate for dataset.filter: keep frames whose precomputed flag is set."""
    keep = frame["passes_filter"]
    return keep
dataset = dataset.filter(filter_from_dict)
# Remove "passes_filter" key from output
def remove_passes_filter(frame):
    """Drop the bookkeeping `passes_filter` key before yielding the frame."""
    del frame["passes_filter"]
    return frame
dataset = dataset.map(remove_passes_filter)
# Decode images: RLDS saves encoded images, only decode now for efficiency
def decode_images(traj):
    """Decode the encoded camera images into uint8 tensors.

    RLDS stores images as encoded bytes; decoding is deferred to this late
    pipeline stage for efficiency.
    """
    for camera_key in ("image", "wrist_image"):
        traj["observation"][camera_key] = tf.io.decode_image(
            traj["observation"][camera_key], expand_animations=False, dtype=tf.uint8
        )
    return traj
dataset = dataset.frame_map(decode_images, num_parallel_calls)
# Shuffle, batch
dataset = dataset.shuffle(shuffle_buffer_size)
dataset = dataset.batch(batch_size)
# Note =>> Seems to reduce memory usage without affecting speed?
dataset = dataset.with_ram_budget(1)
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
def __iter__(self):
    # Stream batches from the tf.data pipeline, converted to numpy arrays.
    yield from self.dataset.as_numpy_iterator()
def __len__(self):
    # This is the approximate number of samples in DROID after filtering.
    # Easier to hardcode than to iterate through the dataset and compute it.
    # NOTE(review): this is an estimate, not derived from the pipeline —
    # downstream code should not rely on it being exact.
    return 20_000_000

View File

@@ -0,0 +1,116 @@
"""RoboArena baseline policy configs."""
from typing import TypeAlias
import openpi.models.model as _model
import openpi.models.pi0_config as pi0_config
import openpi.models.pi0_fast as pi0_fast
import openpi.models.tokenizer as _tokenizer
import openpi.policies.droid_policy as droid_policy
import openpi.transforms as _transforms
ModelType: TypeAlias = _model.ModelType
def get_roboarena_configs():
    """Return the RoboArena DROID baseline inference TrainConfigs."""
    # Import here to avoid circular imports.
    from openpi.training.config import AssetsConfig
    from openpi.training.config import DataConfig
    from openpi.training.config import SimpleDataConfig
    from openpi.training.config import TrainConfig

    def droid_data(model_type=None):
        # Every baseline shares the same DROID data pipeline; only the model
        # type passed to DroidInputs differs (None means "use the default").
        def make_transforms(model):
            input_kwargs = {"action_dim": model.action_dim}
            if model_type is not None:
                input_kwargs["model_type"] = model_type
            return _transforms.Group(
                inputs=[droid_policy.DroidInputs(**input_kwargs)],
                outputs=[droid_policy.DroidOutputs()],
            )

        return SimpleDataConfig(
            assets=AssetsConfig(asset_id="droid"),
            data_transforms=make_transforms,
            base_config=DataConfig(
                prompt_from_task=True,
            ),
        )

    return [
        #
        # RoboArena DROID baseline inference configs.
        #
        # Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer.
        TrainConfig(
            name="paligemma_binning_droid",
            model=pi0_fast.Pi0FASTConfig(
                action_dim=8,
                action_horizon=15,
                max_token_len=400,
                fast_model_tokenizer=_tokenizer.BinningTokenizer,
            ),
            data=droid_data(ModelType.PI0_FAST),
        ),
        # Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer).
        TrainConfig(
            name="paligemma_fast_droid",
            model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
            data=droid_data(ModelType.PI0_FAST),
        ),
        # Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset).
        TrainConfig(
            name="paligemma_fast_specialist_droid",
            model=pi0_fast.Pi0FASTConfig(
                action_dim=8,
                action_horizon=15,
                fast_model_tokenizer=_tokenizer.FASTTokenizer,
                fast_model_tokenizer_kwargs={"fast_tokenizer_path": "KarlP/fast_droid_specialist"},
            ),
            data=droid_data(ModelType.PI0_FAST),
        ),
        # Trained from PaliGemma, using FSQ tokenizer.
        TrainConfig(
            name="paligemma_vq_droid",
            model=pi0_fast.Pi0FASTConfig(
                action_dim=8,
                action_horizon=15,
                fast_model_tokenizer=_tokenizer.FSQTokenizer,
                fast_model_tokenizer_kwargs={"fsq_tokenizer_path": "gs://openpi-assets/tokenizers/droid_fsq_tokenizer"},
            ),
            data=droid_data(ModelType.PI0_FAST),
        ),
        # pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.
        TrainConfig(
            name="paligemma_diffusion_droid",
            model=pi0_config.Pi0Config(action_horizon=10, action_dim=8),
            data=droid_data(),
        ),
    ]

View File

@@ -0,0 +1,703 @@
import numpy as np
from dataclasses import dataclass
from typing import SupportsIndex, Sequence, List, Dict, Any, Tuple, Optional, Union, TypeVar, Protocol
import torch
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
LeRobotDatasetMetadata,
MultiLeRobotDataset,
)
import openpi.transforms as _transforms
from pdb import set_trace
import logging
T_co = TypeVar("T_co", covariant=True)
import openpi.training.config as _config
import openpi.shared.normalize as normalize
def detect_gripper_change_step(
    dataset,
    select_actions: Optional[list[str]] = None,
    gripper_dim: int = -1,
    threshold_method: str = "std_multiplier",
    threshold_multiplier: float = 2.0,
    min_threshold: float = 0.001,
    max_threshold: float = 1.0,
    plot_gripper_changes: bool = False,
):
    """
    Detect the step of gripper change. Only work for the self-collected dataset.
    Modifies the dataset in place by adding 'gripper_change_step_idx' attribute.
    This version uses a sliding window of size 4 centered around each detected
    change index (indices [i-2, i+1]), with duplicates removed.

    Args:
        dataset: LeRobotDataset instance (must expose `meta.episodes` and `hf_dataset`).
        select_actions: List of action keys to process. Defaults to ["action"].
        gripper_dim: Dimension index for gripper in the action vector.
        threshold_method: Method to calculate threshold ('std_multiplier', 'percentile', 'absolute').
        threshold_multiplier: Multiplier for std-based threshold (or the absolute
            threshold itself when threshold_method == 'absolute').
        min_threshold: Minimum threshold value to avoid too sensitive detection.
        max_threshold: Maximum threshold value to avoid missing large changes.
        plot_gripper_changes: Whether to plot gripper changes visualization.

    Returns:
        The same `dataset`, with `gripper_change_step_idx` (sorted int32 array) attached.

    Raises:
        ValueError: If `threshold_method` is unknown.
    """
    # Fix: never use a mutable default argument — it would be shared across calls.
    if select_actions is None:
        select_actions = ["action"]
    episode_lengths = [ep_dict["length"] for ep_dict in dataset.meta.episodes.values()]
    cumulative_lengths = np.cumsum(episode_lengths)
    all_window_indices = set()  # Use a set for automatic deduplication
    action_values = []  # Robustness: defined even if `select_actions` is empty.
    for action_key in select_actions:
        action_values = dataset.hf_dataset[action_key]
        delta_action = np.diff(action_values, axis=0)
        # Handle episode boundaries: the diff across a boundary is meaningless,
        # so replace it with the previous in-episode diff (or zero).
        for end_idx in cumulative_lengths[:-1]:
            if end_idx - 1 < len(delta_action) and end_idx - 2 >= 0:
                delta_action[end_idx - 1] = delta_action[end_idx - 2]
            elif end_idx - 1 < len(delta_action):
                delta_action[end_idx - 1] = 0
        if delta_action.ndim == 1:
            delta_action = delta_action[:, np.newaxis]
        assert delta_action.ndim == 2
        # Extract gripper delta values
        gripper_delta = delta_action[:, gripper_dim]
        # Calculate threshold based on statistical properties
        if threshold_method == "std_multiplier":
            # Use standard deviation to filter out small tremors
            threshold = threshold_multiplier * np.std(gripper_delta)
        elif threshold_method == "percentile":
            # Use percentile-based threshold (e.g., 85th percentile)
            threshold = np.percentile(np.abs(gripper_delta), 85)
        elif threshold_method == "absolute":
            # Use absolute threshold
            threshold = threshold_multiplier
        else:
            raise ValueError(f"Unknown threshold_method: {threshold_method}")
        # Clamp threshold to reasonable bounds
        threshold = np.clip(threshold, min_threshold, max_threshold)
        # Find indices where gripper change exceeds threshold
        significant_change_idx = np.where(np.abs(gripper_delta) > threshold)[0]
        cur_window_indices = set()
        # delta_action[i] is the change from step i to i+1, so valid step
        # indices are [0, len(action_values) - 1].
        max_possible_idx = len(action_values) - 1
        for idx in significant_change_idx:
            # Sliding window of size 4 around idx: include [idx-2, idx-1, idx, idx+1],
            # clipped to the valid step range.
            window_start = idx - 2
            window_end = idx + 1
            current_window_indices = np.arange(
                max(0, window_start), min(max_possible_idx + 1, window_end + 1)
            )
            for w_idx in current_window_indices:
                cur_window_indices.add(w_idx)
                all_window_indices.add(w_idx)
        if plot_gripper_changes:
            num_episodes_to_plot = 5
            end_index_for_plot = cumulative_lengths[num_episodes_to_plot - 1] - 1
            delta_action_to_plot = delta_action[:end_index_for_plot]
            # Restrict detected indices to the plotted range.
            gripper_change_step_idx = np.array(sorted(cur_window_indices), dtype=np.int32)
            gripper_change_step_idx_to_plot = gripper_change_step_idx[gripper_change_step_idx < end_index_for_plot]
            plot_gripper_changes_in_subplots(
                delta_action_to_plot,
                gripper_change_step_idx_to_plot,
                episode_lengths,
                num_episodes_to_plot,
                gripper_dim,
                f"{action_key}_gripper_change"
            )
    # Convert the set to a sorted int32 array.
    gripper_change_step_idx = np.array(sorted(all_window_indices), dtype=np.int32)
    print(f"Total unique gripper change steps: {len(gripper_change_step_idx)}, Total steps: {len(action_values)}")
    dataset.gripper_change_step_idx = gripper_change_step_idx
    return dataset
class Dataset(Protocol[T_co]):
    """Interface for a dataset with random access.

    Structural (duck-typed) protocol: any object with `__getitem__` and
    `__len__` conforms, no inheritance required.
    """

    def __getitem__(self, index: SupportsIndex) -> T_co:
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    def __len__(self) -> int:
        raise NotImplementedError("Subclasses of Dataset should implement __len__.")
class TransformedDataset(Dataset[T_co]):
    """Dataset wrapper that applies a composed transform chain on item access."""

    def __init__(self, dataset: Dataset, transforms: Sequence[_transforms.DataTransformFn]):
        # NOTE: `_dataset` is read by MixtureDataset to reach the wrapped dataset,
        # so the attribute name is part of the de-facto interface.
        self._dataset = dataset
        # Compose once up front; items are transformed lazily on access.
        self._transform = _transforms.compose(transforms)

    def __getitem__(self, index: SupportsIndex) -> T_co:
        raw_item = self._dataset[index]
        return self._transform(raw_item)

    def __len__(self) -> int:
        return len(self._dataset)
def transform_dataset(dataset: Dataset, data_config: _config.DataConfig) -> Dataset:
    """Transform the dataset by applying the data transforms.

    Wraps `dataset` in a TransformedDataset that runs, in order: repack
    transforms, data transforms, normalization (using the config's norm
    stats and quantile setting), and model transforms.
    """
    # Fix: removed the dead `norm_stats = {}` assignment that was immediately
    # overwritten on the next line.
    norm_stats = data_config.norm_stats
    return TransformedDataset(
        dataset,
        [
            *data_config.repack_transforms.inputs,
            *data_config.data_transforms.inputs,
            _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
            *data_config.model_transforms.inputs,
        ],
    )
class MixtureDataset(Dataset):
    """
    A composite dataset that combines multiple datasets, allowing for weighted sampling
    and specific handling based on training stage (e.g., pretrain, finetune) and
    gripper change detection for augmentation.

    This dataset flattens all eligible samples from its constituent datasets and assigns
    sampling weights based on configuration and heuristics (e.g., `gripper_aug_ratio`).
    """

    def __init__(
        self,
        datasets: Sequence[Dataset],
        datasets_name: Sequence[str],
        datasets_meta: Sequence[LeRobotDatasetMetadata],
        datasets_weights: Optional[Dict[str, float]] = None,
        gripper_aug_ratio: float = 1.0,
        shuffle: bool = True,
    ):
        """
        Initializes the MixtureDataset.

        Args:
            datasets (Sequence[Dataset]): A list of `Dataset` objects to be combined.
            datasets_name (Sequence[str]): A list of names corresponding to each dataset in `datasets`.
            datasets_meta (Sequence[LeRobotDatasetMetadata]): Metadata for each dataset,
                used here for `info['total_episodes']` and `info['total_frames']`.
            datasets_weights (Dict[str, float], optional): A dictionary mapping dataset names
                to their base sampling weights. If None, equal weights are assumed.
            gripper_aug_ratio (float): A multiplier applied to the weights of samples
                that contain a detected gripper change. Useful for augmenting rare events.
            shuffle (bool): If True, the flat sample map and sampling weights are shuffled
                (in lockstep) after initial creation.
        """
        self.datasets = datasets
        self.datasets_name = datasets_name
        self.meta = datasets_meta
        # Extract total number of episodes and frames for each dataset from metadata.
        self.num_episodes = [meta.info['total_episodes'] for meta in datasets_meta]
        self.num_frames = [meta.info['total_frames'] for meta in datasets_meta]
        # Compute the flattened list of (dataset_idx, sample_idx) pairs.
        self._compute_len(False)
        # Assign normalized sampling weights to each sample in the flattened map.
        self._get_weights(datasets_weights, gripper_aug_ratio)
        # For training, ensure the sample map and weights are consistent.
        if len(self.flat_sample_map) != len(self.sample_weights):
            raise ValueError(
                f"Mismatch in flat sample map length ({len(self.flat_sample_map)}) "
                f"and sample weights length ({len(self.sample_weights)})."
            )
        if shuffle:
            # Shuffle both the sample map and weights in the same order for training.
            # This ensures random access to samples while maintaining their assigned probabilities.
            indices = np.random.permutation(len(self.flat_sample_map))
            self.flat_sample_map = [self.flat_sample_map[i] for i in indices]
            self.sample_weights = self.sample_weights[indices]

    def __len__(self) -> int:
        """
        Returns the total number of samples in the mixture dataset (after flattening and selection).
        This length represents the effective size of the dataset for iteration.
        """
        return len(self.flat_sample_map)

    def __getitem__(self, index: SupportsIndex):
        """
        Retrieves a specific sample from one of the underlying datasets based on the
        flattened sample map.

        Args:
            index (SupportsIndex): The index in the flattened `flat_sample_map` (0 to `len(self) - 1`).

        Returns:
            The sample data (dictionary) from the underlying dataset.

        Raises:
            IndexError: If the provided index is out of bounds for the dataset.
        """
        if not (0 <= index < len(self.flat_sample_map)):
            raise IndexError(f"Index {index} is out of bounds for the dataset (size: {len(self.flat_sample_map)}).")
        # Retrieve the original dataset index and sample index from the flattened map.
        dataset_idx, sample_idx = self.flat_sample_map[index]
        return self.datasets[dataset_idx][sample_idx]

    def _compute_len(self, is_eval: bool = False):
        """
        Pre-computes and stores `all_sample_indices`, a list of episode indices sampled
        from each constituent dataset. This method prepares the data for `_create_flat_sample_map`.

        Args:
            is_eval (bool): Flag indicating if indices are being computed for an evaluation dataset.
        """
        self.all_sample_indices: List[Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor]] = []
        for i, (ds, meta) in enumerate(zip(self.datasets, self.meta)):
            # Access the underlying LeRobotDataset or MultiLeRobotDataset, bypassing TransformedDataset wrapper.
            actual_ds = ds._dataset if isinstance(ds, TransformedDataset) else ds
            # NOTE(review): `num_indices` is always None here — `_sample_indices`
            # currently ignores it (see the dead branch noted there).
            num_indices = None
            if isinstance(actual_ds, MultiLeRobotDataset):
                # For MultiLeRobotDataset, iterate through its sub-datasets to get indices.
                indices_list_for_multi_ds = []
                for sub_ds in actual_ds._datasets:
                    _from = sub_ds.episode_data_index["from"]
                    _to = sub_ds.episode_data_index["to"]
                    indices = self._sample_indices(
                        _from, _to, num_indices, is_eval=is_eval, dataset_name=self.datasets_name[i]
                    )
                    indices_list_for_multi_ds.append(indices)
                self.all_sample_indices.append(indices_list_for_multi_ds)
            elif isinstance(actual_ds, LeRobotDataset):
                # For a single LeRobotDataset.
                _from = actual_ds.episode_data_index["from"]
                _to = actual_ds.episode_data_index["to"]
                indices = self._sample_indices(
                    _from, _to, num_indices, is_eval=is_eval, dataset_name=self.datasets_name[i]
                )
                self.all_sample_indices.append(indices)
            else:
                raise TypeError(f"Unsupported dataset type: {type(actual_ds)}. "
                                "Expected `LeRobotDataset` or `MultiLeRobotDataset`.")
        # After collecting all sampled episode indices, flatten them into `flat_sample_map`.
        self.flat_sample_map = self._create_flat_sample_map()

    def _create_flat_sample_map(self) -> List[Tuple[int, int]]:
        """
        Converts the potentially nested structure of `self.all_sample_indices` (which can be
        lists of lists of tensors, or lists of tensors) into a flat list of
        `(original_dataset_index, sample_index_within_original_dataset)` tuples.
        This flattened map is then used by `__getitem__` to efficiently retrieve samples.
        """
        flat_map = []
        for dataset_idx, sample_group in enumerate(self.all_sample_indices):
            # Case 1: `MultiLeRobotDataset` where `sample_group` is `List[List[torch.Tensor]]`
            if isinstance(sample_group, list) and len(sample_group) > 0 and isinstance(sample_group[0], list):
                for sub_group in sample_group:  # Iterate through sub-datasets' index lists
                    for tensor_of_indices in sub_group:  # Iterate through tensors of indices for episodes
                        for i in range(tensor_of_indices.numel()):
                            flat_map.append((dataset_idx, tensor_of_indices[i].item()))
            # Case 2: `LeRobotDataset` where `sample_group` is `List[torch.Tensor]`
            elif isinstance(sample_group, list) and len(sample_group) > 0 and isinstance(sample_group[0], torch.Tensor):
                for tensor_of_indices in sample_group:
                    for i in range(tensor_of_indices.numel()):
                        flat_map.append((dataset_idx, tensor_of_indices[i].item()))
            # Case 3: A rare case where `sample_group` might be a single `torch.Tensor` directly
            elif isinstance(sample_group, torch.Tensor):
                for i in range(sample_group.numel()):
                    flat_map.append((dataset_idx, sample_group[i].item()))
        return flat_map

    def _sample_indices(
        self,
        start: List[int],
        end: List[int],
        num_frames: Optional[int],
        random_pad: bool = False,
        is_eval: bool = False,
        dataset_name: Optional[str] = None,  # Added for potential future stage-specific logic
    ) -> List[torch.Tensor]:
        """
        Samples indices for episodes. Called per dataset to determine which frames to include.

        Args:
            start (List[int]): List of starting frame indices for each episode.
            end (List[int]): List of ending frame indices for each episode.
            num_frames (Optional[int]): The target number of frames to sample per episode.
                NOTE(review): currently unused — `target_frames` is set to the
                episode's own frame count below, so every frame is always taken
                and the padding branch is unreachable. Presumably a leftover
                from a stage-specific sampling scheme; confirm before removing.
            random_pad (bool): If True, and `frame_count < target_frames`, shorter episodes
                would be padded with randomly selected indices from themselves.
            is_eval (bool): If True, adjusts indices for evaluation.
            dataset_name (str): The name of the dataset (for debugging or future dataset-specific sampling rules).

        Returns:
            List[torch.Tensor]: A list of PyTorch tensors, where each tensor contains the
                sampled frame indices for a single episode.
        """
        all_indices_for_episodes = []
        for _start, _end in zip(start, end):
            frame_count = _end - _start  # Total frames available in this episode.
            # NOTE(review): target_frames == frame_count makes the first branch
            # always taken; the else/random_pad path below is dead code today.
            target_frames = frame_count
            if frame_count >= target_frames:
                # If enough frames are available, linearly space the indices to sample `target_frames`.
                indices = torch.linspace(_start, _end - 1, steps=target_frames).long()
            else:
                # If fewer frames than `target_frames` are available.
                if random_pad:
                    # Pad the existing frames with randomly chosen duplicates from the episode.
                    pad_size = target_frames - frame_count
                    indices = torch.arange(_start, _end)  # All available original indices
                    # Randomly sample `pad_size` indices from the existing ones.
                    pad_indices = indices[torch.randint(0, frame_count, (pad_size,))]
                    indices = torch.cat([indices, pad_indices])  # Combine original and padded indices
                    indices = indices[torch.randperm(target_frames)]  # Randomly permute to mix original and padded.
                else:
                    # If not padding, simply use all available frames.
                    indices = torch.arange(_start, _end)
            all_indices_for_episodes.append(indices)
        return all_indices_for_episodes

    def _get_weights(self, datasets_weights: Optional[Dict[str, float]], aug_ratio: float = 1.0):
        """
        Assigns normalized sampling weights to each individual sample in the flattened map.
        Weights are adjusted based on base dataset weights and `gripper_aug_ratio` for
        samples that have a detected gripper change.

        Args:
            datasets_weights (Dict[str, float]): A dictionary mapping dataset names to their
                base sampling weights. If a dataset name is not found, a default
                weight of 1.0 is used.
            aug_ratio (float): The augmentation ratio (multiplier) to apply to the base weight
                for samples where a gripper change is detected.
        """
        self.sample_weights: List[float] = []
        self.datasets_weight_map: Dict[str, float] = {}
        if datasets_weights is None:
            num_datasets = len(self.datasets_name)
            datasets_weights = {name: 1.0 / num_datasets for name in self.datasets_name}
        for idx, ds_name in enumerate(self.datasets_name):
            # Access the underlying dataset to get gripper change information.
            # It might be wrapped in a TransformedDataset, so we unwrap it.
            current_base_dataset = self.datasets[idx]._dataset if isinstance(self.datasets[idx], TransformedDataset) else self.datasets[idx]
            base_weight = datasets_weights.get(ds_name, 1.0)  # Get base weight for this dataset
            individual_weights_for_ds: List[float] = []
            # Logic to retrieve `gripper_change_step_idx` and assign weights.
            # NOTE(review): `step_idx in gripper_change_step_idx` below is an O(n)
            # scan of a numpy array per sample — converting it to a set once per
            # dataset would be much faster for large datasets.
            if isinstance(current_base_dataset, MultiLeRobotDataset):
                # For MultiLeRobotDataset, iterate through its sub-datasets.
                # NOTE(review): after this loop, `gripper_change_step_idx` holds only
                # the LAST sub-dataset's value, which is what the `is None` fallback
                # below inspects — verify this is the intended semantics.
                for idj, sub_ds in enumerate(current_base_dataset._datasets):
                    gripper_change_step_idx = getattr(sub_ds, 'gripper_change_step_idx', None)
                    if gripper_change_step_idx is not None:
                        sampled_indices_sub_ds = self.all_sample_indices[idx][idj]
                        for tensor_of_indices in sampled_indices_sub_ds:
                            for step_idx in tensor_of_indices.tolist():
                                if step_idx in gripper_change_step_idx:
                                    individual_weights_for_ds.append(base_weight * aug_ratio)
                                else:
                                    individual_weights_for_ds.append(base_weight)
            elif isinstance(current_base_dataset, LeRobotDataset):
                # For a single LeRobotDataset.
                gripper_change_step_idx = getattr(current_base_dataset, 'gripper_change_step_idx', None)
                if gripper_change_step_idx is not None:
                    sampled_indices_ds = self.all_sample_indices[idx]
                    for tensor_of_indices in sampled_indices_ds:
                        for step_idx in tensor_of_indices.tolist():
                            if step_idx in gripper_change_step_idx:
                                individual_weights_for_ds.append(base_weight * aug_ratio)
                            else:
                                individual_weights_for_ds.append(base_weight)
            # Fallback: no gripper-change info available -> uniform base weights.
            # NOTE(review): if neither isinstance branch ran (unexpected dataset
            # type), `gripper_change_step_idx` is unbound here and this raises
            # NameError — confirm whether that path can occur.
            if gripper_change_step_idx is None:
                print(f"Warning: Gripper change detection not fully supported for dataset type {type(current_base_dataset)}. "
                      "Assigning uniform weights based on `base_weight` for this dataset.")
                num_samples_for_ds_in_flat_map = sum(1 for map_ds_idx, _ in self.flat_sample_map if map_ds_idx == idx)
                individual_weights_for_ds.extend([base_weight] * num_samples_for_ds_in_flat_map)
            # Accumulate individual weights for all samples and for the dataset's total.
            self.sample_weights.extend(individual_weights_for_ds)
            self.datasets_weight_map[ds_name] = self.datasets_weight_map.get(ds_name, 0.0) + sum(individual_weights_for_ds)
        # Final normalization of all individual sample weights across the entire mixture dataset.
        total_sum_of_all_individual_weights = sum(self.sample_weights)
        if total_sum_of_all_individual_weights > 0:
            self.sample_weights = np.array(self.sample_weights, dtype=np.float32)
            self.sample_weights = self.sample_weights / total_sum_of_all_individual_weights
        else:
            self.sample_weights = np.array([], dtype=np.float32)
        # Normalize the `datasets_weight_map` to reflect the effective proportion of each dataset
        # in the final sampling distribution.
        if total_sum_of_all_individual_weights > 0:
            for k in self.datasets_weight_map:
                self.datasets_weight_map[k] /= total_sum_of_all_individual_weights
        else:
            self.datasets_weight_map = {k: 0.0 for k in self.datasets_weight_map}  # All weights become zero.

    def __str__(self) -> str:
        """
        Returns a formatted string representation of the MixtureDataset,
        showing the effective sampling weights and dataset lengths.
        """
        # Define ANSI escape codes for colored and bold text.
        RESET = "\033[0m"
        BOLD = "\033[1m"
        CYAN = "\033[96m"
        YELLOW = "\033[93m"
        GREEN = "\033[92m"
        MAGENTA = "\033[95m"
        # Determine the maximum key length for consistent formatting.
        max_key_len = max(len(k) for k in self.datasets_weight_map.keys()) + 2 if self.datasets_weight_map else 20
        # Build the lines of the string representation.
        lines = [
            f"{BOLD}{MAGENTA}######################################### 👈 Dataset Weight Map: ########################################{RESET}"
        ]
        # Add individual dataset information: name, number of samples, and effective weight.
        # NOTE(review): `idx` comes from dict insertion order of `datasets_weight_map`,
        # which matches `self.datasets` order only when dataset names are unique.
        for idx, (name, weight) in enumerate(self.datasets_weight_map.items()):
            # Use `len(self.datasets[idx])` to get the number of samples in each transformed dataset.
            # Formatting to 2 decimal places for weight and 0 for sample count.
            lines.append(f"{CYAN}{name:<{max_key_len}} : {len(self.datasets[idx]):>18.0f} ({weight*100:>.2f}%){RESET}")
        # Add a separator line.
        separator_length = len(lines[0]) - len(BOLD) - len(MAGENTA) - len(RESET) + 1
        lines.append("-" * separator_length)
        # Add total episodes summary.
        lines.append(f"{CYAN}{'Total Episodes':<{max_key_len}}{RESET} : {YELLOW}{sum(self.num_episodes):>18.0f}{RESET}")
        # Add the closing border, matching the length of the separator.
        lines.append(f"{BOLD}{MAGENTA}{'#' * separator_length}{RESET}")
        return "\n".join(lines)
def create_mixture_dataset(
    data_configs_list,
    action_horizon,
    model_config,
):
    """Build a MixtureDataset of transformed LeRobotDatasets from nested dataset configs."""
    datasets = []
    names = []
    metas = []
    weights = {}
    for config_group in data_configs_list:
        for cfg in config_group:
            root_path = f"{cfg.repo_dir}/{cfg.task_id}/{cfg.subtask_id}"
            meta = LeRobotDatasetMetadata(repo_id=root_path, root=root_path)
            episodes = list(meta.episodes_stats.keys())
            # Randomly subsample episodes when only a fraction of the data is requested.
            if cfg.data_ratio < 1.0:
                sub_length = int(len(episodes) * cfg.data_ratio) + 1
                logging.info(f"sub_length: {sub_length}")
                chosen = np.random.choice(len(episodes), sub_length, replace=False)
                episodes = [episodes[i] for i in chosen]
            print(f"downsample ratio: {cfg.downsample_ratio}")
            # Effective fps after temporal downsampling determines action-chunk timestamps.
            effective_fps = meta.fps // cfg.downsample_ratio
            dataset = LeRobotDataset(
                episodes=episodes,
                repo_id=root_path,
                root=root_path,
                delta_timestamps={
                    key: [t / effective_fps for t in range(action_horizon)]
                    for key in cfg.action_sequence_keys
                },
            )
            # Optionally up-weight frames around gripper open/close events.
            if cfg.use_gripper_aug and cfg.gripper_aug_config is not None:
                aug = cfg.gripper_aug_config
                dataset = detect_gripper_change_step(
                    dataset,
                    select_actions=aug["gripper_action_keys"],
                    gripper_dim=aug["gripper_dim"],
                    threshold_method=aug["gripper_threshold_method"],
                    threshold_multiplier=aug["gripper_threshold_multiplier"],
                    min_threshold=aug["gripper_min_threshold"],
                    max_threshold=aug["gripper_max_threshold"],
                )
            dataset = transform_dataset(dataset, cfg)
            datasets.append(dataset)
            names.append(root_path)
            metas.append(meta)
            weights[root_path] = cfg.weight
    return MixtureDataset(
        datasets,
        names,
        metas,
        weights,
        gripper_aug_ratio=10.0,
    )
def create_mixture_dataset_no_transform(
    data_configs_list,
    action_horizon,
    model_config
):
    """Build a MixtureDataset of raw (untransformed) LeRobotDatasets from nested configs."""
    datasets = []
    names = []
    metas = []
    weights = {}
    for config_group in data_configs_list:
        for cfg in config_group:
            root_path = f"{cfg.repo_dir}/{cfg.task_id}/{cfg.subtask_id}"
            meta = LeRobotDatasetMetadata(repo_id=root_path, root=root_path)
            episodes = list(meta.episodes_stats.keys())
            # Deterministic prefix subsample (unlike create_mixture_dataset's random choice).
            if cfg.data_ratio < 1.0:
                sub_length = int(len(episodes) * cfg.data_ratio) + 1
                episodes = episodes[:sub_length]
            effective_fps = meta.fps // cfg.downsample_ratio
            dataset = LeRobotDataset(
                episodes=episodes,
                repo_id=root_path,
                root=root_path,
                delta_timestamps={
                    key: [t / effective_fps for t in range(action_horizon)]
                    for key in cfg.action_sequence_keys
                },
            )
            # Optionally tag frames around gripper open/close events.
            if cfg.use_gripper_aug and cfg.gripper_aug_config is not None:
                aug = cfg.gripper_aug_config
                dataset = detect_gripper_change_step(
                    dataset,
                    select_actions=aug["gripper_action_keys"],
                    gripper_dim=aug["gripper_dim"],
                    threshold_method=aug["gripper_threshold_method"],
                    threshold_multiplier=aug["gripper_threshold_multiplier"],
                    min_threshold=aug["gripper_min_threshold"],
                    max_threshold=aug["gripper_max_threshold"],
                )
            datasets.append(dataset)
            names.append(root_path)
            metas.append(meta)
            weights[root_path] = cfg.weight
    return MixtureDataset(
        datasets,
        names,
        metas,
        weights,
        gripper_aug_ratio=10.0,
    )
def create_mixture_dataset_calculate_norm_stats(
    data_configs_list,
    action_horizon,
    model_config
):
    """Build a MixtureDataset for norm-stats computation (videos skipped for speed).

    Note: unlike the other factory functions, `data_configs_list` here is a FLAT
    list of dataset configs, not a nested list of groups.
    """
    datasets = []
    names = []
    metas = []
    weights = {}
    for cfg in data_configs_list:
        root_path = f"{cfg.repo_dir}/{cfg.task_id}/{cfg.subtask_id}"
        meta = LeRobotDatasetMetadata(repo_id=root_path, root=root_path)
        episodes = list(meta.episodes_stats.keys())
        # Deterministic prefix subsample of episodes.
        if cfg.data_ratio < 1.0:
            sub_length = int(len(episodes) * cfg.data_ratio) + 1
            episodes = episodes[:sub_length]
        effective_fps = meta.fps // cfg.downsample_ratio
        dataset = LeRobotDataset(
            episodes=episodes,
            repo_id=root_path,
            root=root_path,
            delta_timestamps={
                key: [t / effective_fps for t in range(action_horizon)]
                for key in cfg.action_sequence_keys
            },
            # Norm stats only need low-dim streams; skip video decoding entirely.
            load_video=False,
        )
        if cfg.use_gripper_aug and cfg.gripper_aug_config is not None:
            aug = cfg.gripper_aug_config
            dataset = detect_gripper_change_step(
                dataset,
                select_actions=aug["gripper_action_keys"],
                gripper_dim=aug["gripper_dim"],
                threshold_method=aug["gripper_threshold_method"],
                threshold_multiplier=aug["gripper_threshold_multiplier"],
                min_threshold=aug["gripper_min_threshold"],
                max_threshold=aug["gripper_max_threshold"],
            )
        datasets.append(dataset)
        names.append(root_path)
        metas.append(meta)
        weights[root_path] = cfg.weight
    return MixtureDataset(
        datasets,
        names,
        metas,
        weights,
        gripper_aug_ratio=10.0,
    )

View File

@@ -0,0 +1,123 @@
import dataclasses
from typing import Protocol, runtime_checkable
import jax.numpy as jnp
import optax
import openpi.shared.array_typing as at
@runtime_checkable
class LRScheduleConfig(Protocol):
    """Structural interface for learning-rate schedule configs.

    Implementations return an optax schedule from `create`.
    """

    def create(self) -> optax.Schedule: ...
@dataclasses.dataclass(frozen=True)
class CosineDecaySchedule(LRScheduleConfig):
    """Linear warmup to `peak_lr`, then cosine decay down to `decay_lr`."""

    warmup_steps: int = 1_000
    peak_lr: float = 2.5e-5
    decay_steps: int = 30_000
    decay_lr: float = 2.5e-6

    def create(self) -> optax.Schedule:
        # Start near zero: peak divided by (warmup_steps + 1) avoids a zero LR at step 0.
        initial_lr = self.peak_lr / (self.warmup_steps + 1)
        return optax.warmup_cosine_decay_schedule(
            init_value=initial_lr,
            peak_value=self.peak_lr,
            warmup_steps=self.warmup_steps,
            decay_steps=self.decay_steps,
            end_value=self.decay_lr,
        )
@dataclasses.dataclass(frozen=True)
class RsqrtDecaySchedule(LRScheduleConfig):
    """Linear warmup to `peak_lr`, then inverse-square-root decay."""

    warmup_steps: int = 1_000
    peak_lr: float = 5e-5
    timescale: float = 10_000

    def create(self) -> optax.Schedule:
        def rsqrt_decay(step):
            # Equivalent to peak_lr * sqrt(timescale / (timescale + step)).
            return self.peak_lr / jnp.sqrt((self.timescale + step) / self.timescale)

        warmup = optax.linear_schedule(
            init_value=self.peak_lr / (self.warmup_steps + 1),
            end_value=self.peak_lr,
            transition_steps=self.warmup_steps,
        )
        # Switch from warmup to rsqrt decay at `warmup_steps`.
        return optax.join_schedules([warmup, rsqrt_decay], [self.warmup_steps])
@dataclasses.dataclass(frozen=True)
class WarmupConstantSchedule(LRScheduleConfig):
    """Linear warmup to `peak_lr`, then a constant learning rate."""

    warmup_steps: int = 2_000
    peak_lr: float = 5e-5

    def create(self) -> optax.Schedule:
        initial_lr = self.peak_lr / (self.warmup_steps + 1)
        return optax.warmup_constant_schedule(
            init_value=initial_lr,
            peak_value=self.peak_lr,
            warmup_steps=self.warmup_steps,
        )
@runtime_checkable
class OptimizerConfig(Protocol):
    """Structural interface for optimizer configs.

    `create` builds the optax gradient transformation, given a learning rate
    (scalar or schedule) and an optional weight-decay mask pytree.
    """

    def create(
        self,
        lr: optax.ScalarOrSchedule,
        weight_decay_mask: at.PyTree | None = None,
    ) -> optax.GradientTransformation: ...
@dataclasses.dataclass(frozen=True)
class AdamW(OptimizerConfig):
    """AdamW optimizer with global-norm gradient clipping."""

    b1: float = 0.9
    b2: float = 0.95
    eps: float = 1e-8
    # Changing this to 0 can cause out-of-memory errors for some reason, so we set it to a negligible value.
    weight_decay: float = 1e-10
    clip_gradient_norm: float = 1.0

    def create(
        self,
        lr: optax.ScalarOrSchedule,
        weight_decay_mask: at.PyTree | None = None,
    ) -> optax.GradientTransformation:
        # Clip gradients first, then apply the AdamW update.
        return optax.chain(
            optax.clip_by_global_norm(self.clip_gradient_norm),
            optax.adamw(
                lr,
                b1=self.b1,
                b2=self.b2,
                eps=self.eps,
                weight_decay=self.weight_decay,
                mask=weight_decay_mask,
            ),
        )
@dataclasses.dataclass(frozen=True)
class SGD(OptimizerConfig):
    """SGD optimizer (weight decay is not supported).

    Raises:
        ValueError: from `create`, if a weight-decay mask is supplied.
    """

    # NOTE(review): this field is not read by `create` (the schedule passed in as
    # `lr` is used instead); kept for backward compatibility with existing configs.
    lr: float = 5e-5
    momentum: float = 0.9
    nesterov: bool = False

    def create(
        self,
        lr: optax.ScalarOrSchedule,
        weight_decay_mask: at.PyTree | None = None,
    ) -> optax.GradientTransformation:
        # Raise explicitly instead of using `assert`: asserts are stripped under
        # `python -O`, which would silently ignore an unsupported mask.
        if weight_decay_mask is not None:
            raise ValueError("Weight decay is not supported for SGD")
        return optax.sgd(lr, momentum=self.momentum, nesterov=self.nesterov)
def create_optimizer(
    optimizer: OptimizerConfig, lr_schedule: LRScheduleConfig, weight_decay_mask: at.PyTree | None = None
) -> optax.GradientTransformation:
    """Instantiate the gradient transformation from schedule + optimizer configs."""
    return optimizer.create(lr_schedule.create(), weight_decay_mask=weight_decay_mask)

View File

@@ -0,0 +1,102 @@
import contextlib
import logging
import jax
import numpy as np
# Logical names for the two mesh axes used throughout training.
BATCH_AXIS = "batch"
FSDP_AXIS = "fsdp"
# In FSDP, we shard the data across both the batch and FSDP axes.
DATA_AXIS = (BATCH_AXIS, FSDP_AXIS)
class _MeshState:
    # Process-global holder for the mesh installed by `set_mesh`;
    # None whenever no `set_mesh` context is active.
    active_mesh: jax.sharding.Mesh | None = None
def make_mesh(num_fsdp_devices: int) -> jax.sharding.Mesh:
    """Create a 2-D (batch, fsdp) device mesh spanning all available devices.

    Raises:
        ValueError: if the device count is not divisible by `num_fsdp_devices`.
    """
    total_devices = jax.device_count()
    if total_devices % num_fsdp_devices != 0:
        raise ValueError(
            f"Number of devices {jax.device_count()} must be divisible by the number of FSDP devices {num_fsdp_devices}."
        )
    return jax.make_mesh((total_devices // num_fsdp_devices, num_fsdp_devices), (BATCH_AXIS, FSDP_AXIS))
@contextlib.contextmanager
def set_mesh(mesh: jax.sharding.Mesh):
    """Install `mesh` as the global active mesh for the duration of the context.

    Plumbing the mesh deep into the module tree is extremely cumbersome; until the
    JAX team lands a better API, a custom context manager like this one is the
    recommended way to maintain a reference to a global mesh. Consumed only by
    `activation_sharding_constraint`. Nesting is rejected.
    """
    if _MeshState.active_mesh is not None:
        raise ValueError("Cannot nest set_mesh context managers.")
    _MeshState.active_mesh = mesh
    try:
        yield
    finally:
        # Always clear, even if the body raised, so later contexts can be entered.
        _MeshState.active_mesh = None
def activation_sharding_constraint(pytree):
    """Constrain activations to the data sharding of the active mesh.

    No-op when no `set_mesh` context is active.
    """
    mesh = _MeshState.active_mesh
    if mesh is None:
        return pytree
    sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(DATA_AXIS))
    return jax.lax.with_sharding_constraint(pytree, sharding)
def fsdp_sharding(
    pytree,
    mesh: jax.sharding.Mesh,
    *,
    min_size_mbytes: int = 4,  # 4 MiB
    log: bool = False,
):
    """Apply FSDP sharding to a pytree of arrays based on the mesh shape.

    Args:
        pytree: A pytree to be apply sharding specified by the mesh, note that only array types (eg. contains .shape attr)
            will be considered for sharding.
        mesh: The mesh being used for applying sharding on to pytree.
        min_size_mbytes: The minimum size of the array in MiB to be considered for sharding, any array smaller than this
            will be replicated.
        log: If true, will log the sharding decisions for arrays that are being considered for sharding.

    Returns:
        The sharded pytree.
    """
    min_size_bytes = min_size_mbytes * 2**20

    def _shard_arr(kp, array: jax.ShapeDtypeStruct):
        replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
        # If FSDP is not actually going to be used, replicate everything to avoid extraneous logging.
        if mesh.shape[FSDP_AXIS] == 1:
            return replicated
        # Replicate non-array leaves as well as scalars and vectors.
        if not hasattr(array, "shape") or len(array.shape) < 2:
            return replicated
        # Replicate small arrays.
        if (arr_size := np.prod(array.shape) * np.dtype(array.dtype).itemsize) < min_size_bytes:
            return replicated
        # Shard along the largest axis that is divisible by the FSDP dimension.
        for i in np.argsort(array.shape)[::-1]:
            if array.shape[i] % mesh.shape[FSDP_AXIS] == 0:
                if log:
                    logging.info(
                        f"Sharding {jax.tree_util.keystr(kp)} of shape {array.shape} ({arr_size / 2**20:.2f} MiB) along axis {i}"
                    )
                spec = [None] * len(array.shape)
                spec[i] = FSDP_AXIS
                return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*spec))
        # No axis divides evenly: fall back to replication.
        if log:
            logging.warning(
                f"Could not find a valid sharding for {jax.tree_util.keystr(kp)} of shape {array.shape} with mesh of shape {mesh.shape}"
            )
        return replicated

    return jax.tree_util.tree_map_with_path(_shard_arr, pytree)

View File

@@ -0,0 +1,38 @@
from collections.abc import Callable
from typing import Any
from flax import nnx
from flax import struct
import jax
import optax
from openpi.models import model as _model
from openpi.shared import array_typing as at
@at.typecheck
@struct.dataclass
class TrainState:
    """Serializable container for everything the training loop carries across steps."""

    # Global training step counter.
    step: at.Int[at.ArrayLike, ""]
    # Trainable parameters as an nnx state tree.
    params: nnx.State
    # Static graph definition; combined with `params` to rebuild the model.
    model_def: nnx.GraphDef[_model.BaseModel]
    # Optimizer state produced by `tx`.
    opt_state: optax.OptState
    # Gradient transformation; excluded from the pytree (static).
    tx: optax.GradientTransformation = struct.field(pytree_node=False)
    # EMA decay rate; presumably None disables EMA tracking (see `ema_params`).
    ema_decay: float | None = struct.field(pytree_node=False)
    # EMA copy of `params`, when maintained.
    ema_params: nnx.State | None = None
@at.typecheck
def tree_to_info(tree: at.PyTree, interp_func: Callable[[Any], str] = str) -> str:
    """Render a PyTree as one "path: value" line per leaf, for logging.

    `interp_func` maps each leaf to its displayed string (defaults to `str`).
    """
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    lines = [f"{jax.tree_util.keystr(path)}: {interp_func(value)}" for path, value in leaves_with_paths]
    return "\n".join(lines)
@at.typecheck
def array_tree_to_info(tree: at.PyTree) -> str:
    """Converts a PyTree of arrays into a human-readable string for logging."""

    def _describe(x):
        # Show shape and dtype for each array leaf.
        return f"{x.shape}@{x.dtype}"

    return tree_to_info(tree, _describe)

View File

@@ -0,0 +1,103 @@
import dataclasses
import logging
import re
from typing import Protocol, runtime_checkable
import flax.traverse_util
import numpy as np
import openpi.models.model as _model
import openpi.shared.array_typing as at
import openpi.shared.download as download
from pathlib import Path
# Module-level logger for weight-loading diagnostics.
logger = logging.getLogger(__name__)
@runtime_checkable
class WeightLoader(Protocol):
    """Structural interface for model-weight loaders."""

    def load(self, params: at.Params) -> at.Params:
        """Loads the model weights.

        Args:
            params: Parameters of the model. This is a nested structure of array-like objects that
                represent the model's parameters.

        Returns:
            Loaded parameters. The structure must be identical to `params`. If returning a subset of
            the parameters the loader must merge the loaded parameters with `params`.
        """
@dataclasses.dataclass(frozen=True)
class NoOpWeightLoader(WeightLoader):
    """Loader that returns the reference parameters unchanged."""

    def load(self, params: at.Params) -> at.Params:
        return params
@dataclasses.dataclass(frozen=True)
class CheckpointWeightLoader(WeightLoader):
    """Loads an entire set of weights from a checkpoint.

    Compatible with:
      trained checkpoints:
        example: "./checkpoints/<config>/<exp>/<step>/params"
      released checkpoints:
        example: "gs://openpi-assets/checkpoints/<model>/params"
    """

    # Local path or URL of the checkpoint's params directory.
    params_path: str

    def load(self, params: at.Params) -> at.Params:
        """Restore checkpoint params and backfill any missing LoRA weights from `params`."""
        # We are loading np.ndarray and relying on the training code to properly convert and shard the params.
        loaded_params = _model.restore_params(download.maybe_download(self.params_path), restore_type=np.ndarray)
        # Add all missing LoRA weights.
        return _merge_params(loaded_params, params, missing_regex=".*lora.*")
@dataclasses.dataclass(frozen=True)
class PaliGemmaWeightLoader(WeightLoader):
    """Loads weights from the official PaliGemma checkpoint.

    This will overwrite existing weights with similar names while keeping all extra weights intact.
    This allows us to support the action expert which is used by the Pi0 model.
    """

    # Path to the official .npz checkpoint file.
    params_path: str

    def load(self, params: at.Params) -> at.Params:
        with Path(self.params_path).open("rb") as f:
            flat_params = dict(np.load(f, allow_pickle=False))
        nested = flax.traverse_util.unflatten_dict(flat_params, sep="/")
        loaded_params = {"PaliGemma": nested["params"]}
        # Add all missing weights.
        return _merge_params(loaded_params, params, missing_regex=".*")
def _merge_params(loaded_params: at.Params, params: at.Params, *, missing_regex: str) -> at.Params:
    """Merges the loaded parameters with the reference parameters.

    Args:
        loaded_params: The parameters to merge.
        params: The reference parameters.
        missing_regex: A regex pattern for all missing keys that should be merged from the reference parameters.

    Returns:
        A new dictionary with the merged parameters.
    """
    flat_ref = flax.traverse_util.flatten_dict(params, sep="/")
    flat_loaded = flax.traverse_util.flatten_dict(loaded_params, sep="/")

    # Keep every loaded weight that exists in the reference tree, cast to the reference dtype.
    merged = {}
    for key, value in flat_loaded.items():
        if key not in flat_ref:
            continue
        ref_dtype = flat_ref[key].dtype
        merged[key] = value if value.dtype == ref_dtype else value.astype(ref_dtype)
    # Drop our references to the loaded arrays so they can be garbage collected.
    flat_loaded.clear()

    # Backfill reference weights whose keys fully match the missing regex.
    pattern = re.compile(missing_regex)
    for key in flat_ref:
        if key not in merged and pattern.fullmatch(key):
            merged[key] = flat_ref[key]

    return flax.traverse_util.unflatten_dict(merged, sep="/")