multi-node openpi commit
This commit is contained in:
159
policy/openpi-InternData-A1/src/openpi/training/checkpoints.py
Normal file
159
policy/openpi-InternData-A1/src/openpi/training/checkpoints.py
Normal file
@@ -0,0 +1,159 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures as futures
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import Protocol
|
||||
|
||||
from etils import epath
|
||||
import jax
|
||||
import orbax.checkpoint as ocp
|
||||
import orbax.checkpoint.future as future
|
||||
|
||||
from openpi.shared import array_typing as at
|
||||
import openpi.shared.normalize as _normalize
|
||||
import openpi.training.data_loader as _data_loader
|
||||
import openpi.training.utils as training_utils
|
||||
|
||||
|
||||
def initialize_checkpoint_dir(
    checkpoint_dir: epath.Path | str, *, keep_period: int | None, overwrite: bool, resume: bool
) -> tuple[ocp.CheckpointManager, bool]:
    """Prepare the checkpoint directory and construct an Orbax CheckpointManager.

    Args:
        checkpoint_dir: Target directory for checkpoints.
        keep_period: Forwarded to Orbax `CheckpointManagerOptions`; checkpoints at
            steps matching this period are retained beyond `max_to_keep`.
        overwrite: If the directory already exists, wipe it and start fresh.
        resume: If the directory already exists, resume from its contents.

    Returns:
        A `(manager, resuming)` pair; `resuming` tells the caller whether it should
        restore state from an existing checkpoint.

    Raises:
        FileExistsError: the directory exists and neither flag was set.
    """
    checkpoint_dir = epath.Path(checkpoint_dir).resolve()
    resuming = False
    if checkpoint_dir.exists():
        if overwrite:
            checkpoint_dir.rmtree()
            checkpoint_dir.mkdir(parents=True, exist_ok=True)
            logging.info(f"Wiped checkpoint directory {checkpoint_dir}")
        elif resume:
            resuming = True
        else:
            raise FileExistsError(
                f"Checkpoint directory {checkpoint_dir} already exists. Use --overwrite or --resume "
                "to indicate how to handle it."
            )

    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    mngr = ocp.CheckpointManager(
        checkpoint_dir,
        item_handlers={
            # "assets" invokes an arbitrary callback at save time (e.g. norm stats);
            # see CallbackHandler below.
            "assets": CallbackHandler(),
            "train_state": ocp.PyTreeCheckpointHandler(),
            "params": ocp.PyTreeCheckpointHandler(),
        },
        options=ocp.CheckpointManagerOptions(
            max_to_keep=1,
            keep_period=keep_period,
            create=False,  # directory is created above
            # Generous commit timeout for slow shared filesystems in multi-node runs.
            async_options=ocp.AsyncOptions(timeout_secs=7200),
        ),
    )

    # Special case: the checkpoint directory exists and the user requests to resume training, but the training run did
    # not get to the first checkpoint saved. In this case, we don't actually want the train script to try and restore a
    # checkpoint, since it will fail.
    if resuming and tuple(mngr.all_steps()) in [(), (0,)]:
        logging.info("Checkpoint directory exists, but does not contain any checkpoints. Aborting resume.")
        resuming = False

    return mngr, resuming
|
||||
|
||||
|
||||
def save_state(
    checkpoint_manager: ocp.CheckpointManager,
    state: training_utils.TrainState,
    data_loader: _data_loader.DataLoader,
    step: int,
):
    """Save a training checkpoint: assets (norm stats), train state, and inference params."""

    def _write_norm_stats(directory: epath.Path):
        # Persist normalization stats next to the checkpoint so inference can find them.
        cfg = data_loader.data_config()
        if cfg.norm_stats is not None and cfg.asset_id is not None:
            _normalize.save(directory / cfg.asset_id, cfg.norm_stats)

    # Split params that can be used for inference into a separate item.
    with at.disable_typechecking():
        remainder_state, inference_params = _split_params(state)

    checkpoint_manager.save(
        step,
        {
            "assets": _write_norm_stats,
            "train_state": remainder_state,
            "params": {"params": inference_params},
        },
    )
|
||||
|
||||
|
||||
def restore_state(
    checkpoint_manager: ocp.CheckpointManager,
    state: training_utils.TrainState,
    data_loader: _data_loader.DataLoader,
    step: int | None = None,
) -> training_utils.TrainState:
    """Restore a TrainState previously written by `save_state`.

    `state` acts as the restore template: Orbax uses its structure/shapes as the
    target and fills in the checkpointed values.

    Args:
        checkpoint_manager: Manager built by `initialize_checkpoint_dir`.
        state: Template train state matching the checkpointed structure.
        data_loader: Unused; accepted to mirror `save_state`'s signature.
        step: Step to restore; None lets Orbax choose (presumably the latest
            saved step — confirm against Orbax docs).

    Returns:
        The restored train state with inference params merged back in.
    """
    del data_loader

    with at.disable_typechecking():
        # Split params that can be used for inference into a separate item.
        train_state, params = _split_params(state)
        restored = checkpoint_manager.restore(
            step,
            items={
                "train_state": train_state,
                "params": {"params": params},
            },
        )
    return _merge_params(restored["train_state"], restored["params"])
|
||||
|
||||
|
||||
def load_norm_stats(assets_dir: epath.Path | str, asset_id: str) -> dict[str, _normalize.NormStats] | None:
    """Load normalization stats for `asset_id` from the given assets directory."""
    stats_path = epath.Path(assets_dir) / asset_id
    stats = _normalize.load(stats_path)
    logging.info(f"Loaded norm stats from {stats_path}")
    return stats
|
||||
|
||||
|
||||
class Callback(Protocol):
    """Signature for checkpoint-asset callbacks: called with the directory to write into."""

    def __call__(self, directory: epath.Path) -> None: ...
|
||||
|
||||
|
||||
class CallbackHandler(ocp.AsyncCheckpointHandler):
    """A CheckpointHandler for calling an arbitrary function asynchronously. Only for saving, not for restoring."""

    def save(self, directory: epath.Path, args: CallbackSave):
        # Only the primary host runs the callback so multi-process runs don't
        # write the same assets repeatedly (or race on shared storage).
        if jax.process_index() == 0:
            args.callback(directory)

    async def async_save(self, directory: epath.Path, args: CallbackSave) -> list[futures.Future]:
        # Run the synchronous save in a worker thread, wrapped in an Orbax commit
        # future so the CheckpointManager can await completion of the async save.
        return [future.CommitFutureAwaitingContractedSignals(asyncio.to_thread(self.save, directory, args))]

    def restore(self, *args, **kwargs):
        raise NotImplementedError("CallbackHandler does not support restore")
|
||||
|
||||
|
||||
@ocp.args.register_with_handler(CallbackHandler, for_save=True)
@dataclasses.dataclass
class CallbackSave(ocp.args.CheckpointArgs):
    """Save-time args for CallbackHandler: wraps the callback to invoke."""

    # Function invoked with the item directory when the "assets" item is saved.
    callback: Callback
|
||||
|
||||
|
||||
@ocp.args.register_with_handler(CallbackHandler, for_restore=True)
class CallbackRestore(ocp.args.CheckpointArgs):
    """Restore-time args for CallbackHandler; restore is unsupported and will raise."""

    ...
|
||||
|
||||
|
||||
def _split_params(state: training_utils.TrainState) -> tuple[training_utils.TrainState, at.Params]:
|
||||
if state.ema_params is not None:
|
||||
params = state.ema_params
|
||||
train_state = dataclasses.replace(state, ema_params=None)
|
||||
else:
|
||||
params = state.params
|
||||
train_state = dataclasses.replace(state, params={})
|
||||
return train_state, params
|
||||
|
||||
|
||||
def _merge_params(train_state: training_utils.TrainState, params: dict[str, at.Params]) -> training_utils.TrainState:
|
||||
# Revert the logic inside `_split_params`. Assumes that existence of `params` means that EMA params were used during the split.
|
||||
if train_state.params:
|
||||
return dataclasses.replace(train_state, ema_params=params["params"])
|
||||
return dataclasses.replace(train_state, params=params["params"])
|
||||
1904
policy/openpi-InternData-A1/src/openpi/training/config.py
Normal file
1904
policy/openpi-InternData-A1/src/openpi/training/config.py
Normal file
File diff suppressed because it is too large
Load Diff
721
policy/openpi-InternData-A1/src/openpi/training/data_loader.py
Normal file
721
policy/openpi-InternData-A1/src/openpi/training/data_loader.py
Normal file
@@ -0,0 +1,721 @@
|
||||
from collections.abc import Iterator, Sequence
import logging
import multiprocessing
import os
import typing
from typing import Literal, Protocol, SupportsIndex, TypeVar, Dict
import sys

import jax
import jax.numpy as jnp
import lerobot.common.datasets.lerobot_dataset as lerobot_dataset
import numpy as np
import torch
from torch.utils.data import dataloader
from torch.multiprocessing import reductions
from multiprocessing.reduction import ForkingPickler
import openpi.shared.normalize as normalize

# Keep a handle to torch's original collate so the override below can delegate to it.
default_collate_func = dataloader.default_collate
import psutil


def default_collate_override(batch):
    # Disable torch's shared-memory path before collating in worker processes.
    # NOTE(review): presumably a workaround for shared-memory / fd exhaustion with
    # many spawn workers — confirm before removing.
    dataloader._use_shared_memory = False
    return default_collate_func(batch)


setattr(dataloader, 'default_collate', default_collate_override)

# Strip torch's custom (shared-memory-based) pickle reducers for its storage
# classes so tensors crossing process boundaries are serialized by value.
for t in torch._storage_classes:
    if sys.version_info[0] == 2:
        if t in ForkingPickler.dispatch:
            del ForkingPickler.dispatch[t]
    else:
        if t in ForkingPickler._extra_reducers:
            del ForkingPickler._extra_reducers[t]

import openpi.models.model as _model
import openpi.training.config as _config
from openpi.training.droid_rlds_dataset import DroidRldsDataset
import openpi.transforms as _transforms
from openpi.training.mixture_dataset import create_mixture_dataset

# Covariant sample type shared by the dataset/loader protocols below.
T_co = TypeVar("T_co", covariant=True)
import copy
from memory_profiler import profile  # NOTE(review): profiling leftover; referenced by a decorator below
from pdb import set_trace  # NOTE(review): debugging leftover; remove once confirmed unused
||||
|
||||
class Dataset(Protocol[T_co]):
    """Interface for a dataset with random access."""

    def __getitem__(self, index: SupportsIndex) -> T_co:
        # Structural protocol: implementers provide random access by integer index.
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    def __len__(self) -> int:
        # Total number of samples.
        raise NotImplementedError("Subclasses of Dataset should implement __len__.")
|
||||
|
||||
|
||||
class IterableDataset(Protocol[T_co]):
    """Interface for an iterable dataset."""

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError("Subclasses of IterableDataset should implement __iter__.")

    def __len__(self) -> int:
        # Fix: the error message previously referred to "Dataset" instead of
        # "IterableDataset" (copy-paste from the Dataset protocol above).
        raise NotImplementedError("Subclasses of IterableDataset should implement __len__.")
|
||||
|
||||
|
||||
class DataLoader(Protocol[T_co]):
    """Interface for a data loader."""

    def data_config(self) -> _config.DataConfig:
        """Get the data config for this data loader."""
        raise NotImplementedError("Subclasses of DataLoader should implement data_config.")

    def __iter__(self) -> Iterator[T_co]:
        # Yields batches of type T_co (in this file, (Observation, Actions) tuples).
        raise NotImplementedError("Subclasses of DataLoader should implement __iter__.")
|
||||
|
||||
|
||||
class TransformedDataset(Dataset[T_co]):
    """Random-access dataset wrapper that applies a composed transform pipeline per sample."""

    def __init__(self, dataset: Dataset, transforms: Sequence[_transforms.DataTransformFn]):
        self._dataset = dataset
        self._transform = _transforms.compose(transforms)

    def __getitem__(self, index: SupportsIndex) -> T_co:
        raw_sample = self._dataset[index]
        return self._transform(raw_sample)

    def __len__(self) -> int:
        # Length is unchanged by transformation.
        return len(self._dataset)
