chore: replace hard-coded obs values with constants throughout all the source code (#2037)
* chore: replace hard-coded OBS values with constants throughout all the source code * chore(tests): replace hard-coded OBS values with constants throughout all the test code
This commit is contained in:
@@ -24,6 +24,7 @@ from torch.multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
from lerobot.utils.transition import Transition
|
||||
from tests.utils import require_package
|
||||
|
||||
@@ -33,12 +34,12 @@ def create_test_transitions(count: int = 3) -> list[Transition]:
|
||||
transitions = []
|
||||
for i in range(count):
|
||||
transition = Transition(
|
||||
state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
|
||||
state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)},
|
||||
action=torch.randn(5),
|
||||
reward=torch.tensor(1.0 + i),
|
||||
done=torch.tensor(i == count - 1), # Last transition is done
|
||||
truncated=torch.tensor(False),
|
||||
next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)},
|
||||
next_state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)},
|
||||
complementary_info={"step": torch.tensor(i), "episode_id": i // 2},
|
||||
)
|
||||
transitions.append(transition)
|
||||
|
||||
Reference in New Issue
Block a user