Co-authored-by: Simon Alibert <simon.alibert@huggingface.co> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Pablo <pablo.montalvo.leroux@gmail.com>
69 lines
1.9 KiB
Python
69 lines
1.9 KiB
Python
import torch
|
|
|
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
|
from lerobot.common.policies.factory import make_policy
|
|
from lerobot.configs.policies import PreTrainedConfig
|
|
|
|
def _enable_cudnn_autotuner() -> None:
    """Turn on cuDNN benchmark (auto-tuner) mode for the rest of the process."""
    # In benchmark mode cuDNN tries several convolution algorithms the first time
    # it sees an input shape and reuses the fastest one afterwards. The benchmark
    # in ``main`` feeds the same batch repeatedly, so any auto-tuning should
    # happen during the warmup iterations rather than the timed ones.
    torch.backends.cudnn.benchmark = True


_enable_cudnn_autotuner()
|
|
|
|
|
|
def _move_batch_to_device(batch, device):
    """Move every tensor in ``batch`` to ``device`` in place and return ``batch``.

    Floating-point tensors are cast to float32. Integer and boolean tensors keep
    their dtype: a blanket float32 cast would make large integer indices inexact
    and turn boolean masks into floats. Non-tensor values are left untouched.
    """
    for key, value in batch.items():
        if not isinstance(value, torch.Tensor):
            continue
        if value.is_floating_point():
            batch[key] = value.to(device=device, dtype=torch.float32)
        else:
            batch[key] = value.to(device=device)
    return batch


def _time_select_action(policy, batch, warmup_iters, benchmark_iters):
    """Return the mean latency (ms) of ``policy.select_action(batch)`` on the current CUDA device."""
    # Untimed warmup so one-off start-up costs do not land in the measured window.
    for _ in range(warmup_iters):
        policy.select_action(batch)
        # NOTE(review): reset() presumably clears any cached/queued actions so the
        # next select_action runs a full forward pass -- confirm against the policy.
        policy.reset()
    # Wait for all queued warmup work before the timed section starts.
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for _ in range(benchmark_iters):
        policy.select_action(batch)
        policy.reset()
    end_event.record()

    # elapsed_time() is only valid once both events have completed on the GPU.
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / benchmark_iters


def main(
    device: str = "cuda",
    dataset_repo_id: str = "danaaubakirova/koch_test",
    ckpt_torch_dir: str = "lerobot/pi0",
    warmup_iters: int = 10,
    benchmark_iters: int = 30,
) -> float:
    """Benchmark the average latency of ``select_action`` for a pretrained policy.

    Loads the first sample of episode 0 of ``dataset_repo_id``, builds the policy
    from the ``ckpt_torch_dir`` checkpoint, runs ``warmup_iters`` untimed calls,
    then times ``benchmark_iters`` calls with CUDA events.

    Args:
        device: CUDA device string, e.g. ``"cuda"`` or ``"cuda:1"``.
        dataset_repo_id: Dataset repo id passed to ``LeRobotDataset``.
        ckpt_torch_dir: Checkpoint passed to ``PreTrainedConfig.from_pretrained``
            (a Hub repo id by default; a local checkpoint directory was used
            previously -- confirm it is still accepted).
        warmup_iters: Number of untimed warmup calls (>= 0).
        benchmark_iters: Number of timed calls (> 0).

    Returns:
        Average time per ``select_action`` call in milliseconds (also printed).

    Raises:
        ValueError: If the iteration counts are invalid or ``device`` is not CUDA.
        RuntimeError: If CUDA is not available.
    """
    # Validate cheaply before downloading the dataset or the checkpoint.
    if benchmark_iters <= 0:
        raise ValueError(f"benchmark_iters must be positive, got {benchmark_iters}")
    if warmup_iters < 0:
        raise ValueError(f"warmup_iters must be non-negative, got {warmup_iters}")
    torch_device = torch.device(device)
    if torch_device.type != "cuda":
        raise ValueError(f"this benchmark times with CUDA events and needs a CUDA device, got {device!r}")
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available; this benchmark requires a GPU")

    dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=0, batch_size=1)
    batch = _move_batch_to_device(next(iter(dataloader)), device)

    cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
    cfg.pretrained_path = ckpt_torch_dir
    policy = make_policy(cfg, device, ds_meta=dataset.meta)

    # CUDA events and torch.cuda.synchronize() act on the *current* CUDA device;
    # make that the device the policy runs on (matters for e.g. "cuda:1").
    with torch.cuda.device(torch_device):
        avg_time_per_iter = _time_select_action(policy, batch, warmup_iters, benchmark_iters)

    print(f"Average execution time per iteration: {avg_time_per_iter:.3f} ms")
    return avg_time_per_iter
|
|
|
|
|
|
if __name__ == "__main__":
    # Apply inference mode as a decorator so the whole run (data loading, policy
    # construction and the timed loop) executes with gradient tracking disabled.
    run_benchmark = torch.inference_mode()(main)
    run_benchmark()
|