|
||||
|
||||
|
||||
class IterableTransformedDataset(IterableDataset[T_co]):
    """Applies a transform pipeline to samples from an iterable dataset.

    When `is_batched` is True, the upstream iterable yields whole batches (a dict
    of arrays with a leading batch dimension); each batch is unstacked, the
    transforms are applied sample-by-sample, and the results are restacked.
    """

    def __init__(
        self,
        dataset: IterableDataset,
        transforms: Sequence[_transforms.DataTransformFn],
        *,
        is_batched: bool = False,
    ):
        self._dataset = dataset
        self._transform = _transforms.compose(transforms)
        self._is_batched = is_batched

    def __iter__(self):
        for sample in self._dataset:
            if self._is_batched:
                # Transforms are designed to be applied to individual samples. So we need to split the batch into
                # individual samples and apply the transform to each sample individually.
                # Batch size is read from the first leaf; assumes all leaves share it — TODO confirm.
                batch_size = next(v.shape[0] for v in sample.values())

                # Split batch into individual samples using tree_map
                # (the lambda's late-bound `i` is consumed immediately, hence the noqa).
                individual_samples = [jax.tree.map(lambda x: x[i], sample) for i in range(batch_size)]  # noqa: B023

                # Transform each sample
                transformed = [self._transform(s) for s in individual_samples]

                # Recombine batch with tree_map
                yield jax.tree.map(lambda *x: np.stack(x, axis=0), *transformed)
            else:
                yield self._transform(sample)

    def __len__(self) -> int:
        # Delegates; only meaningful if the wrapped iterable defines __len__.
        return len(self._dataset)
|
||||
|
||||
|
||||
class FakeDataset(Dataset):
    """Synthetic dataset producing random observations/actions that match a model's input spec."""

    def __init__(self, model_config: _model.BaseModelConfig, num_samples: int):
        self._num_samples = num_samples
        self._observation_spec, self._action_spec = model_config.inputs_spec()

    def __getitem__(self, index: SupportsIndex) -> dict:
        # Seed from the index so every element is deterministic yet distinct.
        rng = jax.random.key(index.__index__())

        def sample_leaf(spec: jax.ShapeDtypeStruct):
            nonlocal rng
            rng, leaf_key = jax.random.split(rng)
            # Remove the batch dimension.
            leaf_shape = spec.shape[1:]
            if spec.dtype == jnp.float32:
                return jax.random.uniform(leaf_key, shape=leaf_shape, minval=-1.0, maxval=1.0)
            if spec.dtype == jnp.int32:
                return jax.random.randint(leaf_key, shape=leaf_shape, minval=0, maxval=2048)
            # Non-float/int leaves get zeros of the spec's dtype.
            return jnp.zeros(shape=leaf_shape, dtype=spec.dtype)

        observation = jax.tree.map(sample_leaf, self._observation_spec)
        action = jax.tree.map(sample_leaf, self._action_spec)

        return {**observation.to_dict(), "actions": action}

    def __len__(self) -> int:
        return self._num_samples
|
||||
|
||||
|
||||
def create_torch_dataset(
    data_config: _config.DataConfig, action_horizon: int, model_config: _model.BaseModelConfig
) -> Dataset:
    """Create a dataset for training.

    Args:
        data_config: The data configuration; `repo_id` selects the LeRobot dataset
            ("fake" yields a synthetic dataset for smoke tests).
        action_horizon: Number of future action steps per sample; converted to
            per-key delta timestamps using the dataset's fps.
        model_config: Used only by the fake dataset to derive input specs.

    Returns:
        A `Dataset` of raw (untransformed) samples.

    Raises:
        ValueError: if `repo_id` is unset.
    """
    repo_id = data_config.repo_id
    if repo_id is None:
        raise ValueError("Repo ID is not set. Cannot create dataset.")
    if repo_id == "fake":
        return FakeDataset(model_config, num_samples=1024)

    dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(repo_id)
    dataset = lerobot_dataset.LeRobotDataset(
        # Consistency fix: use the validated local `repo_id` (was `data_config.repo_id`).
        repo_id,
        delta_timestamps={
            key: [t / dataset_meta.fps for t in range(action_horizon)] for key in data_config.action_sequence_keys
        },
    )

    if data_config.prompt_from_task:
        # Attach a natural-language prompt derived from the LeRobot task table.
        dataset = TransformedDataset(dataset, [_transforms.PromptFromLeRobotTask(dataset_meta.tasks)])

    return dataset
|
||||
|
||||
|
||||
def create_rlds_dataset(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    *,
    shuffle: bool = False,
) -> Dataset:
    """Build an RLDS-backed dataset; batching happens inside the dataset itself.

    Args:
        data_config: Supplies the RLDS data dir, action space, and filter dict path.
        action_horizon: Number of actions chunked per sample.
        batch_size: Batch size produced by the dataset (already batched upstream).
        shuffle: Whether the dataset shuffles internally.
    """
    # At the moment, we only support DROID for RLDS datasets.
    return DroidRldsDataset(
        data_dir=data_config.rlds_data_dir,
        batch_size=batch_size,
        shuffle=shuffle,
        action_chunk_size=action_horizon,
        action_space=data_config.action_space,
        filter_dict_path=data_config.filter_dict_path,
    )
|
||||
|
||||
|
||||
def transform_dataset(dataset: Dataset, data_config: _config.DataConfig, *, skip_norm_stats: bool = False) -> Dataset:
    """Wrap a random-access dataset with the configured transform pipeline."""
    if skip_norm_stats or data_config.repo_id == "fake":
        norm_stats = {}
    else:
        # Real datasets must have precomputed normalization stats.
        if data_config.norm_stats is None:
            raise ValueError(
                "Normalization stats not found. "
                "Make sure to run `scripts/compute_norm_stats.py --config-name=<your-config>`."
            )
        norm_stats = data_config.norm_stats

    pipeline = [
        *data_config.repack_transforms.inputs,
        *data_config.data_transforms.inputs,
        _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
        *data_config.model_transforms.inputs,
    ]
    return TransformedDataset(dataset, pipeline)
|
||||
|
||||
|
||||
def transform_iterable_dataset(
    dataset: IterableDataset,
    data_config: _config.DataConfig,
    *,
    skip_norm_stats: bool = False,
    is_batched: bool = False,
) -> IterableDataset:
    """Wrap an iterable dataset with the configured transform pipeline."""
    if skip_norm_stats or data_config.repo_id == "fake":
        norm_stats = {}
    else:
        # Real datasets must have precomputed normalization stats.
        if data_config.norm_stats is None:
            raise ValueError(
                "Normalization stats not found. "
                "Make sure to run `scripts/compute_norm_stats.py --config-name=<your-config>`."
            )
        norm_stats = data_config.norm_stats

    pipeline = [
        *data_config.repack_transforms.inputs,
        *data_config.data_transforms.inputs,
        _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
        *data_config.model_transforms.inputs,
    ]
    return IterableTransformedDataset(dataset, pipeline, is_batched=is_batched)
|
||||
|
||||
|
||||
def create_data_loader(
    config: _config.TrainConfig,
    *,
    sharding: jax.sharding.Sharding | None = None,
    shuffle: bool = False,
    num_batches: int | None = None,
    skip_norm_stats: bool = False,
    framework: Literal["jax", "pytorch"] = "jax",
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader for training.

    Args:
        config: The training configuration.
        sharding: The sharding to use for the data loader (JAX only).
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return.
        skip_norm_stats: Whether to skip data normalization.
        framework: The framework to use ("jax" or "pytorch").

    Returns:
        A loader yielding (Observation, Actions) batches; RLDS-backed when
        `rlds_data_dir` is set, torch/LeRobot-backed otherwise.
    """
    data_config = config.data.create(config.assets_dirs, config.model)
    logging.info(f"data_config: {data_config}")

    # RLDS datasets (currently DROID only) use a dedicated iterable pipeline.
    if data_config.rlds_data_dir is not None:
        return create_rlds_data_loader(
            data_config,
            action_horizon=config.model.action_horizon,
            batch_size=config.batch_size,
            sharding=sharding,
            shuffle=shuffle,
            num_batches=num_batches,
            skip_norm_stats=skip_norm_stats,
            framework=framework,
        )
    return create_torch_data_loader(
        data_config,
        model_config=config.model,
        action_horizon=config.model.action_horizon,
        batch_size=config.batch_size,
        sharding=sharding,
        shuffle=shuffle,
        num_batches=num_batches,
        num_workers=config.num_workers,
        seed=config.seed,
        skip_norm_stats=skip_norm_stats,
        framework=framework,
    )
|
||||
|
||||
def create_data_loader_multi(
    config: _config.TrainConfig,
    *,
    sharding: jax.sharding.Sharding | None = None,
    shuffle: bool = False,
    num_batches: int | None = None,
    skip_norm_stats: bool = False,
    framework: Literal["jax", "pytorch"] = "jax",
    global_norm_stats: Dict[str, normalize.NormStats] | None = None,
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader over a mixture of datasets.

    Unlike `create_data_loader`, `config.data` is iterated as a sequence of data
    config factories, and all resulting configs feed one mixture dataset.

    Args:
        config: The training configuration.
        sharding: The sharding to use for the data loader (JAX only).
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return.
        skip_norm_stats: Whether to skip data normalization.
        framework: The framework to use ("jax" or "pytorch").
        global_norm_stats: Shared normalization stats passed to each factory
            and the torch loader (NOTE(review): factory `create` here takes
            (model, global_norm_stats), unlike the single-loader variant — confirm).
    """
    data_configs_list = []
    for data_config_factory in config.data:
        data_configs = data_config_factory.create(config.model, global_norm_stats)
        logging.info(f"data_config: {data_configs}")
        data_configs_list.append(data_configs)

    return create_torch_data_loader_multi(
        data_configs_list,
        model_config=config.model,
        action_horizon=config.model.action_horizon,
        batch_size=config.batch_size,
        sharding=sharding,
        shuffle=shuffle,
        num_batches=num_batches,
        num_workers=config.num_workers,
        seed=config.seed,
        skip_norm_stats=skip_norm_stats,
        framework=framework,
        global_norm_stats=global_norm_stats,
    )
|
||||
|
||||
|
||||
def create_torch_data_loader(
    data_config: _config.DataConfig,
    model_config: _model.BaseModelConfig,
    action_horizon: int,
    batch_size: int,
    *,
    sharding: jax.sharding.Sharding | None = None,
    skip_norm_stats: bool = False,
    shuffle: bool = False,
    num_batches: int | None = None,
    num_workers: int = 0,
    seed: int = 0,
    framework: str = "jax",
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader for training.

    Args:
        data_config: The data configuration.
        action_horizon: The action horizon.
        batch_size: The global batch size (divided across processes below).
        sharding: The sharding to use for the data loader. If None, the data loader will
            use a single device sharding.
        skip_norm_stats: Whether to skip data normalization.
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return. If the number exceeds the
            number of batches in the dataset, the data loader will loop over the dataset.
            If not provided, will iterate over the dataset indefinitely.
        num_workers: The number of worker processes to use. If zero, the data loader will
            execute in the main process.
        seed: The seed to use for shuffling the data.
        framework: "jax" (default) or "pytorch"; selects the distribution strategy.
    """
    dataset = create_torch_dataset(data_config, action_horizon, model_config)
    dataset = transform_dataset(dataset, data_config, skip_norm_stats=skip_norm_stats)

    # Use TorchDataLoader for both frameworks
    # For PyTorch DDP, create DistributedSampler and divide batch size by world size
    # For JAX, divide by process count
    sampler = None
    if framework == "pytorch":
        if torch.distributed.is_initialized():
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset,
                num_replicas=torch.distributed.get_world_size(),
                rank=torch.distributed.get_rank(),
                shuffle=shuffle,
                drop_last=True,
            )
            local_batch_size = batch_size // torch.distributed.get_world_size()
        else:
            # Single-process PyTorch: no sampler, full batch locally.
            local_batch_size = batch_size
    else:
        local_batch_size = batch_size // jax.process_count()
        if jax.process_count() > 1:
            # Multi-host JAX: each process draws a disjoint strided slice of indices.
            sampler = JaxProcessDistributedSampler(
                dataset_size=len(dataset),
                num_replicas=jax.process_count(),
                rank=jax.process_index(),
                shuffle=shuffle,
                seed=seed,
            )
    logging.info(f"local_batch_size: {local_batch_size}")
    data_loader = TorchDataLoader(
        dataset,
        local_batch_size=local_batch_size,
        sharding=None if framework == "pytorch" else sharding,
        shuffle=(sampler is None and shuffle),  # Don't shuffle if using sampler
        sampler=sampler,
        num_batches=num_batches,
        num_workers=num_workers,
        seed=seed,
        framework=framework,
    )

    return DataLoaderImpl(data_config, data_loader)
|
||||
|
||||
def create_torch_data_loader_multi(
    data_configs_list: list[_config.DataConfig],
    model_config: _model.BaseModelConfig,
    action_horizon: int,
    batch_size: int,
    *,
    sharding: jax.sharding.Sharding | None = None,
    skip_norm_stats: bool = False,
    shuffle: bool = False,
    num_batches: int | None = None,
    num_workers: int = 0,
    seed: int = 0,
    framework: str = "jax",
    global_norm_stats: Dict[str, normalize.NormStats] | None = None,
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create a data loader over a mixture of datasets.

    Args:
        data_configs_list: Data configs (one per source dataset) fed to the
            mixture dataset builder.
        action_horizon: The action horizon.
        batch_size: The global batch size (divided across processes below).
        sharding: The sharding to use for the data loader. If None, the data loader will
            use a single device sharding.
        skip_norm_stats: Whether to skip data normalization.
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return. If the number exceeds the
            number of batches in the dataset, the data loader will loop over the dataset.
            If not provided, will iterate over the dataset indefinitely.
        num_workers: The number of worker processes to use. If zero, the data loader will
            execute in the main process.
        seed: The seed to use for shuffling the data.
        framework: "jax" (default) or "pytorch"; selects the distribution strategy.
        global_norm_stats: Currently unused here beyond the signature — TODO confirm.
    """
    dataset = create_mixture_dataset(data_configs_list, action_horizon, model_config)
    # Use TorchDataLoader for both frameworks
    # For PyTorch DDP, create DistributedSampler and divide batch size by world size
    # For JAX, divide by process count
    sampler = None
    if framework == "pytorch":
        if torch.distributed.is_initialized():
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset,
                num_replicas=torch.distributed.get_world_size(),
                rank=torch.distributed.get_rank(),
                shuffle=shuffle,
                drop_last=True,
            )
            local_batch_size = batch_size // torch.distributed.get_world_size()
        else:
            local_batch_size = batch_size
    else:
        local_batch_size = batch_size // jax.process_count()
        if jax.process_count() > 1:
            sampler = JaxProcessDistributedSampler(
                dataset_size=len(dataset),
                num_replicas=jax.process_count(),
                rank=jax.process_index(),
                shuffle=shuffle,
                seed=seed,
            )
    logging.info(f"local_batch_size: {local_batch_size}")
    data_loader = TorchDataLoader(
        dataset,
        local_batch_size=local_batch_size,
        sharding=None if framework == "pytorch" else sharding,
        shuffle=(sampler is None and shuffle),  # Don't shuffle if using sampler
        sampler=sampler,
        num_batches=num_batches,
        num_workers=num_workers,
        seed=seed,
        framework=framework,
    )

    # NOTE(review): assumes each entry of data_configs_list is indexable and its
    # first element is the representative DataConfig — confirm against the
    # factory's `create` return type.
    return DataLoaderImpl(data_configs_list[0][0], data_loader)
|
||||
|
||||
|
||||
def create_rlds_data_loader(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    *,
    sharding: jax.sharding.Sharding | None = None,
    skip_norm_stats: bool = False,
    shuffle: bool = False,
    num_batches: int | None = None,
    framework: str = "jax",
) -> DataLoader[tuple[_model.Observation, _model.Actions]]:
    """Create an RLDS data loader for training.

    Note: This data loader requires some extra dependencies -- see examples/droid/README_train.md

    Args:
        data_config: The data configuration.
        action_horizon: The action horizon.
        batch_size: The batch size.
        sharding: The sharding to use for the data loader. If None, the data loader will
            use a single device sharding.
        skip_norm_stats: Whether to skip data normalization.
        shuffle: Whether to shuffle the data.
        num_batches: Determines the number of batches to return. If the number exceeds the
            number of batches in the dataset, the data loader will loop over the dataset.
            If not provided, will iterate over the dataset indefinitely.
        framework: Only "jax" is supported for RLDS.
    """
    if framework == "pytorch":
        raise NotImplementedError("PyTorch RLDS data loader is not supported yet")

    raw_dataset = create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=shuffle)
    transformed = transform_iterable_dataset(raw_dataset, data_config, skip_norm_stats=skip_norm_stats, is_batched=True)
    loader = RLDSDataLoader(transformed, sharding=sharding, num_batches=num_batches)
    return DataLoaderImpl(data_config, loader)
