Fix sample beta for smolvla as done for pi0, remove sample_beta func (#1611)
This commit is contained in:
@@ -90,12 +90,6 @@ def create_sinusoidal_pos_embedding(
|
|||||||
return pos_emb
|
return pos_emb
|
||||||
|
|
||||||
|
|
||||||
def sample_beta(alpha, beta, bsize, device):
|
|
||||||
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
|
|
||||||
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
|
|
||||||
return gamma1 / (gamma1 + gamma2)
|
|
||||||
|
|
||||||
|
|
||||||
def make_att_2d_masks(pad_masks, att_masks):
|
def make_att_2d_masks(pad_masks, att_masks):
|
||||||
"""Copied from big_vision.
|
"""Copied from big_vision.
|
||||||
|
|
||||||
|
|||||||
@@ -194,12 +194,6 @@ def create_sinusoidal_pos_embedding(
|
|||||||
return pos_emb
|
return pos_emb
|
||||||
|
|
||||||
|
|
||||||
def sample_beta(alpha, beta, bsize, device):
|
|
||||||
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
|
|
||||||
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
|
|
||||||
return gamma1 / (gamma1 + gamma2)
|
|
||||||
|
|
||||||
|
|
||||||
def make_att_2d_masks(pad_masks, att_masks):
|
def make_att_2d_masks(pad_masks, att_masks):
|
||||||
"""Copied from big_vision.
|
"""Copied from big_vision.
|
||||||
|
|
||||||
@@ -690,9 +684,10 @@ class VLAFlowMatching(nn.Module):
|
|||||||
return noise
|
return noise
|
||||||
|
|
||||||
def sample_time(self, bsize, device):
|
def sample_time(self, bsize, device):
|
||||||
time_beta = sample_beta(1.5, 1.0, bsize, device)
|
beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
|
||||||
|
time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32)
|
||||||
time = time_beta * 0.999 + 0.001
|
time = time_beta * 0.999 + 0.001
|
||||||
return time.to(dtype=torch.float32, device=device)
|
return time
|
||||||
|
|
||||||
def embed_prefix(
|
def embed_prefix(
|
||||||
self, images, img_masks, lang_tokens, lang_masks, state: torch.Tensor = None
|
self, images, img_masks, lang_tokens, lang_masks, state: torch.Tensor = None
|
||||||
|
|||||||
Reference in New Issue
Block a user