76 lines
2.4 KiB
Python
Executable File
76 lines
2.4 KiB
Python
Executable File
"""Compute normalization statistics for a config.
|
|
|
|
This script computes the normalization statistics for a given config. It computes
the mean and standard deviation of the `state` and `actions` data in the dataset
and saves them to the config assets directory.
|
|
"""
|
|
|
|
import numpy as np
|
|
import tqdm
|
|
import tyro
|
|
|
|
import openpi.shared.normalize as normalize
|
|
import openpi.training.config as _config
|
|
import openpi.training.data_loader as _data_loader
|
|
import openpi.transforms as transforms
|
|
|
|
|
|
class RemoveStrings(transforms.DataTransformFn):
    """Data transform that drops every entry whose value has a NumPy string dtype."""

    def __call__(self, x: dict) -> dict:
        """Return a new dict holding only the non-string entries of ``x``.

        Each value is converted with ``np.asarray`` and the entry is kept unless
        the resulting dtype is a sub-dtype of ``np.str_``. The input dict is not
        modified.
        """
        kept = {}
        for name, value in x.items():
            if np.issubdtype(np.asarray(value).dtype, np.str_):
                continue
            kept[name] = value
        return kept
|
|
|
|
|
|
def create_dataset(config: _config.TrainConfig) -> tuple[_config.DataConfig, _data_loader.Dataset]:
    """Build the transformed dataset that the normalization statistics are computed over.

    Args:
        config: Training config; its data factory, assets directory and model
            settings are used to create the data config and the dataset.

    Returns:
        A ``(data_config, dataset)`` pair: the data config created from ``config``
        and the dataset wrapped with the repack input transforms, the data input
        transforms and a final ``RemoveStrings`` step, in that order.

    Raises:
        ValueError: If the created data config has no ``repo_id``.
    """
    data_config = config.data.create(config.assets_dirs, config.model)
    if data_config.repo_id is None:
        raise ValueError("Data config must have a repo_id")

    raw_dataset = _data_loader.create_dataset(data_config, config.model)
    pipeline = [
        *data_config.repack_transforms.inputs,
        *data_config.data_transforms.inputs,
        # Strings are not supported by JAX and are not needed to compute norm stats,
        # so they are removed as the last step.
        RemoveStrings(),
    ]
    return data_config, _data_loader.TransformedDataset(raw_dataset, pipeline)
|
|
|
|
|
|
def main(config_name: str, max_frames: int | None = None):
|
|
config = _config.get_config(config_name)
|
|
data_config, dataset = create_dataset(config)
|
|
|
|
num_frames = len(dataset)
|
|
shuffle = False
|
|
|
|
if max_frames is not None and max_frames < num_frames:
|
|
num_frames = max_frames
|
|
shuffle = True
|
|
|
|
data_loader = _data_loader.TorchDataLoader(
|
|
dataset,
|
|
local_batch_size=1,
|
|
num_workers=8,
|
|
shuffle=shuffle,
|
|
num_batches=num_frames,
|
|
)
|
|
|
|
keys = ["state", "actions"]
|
|
stats = {key: normalize.RunningStats() for key in keys}
|
|
|
|
for batch in tqdm.tqdm(data_loader, total=num_frames, desc="Computing stats"):
|
|
for key in keys:
|
|
values = np.asarray(batch[key][0])
|
|
stats[key].update(values.reshape(-1, values.shape[-1]))
|
|
|
|
norm_stats = {key: stats.get_statistics() for key, stats in stats.items()}
|
|
|
|
output_path = config.assets_dirs / data_config.repo_id
|
|
print(f"Writing stats to: {output_path}")
|
|
normalize.save(output_path, norm_stats)
|
|
|
|
|
|
# Script entry point: tyro.cli builds the command-line interface from the signature
# of `main` and calls it with the parsed arguments.
if __name__ == "__main__":
    tyro.cli(main)
|