|
||||
|
||||
|
||||
class JaxProcessDistributedSampler(torch.utils.data.Sampler[int]):
    """Simple sampler to split dataset indices across JAX processes.

    Each process sees a disjoint slice of indices using striding by num_replicas.
    The tail of the dataset (up to num_replicas - 1 indices) is dropped so every
    rank yields exactly the same number of indices — the previous version kept
    the remainder, so ranks could disagree on batch counts and its ceil-based
    `__len__` was wrong for non-zero ranks.

    Shuffling (if enabled) is deterministic via the provided seed. Call
    `set_epoch` between epochs to vary the permutation (mirrors
    `torch.utils.data.distributed.DistributedSampler.set_epoch`); by default the
    permutation is the same every epoch, preserving the old behavior.
    """

    def __init__(
        self,
        dataset_size: int,
        *,
        num_replicas: int,
        rank: int,
        shuffle: bool,
        seed: int,
    ) -> None:
        self._dataset_size = max(0, dataset_size)
        self._num_replicas = max(1, num_replicas)
        self._rank = max(0, rank)
        self._shuffle = shuffle
        self._seed = seed
        self._epoch = 0  # folded into the shuffle seed; see set_epoch()

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch number so each epoch uses a different permutation."""
        self._epoch = epoch

    def __iter__(self):
        if self._shuffle and self._dataset_size > 0:
            g = torch.Generator()
            g.manual_seed(self._seed + self._epoch)
            indices = torch.randperm(self._dataset_size, generator=g).tolist()
        else:
            indices = list(range(self._dataset_size))
        # Drop the remainder so every rank yields the same number of indices,
        # keeping per-step collectives across processes in sync.
        per_rank = self._dataset_size // self._num_replicas
        indices = indices[: per_rank * self._num_replicas]
        # Strided split across processes.
        return iter(indices[self._rank :: self._num_replicas])

    def __len__(self) -> int:
        # Equal per-rank length after dropping the remainder.
        return self._dataset_size // self._num_replicas
|
||||
|
||||
class TorchDataLoader:
    """Torch data loader implementation.

    Wraps `torch.utils.data.DataLoader` with infinite (or `num_batches`-bounded)
    iteration and, for JAX, conversion of each batch into process-local global
    sharded arrays.
    """

    def __init__(
        self,
        dataset,
        local_batch_size: int,
        *,
        sharding: jax.sharding.Sharding | None = None,
        shuffle: bool = False,
        sampler: torch.utils.data.Sampler | None = None,
        num_batches: int | None = None,
        num_workers: int = 0,
        seed: int = 0,
        framework: str = "jax",
    ):
        """Create a PyTorch data loader.

        Args:
            dataset: The dataset to load.
            local_batch_size: The local batch size for each process.
            sharding: The sharding to use for the data loader.
            shuffle: Whether to shuffle the data.
            sampler: Optional sampler controlling which indices this process sees.
            num_batches: If provided, determines the number of returned batches. If the
                number is larger than the number of batches in the dataset, the data loader
                will loop over the dataset. If not provided, will iterate over the dataset
                indefinitely.
            num_workers: The number of worker processes to use. If zero, the data loader will
                execute in the main process.
            seed: The seed to use for shuffling the data.
            framework: "jax" (default) or "pytorch"; controls default sharding and
                per-batch conversion.
        """
        if len(dataset) < local_batch_size:
            raise ValueError(f"Local batch size ({local_batch_size}) is larger than the dataset size ({len(dataset)}).")

        # Store sharding - None for PyTorch, JAX sharding for JAX
        self._sharding = sharding
        if sharding is None and framework == "jax":
            # Use data parallel sharding by default for JAX only.
            self._sharding = jax.sharding.NamedSharding(
                jax.sharding.Mesh(jax.devices(), ("B",)),
                jax.sharding.PartitionSpec("B"),
            )
        self._num_batches = num_batches

        mp_context = None
        if num_workers > 0:
            # "spawn" avoids forking a process that may already hold JAX/CUDA state.
            mp_context = multiprocessing.get_context("spawn")

        generator = torch.Generator()
        generator.manual_seed(seed)
        self._data_loader = torch.utils.data.DataLoader(
            typing.cast(torch.utils.data.Dataset, dataset),
            batch_size=local_batch_size,
            shuffle=(sampler is None and shuffle),  # Don't shuffle if using sampler
            sampler=sampler,
            num_workers=num_workers,
            multiprocessing_context=mp_context,
            persistent_workers=num_workers > 0,
            collate_fn=_collate_fn,
            worker_init_fn=_worker_init_fn,
            drop_last=True,
            generator=generator,
            pin_memory=False,
        )

    @property
    def torch_loader(self) -> torch.utils.data.DataLoader:
        # Expose the underlying loader (e.g. for sampler.set_epoch access).
        return self._data_loader

    # Fix: removed the leftover `@profile` (memory_profiler) decorator from this
    # hot path; it was a debugging artifact adding per-iteration overhead.
    def __iter__(self):
        num_items = 0
        while True:
            data_iter = iter(self._data_loader)
            while True:
                if self._num_batches is not None and num_items >= self._num_batches:
                    return
                try:
                    batch = next(data_iter)
                except StopIteration:
                    break  # We've exhausted the dataset. Create a new iterator and start over.
                num_items += 1
                # For JAX, convert to sharded arrays; for PyTorch, return torch tensors
                if self._sharding is not None:
                    yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)
                else:
                    yield jax.tree.map(torch.as_tensor, batch)
|
||||
|
||||
|
||||
def _collate_fn(items):
|
||||
"""Collate the batch elements into batched numpy arrays."""
|
||||
# Make sure to convert to numpy arrays before stacking since some of the incoming elements
|
||||
# may be JAX arrays.
|
||||
return jax.tree.map(lambda *xs: np.stack([np.asarray(x) for x in xs], axis=0), *items)
|
||||
|
||||
|
||||
def _worker_init_fn(worker_id: int) -> None:
|
||||
"""Tell JAX inside the worker process not to preallocate the GPU memory."""
|
||||
# NOTE: This is called after jax is imported inside the worker process. This
|
||||
# means that this approach will not work for selecting the backend.
|
||||
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
|
||||
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
|
||||
|
||||
|
||||
class RLDSDataLoader:
    """Shallow wrapper around the DROID data loader to make it compatible with openpi.

    All batching already happens in the DROID dataset, so we don't need to do anything here.
    """

    def __init__(
        self,
        dataset: DroidRldsDataset,
        *,
        sharding: jax.sharding.Sharding | None = None,
        num_batches: int | None = None,
    ):
        """Wrap a pre-batched RLDS dataset.

        Args:
            dataset: Pre-batched iterable dataset to draw batches from.
            sharding: Sharding for the produced JAX arrays. Defaults to data-parallel
                sharding over all local devices.
            num_batches: Stop after this many batches; iterate forever when None.

        Raises:
            NotImplementedError: If running with more than one JAX process.
        """
        self._dataset = dataset
        # Fix: the original assigned self._num_batches twice; assign it once here.
        self._num_batches = num_batches

        if jax.process_count() > 1:
            raise NotImplementedError("Data loading with multiple processes is not supported.")

        if sharding is None:
            # Use data parallel sharding by default.
            sharding = jax.sharding.NamedSharding(
                jax.sharding.Mesh(jax.devices(), ("B",)),
                jax.sharding.PartitionSpec("B"),
            )

        self._sharding = sharding

    def __iter__(self):
        # Yields sharded batches, restarting the dataset iterator when exhausted,
        # until num_batches (if set) have been produced.
        num_items = 0
        while True:
            data_iter = iter(self._dataset)
            while True:
                if self._num_batches is not None and num_items >= self._num_batches:
                    return
                try:
                    batch = next(data_iter)
                except StopIteration:
                    break  # We've exhausted the dataset. Create a new iterator and start over.
                num_items += 1
                yield jax.tree.map(lambda x: jax.make_array_from_process_local_data(self._sharding, x), batch)
|
||||
|
||||
|
||||
class DataLoaderImpl(DataLoader):
    """Pairs a data loader with its data config and repacks raw batches into
    (Observation, actions) tuples for training."""

    def __init__(self, data_config: _config.DataConfig, data_loader: TorchDataLoader | RLDSDataLoader):
        self._data_config = data_config
        self._data_loader = data_loader

    def data_config(self) -> _config.DataConfig:
        """Return the data config associated with this loader."""
        return self._data_config

    def __iter__(self):
        # Convert each raw batch dict into the model Observation plus action targets.
        yield from (
            (_model.Observation.from_dict(batch), batch["actions"]) for batch in self._data_loader
        )
|
||||
@@ -0,0 +1,221 @@
|
||||
"""
|
||||
RLDS-based data loader for DROID.
|
||||
While openpi typically uses LeRobot's data loader, it is not currently scalable enough for larger datasets like DROID.
|
||||
Thus, we provide a data loader example here that uses the RLDS data format.
|
||||
The data loader also applies a few DROID-specific data filters / transformations.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from enum import auto
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import tqdm
|
||||
|
||||
import openpi.shared.download as download
|
||||
|
||||
|
||||
class DroidActionSpace(Enum):
    """Action space for DROID dataset."""

    # Absolute joint positions -- the default, since it allows policy evaluation in simulation.
    JOINT_POSITION = auto()
    # Joint velocities.
    JOINT_VELOCITY = auto()
