Initial commit
scripts/compute_norm_stats.py (Normal file, 75 lines)
@@ -0,0 +1,75 @@
"""Compute normalization statistics for a config.

This script is used to compute the normalization statistics for a given config. It
will compute the mean and standard deviation of the data in the dataset and save it
to the config assets directory.
"""

import numpy as np
import tqdm
import tyro

import openpi.shared.normalize as normalize
import openpi.training.config as _config
import openpi.training.data_loader as _data_loader
import openpi.transforms as transforms


class RemoveStrings(transforms.DataTransformFn):
    def __call__(self, x: dict) -> dict:
        return {k: v for k, v in x.items() if not np.issubdtype(np.asarray(v).dtype, np.str_)}


def create_dataset(config: _config.TrainConfig) -> tuple[_config.DataConfig, _data_loader.Dataset]:
    data_config = config.data.create(config.assets_dirs, config.model)
    if data_config.repo_id is None:
        raise ValueError("Data config must have a repo_id")
    dataset = _data_loader.create_dataset(data_config, config.model)
    dataset = _data_loader.TransformedDataset(
        dataset,
        [
            *data_config.repack_transforms.inputs,
            *data_config.data_transforms.inputs,
            # Remove strings since they are not supported by JAX and are not needed to compute norm stats.
            RemoveStrings(),
        ],
    )
    return data_config, dataset


def main(config_name: str, max_frames: int | None = None):
    config = _config.get_config(config_name)
    data_config, dataset = create_dataset(config)

    num_frames = len(dataset)
    shuffle = False

    if max_frames is not None and max_frames < num_frames:
        num_frames = max_frames
        shuffle = True

    data_loader = _data_loader.TorchDataLoader(
        dataset,
        local_batch_size=1,
        num_workers=8,
        shuffle=shuffle,
        num_batches=num_frames,
    )

    keys = ["state", "actions"]
    stats = {key: normalize.RunningStats() for key in keys}

    for batch in tqdm.tqdm(data_loader, total=num_frames, desc="Computing stats"):
        for key in keys:
            values = np.asarray(batch[key][0])
            stats[key].update(values.reshape(-1, values.shape[-1]))

    norm_stats = {key: stats.get_statistics() for key, stats in stats.items()}

    output_path = config.assets_dirs / data_config.repo_id
    print(f"Writing stats to: {output_path}")
    normalize.save(output_path, norm_stats)


if __name__ == "__main__":
    tyro.cli(main)
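
Not part of the committed file, but a hedged sketch of how the pieces above fit together. The invocation assumes tyro's default mapping of main()'s parameters to kebab-case flags, and "example_config" is a placeholder config name; the RunningStats calls mirror the update/get_statistics pattern already used in the script.

# Assumed invocation (tyro exposes main()'s parameters as CLI flags;
# "example_config" is a placeholder, not a config defined in this commit):
#
#   python scripts/compute_norm_stats.py --config-name example_config --max-frames 1000
#
# Minimal sketch of the accumulation pattern from the main loop, using only the
# RunningStats API already referenced above (update / get_statistics):
import numpy as np

import openpi.shared.normalize as normalize

stats = normalize.RunningStats()
for _ in range(100):
    chunk = np.random.rand(50, 7)                     # stand-in for batch["actions"][0]
    stats.update(chunk.reshape(-1, chunk.shape[-1]))  # flatten to (num_vectors, dim)
print(stats.get_statistics())                         # aggregated per-dimension statistics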