fix bug about sampling time from beta distribution (#1605)
* fix bug about sampling t from beta distribution * fix: address review comments ---------
This commit is contained in:
@@ -515,9 +515,10 @@ class PI0FlowMatching(nn.Module):
|
||||
return noise
|
||||
|
||||
def sample_time(self, bsize, device):
|
||||
time_beta = sample_beta(1.5, 1.0, bsize, device)
|
||||
beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
|
||||
time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32)
|
||||
time = time_beta * 0.999 + 0.001
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
return time
|
||||
|
||||
def embed_prefix(
|
||||
self, images, img_masks, lang_tokens, lang_masks
|
||||
|
||||
Reference in New Issue
Block a user