chore: replace hard-coded "action" key strings with the ACTION constant throughout the source code (#2055)

* chore: replace hard-coded 'action' values with constants throughout all the source code

* chore(tests): replace hard-coded 'action' values with constants throughout all the test code
This commit is contained in:
Steven Palma
2025-09-26 13:33:18 +02:00
committed by GitHub
parent 9627765ce2
commit d2782cf66b
47 changed files with 269 additions and 255 deletions

View File

@@ -22,7 +22,7 @@ import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE, OBS_STR
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE, OBS_STR
from tests.fixtures.constants import DUMMY_REPO_ID
@@ -63,7 +63,7 @@ def create_random_image() -> torch.Tensor:
def create_dummy_transition() -> dict:
return {
OBS_IMAGE: create_random_image(),
"action": torch.randn(4),
ACTION: torch.randn(4),
"reward": torch.tensor(1.0),
OBS_STATE: torch.randn(
10,
@@ -341,7 +341,7 @@ def test_sample_batch(replay_buffer):
f"{k} should be equal to one of the dummy states."
)
for got_action_item in got_batch_transition["action"]:
for got_action_item in got_batch_transition[ACTION]:
assert any(torch.equal(got_action_item, dummy_action) for dummy_action in dummy_actions), (
"Actions should be equal to the dummy actions."
)
@@ -378,7 +378,7 @@ def test_to_lerobot_dataset(tmp_path):
for i in range(len(ds)):
for feature, value in ds[i].items():
if feature == "action":
if feature == ACTION:
assert torch.equal(value, buffer.actions[i])
elif feature == "next.reward":
assert torch.equal(value, buffer.rewards[i])
@@ -495,7 +495,7 @@ def test_buffer_sample_alignment():
for i in range(50):
state_sig = batch["state"]["state_value"][i].item()
action_val = batch["action"][i].item()
action_val = batch[ACTION][i].item()
reward_val = batch["reward"][i].item()
next_state_sig = batch["next_state"]["state_value"][i].item()
is_done = batch["done"][i].item() > 0.5