chore: replace hard-coded action values with constants throughout all the source code (#2055)
* chore: replace hard-coded 'action' values with constants throughout all the source code
* chore(tests): replace hard-coded action values with constants throughout all the test code
This commit is contained in:
@@ -22,7 +22,7 @@ import torch
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE, OBS_STR
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ def create_random_image() -> torch.Tensor:
|
||||
def create_dummy_transition() -> dict:
|
||||
return {
|
||||
OBS_IMAGE: create_random_image(),
|
||||
"action": torch.randn(4),
|
||||
ACTION: torch.randn(4),
|
||||
"reward": torch.tensor(1.0),
|
||||
OBS_STATE: torch.randn(
|
||||
10,
|
||||
@@ -341,7 +341,7 @@ def test_sample_batch(replay_buffer):
|
||||
f"{k} should be equal to one of the dummy states."
|
||||
)
|
||||
|
||||
for got_action_item in got_batch_transition["action"]:
|
||||
for got_action_item in got_batch_transition[ACTION]:
|
||||
assert any(torch.equal(got_action_item, dummy_action) for dummy_action in dummy_actions), (
|
||||
"Actions should be equal to the dummy actions."
|
||||
)
|
||||
@@ -378,7 +378,7 @@ def test_to_lerobot_dataset(tmp_path):
|
||||
|
||||
for i in range(len(ds)):
|
||||
for feature, value in ds[i].items():
|
||||
if feature == "action":
|
||||
if feature == ACTION:
|
||||
assert torch.equal(value, buffer.actions[i])
|
||||
elif feature == "next.reward":
|
||||
assert torch.equal(value, buffer.rewards[i])
|
||||
@@ -495,7 +495,7 @@ def test_buffer_sample_alignment():
|
||||
|
||||
for i in range(50):
|
||||
state_sig = batch["state"]["state_value"][i].item()
|
||||
action_val = batch["action"][i].item()
|
||||
action_val = batch[ACTION][i].item()
|
||||
reward_val = batch["reward"][i].item()
|
||||
next_state_sig = batch["next_state"]["state_value"][i].item()
|
||||
is_done = batch["done"][i].item() > 0.5
|
||||
|
||||
Reference in New Issue
Block a user