|
||||
|
||||
|
||||
class DroidRldsDataset:
    """tf.data pipeline over the DROID RLDS dataset.

    Yields batched numpy dicts with keys "actions" (action chunks), "observation"
    (decoded images + proprio), "prompt", and "step_id" -- see `restructure` below.
    """

    def __init__(
        self,
        data_dir: str,
        batch_size: int,
        *,  # Force keyword-only arguments
        shuffle: bool = True,
        action_chunk_size: int = 16,
        # We default to joint position actions, since they allow policy evaluation in simulation.
        action_space: DroidActionSpace = DroidActionSpace.JOINT_POSITION,
        max_loaded_steps_per_episode: int = 100,
        # Reduce this if you are running out of memory, but careful -- below ~100k shuffling is not sufficiently random.
        shuffle_buffer_size: int = 250_000,
        num_parallel_reads: int = -1,  # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level
        num_parallel_calls: int = -1,  # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level
        filter_dict_path=None,  # Path to json file with indices to sample during training
    ):
        # NOTE(review): `max_loaded_steps_per_episode` is accepted but never used below -- confirm intent.
        # Import tensorflow here to not make it mandatory in case RLDS data loader is not used.
        import dlimp as dl
        import tensorflow as tf
        import tensorflow_datasets as tfds

        # Configure Tensorflow with *no GPU devices* (to prevent clobber with PyTorch / JAX)
        tf.config.set_visible_devices([], "GPU")

        builder = tfds.builder("droid", data_dir=data_dir, version="1.0.1")
        dataset = dl.DLataset.from_rlds(builder, split="train", shuffle=shuffle, num_parallel_reads=num_parallel_reads)

        # Filter out any unsuccessful trajectories -- we use the file name to check this
        dataset = dataset.filter(
            lambda traj: tf.strings.regex_full_match(
                traj["traj_metadata"]["episode_metadata"]["file_path"][0], ".*success.*"
            )
        )

        # # Repeat dataset so we never run out of data.
        dataset = dataset.repeat()

        # Load the filter dictionary if provided.
        # The filter dictionary is a JSON file that maps episode keys to ranges of frames to sample
        # (e.g.,
        # {
        #     "<episode key>": [[0, 100], [200, 300]]
        # }
        # means keep frames 0-99 and 200-299).
        if filter_dict_path is not None:
            cached_filter_dict_path = download.maybe_download(filter_dict_path)
            with Path(cached_filter_dict_path).open("r") as f:
                filter_dict = json.load(f)

            logging.info(f"Using filter dictionary with {len(filter_dict)} episodes")

            # Build a static hash table mapping "<episode key>--<t>" -> True for
            # every frame inside a kept range; missing keys default to False.
            keys_tensor = []
            values_tensor = []

            for episode_key, ranges in tqdm.tqdm(filter_dict.items(), desc="Creating idle filter hash table..."):
                for start, end in ranges:
                    for t in range(start, end):
                        frame_key = f"{episode_key}--{t}"
                        keys_tensor.append(frame_key)
                        values_tensor.append(True)
            self.filter_table = tf.lookup.StaticHashTable(
                tf.lookup.KeyValueTensorInitializer(keys_tensor, values_tensor), default_value=False
            )
            logging.info("Filter hash table initialized")
        else:
            # No filter file: a table whose default value keeps every frame.
            self.filter_table = tf.lookup.StaticHashTable(
                tf.lookup.KeyValueTensorInitializer([""], [True]), default_value=True
            )

        def restructure(traj):
            """Reformat observation and action keys, sample language instruction."""
            # Important: we use joint *position* action space -- easier to simulate!
            actions = tf.concat(
                (
                    (
                        traj["action_dict"]["joint_position"]
                        if action_space == DroidActionSpace.JOINT_POSITION
                        else traj["action_dict"]["joint_velocity"]
                    ),
                    traj["action_dict"]["gripper_position"],
                ),
                axis=-1,
            )
            # Randomly samples one of the two exterior images in DROID during training (we only train with one at a time).
            # Note: the "left" refers to the left camera in the stereo pair, we only train on the left camera.
            exterior_img = tf.cond(
                tf.random.uniform(shape=[]) > 0.5,
                lambda: traj["observation"]["exterior_image_1_left"],
                lambda: traj["observation"]["exterior_image_2_left"],
            )
            wrist_img = traj["observation"]["wrist_image_left"]
            # Randomly sample one of the three language instructions
            instruction = tf.random.shuffle(
                [traj["language_instruction"], traj["language_instruction_2"], traj["language_instruction_3"]]
            )[0]

            traj_len = tf.shape(traj["action"])[0]
            indices = tf.as_string(tf.range(traj_len))

            # Data filtering:
            # Compute a uniquely-identifying step ID by concatenating the recording folderpath, file path,
            # and each step's time step index. This will index into the filter hash table, and if it returns true,
            # then the frame passes the filter.
            step_id = (
                traj["traj_metadata"]["episode_metadata"]["recording_folderpath"]
                + "--"
                + traj["traj_metadata"]["episode_metadata"]["file_path"]
                + "--"
                + indices
            )
            passes_filter = self.filter_table.lookup(step_id)

            return {
                "actions": actions,
                "observation": {
                    "image": exterior_img,
                    "wrist_image": wrist_img,
                    "joint_position": traj["observation"]["joint_position"],
                    "gripper_position": traj["observation"]["gripper_position"],
                },
                "prompt": instruction,
                "step_id": step_id,
                "passes_filter": passes_filter,
            }

        dataset = dataset.traj_map(restructure, num_parallel_calls)

        def chunk_actions(traj):
            """Splits episode into action chunks."""
            traj_len = tf.shape(traj["actions"])[0]

            # For each step in the trajectory, construct indices for the next n actions
            action_chunk_indices = tf.broadcast_to(
                tf.range(action_chunk_size)[None],
                [traj_len, action_chunk_size],
            ) + tf.broadcast_to(
                tf.range(traj_len)[:, None],
                [traj_len, action_chunk_size],
            )

            # Cap to length of the sequence --> final chunks will repeat the last action
            # This makes sense, since we are using absolute joint + gripper position actions
            action_chunk_indices = tf.minimum(action_chunk_indices, traj_len - 1)

            # Gather the actions for each chunk
            traj["actions"] = tf.gather(traj["actions"], action_chunk_indices)
            return traj

        dataset = dataset.traj_map(chunk_actions, num_parallel_calls)

        # Flatten: map from trajectory dataset to dataset of individual action chunks
        dataset = dataset.flatten(num_parallel_calls=num_parallel_calls)

        # Filter data that doesn't pass the filter
        def filter_from_dict(frame):
            return frame["passes_filter"]

        dataset = dataset.filter(filter_from_dict)

        # Remove "passes_filter" key from output
        def remove_passes_filter(frame):
            frame.pop("passes_filter")
            return frame

        dataset = dataset.map(remove_passes_filter)

        # Decode images: RLDS saves encoded images, only decode now for efficiency
        def decode_images(traj):
            traj["observation"]["image"] = tf.io.decode_image(
                traj["observation"]["image"], expand_animations=False, dtype=tf.uint8
            )
            traj["observation"]["wrist_image"] = tf.io.decode_image(
                traj["observation"]["wrist_image"], expand_animations=False, dtype=tf.uint8
            )
            return traj

        dataset = dataset.frame_map(decode_images, num_parallel_calls)

        # Shuffle, batch
        dataset = dataset.shuffle(shuffle_buffer_size)
        dataset = dataset.batch(batch_size)
        # Note =>> Seems to reduce memory usage without affecting speed?
        dataset = dataset.with_ram_budget(1)

        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        # Yields batches as numpy structures (the pipeline repeats forever).
        yield from self.dataset.as_numpy_iterator()

    def __len__(self):
        # This is the approximate number of samples in DROID after filtering.
        # Easier to hardcode than to iterate through the dataset and compute it.
        return 20_000_000
|
||||
@@ -0,0 +1,116 @@
|
||||
"""RoboArena baseline policy configs."""
|
||||
|
||||
from typing import TypeAlias
|
||||
|
||||
import openpi.models.model as _model
|
||||
import openpi.models.pi0_config as pi0_config
|
||||
import openpi.models.pi0_fast as pi0_fast
|
||||
import openpi.models.tokenizer as _tokenizer
|
||||
import openpi.policies.droid_policy as droid_policy
|
||||
import openpi.transforms as _transforms
|
||||
|
||||
ModelType: TypeAlias = _model.ModelType
|
||||
|
||||
|
||||
def get_roboarena_configs():
    """Return the RoboArena DROID baseline TrainConfigs (inference-oriented).

    Each config differs only in the model/tokenizer choice; the DROID data
    pipeline (assets, transforms, prompt_from_task) is identical across them.
    """
    # Import here to avoid circular imports.
    from openpi.training.config import AssetsConfig
    from openpi.training.config import DataConfig
    from openpi.training.config import SimpleDataConfig
    from openpi.training.config import TrainConfig

    return [
        #
        # RoboArena DROID baseline inference configs.
        #
        TrainConfig(
            # Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer.
            name="paligemma_binning_droid",
            model=pi0_fast.Pi0FASTConfig(
                action_dim=8,
                action_horizon=15,
                max_token_len=400,
                fast_model_tokenizer=_tokenizer.BinningTokenizer,
            ),
            data=SimpleDataConfig(
                assets=AssetsConfig(asset_id="droid"),
                data_transforms=lambda model: _transforms.Group(
                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                    outputs=[droid_policy.DroidOutputs()],
                ),
                base_config=DataConfig(
                    prompt_from_task=True,
                ),
            ),
        ),
        TrainConfig(
            # Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer).
            name="paligemma_fast_droid",
            model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
            data=SimpleDataConfig(
                assets=AssetsConfig(asset_id="droid"),
                data_transforms=lambda model: _transforms.Group(
                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                    outputs=[droid_policy.DroidOutputs()],
                ),
                base_config=DataConfig(
                    prompt_from_task=True,
                ),
            ),
        ),
        TrainConfig(
            # Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset).
            name="paligemma_fast_specialist_droid",
            model=pi0_fast.Pi0FASTConfig(
                action_dim=8,
                action_horizon=15,
                fast_model_tokenizer=_tokenizer.FASTTokenizer,
                fast_model_tokenizer_kwargs={"fast_tokenizer_path": "KarlP/fast_droid_specialist"},
            ),
            data=SimpleDataConfig(
                assets=AssetsConfig(asset_id="droid"),
                data_transforms=lambda model: _transforms.Group(
                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                    outputs=[droid_policy.DroidOutputs()],
                ),
                base_config=DataConfig(
                    prompt_from_task=True,
                ),
            ),
        ),
        TrainConfig(
            # Trained from PaliGemma, using FSQ tokenizer.
            name="paligemma_vq_droid",
            model=pi0_fast.Pi0FASTConfig(
                action_dim=8,
                action_horizon=15,
                fast_model_tokenizer=_tokenizer.FSQTokenizer,
                fast_model_tokenizer_kwargs={"fsq_tokenizer_path": "gs://openpi-assets/tokenizers/droid_fsq_tokenizer"},
            ),
            data=SimpleDataConfig(
                assets=AssetsConfig(asset_id="droid"),
                data_transforms=lambda model: _transforms.Group(
                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                    outputs=[droid_policy.DroidOutputs()],
                ),
                base_config=DataConfig(
                    prompt_from_task=True,
                ),
            ),
        ),
        TrainConfig(
            # pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.
            name="paligemma_diffusion_droid",
            model=pi0_config.Pi0Config(action_horizon=10, action_dim=8),
            data=SimpleDataConfig(
                assets=AssetsConfig(asset_id="droid"),
                data_transforms=lambda model: _transforms.Group(
                    inputs=[droid_policy.DroidInputs(action_dim=model.action_dim)],
                    outputs=[droid_policy.DroidOutputs()],
                ),
                base_config=DataConfig(
                    prompt_from_task=True,
                ),
            ),
        ),
    ]
