diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py index c06e620ba..5659e8727 100644 --- a/lerobot/common/policies/utils.py +++ b/lerobot/common/policies/utils.py @@ -14,15 +14,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import deque + import torch from torch import nn -def populate_queues(queues, batch): +def populate_queues( + queues: dict[str, deque], batch: dict[str, torch.Tensor], exclude_keys: list[str] | None = None +): + if exclude_keys is None: + exclude_keys = [] for key in batch: # Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the # queues have the keys they want). - if key not in queues: + if key not in queues or key in exclude_keys: continue if len(queues[key]) != queues[key].maxlen: # initialize by copying the first observation several times until the queue is full diff --git a/pyproject.toml b/pyproject.toml index 68122794b..127295b68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,7 +89,7 @@ intelrealsense = [ "pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'", ] pi0 = ["transformers>=4.48.0"] -smolvla = ["transformers>=4.50.3", "num2words>=0.5.14", "accelerate>=1.7.0"] +smolvla = ["transformers>=4.50.3"] pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"] stretch = [ "hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",