From c7c3b477d6d39cf7046a7225eecc4e5debe67065 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 28 Jul 2025 17:28:55 +0200 Subject: [PATCH] Fix sample beta for smolvla as done for pi0, remove sample_beta func (#1611) --- src/lerobot/policies/pi0/modeling_pi0.py | 6 ------ src/lerobot/policies/smolvla/modeling_smolvla.py | 11 +++-------- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index e9e6014f8..a34aa34f9 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -90,12 +90,6 @@ def create_sinusoidal_pos_embedding( return pos_emb -def sample_beta(alpha, beta, bsize, device): - gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha) - gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta) - return gamma1 / (gamma1 + gamma2) - - def make_att_2d_masks(pad_masks, att_masks): """Copied from big_vision. diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index d2f78068c..469645e84 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -194,12 +194,6 @@ def create_sinusoidal_pos_embedding( return pos_emb -def sample_beta(alpha, beta, bsize, device): - gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha) - gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta) - return gamma1 / (gamma1 + gamma2) - - def make_att_2d_masks(pad_masks, att_masks): """Copied from big_vision. @@ -690,9 +684,10 @@ class VLAFlowMatching(nn.Module): return noise def sample_time(self, bsize, device): - time_beta = sample_beta(1.5, 1.0, bsize, device) + beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0) + time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32) time = time_beta * 0.999 + 0.001 - return time.to(dtype=torch.float32, device=device) + return time def embed_prefix( self, images, img_masks, lang_tokens, lang_masks, state: torch.Tensor = None