|
||||
@@ -0,0 +1,703 @@
|
||||
import numpy as np
|
||||
from dataclasses import dataclass
|
||||
from typing import SupportsIndex, Sequence, List, Dict, Any, Tuple, Optional, Union, TypeVar, Protocol
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
LeRobotDatasetMetadata,
|
||||
MultiLeRobotDataset,
|
||||
)
|
||||
|
||||
import openpi.transforms as _transforms
|
||||
from pdb import set_trace
|
||||
import logging
|
||||
T_co = TypeVar("T_co", covariant=True)
|
||||
import openpi.training.config as _config
|
||||
import openpi.shared.normalize as normalize
|
||||
|
||||
def detect_gripper_change_step(
    dataset,
    select_actions: list[str] | None = None,
    gripper_dim: int = -1,
    threshold_method: str = "std_multiplier",
    threshold_multiplier: float = 2.0,
    min_threshold: float = 0.001,
    max_threshold: float = 1.0,
    plot_gripper_changes: bool = False,
):
    """
    Detect the step of gripper change. Only work for the self-collected dataset.
    Modifies the dataset in place by adding 'gripper_change_step_idx' attribute.
    This version uses a sliding window of size 4 centered around non_zero_idx,
    including the indices and removing duplicates.

    Args:
        dataset: LeRobotDataset instance
        select_actions: List of action keys to process. Defaults to ["action"].
        gripper_dim: Dimension index for gripper in the action vector
        threshold_method: Method to calculate threshold ('std_multiplier', 'percentile', 'absolute')
        threshold_multiplier: Multiplier for std-based threshold
        min_threshold: Minimum threshold value to avoid too sensitive detection
        max_threshold: Maximum threshold value to avoid missing large changes
        plot_gripper_changes: Whether to plot gripper changes visualization

    Returns:
        The same `dataset`, with a sorted int32 `gripper_change_step_idx` array attached.
    """
    # Fix: avoid a shared mutable default argument (["action"] was previously the default).
    if select_actions is None:
        select_actions = ["action"]

    episode_lengths = [ep_dict["length"] for ep_dict in dataset.meta.episodes.values()]
    cumulative_lengths = np.cumsum(episode_lengths)

    all_window_indices = set()  # Use a set for automatic deduplication

    for action_key in select_actions:
        action_values = dataset.hf_dataset[action_key]

        delta_action = np.diff(action_values, axis=0)

        # Handle episode boundaries: a diff across two episodes is meaningless, so
        # replace it with the previous in-episode diff (or zero at the very start).
        for end_idx in cumulative_lengths[:-1]:
            if end_idx - 1 < len(delta_action) and end_idx - 2 >= 0:
                delta_action[end_idx - 1] = delta_action[end_idx - 2]
            elif end_idx - 1 < len(delta_action):
                delta_action[end_idx - 1] = 0

        if delta_action.ndim == 1:
            delta_action = delta_action[:, np.newaxis]

        assert delta_action.ndim == 2

        # Extract gripper delta values
        gripper_delta = delta_action[:, gripper_dim]

        # Calculate threshold based on statistical properties
        if threshold_method == "std_multiplier":
            # Use standard deviation to filter out small tremors
            std_val = np.std(gripper_delta)
            threshold = threshold_multiplier * std_val
        elif threshold_method == "percentile":
            # Use percentile-based threshold (85th percentile of absolute deltas)
            threshold = np.percentile(np.abs(gripper_delta), 85)
        elif threshold_method == "absolute":
            # Use absolute threshold
            threshold = threshold_multiplier
        else:
            raise ValueError(f"Unknown threshold_method: {threshold_method}")

        # Clamp threshold to reasonable bounds
        threshold = np.clip(threshold, min_threshold, max_threshold)

        # Find indices where gripper change exceeds threshold
        significant_change_idx = np.where(np.abs(gripper_delta) > threshold)[0]

        cur_window_indices = set()
        for idx in significant_change_idx:
            # Sliding window [idx-2, idx+1] around each significant change:
            # delta_action[i] is the change from step i to i+1, so include the
            # two steps before and one step after.
            window_start = idx - 2
            window_end = idx + 1

            # Clamp window indices to the valid step range [0, len(action_values) - 1].
            max_possible_idx = len(action_values) - 1

            current_window_indices = np.arange(
                max(0, window_start), min(max_possible_idx + 1, window_end + 1)
            )
            for w_idx in current_window_indices:
                cur_window_indices.add(w_idx)
                all_window_indices.add(w_idx)

        if plot_gripper_changes:
            num_episodes_to_plot = 5
            end_index_for_plot = cumulative_lengths[num_episodes_to_plot - 1] - 1
            delta_action_to_plot = delta_action[:end_index_for_plot]

            # Restrict the detected steps to the plotted range.
            gripper_change_step_idx = np.array(sorted(cur_window_indices)).astype(np.int32)
            gripper_change_step_idx_to_plot = gripper_change_step_idx[gripper_change_step_idx < end_index_for_plot]

            # NOTE(review): `plot_gripper_changes_in_subplots` is not defined in this
            # module -- confirm it is imported/available before enabling plotting.
            plot_gripper_changes_in_subplots(
                delta_action_to_plot,
                gripper_change_step_idx_to_plot,
                episode_lengths,
                num_episodes_to_plot,
                gripper_dim,
                f"{action_key}_gripper_change"
            )

    # Convert the set to a sorted numpy array.
    gripper_change_step_idx = np.array(sorted(all_window_indices)).astype(np.int32)

    print(f"Total unique gripper change steps: {len(gripper_change_step_idx)}, Total steps: {len(action_values)}")

    dataset.gripper_change_step_idx = gripper_change_step_idx

    return dataset
|
||||
|
||||
class Dataset(Protocol[T_co]):
    """Interface for a dataset with random access."""

    def __getitem__(self, index: SupportsIndex) -> T_co:
        """Return the sample at `index`."""
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    def __len__(self) -> int:
        """Return the total number of samples."""
        raise NotImplementedError("Subclasses of Dataset should implement __len__.")
|
||||
|
||||
class TransformedDataset(Dataset[T_co]):
    """Wraps a dataset so every retrieved sample flows through a composed transform chain."""

    def __init__(self, dataset: Dataset, transforms: Sequence[_transforms.DataTransformFn]):
        # Keep the wrapped dataset under `_dataset`; MixtureDataset unwraps via this attribute.
        self._dataset = dataset
        self._transform = _transforms.compose(transforms)

    def __getitem__(self, index: SupportsIndex) -> T_co:
        raw_sample = self._dataset[index]
        return self._transform(raw_sample)

    def __len__(self) -> int:
        return len(self._dataset)
|
||||
|
||||
def transform_dataset(dataset: Dataset, data_config: _config.DataConfig) -> Dataset:
    """Transform the dataset by applying the data transforms.

    Samples flow through repack transforms -> data transforms -> normalization ->
    model transforms, in that order.
    """
    # Fix: removed a dead `norm_stats = {}` assignment that was immediately overwritten.
    norm_stats = data_config.norm_stats

    return TransformedDataset(
        dataset,
        [
            *data_config.repack_transforms.inputs,
            *data_config.data_transforms.inputs,
            _transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
            *data_config.model_transforms.inputs,
        ],
    )
