forked from tangger/lerobot
Compare commits
1 Commits
my-fix-bas
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bfd26eef5a |
@@ -168,7 +168,7 @@ available_datasets = sorted(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# lists all available policies from `lerobot/common/policies`
|
# lists all available policies from `lerobot/common/policies`
|
||||||
available_policies = ["act", "diffusion", "tdmpc", "vqbet", "smolvla"]
|
available_policies = ["act", "diffusion", "tdmpc", "vqbet"]
|
||||||
|
|
||||||
# lists all available robots from `lerobot/common/robot_devices/robots`
|
# lists all available robots from `lerobot/common/robot_devices/robots`
|
||||||
available_robots = [
|
available_robots = [
|
||||||
|
|||||||
@@ -662,7 +662,6 @@ class VLAFlowMatching(nn.Module):
|
|||||||
self.config.max_period,
|
self.config.max_period,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
time_emb = time_emb.type(dtype=dtype)
|
time_emb = time_emb.type(dtype=dtype)
|
||||||
|
|
||||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||||
|
|||||||
@@ -272,6 +272,7 @@ def control_loop(
|
|||||||
action = {"action": action}
|
action = {"action": action}
|
||||||
|
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
|
observation = {k: v for k, v in observation.items() if k not in ["task", "robot_type"]}
|
||||||
frame = {**observation, **action, "task": single_task}
|
frame = {**observation, **action, "task": single_task}
|
||||||
dataset.add_frame(frame)
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import pytest
|
|||||||
import lerobot
|
import lerobot
|
||||||
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||||
from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
|
||||||
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
||||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
||||||
from tests.utils import require_env
|
from tests.utils import require_env
|
||||||
@@ -46,7 +45,7 @@ def test_available_policies():
|
|||||||
This test verifies that the class attribute `name` for all policies is
|
This test verifies that the class attribute `name` for all policies is
|
||||||
consistent with those listed in `lerobot/__init__.py`.
|
consistent with those listed in `lerobot/__init__.py`.
|
||||||
"""
|
"""
|
||||||
policy_classes = [ACTPolicy, DiffusionPolicy, TDMPCPolicy, VQBeTPolicy, SmolVLAPolicy]
|
policy_classes = [ACTPolicy, DiffusionPolicy, TDMPCPolicy, VQBeTPolicy]
|
||||||
policies = [pol_cls.name for pol_cls in policy_classes]
|
policies = [pol_cls.name for pol_cls in policy_classes]
|
||||||
assert set(policies) == set(lerobot.available_policies), policies
|
assert set(policies) == set(lerobot.available_policies), policies
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user