Renamed set_seed -> set_global_seed

This commit is contained in:
Simon Alibert
2024-03-25 17:19:28 +01:00
parent 058ac991eb
commit 7cdd6d2450
9 changed files with 16 additions and 16 deletions

View File

@@ -4,7 +4,7 @@ from typing import Optional
from tensordict import TensorDict
from torchrl.envs import EnvBase
from lerobot.common.utils import set_seed
from lerobot.common.utils import set_global_seed
class AbstractEnv(EnvBase):
@@ -67,4 +67,4 @@ class AbstractEnv(EnvBase):
raise NotImplementedError("Abstract method")
def _set_seed(self, seed: Optional[int]):
set_seed(seed)
set_global_seed(seed)

View File

@@ -29,7 +29,7 @@ from lerobot.common.envs.aloha.tasks.sim_end_effector import (
TransferCubeEndEffectorTask,
)
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
from lerobot.common.utils import set_seed
from lerobot.common.utils import set_global_seed
_has_gym = importlib.util.find_spec("gymnasium") is not None
@@ -290,7 +290,7 @@ class AlohaEnv(AbstractEnv):
)
def _set_seed(self, seed: Optional[int]):
set_seed(seed)
set_global_seed(seed)
# TODO(rcadene): seed the env
# self._env.seed(seed)
logging.warning("Aloha env is not seeded")

View File

@@ -16,7 +16,7 @@ from torchrl.data.tensor_specs import (
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
from lerobot.common.envs.abstract import AbstractEnv
from lerobot.common.utils import set_seed
from lerobot.common.utils import set_global_seed
_has_gym = importlib.util.find_spec("gymnasium") is not None
@@ -238,6 +238,6 @@ class PushtEnv(AbstractEnv):
def _set_seed(self, seed: Optional[int]):
# Set global seed.
set_seed(seed)
set_global_seed(seed)
# Set PushTImageEnv seed as it relies on it's own internal _seed attribute.
self._env.seed(seed)

View File

@@ -1,4 +1,5 @@
import importlib
import logging
from collections import deque
from typing import Optional
@@ -15,7 +16,7 @@ from torchrl.data.tensor_specs import (
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
from lerobot.common.envs.abstract import AbstractEnv
from lerobot.common.utils import set_seed
from lerobot.common.utils import set_global_seed
MAX_NUM_ACTIONS = 4
@@ -229,8 +230,9 @@ class SimxarmEnv(AbstractEnv):
)
def _set_seed(self, seed: Optional[int]):
set_seed(seed)
set_global_seed(seed)
# self._env.seed(seed)
# self._env.action_space.seed(seed)
# self.set_seed(seed)
logging.warning("simxarm env is not seeded")
self._seed = seed

View File

@@ -26,7 +26,7 @@ def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
return device
def set_seed(seed):
def set_global_seed(seed):
"""Set seed for reproducibility."""
random.seed(seed)
np.random.seed(seed)