From 4b88842d20c3872674a77a1cc06ca023b443bb9f Mon Sep 17 00:00:00 2001 From: Kleist Bond <61907235+KleistvonLiu@users.noreply.github.com> Date: Mon, 28 Jul 2025 21:17:30 +0800 Subject: [PATCH] fix bug about sampling time from beta distribution (#1605) * fix bug about sampling t from beta distribution * fix: address review comments --------- --- src/lerobot/policies/pi0/modeling_pi0.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 11feca964..e9e6014f8 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -515,9 +515,10 @@ class PI0FlowMatching(nn.Module): return noise def sample_time(self, bsize, device): - time_beta = sample_beta(1.5, 1.0, bsize, device) + beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0) + time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32) time = time_beta * 0.999 + 0.001 - return time.to(dtype=torch.float32, device=device) + return time def embed_prefix( self, images, img_masks, lang_tokens, lang_masks