diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 11feca96..e9e6014f 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -515,9 +515,10 @@ class PI0FlowMatching(nn.Module): return noise def sample_time(self, bsize, device): - time_beta = sample_beta(1.5, 1.0, bsize, device) + beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0) + time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32) time = time_beta * 0.999 + 0.001 - return time.to(dtype=torch.float32, device=device) + return time def embed_prefix( self, images, img_masks, lang_tokens, lang_masks