Fix sample beta for smolvla as done for pi0, remove sample_beta func (#1611)

This commit is contained in:
Michel Aractingi
2025-07-28 17:28:55 +02:00
committed by GitHub
parent b267cd40f7
commit c7c3b477d6
2 changed files with 3 additions and 14 deletions

View File

@@ -90,12 +90,6 @@ def create_sinusoidal_pos_embedding(
return pos_emb
def sample_beta(alpha, beta, bsize, device):
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
return gamma1 / (gamma1 + gamma2)
def make_att_2d_masks(pad_masks, att_masks):
"""Copied from big_vision.

View File

@@ -194,12 +194,6 @@ def create_sinusoidal_pos_embedding(
return pos_emb
def sample_beta(alpha, beta, bsize, device):
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
return gamma1 / (gamma1 + gamma2)
def make_att_2d_masks(pad_masks, att_masks):
"""Copied from big_vision.
@@ -690,9 +684,10 @@ class VLAFlowMatching(nn.Module):
return noise
def sample_time(self, bsize, device):
time_beta = sample_beta(1.5, 1.0, bsize, device)
beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32)
time = time_beta * 0.999 + 0.001
return time.to(dtype=torch.float32, device=device)
return time
def embed_prefix(
self, images, img_masks, lang_tokens, lang_masks, state: torch.Tensor = None