Minor fixes for #47

This commit is contained in:
Simon Alibert
2024-03-25 18:50:47 +01:00
parent f00252552a
commit c5635b7d94
37 changed files with 13 additions and 18 deletions

View File

@@ -21,7 +21,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
def __init__(
self,
dataset_id: str,
version: str | None,
version: str | None = None,
batch_size: int = None,
*,
shuffle: bool = True,
@@ -32,7 +32,6 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
collate_fn: Callable = None,
writer: Writer = None,
transform: "torchrl.envs.Transform" = None,
# storage = None,
):
self.dataset_id = dataset_id
self.version = version

View File

@@ -63,7 +63,7 @@ class AlohaEnv(AbstractEnv):
def _make_env(self):
if not _has_gym:
raise ImportError("Cannot import gym.")
raise ImportError("Cannot import gymnasium.")
if not self.from_pixels:
raise NotImplementedError()

View File

@@ -50,7 +50,7 @@ class PushtEnv(AbstractEnv):
def _make_env(self):
if not _has_gym:
raise ImportError("Cannot import gym.")
raise ImportError("Cannot import gymnasium.")
# TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
# from lerobot.common.envs.pusht.pusht_env import PushTEnv

View File

@@ -49,10 +49,8 @@ class SimxarmEnv(AbstractEnv):
)
def _make_env(self):
# if not _has_simxarm:
# raise ImportError("Cannot import simxarm.")
if not _has_gym:
raise ImportError("Cannot import gym.")
raise ImportError("Cannot import gymnasium.")
import gymnasium
@@ -231,8 +229,6 @@ class SimxarmEnv(AbstractEnv):
def _set_seed(self, seed: Optional[int]):
set_global_seed(seed)
# self._env.seed(seed)
# self._env.action_space.seed(seed)
# self.set_seed(seed)
logging.warning("simxarm env is not seeded")
self._seed = seed
# TODO(aliberts): change self._reset so that it takes in a seed value
logging.warning("simxarm env is not properly seeded")

View File

@@ -4,11 +4,11 @@ import gymnasium as gym
import numpy as np
from gymnasium.wrappers import TimeLimit
from lerobot.common.envs.simxarm.simxarm.task.base import Base as Base
from lerobot.common.envs.simxarm.simxarm.task.lift import Lift
from lerobot.common.envs.simxarm.simxarm.task.peg_in_box import PegInBox
from lerobot.common.envs.simxarm.simxarm.task.push import Push
from lerobot.common.envs.simxarm.simxarm.task.reach import Reach
from lerobot.common.envs.simxarm.simxarm.tasks.base import Base as Base
from lerobot.common.envs.simxarm.simxarm.tasks.lift import Lift
from lerobot.common.envs.simxarm.simxarm.tasks.peg_in_box import PegInBox
from lerobot.common.envs.simxarm.simxarm.tasks.push import Push
from lerobot.common.envs.simxarm.simxarm.tasks.reach import Reach
TASKS = OrderedDict(
(

View File

@@ -4,7 +4,7 @@ import mujoco
import numpy as np
from gymnasium_robotics.envs import robot_env
from lerobot.common.envs.simxarm.simxarm.task import mocap
from lerobot.common.envs.simxarm.simxarm.tasks import mocap
class Base(robot_env.MujocoRobotEnv):

View File

@@ -231,7 +231,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
num_parallel_env = rollout.batch_size[0]
if num_parallel_env != 1:
# TODO(rcadene): when num_parallel_env > 1, episode needs to be incremented and we need to add tests
# TODO(rcadene): when num_parallel_env > 1, rollout["episode"] needs to be properly set and we need to add tests
raise NotImplementedError()
num_max_steps = rollout.batch_size[1]