forked from tangger/lerobot
fix(smolvla): update record.py, fix populate_queues and remove unused dependencies (#1208)
This commit is contained in:
@@ -14,15 +14,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
def populate_queues(queues, batch):
|
||||
def populate_queues(
|
||||
queues: dict[str, deque], batch: dict[str, torch.Tensor], exclude_keys: list[str] | None = None
|
||||
):
|
||||
if exclude_keys is None:
|
||||
exclude_keys = []
|
||||
for key in batch:
|
||||
# Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the
|
||||
# queues have the keys they want).
|
||||
if key not in queues:
|
||||
if key not in queues or key in exclude_keys:
|
||||
continue
|
||||
if len(queues[key]) != queues[key].maxlen:
|
||||
# initialize by copying the first observation several times until the queue is full
|
||||
|
||||
@@ -89,7 +89,7 @@ intelrealsense = [
|
||||
"pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'",
|
||||
]
|
||||
pi0 = ["transformers>=4.48.0"]
|
||||
smolvla = ["transformers>=4.50.3", "num2words>=0.5.14", "accelerate>=1.7.0"]
|
||||
smolvla = ["transformers>=4.50.3"]
|
||||
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
|
||||
stretch = [
|
||||
"hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",
|
||||
|
||||
Reference in New Issue
Block a user