Minor fixes for #47

This commit is contained in:
Simon Alibert
2024-03-25 18:50:47 +01:00
parent f00252552a
commit c5635b7d94
37 changed files with 13 additions and 18 deletions

View File

@@ -21,7 +21,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
def __init__( def __init__(
self, self,
dataset_id: str, dataset_id: str,
version: str | None, version: str | None = None,
batch_size: int = None, batch_size: int = None,
*, *,
shuffle: bool = True, shuffle: bool = True,
@@ -32,7 +32,6 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
collate_fn: Callable = None, collate_fn: Callable = None,
writer: Writer = None, writer: Writer = None,
transform: "torchrl.envs.Transform" = None, transform: "torchrl.envs.Transform" = None,
# storage = None,
): ):
self.dataset_id = dataset_id self.dataset_id = dataset_id
self.version = version self.version = version

View File

@@ -63,7 +63,7 @@ class AlohaEnv(AbstractEnv):
def _make_env(self): def _make_env(self):
if not _has_gym: if not _has_gym:
raise ImportError("Cannot import gym.") raise ImportError("Cannot import gymnasium.")
if not self.from_pixels: if not self.from_pixels:
raise NotImplementedError() raise NotImplementedError()

View File

@@ -50,7 +50,7 @@ class PushtEnv(AbstractEnv):
def _make_env(self): def _make_env(self):
if not _has_gym: if not _has_gym:
raise ImportError("Cannot import gym.") raise ImportError("Cannot import gymnasium.")
# TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on) # TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
# from lerobot.common.envs.pusht.pusht_env import PushTEnv # from lerobot.common.envs.pusht.pusht_env import PushTEnv

View File

@@ -49,10 +49,8 @@ class SimxarmEnv(AbstractEnv):
) )
def _make_env(self): def _make_env(self):
# if not _has_simxarm:
# raise ImportError("Cannot import simxarm.")
if not _has_gym: if not _has_gym:
raise ImportError("Cannot import gym.") raise ImportError("Cannot import gymnasium.")
import gymnasium import gymnasium
@@ -231,8 +229,6 @@ class SimxarmEnv(AbstractEnv):
def _set_seed(self, seed: Optional[int]): def _set_seed(self, seed: Optional[int]):
set_global_seed(seed) set_global_seed(seed)
# self._env.seed(seed)
# self._env.action_space.seed(seed)
# self.set_seed(seed)
logging.warning("simxarm env is not seeded")
self._seed = seed self._seed = seed
# TODO(aliberts): change self._reset so that it takes in a seed value
logging.warning("simxarm env is not properly seeded")

View File

@@ -4,11 +4,11 @@ import gymnasium as gym
import numpy as np import numpy as np
from gymnasium.wrappers import TimeLimit from gymnasium.wrappers import TimeLimit
from lerobot.common.envs.simxarm.simxarm.task.base import Base as Base from lerobot.common.envs.simxarm.simxarm.tasks.base import Base as Base
from lerobot.common.envs.simxarm.simxarm.task.lift import Lift from lerobot.common.envs.simxarm.simxarm.tasks.lift import Lift
from lerobot.common.envs.simxarm.simxarm.task.peg_in_box import PegInBox from lerobot.common.envs.simxarm.simxarm.tasks.peg_in_box import PegInBox
from lerobot.common.envs.simxarm.simxarm.task.push import Push from lerobot.common.envs.simxarm.simxarm.tasks.push import Push
from lerobot.common.envs.simxarm.simxarm.task.reach import Reach from lerobot.common.envs.simxarm.simxarm.tasks.reach import Reach
TASKS = OrderedDict( TASKS = OrderedDict(
( (

View File

@@ -4,7 +4,7 @@ import mujoco
import numpy as np import numpy as np
from gymnasium_robotics.envs import robot_env from gymnasium_robotics.envs import robot_env
from lerobot.common.envs.simxarm.simxarm.task import mocap from lerobot.common.envs.simxarm.simxarm.tasks import mocap
class Base(robot_env.MujocoRobotEnv): class Base(robot_env.MujocoRobotEnv):

View File

@@ -231,7 +231,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
num_parallel_env = rollout.batch_size[0] num_parallel_env = rollout.batch_size[0]
if num_parallel_env != 1: if num_parallel_env != 1:
# TODO(rcadene): when num_parallel_env > 1, episode needs to be incremented and we need to add tests # TODO(rcadene): when num_parallel_env > 1, rollout["episode"] needs to be properly set and we need to add tests
raise NotImplementedError() raise NotImplementedError()
num_max_steps = rollout.batch_size[1] num_max_steps = rollout.batch_size[1]