fix bug about sampling time from beta distribution (#1605)

* fix bug about sampling t from beta distribution

* fix: address review comments

---------
This commit is contained in:
Kleist Bond
2025-07-28 21:17:30 +08:00
committed by GitHub
parent c3d5e494c0
commit 4b88842d20

View File

@@ -515,9 +515,10 @@ class PI0FlowMatching(nn.Module):
return noise
def sample_time(self, bsize, device):
time_beta = sample_beta(1.5, 1.0, bsize, device)
beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32)
time = time_beta * 0.999 + 0.001
return time.to(dtype=torch.float32, device=device)
return time
def embed_prefix(
self, images, img_masks, lang_tokens, lang_masks