"""Compute normalization statistics for a config. This script is used to compute the normalization statistics for a given config. It will compute the mean and standard deviation of the data in the dataset and save it to the config metadata directory. """ import numpy as np import tqdm import tyro import openpi.shared.normalize as normalize import openpi.training.config as _config import openpi.training.data_loader as _data_loader def create_dataset(config: _config.TrainConfig) -> tuple[str, _data_loader.Dataset]: model = config.create_model() data_config = config.data.create(config.metadata_dir, model) if data_config.repo_id is None: raise ValueError("Data config must have a repo_id") dataset = _data_loader.TransformedDataset( _data_loader.create_dataset(data_config, model), [ *data_config.repack_transforms.inputs, *data_config.data_transforms.inputs, ], ) return data_config.repo_id, dataset def main(config_name: str, max_frames: int | None = None): config = _config.get_config(config_name) repo_id, dataset = create_dataset(config) num_frames = len(dataset) shuffle = False if max_frames is not None and max_frames < num_frames: num_frames = max_frames shuffle = True data_loader = _data_loader.TorchDataLoader( dataset, local_batch_size=1, num_workers=8, shuffle=shuffle, num_batches=num_frames, ) keys = ["state", "actions"] stats = {key: normalize.RunningStats() for key in keys} for batch in tqdm.tqdm(data_loader, total=num_frames, desc="Computing stats"): for key in keys: values = np.asarray(batch[key][0]) stats[key].update(values.reshape(-1, values.shape[-1])) norm_stats = {key: stats.get_statistics() for key, stats in stats.items()} output_path = config.metadata_dir / repo_id print(f"Writing stats to: {output_path}") normalize.save(output_path, norm_stats) if __name__ == "__main__": tyro.cli(main)