forked from tangger/lerobot
chore: replace hard-coded obs values with constants throughout all the source code (#2037)
* chore: replace hard-coded OBS values with constants throughout all the source code * chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
@@ -23,6 +23,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
from tests.utils import require_package
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -44,7 +45,7 @@ class MockPolicy:
|
||||
|
||||
def predict_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Return a chunk of 20 dummy actions."""
|
||||
batch_size = len(observation["observation.state"])
|
||||
batch_size = len(observation[OBS_STATE])
|
||||
return torch.zeros(batch_size, 20, 6)
|
||||
|
||||
def __init__(self):
|
||||
@@ -77,7 +78,7 @@ def policy_server():
|
||||
|
||||
# Add mock lerobot_features that the observation similarity functions need
|
||||
server.lerobot_features = {
|
||||
"observation.state": {
|
||||
OBS_STATE: {
|
||||
"dtype": "float32",
|
||||
"shape": [6],
|
||||
"names": ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"],
|
||||
|
||||
Reference in New Issue
Block a user