fix environment seeding

This commit is contained in:
Alexander Soare
2024-03-22 13:25:23 +00:00
parent b633748987
commit b9047fbdd2
3 changed files with 16 additions and 11 deletions

View File

@@ -4,6 +4,8 @@ from typing import Optional
from tensordict import TensorDict
from torchrl.envs import EnvBase
from lerobot.common.utils import set_seed
class AbstractEnv(EnvBase):
def __init__(
@@ -34,7 +36,13 @@ class AbstractEnv(EnvBase):
self._make_env()
self._make_spec()
self._current_seed = self.set_seed(seed)
# self._next_seed will be used for the next reset. It is recommended that when self.set_seed is called
# you store the return value in self._next_seed (it will be a new randomly generated seed).
self._next_seed = seed
# Don't store the result of this in self._next_seed, as we want to make sure that the first time
# self._reset is called, we use seed.
self.set_seed(seed)
if self.num_prev_obs > 0:
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
@@ -59,4 +67,4 @@ class AbstractEnv(EnvBase):
raise NotImplementedError("Abstract method")
def _set_seed(self, seed: Optional[int]):
raise NotImplementedError("Abstract method")
set_seed(seed)