|
||||
|
||||
class MixtureDataset(Dataset):
|
||||
"""
|
||||
A composite dataset that combines multiple datasets, allowing for weighted sampling
|
||||
and specific handling based on training stage (e.g., pretrain, finetune) and
|
||||
gripper change detection for augmentation.
|
||||
|
||||
This dataset flattens all eligible samples from its constituent datasets and assigns
|
||||
sampling weights based on configuration and heuristics (e.g., `gripper_aug_ratio`).
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    datasets: Sequence[Dataset],
    datasets_name: Sequence[str],
    datasets_meta: Sequence[LeRobotDatasetMetadata],
    datasets_weights: Optional[Dict[str, float]] = None,
    gripper_aug_ratio: float = 1.0,
    shuffle: bool = True,
):
    """
    Initializes the MixtureDataset.

    Args:
        datasets (Sequence[Dataset]): A list of `Dataset` objects to be combined.
        datasets_name (Sequence[str]): A list of names corresponding to each dataset in `datasets`.
        datasets_meta (Sequence[LeRobotDatasetMetadata]): Metadata for each dataset;
            `meta.info['total_episodes']` and `meta.info['total_frames']` are read below.
        datasets_weights (Optional[Dict[str, float]]): A dictionary mapping dataset names
            to their base sampling weights. If None, equal weights are assumed.
        gripper_aug_ratio (float): A multiplier applied to the weights of samples
            that contain a detected gripper change. Useful for augmenting rare events.
        shuffle (bool): If True, the flat sample map and sampling weights are shuffled
            (in the same order) after initial creation.
    """
    self.datasets = datasets
    self.datasets_name = datasets_name
    self.meta = datasets_meta
    # Extract total number of episodes and frames for each dataset from metadata.
    self.num_episodes = [meta.info['total_episodes'] for meta in datasets_meta]
    self.num_frames = [meta.info['total_frames'] for meta in datasets_meta]


    # Compute the flattened list of (dataset_idx, sample_idx) pairs.
    # This involves sampling indices based on the stage and dataset type.
    self._compute_len(False)
    # Assign normalized sampling weights to each sample in the flattened map.
    # NOTE(review): `_get_weights` is defined elsewhere in this file -- it must set
    # `self.sample_weights` as an indexable array; confirm.
    self._get_weights(datasets_weights, gripper_aug_ratio)

    # For training, ensure the sample map and weights are consistent.
    if len(self.flat_sample_map) != len(self.sample_weights):
        raise ValueError(
            f"Mismatch in flat sample map length ({len(self.flat_sample_map)}) "
            f"and sample weights length ({len(self.sample_weights)})."
        )
    if shuffle:
        # Shuffle both the sample map and weights in the same order for training.
        # This ensures random access to samples while maintaining their assigned probabilities.
        indices = np.random.permutation(len(self.flat_sample_map))
        self.flat_sample_map = [self.flat_sample_map[i] for i in indices]
        self.sample_weights = self.sample_weights[indices]
|
||||
|
||||
def __len__(self) -> int:
    """Effective dataset size: the number of entries in the flattened sample map."""
    return len(self.flat_sample_map)
|
||||
|
||||
def __getitem__(self, index: SupportsIndex):
    """
    Retrieves a specific sample from one of the underlying datasets based on the
    flattened sample map.

    Args:
        index (SupportsIndex): The index in the flattened `flat_sample_map` (0 to `len(self) - 1`).

    Returns:
        The sample from the selected underlying dataset (the dataset index in the
        map is used only to route the lookup; it is not returned).

    Raises:
        IndexError: If the provided index is out of bounds for the dataset.
    """
    if not (0 <= index < len(self.flat_sample_map)):
        raise IndexError(f"Index {index} is out of bounds for the dataset (size: {len(self.flat_sample_map)}).")

    # Retrieve the original dataset index and sample index from the flattened map.
    dataset_idx, sample_idx = self.flat_sample_map[index]
    return self.datasets[dataset_idx][sample_idx]
|
||||
|
||||
def _compute_len(self, is_eval: bool = False):
    """
    Pre-computes and stores `all_sample_indices`, a list of episode indices sampled
    from each constituent dataset. This method prepares the data for `_create_flat_sample_map`.

    Args:
        is_eval (bool): Flag indicating if indices are being computed for an evaluation dataset.
    """
    self.all_sample_indices: List[Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor]] = []

    for i, (ds, meta) in enumerate(zip(self.datasets, self.meta)):
        # Access the underlying LeRobotDataset or MultiLeRobotDataset, bypassing TransformedDataset wrapper.
        actual_ds = ds._dataset if isinstance(ds, TransformedDataset) else ds

        # Determine the number of indices to sample for this dataset based on the current stage.
        # "stage1" typically uses a limited number of indices (`num_indices`), while other stages
        # might use all available data or a different strategy.
        num_indices = None  # NOTE(review): always None here, so episodes are sampled at full length.

        if isinstance(actual_ds, MultiLeRobotDataset):
            # For MultiLeRobotDataset, iterate through its sub-datasets to get indices.
            indices_list_for_multi_ds = []
            for sub_ds in actual_ds._datasets:
                _from = sub_ds.episode_data_index["from"]
                _to = sub_ds.episode_data_index["to"]
                indices = self._sample_indices(
                    _from, _to, num_indices, is_eval=is_eval, dataset_name=self.datasets_name[i]
                )
                indices_list_for_multi_ds.append(indices)
            self.all_sample_indices.append(indices_list_for_multi_ds)
        elif isinstance(actual_ds, LeRobotDataset):
            # For a single LeRobotDataset.
            _from = actual_ds.episode_data_index["from"]
            _to = actual_ds.episode_data_index["to"]
            indices = self._sample_indices(
                _from, _to, num_indices, is_eval=is_eval, dataset_name=self.datasets_name[i]
            )
            self.all_sample_indices.append(indices)
        else:
            raise TypeError(f"Unsupported dataset type: {type(actual_ds)}. "
                            "Expected `LeRobotDataset` or `MultiLeRobotDataset`.")

    # After collecting all sampled episode indices, flatten them into `flat_sample_map`.
    self.flat_sample_map = self._create_flat_sample_map()
|
||||
|
||||
def _create_flat_sample_map(self) -> List[Tuple[int, int]]:
    """
    Flatten `self.all_sample_indices` (lists of lists of tensors, lists of tensors,
    or bare tensors) into `(original_dataset_index, sample_index)` tuples.

    The resulting map is what `__getitem__` uses to route a flat index to the
    right underlying dataset and sample.
    """
    flat_map: List[Tuple[int, int]] = []

    def _append_tensor(dataset_idx: int, index_tensor: torch.Tensor) -> None:
        # Emit one (dataset_idx, sample_idx) pair per index stored in the tensor.
        flat_map.extend((dataset_idx, index_tensor[i].item()) for i in range(index_tensor.numel()))

    for dataset_idx, sample_group in enumerate(self.all_sample_indices):
        # Rare case: a single tensor of indices for the whole dataset.
        if isinstance(sample_group, torch.Tensor):
            _append_tensor(dataset_idx, sample_group)
        # MultiLeRobotDataset: list of per-sub-dataset lists of index tensors.
        elif isinstance(sample_group, list) and sample_group and isinstance(sample_group[0], list):
            for sub_group in sample_group:
                for tensor_of_indices in sub_group:
                    _append_tensor(dataset_idx, tensor_of_indices)
        # LeRobotDataset: list of index tensors (one per episode).
        elif isinstance(sample_group, list) and sample_group and isinstance(sample_group[0], torch.Tensor):
            for tensor_of_indices in sample_group:
                _append_tensor(dataset_idx, tensor_of_indices)
    return flat_map
|
||||
|
||||
def _sample_indices(
|
||||
self,
|
||||
start: List[int],
|
||||
end: List[int],
|
||||
num_frames: Optional[int],
|
||||
random_pad: bool = False,
|
||||
is_eval: bool = False,
|
||||
dataset_name: str = None, # Added for potential future stage-specific logic
|
||||
) -> List[torch.Tensor]:
|
||||
"""
|
||||
Samples indices for episodes based on the current stage and dataset-specific rules.
|
||||
This function is called per episode to determine which frames to include.
|
||||
|
||||
Args:
|
||||
start (List[int]): List of starting frame indices for each episode.
|
||||
end (List[int]): List of ending frame indices for each episode.
|
||||
num_frames (Optional[int]): The target number of frames to sample per episode.
|
||||
This is primarily used for "stage1" where sampling
|
||||
a fixed number of frames per episode might be desired.
|
||||
random_pad (bool): If True, and `frame_count < target_frames`, shorter episodes
|
||||
will be padded with randomly selected indices from themselves.
|
||||
is_eval (bool): If True, adjusts indices for evaluation (e.g., shifting by 1 for stage1
|
||||
to ensure predicted frames are not identical to observed frames).
|
||||
dataset_name (str): The name of the dataset (for debugging or future dataset-specific sampling rules).
|
||||
|
||||
Returns:
|
||||
List[torch.Tensor]: A list of PyTorch tensors, where each tensor contains the
|
||||
sampled frame indices for a single episode.
|
||||
"""
|
||||
all_indices_for_episodes = []
|
||||
for _start, _end in zip(start, end):
|
||||
frame_count = _end - _start # Total frames available in this episode.
|
||||
target_frames = frame_count
|
||||
if frame_count >= target_frames:
|
||||
# If enough frames are available, linearly space the indices to sample `target_frames`.
|
||||
indices = torch.linspace(_start, _end - 1, steps=target_frames).long()
|
||||
else:
|
||||
# If fewer frames than `target_frames` are available.
|
||||
if random_pad:
|
||||
# Pad the existing frames with randomly chosen duplicates from the episode.
|
||||
pad_size = target_frames - frame_count
|
||||
indices = torch.arange(_start, _end) # All available original indices
|
||||
# Randomly sample `pad_size` indices from the existing ones.
|
||||
pad_indices = indices[torch.randint(0, frame_count, (pad_size,))]
|
||||
indices = torch.cat([indices, pad_indices]) # Combine original and padded indices
|
||||
indices = indices[torch.randperm(target_frames)] # Randomly permute to mix original and padded.
|
||||
else:
|
||||
# If not padding, simply use all available frames.
|
||||
indices = torch.arange(_start, _end)
|
||||
|
||||
all_indices_for_episodes.append(indices)
|
||||
|
||||
return all_indices_for_episodes
|
||||
|
||||
def _get_weights(self, datasets_weights: Dict[str, float], aug_ratio: float = 1.0):
    """Assign a normalized sampling weight to every sample in `flat_sample_map`.

    Each sample gets its dataset's base weight; samples whose frame index is a
    detected gripper-change step get `base_weight * aug_ratio` instead. All
    weights are then normalized to sum to 1 across the whole mixture.

    Side effects:
        self.sample_weights: np.float32 array of per-sample probabilities.
        self.datasets_weight_map: per-dataset share of the total weight mass.

    Args:
        datasets_weights: Mapping from dataset name to base weight. If None,
            uniform weights (1 / num_datasets) are used; names missing from
            the mapping default to 1.0.
        aug_ratio: Multiplier applied to the base weight of gripper-change samples.
    """
    self.sample_weights: List[float] = []
    self.datasets_weight_map: Dict[str, float] = {}

    if datasets_weights is None:
        num_datasets = len(self.datasets_name)
        datasets_weights = {name: 1.0 / num_datasets for name in self.datasets_name}

    for idx, ds_name in enumerate(self.datasets_name):
        # Unwrap TransformedDataset to reach the dataset holding gripper metadata.
        current_base_dataset = self.datasets[idx]._dataset if isinstance(self.datasets[idx], TransformedDataset) else self.datasets[idx]
        base_weight = datasets_weights.get(ds_name, 1.0)  # Default weight when name is missing.

        individual_weights_for_ds: List[float] = []

        if isinstance(current_base_dataset, MultiLeRobotDataset):
            # Weight samples of each sub-dataset separately.
            for idj, sub_ds in enumerate(current_base_dataset._datasets):
                gripper_change_step_idx = getattr(sub_ds, 'gripper_change_step_idx', None)
                if gripper_change_step_idx is not None:
                    sampled_indices_sub_ds = self.all_sample_indices[idx][idj]
                    for tensor_of_indices in sampled_indices_sub_ds:
                        for step_idx in tensor_of_indices.tolist():
                            # NOTE(review): membership test is O(len(gripper_change_step_idx))
                            # per sample if this is a list — consider a set upstream.
                            if step_idx in gripper_change_step_idx:
                                individual_weights_for_ds.append(base_weight * aug_ratio)
                            else:
                                individual_weights_for_ds.append(base_weight)
        elif isinstance(current_base_dataset, LeRobotDataset):
            gripper_change_step_idx = getattr(current_base_dataset, 'gripper_change_step_idx', None)
            if gripper_change_step_idx is not None:
                sampled_indices_ds = self.all_sample_indices[idx]
                for tensor_of_indices in sampled_indices_ds:
                    for step_idx in tensor_of_indices.tolist():
                        if step_idx in gripper_change_step_idx:
                            individual_weights_for_ds.append(base_weight * aug_ratio)
                        else:
                            individual_weights_for_ds.append(base_weight)
        # Fallback: no gripper metadata — uniform base weight for every sample
        # of this dataset in the flat map.
        # NOTE(review): if `current_base_dataset` is neither dataset type,
        # `gripper_change_step_idx` is unbound here (NameError); and for the
        # Multi branch this checks only the *last* sub-dataset's value — verify.
        if gripper_change_step_idx is None:
            print(f"Warning: Gripper change detection not fully supported for dataset type {type(current_base_dataset)}. "
                  "Assigning uniform weights based on `base_weight` for this dataset.")
            num_samples_for_ds_in_flat_map = sum(1 for map_ds_idx, _ in self.flat_sample_map if map_ds_idx == idx)
            individual_weights_for_ds.extend([base_weight] * num_samples_for_ds_in_flat_map)

        # Accumulate per-sample weights and the dataset's total mass.
        self.sample_weights.extend(individual_weights_for_ds)
        self.datasets_weight_map[ds_name] = self.datasets_weight_map.get(ds_name, 0.0) + sum(individual_weights_for_ds)

    # Normalize per-sample weights to a probability distribution.
    total_sum_of_all_individual_weights = sum(self.sample_weights)
    if total_sum_of_all_individual_weights > 0:
        self.sample_weights = np.array(self.sample_weights, dtype=np.float32)
        self.sample_weights = self.sample_weights / total_sum_of_all_individual_weights
    else:
        self.sample_weights = np.array([], dtype=np.float32)

    # Normalize the per-dataset map so it reflects each dataset's effective
    # proportion of the final sampling distribution.
    if total_sum_of_all_individual_weights > 0:
        for k in self.datasets_weight_map:
            self.datasets_weight_map[k] /= total_sum_of_all_individual_weights
    else:
        self.datasets_weight_map = {k: 0.0 for k in self.datasets_weight_map}  # All weights become zero.
|
||||
|
||||
def __str__(self) -> str:
    """Render a colored summary table of the mixture: per-dataset sample
    counts, effective sampling percentages, and the total episode count."""
    # ANSI escape codes for colored and bold terminal output.
    RESET = "\033[0m"
    BOLD = "\033[1m"
    CYAN = "\033[96m"
    YELLOW = "\033[93m"
    GREEN = "\033[92m"  # NOTE(review): currently unused.
    MAGENTA = "\033[95m"

    # Pad dataset names to the longest one (+2) for column alignment.
    max_key_len = max(len(k) for k in self.datasets_weight_map.keys()) + 2 if self.datasets_weight_map else 20

    lines = [
        f"{BOLD}{MAGENTA}######################################### 👈 Dataset Weight Map: ########################################{RESET}"
    ]

    # One row per dataset: name, sample count, effective weight percentage.
    # NOTE(review): assumes `self.datasets_weight_map` iterates in the same
    # order as `self.datasets` so `idx` lines up — confirm.
    for idx, (name, weight) in enumerate(self.datasets_weight_map.items()):
        lines.append(f"{CYAN}{name:<{max_key_len}} : {len(self.datasets[idx]):>18.0f} ({weight*100:>.2f}%){RESET}")

    # Separator sized to the visible header width (ANSI codes subtracted).
    separator_length = len(lines[0]) - len(BOLD) - len(MAGENTA) - len(RESET) + 1
    lines.append("-" * separator_length)

    # Totals row.
    lines.append(f"{CYAN}{'Total Episodes':<{max_key_len}}{RESET} : {YELLOW}{sum(self.num_episodes):>18.0f}{RESET}")

    # Closing border matching the separator width.
    lines.append(f"{BOLD}{MAGENTA}{'#' * separator_length}{RESET}")

    return "\n".join(lines)
|
||||
|
||||
|
||||
def create_mixture_dataset(
    data_configs_list,
    action_horizon,
    model_config,
):
    """Build a `MixtureDataset` from nested dataset configs, applying
    per-dataset transforms and optional gripper-change annotation.

    `model_config` is accepted for signature parity with sibling factories;
    it is not used here.
    """
    datasets = []
    names = []
    metas = []
    weights = {}

    for ds_configs in data_configs_list:
        for ds_config in ds_configs:
            root_path = f"{ds_config.repo_dir}/{ds_config.task_id}/{ds_config.subtask_id}"

            dataset_meta = LeRobotDatasetMetadata(repo_id=root_path, root=root_path)
            episodes = list(dataset_meta.episodes_stats.keys())
            if ds_config.data_ratio < 1.0:
                # Randomly subsample a fraction of the episodes (without replacement).
                sub_length = int(len(episodes) * ds_config.data_ratio) + 1
                logging.info(f"sub_length: {sub_length}")
                chosen = np.random.choice(len(episodes), sub_length, replace=False)
                episodes = [episodes[i] for i in chosen]
            print(f"downsample ratio: {ds_config.downsample_ratio}")
            # Action timestamps spaced at the downsampled FPS.
            effective_fps = dataset_meta.fps // ds_config.downsample_ratio
            dataset = LeRobotDataset(
                episodes=episodes,
                repo_id=root_path,
                root=root_path,
                delta_timestamps={
                    key: [t / effective_fps for t in range(action_horizon)]
                    for key in ds_config.action_sequence_keys
                },
            )
            if ds_config.use_gripper_aug and ds_config.gripper_aug_config is not None:
                # Annotate gripper-change steps used later for sample upweighting.
                aug_cfg = ds_config.gripper_aug_config
                dataset = detect_gripper_change_step(
                    dataset,
                    select_actions=aug_cfg["gripper_action_keys"],
                    gripper_dim=aug_cfg["gripper_dim"],
                    threshold_method=aug_cfg["gripper_threshold_method"],
                    threshold_multiplier=aug_cfg["gripper_threshold_multiplier"],
                    min_threshold=aug_cfg["gripper_min_threshold"],
                    max_threshold=aug_cfg["gripper_max_threshold"],
                )

            dataset = transform_dataset(dataset, ds_config)

            datasets.append(dataset)
            names.append(root_path)
            metas.append(dataset_meta)
            weights[root_path] = ds_config.weight

    return MixtureDataset(
        datasets,
        names,
        metas,
        weights,
        gripper_aug_ratio=10.0,
    )
|
||||
|
||||
def create_mixture_dataset_no_transform(
    data_configs_list,
    action_horizon,
    model_config
):
    """Build a `MixtureDataset` without applying `transform_dataset`.

    Differences from `create_mixture_dataset`: episode subsampling takes a
    deterministic prefix instead of a random choice, and no per-sample
    transform is attached. `model_config` is accepted for signature parity;
    it is not used here.
    """
    all_datasets = []
    all_datasets_name = []
    all_datasets_meta = []
    all_datasets_weight = {}

    # `data_configs_list` is nested: one list of dataset configs per group.
    for ds_configs in data_configs_list:
        for ds_config in ds_configs:
            repo_dir = ds_config.repo_dir
            task_id = ds_config.task_id
            subtask_id = ds_config.subtask_id
            root_path = f"{repo_dir}/{task_id}/{subtask_id}"

            dataset_meta = LeRobotDatasetMetadata(repo_id=root_path, root=root_path)
            episodes = list(dataset_meta.episodes_stats.keys())
            if ds_config.data_ratio < 1.0:
                # Deterministic prefix subsample (contrast: `create_mixture_dataset`
                # picks episodes at random).
                sub_length = int(len(episodes) * ds_config.data_ratio) + 1
                episodes = episodes[:sub_length]
            dataset = LeRobotDataset(
                episodes=episodes,
                repo_id=root_path,
                root=root_path,
                # Action timestamps are spaced at the downsampled FPS.
                delta_timestamps={
                    key: [t / (dataset_meta.fps // ds_config.downsample_ratio) for t in range(action_horizon)] for key in ds_config.action_sequence_keys
                },
            )
            if ds_config.use_gripper_aug and ds_config.gripper_aug_config is not None:
                # Annotate gripper-change steps used later for sample upweighting.
                gripper_aug_config = ds_config.gripper_aug_config
                dataset = detect_gripper_change_step(
                    dataset,
                    select_actions=gripper_aug_config["gripper_action_keys"],
                    gripper_dim=gripper_aug_config["gripper_dim"],
                    threshold_method=gripper_aug_config["gripper_threshold_method"],
                    threshold_multiplier=gripper_aug_config["gripper_threshold_multiplier"],
                    min_threshold=gripper_aug_config["gripper_min_threshold"],
                    max_threshold=gripper_aug_config["gripper_max_threshold"],
                )

            dataset_name = root_path
            dataset_weight = ds_config.weight

            all_datasets.append(dataset)
            all_datasets_name.append(dataset_name)
            all_datasets_meta.append(dataset_meta)
            all_datasets_weight[dataset_name] = dataset_weight

    mixture_dataset = MixtureDataset(
        all_datasets,
        all_datasets_name,
        all_datasets_meta,
        all_datasets_weight,
        gripper_aug_ratio=10.0,
    )
    return mixture_dataset
|
||||
|
||||
def create_mixture_dataset_calculate_norm_stats(
    data_configs_list,
    action_horizon,
    model_config
):
    """Build a `MixtureDataset` suitable for computing normalization stats.

    Like `create_mixture_dataset_no_transform`, but video decoding is disabled
    (`load_video=False`) since norm stats only need numeric streams. NOTE:
    unlike the sibling factories, `data_configs_list` here is a FLAT list of
    dataset configs, not a nested list. `model_config` is unused.
    """
    all_datasets = []
    all_datasets_name = []
    all_datasets_meta = []
    all_datasets_weight = {}

    for ds_config in data_configs_list:
        repo_dir = ds_config.repo_dir
        task_id = ds_config.task_id
        subtask_id = ds_config.subtask_id
        root_path = f"{repo_dir}/{task_id}/{subtask_id}"

        dataset_meta = LeRobotDatasetMetadata(repo_id=root_path, root=root_path)
        episodes = list(dataset_meta.episodes_stats.keys())
        if ds_config.data_ratio < 1.0:
            # Deterministic prefix subsample.
            sub_length = int(len(episodes) * ds_config.data_ratio) + 1
            episodes = episodes[:sub_length]
        dataset = LeRobotDataset(
            episodes=episodes,
            repo_id=root_path,
            root=root_path,
            # Action timestamps are spaced at the downsampled FPS.
            delta_timestamps={
                key: [t / (dataset_meta.fps // ds_config.downsample_ratio) for t in range(action_horizon)] for key in ds_config.action_sequence_keys
            },
            # Skip video decoding: stats are computed from low-dim streams only.
            load_video=False,

        )
        if ds_config.use_gripper_aug and ds_config.gripper_aug_config is not None:
            # Annotate gripper-change steps used later for sample upweighting.
            gripper_aug_config = ds_config.gripper_aug_config
            dataset = detect_gripper_change_step(
                dataset,
                select_actions=gripper_aug_config["gripper_action_keys"],
                gripper_dim=gripper_aug_config["gripper_dim"],
                threshold_method=gripper_aug_config["gripper_threshold_method"],
                threshold_multiplier=gripper_aug_config["gripper_threshold_multiplier"],
                min_threshold=gripper_aug_config["gripper_min_threshold"],
                max_threshold=gripper_aug_config["gripper_max_threshold"],
            )

        dataset_name = root_path
        dataset_weight = ds_config.weight

        all_datasets.append(dataset)
        all_datasets_name.append(dataset_name)
        all_datasets_meta.append(dataset_meta)
        all_datasets_weight[dataset_name] = dataset_weight

    mixture_dataset = MixtureDataset(
        all_datasets,
        all_datasets_name,
        all_datasets_meta,
        all_datasets_weight,
        gripper_aug_ratio=10.0,
    )
    return mixture_dataset
|
||||
|
||||
123
policy/openpi-InternData-A1/src/openpi/training/optimizer.py
Normal file
123
policy/openpi-InternData-A1/src/openpi/training/optimizer.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import dataclasses
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
|
||||
import openpi.shared.array_typing as at
|
||||
|
||||
|
||||
@runtime_checkable
class LRScheduleConfig(Protocol):
    """Structural interface for learning-rate schedule configs."""

    def create(self) -> optax.Schedule:
        """Build the optax learning-rate schedule described by this config."""
        ...
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class CosineDecaySchedule(LRScheduleConfig):
    """Cosine decay schedule with linear warmup.

    The learning rate ramps up to `peak_lr` over `warmup_steps`, then follows
    a cosine decay down to `decay_lr` over `decay_steps`.
    """

    warmup_steps: int = 1_000  # Number of linear warmup steps.
    peak_lr: float = 2.5e-5    # Learning rate reached at the end of warmup.
    decay_steps: int = 30_000  # Length of the cosine decay phase.
    decay_lr: float = 2.5e-6   # Final learning rate after decay.

    def create(self) -> optax.Schedule:
        """Build the optax warmup + cosine-decay schedule."""
        return optax.warmup_cosine_decay_schedule(
            # Start near zero; the +1 keeps this finite when warmup_steps == 0.
            init_value=self.peak_lr / (self.warmup_steps + 1),
            peak_value=self.peak_lr,
            warmup_steps=self.warmup_steps,
            decay_steps=self.decay_steps,
            end_value=self.decay_lr,
        )
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class RsqrtDecaySchedule(LRScheduleConfig):
    """Inverse square root decay schedule with linear warmup."""

    warmup_steps: int = 1_000  # Number of linear warmup steps.
    peak_lr: float = 5e-5      # Learning rate reached at the end of warmup.
    timescale: float = 10_000  # Controls how quickly the rsqrt decay falls off.

    def create(self) -> optax.Schedule:
        """Build the schedule: linear warmup, then
        peak_lr / sqrt((timescale + step) / timescale) after the boundary."""
        return optax.join_schedules(
            [
                optax.linear_schedule(
                    # Start near zero; the +1 keeps this finite when warmup_steps == 0.
                    init_value=self.peak_lr / (self.warmup_steps + 1),
                    end_value=self.peak_lr,
                    transition_steps=self.warmup_steps,
                ),
                # `step` here is relative to the `warmup_steps` boundary
                # (optax.join_schedules offsets later schedules).
                lambda step: self.peak_lr / jnp.sqrt((self.timescale + step) / self.timescale),
            ],
            [self.warmup_steps],
        )
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class WarmupConstantSchedule(LRScheduleConfig):
    """Constant learning rate after a linear warmup."""

    warmup_steps: int = 2_000  # Number of linear warmup steps.
    peak_lr: float = 5e-5      # Constant learning rate after warmup.

    def create(self) -> optax.Schedule:
        """Build the optax warmup-then-constant schedule."""
        return optax.warmup_constant_schedule(
            # Start near zero; the +1 keeps this finite when warmup_steps == 0.
            init_value=self.peak_lr / (self.warmup_steps + 1),
            peak_value=self.peak_lr,
            warmup_steps=self.warmup_steps,
        )
|
||||
|
||||
|
||||
@runtime_checkable
class OptimizerConfig(Protocol):
    """Structural interface for optimizer configs."""

    def create(
        self,
        lr: optax.ScalarOrSchedule,
        weight_decay_mask: at.PyTree | None = None,
    ) -> optax.GradientTransformation:
        """Create the gradient transformation.

        Args:
            lr: Learning rate — a scalar or an optax schedule.
            weight_decay_mask: Optional pytree mask selecting which parameters
                receive weight decay.
        """
        ...
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class AdamW(OptimizerConfig):
    """AdamW optimizer with global-norm gradient clipping."""

    b1: float = 0.9    # First-moment (momentum) decay.
    b2: float = 0.95   # Second-moment decay.
    eps: float = 1e-8  # Numerical-stability epsilon.
    # Changing this to 0 can cause out-of-memory errors for some reason, so we set it to a negligible value.
    weight_decay: float = 1e-10
    clip_gradient_norm: float = 1.0  # Max global gradient norm before rescaling.

    def create(
        self,
        lr: optax.ScalarOrSchedule,
        weight_decay_mask: at.PyTree | None = None,
    ) -> optax.GradientTransformation:
        """Build clip-by-global-norm chained into optax.adamw."""
        tx = optax.adamw(
            lr, b1=self.b1, b2=self.b2, eps=self.eps, weight_decay=self.weight_decay, mask=weight_decay_mask
        )
        # Clip gradients first, then apply the AdamW update.
        return optax.chain(optax.clip_by_global_norm(self.clip_gradient_norm), tx)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class SGD(OptimizerConfig):
    """SGD optimizer configuration (no weight-decay support)."""

    # NOTE(review): `lr` here appears unused by `create` — the `lr` argument
    # passed to `create` (usually a schedule) takes effect. Confirm intent.
    lr: float = 5e-5
    momentum: float = 0.9   # Momentum coefficient.
    nesterov: bool = False  # Use Nesterov momentum when True.

    def create(
        self,
        lr: optax.ScalarOrSchedule,
        weight_decay_mask: at.PyTree | None = None,
    ) -> optax.GradientTransformation:
        """Build the optax SGD transform.

        Raises:
            ValueError: if `weight_decay_mask` is provided — weight decay is
                not supported for SGD here.
        """
        # BUG FIX: this was an `assert`, which is stripped when Python runs
        # with -O and would then silently ignore the mask. Raise instead.
        if weight_decay_mask is not None:
            raise ValueError("Weight decay is not supported for SGD")
        return optax.sgd(lr, momentum=self.momentum, nesterov=self.nesterov)
|
||||
|
||||
|
||||
def create_optimizer(
    optimizer: OptimizerConfig, lr_schedule: LRScheduleConfig, weight_decay_mask: at.PyTree | None = None
) -> optax.GradientTransformation:
    """Materialize the LR schedule and hand it to the optimizer config."""
    schedule = lr_schedule.create()
    return optimizer.create(schedule, weight_decay_mask=weight_decay_mask)
|
||||
102
policy/openpi-InternData-A1/src/openpi/training/sharding.py
Normal file
102
policy/openpi-InternData-A1/src/openpi/training/sharding.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import contextlib
|
||||
import logging
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
|
||||
BATCH_AXIS = "batch"
|
||||
FSDP_AXIS = "fsdp"
|
||||
# In FSDP, we shard the data across both the batch and FSDP axes.
|
||||
DATA_AXIS = (BATCH_AXIS, FSDP_AXIS)
|
||||
|
||||
|
||||
class _MeshState:
    """Module-level holder for the mesh installed by `set_mesh`."""

    # Shared class attribute: the currently active mesh, or None when no
    # `set_mesh` context is active. Read by `activation_sharding_constraint`.
    active_mesh: jax.sharding.Mesh | None = None
|
||||
|
||||
|
||||
def make_mesh(num_fsdp_devices: int) -> jax.sharding.Mesh:
    """Create a 2-D (batch, fsdp) device mesh over all available devices.

    Args:
        num_fsdp_devices: Size of the FSDP axis; must divide the total
            device count evenly.

    Raises:
        ValueError: if the device count is not divisible by `num_fsdp_devices`.
    """
    total_devices = jax.device_count()
    if total_devices % num_fsdp_devices != 0:
        raise ValueError(
            f"Number of devices {jax.device_count()} must be divisible by the number of FSDP devices {num_fsdp_devices}."
        )
    # Remaining devices form the pure data-parallel (batch) axis.
    shape = (total_devices // num_fsdp_devices, num_fsdp_devices)
    return jax.make_mesh(shape, (BATCH_AXIS, FSDP_AXIS))
|
||||
|
||||
|
||||
@contextlib.contextmanager
def set_mesh(mesh: jax.sharding.Mesh):
    """Install `mesh` as the process-global active mesh for the context's duration.

    Plumbing the mesh deep into the module tree is extremely cumbersome; until
    the JAX team lands a better API, a custom context manager like this one is
    the recommended way to maintain a reference to a global mesh. This is only
    used in `activation_sharding_constraint` below.
    """
    # Nesting is disallowed: a second mesh would silently shadow the first.
    if _MeshState.active_mesh is not None:
        raise ValueError("Cannot nest set_mesh context managers.")
    _MeshState.active_mesh = mesh
    try:
        yield
    finally:
        # Always clear the global, even if the body raised.
        _MeshState.active_mesh = None
|
||||
|
||||
|
||||
def activation_sharding_constraint(pytree):
    """Constrain activations to the (batch, fsdp) data sharding of the mesh
    installed by `set_mesh`; a no-op when no mesh context is active."""
    if _MeshState.active_mesh is None:
        return pytree
    return jax.lax.with_sharding_constraint(
        pytree, jax.sharding.NamedSharding(_MeshState.active_mesh, jax.sharding.PartitionSpec(DATA_AXIS))
    )
|
||||
|
||||
|
||||
def fsdp_sharding(
    pytree,
    mesh: jax.sharding.Mesh,
    *,
    min_size_mbytes: int = 4,  # 4 MiB
    log: bool = False,
):
    """Apply FSDP sharding to a pytree of arrays based on the mesh shape.

    Args:
        pytree: A pytree to be apply sharding specified by the mesh, note that only array types (eg. contains .shape attr)
            will be considered for sharding.
        mesh: The mesh being used for applying sharding on to pytree.
        min_size_mbytes: The minimum size of the array in MiB to be considered for sharding, any array smaller than this
            will be replicated.
        log: If true, will log the sharding decisions for arrays that are being considered for sharding.

    Returns:
        The sharded pytree.
    """
    min_size_bytes = min_size_mbytes * 2**20

    def _shard_arr(kp, array: jax.ShapeDtypeStruct):
        # if fsdp is not actually going to be used, replicate everything to avoid extraneous logging
        if mesh.shape[FSDP_AXIS] == 1:
            return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
        # replicate scalar and vector arrays (and non-array leaves without .shape)
        if not hasattr(array, "shape"):
            return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
        if len(array.shape) < 2:
            return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
        # replicate small arrays (below the min_size_bytes threshold)
        if (arr_size := np.prod(array.shape) * np.dtype(array.dtype).itemsize) < min_size_bytes:
            return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

        # shard matrices and larger tensors along the largest axis that is divisible by the fsdp dimension
        # (axes are visited largest-first so the biggest dimension wins)
        axes = np.argsort(array.shape)[::-1]
        spec = [None] * len(axes)
        for i in axes:
            if array.shape[i] % mesh.shape[FSDP_AXIS] == 0:
                if log:
                    logging.info(
                        f"Sharding {jax.tree_util.keystr(kp)} of shape {array.shape} ({arr_size / 2**20:.2f} MiB) along axis {i}"
                    )
                spec[i] = FSDP_AXIS
                return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*spec))

        # replicate if no valid sharding was found
        if log:
            logging.warning(
                f"Could not find a valid sharding for {jax.tree_util.keystr(kp)} of shape {array.shape} with mesh of shape {mesh.shape}"
            )
        return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    return jax.tree_util.tree_map_with_path(_shard_arr, pytree)
|
||||
38
policy/openpi-InternData-A1/src/openpi/training/utils.py
Normal file
38
policy/openpi-InternData-A1/src/openpi/training/utils.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from flax import nnx
|
||||
from flax import struct
|
||||
import jax
|
||||
import optax
|
||||
|
||||
from openpi.models import model as _model
|
||||
from openpi.shared import array_typing as at
|
||||
|
||||
|
||||
@at.typecheck
@struct.dataclass
class TrainState:
    """Pytree container holding everything needed to run / resume training."""

    # Global optimization step counter.
    step: at.Int[at.ArrayLike, ""]
    # Current model parameters.
    params: nnx.State
    # Static graph definition used to rebuild the model from `params`.
    model_def: nnx.GraphDef[_model.BaseModel]
    # Optimizer state matching `tx`.
    opt_state: optax.OptState
    # Gradient transformation; static (excluded from the pytree leaves).
    tx: optax.GradientTransformation = struct.field(pytree_node=False)

    # EMA decay rate; None disables EMA tracking. Static field.
    ema_decay: float | None = struct.field(pytree_node=False)
    # EMA copy of `params` when EMA is enabled, otherwise None.
    ema_params: nnx.State | None = None
|
||||
|
||||
|
||||
@at.typecheck
def tree_to_info(tree: at.PyTree, interp_func: Callable[[Any], str] = str) -> str:
    """Render a PyTree as one "path: value" line per leaf, for logging.

    `interp_func` converts each leaf to its printed form (defaults to `str`).
    """
    flat, _ = jax.tree_util.tree_flatten_with_path(tree)
    rendered = [f"{jax.tree_util.keystr(path)}: {interp_func(leaf)}" for path, leaf in flat]
    return "\n".join(rendered)
|
||||
|
||||
|
||||
@at.typecheck
def array_tree_to_info(tree: at.PyTree) -> str:
    """Render a PyTree of arrays as "path: shape@dtype" lines for logging."""

    def _shape_dtype(arr) -> str:
        return f"{arr.shape}@{arr.dtype}"

    return tree_to_info(tree, _shape_dtype)
|
||||
@@ -0,0 +1,103 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
import re
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
import flax.traverse_util
|
||||
import numpy as np
|
||||
|
||||
import openpi.models.model as _model
|
||||
import openpi.shared.array_typing as at
|
||||
import openpi.shared.download as download
|
||||
from pathlib import Path
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@runtime_checkable
class WeightLoader(Protocol):
    """Structural interface for model-weight loaders."""

    def load(self, params: at.Params) -> at.Params:
        """Loads the model weights.

        Args:
            params: Parameters of the model. This is a nested structure of array-like objects that
                represent the model's parameters.

        Returns:
            Loaded parameters. The structure must be identical to `params`. If returning a subset of
            the parameters the loader must merge the loaded parameters with `params`.
        """
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class NoOpWeightLoader(WeightLoader):
    """Weight loader that returns the reference parameters unchanged."""

    def load(self, params: at.Params) -> at.Params:
        return params
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class CheckpointWeightLoader(WeightLoader):
    """Loads an entire set of weights from a checkpoint.

    Compatible with:
      trained checkpoints:
        example: "./checkpoints/<config>/<exp>/<step>/params"
      released checkpoints:
        example: "gs://openpi-assets/checkpoints/<model>/params"
    """

    # Local path or remote URI accepted by `download.maybe_download`.
    params_path: str

    def load(self, params: at.Params) -> at.Params:
        # We are loading np.ndarray and relying on the training code to properly convert and shard the params.
        loaded_params = _model.restore_params(download.maybe_download(self.params_path), restore_type=np.ndarray)
        # Add all missing LoRA weights (taken from the reference `params`).
        return _merge_params(loaded_params, params, missing_regex=".*lora.*")
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class PaliGemmaWeightLoader(WeightLoader):
    """Loads weights from the official PaliGemma checkpoint.

    This will overwrite existing weights with similar names while keeping all extra weights intact.
    This allows us to support the action expert which is used by the Pi0 model.
    """

    # Path to the official checkpoint file (an npz archive).
    params_path: str

    def load(self, params: at.Params) -> at.Params:
        path = Path(self.params_path)
        with path.open("rb") as f:
            # The checkpoint is a flat npz with "/"-separated keys.
            flat_params = dict(np.load(f, allow_pickle=False))
        loaded_params = {"PaliGemma": flax.traverse_util.unflatten_dict(flat_params, sep="/")["params"]}
        # Add all missing weights (regex ".*" pulls everything else from `params`).
        return _merge_params(loaded_params, params, missing_regex=".*")
|
||||
|
||||
|
||||
|
||||
def _merge_params(loaded_params: at.Params, params: at.Params, *, missing_regex: str) -> at.Params:
    """Merges the loaded parameters with the reference parameters.

    Args:
        loaded_params: The parameters to merge.
        params: The reference parameters.
        missing_regex: A regex pattern for all missing keys that should be merged from the reference parameters.

    Returns:
        A new dictionary with the merged parameters.
    """
    flat_ref = flax.traverse_util.flatten_dict(params, sep="/")
    flat_loaded = flax.traverse_util.flatten_dict(loaded_params, sep="/")

    # Keep every loaded weight that also exists in the reference, casting it to
    # the reference dtype when the dtypes differ.
    merged = {}
    for key, value in flat_loaded.items():
        ref = flat_ref.get(key)
        if ref is None:
            continue
        merged[key] = value if value.dtype == ref.dtype else value.astype(ref.dtype)

    # Drop the remaining loaded entries so their arrays can be reclaimed early.
    flat_loaded.clear()

    # Pull any still-missing keys that match the regex from the reference.
    pattern = re.compile(missing_regex)
    for key, ref in flat_ref.items():
        if key not in merged and pattern.fullmatch(key):
            merged[key] = ref

    return flax.traverse_util.unflatten_dict(merged, sep="/")
|
||||
Reference in New Issue
Block a user