Minor fixes for #47
This commit is contained in:
@@ -21,7 +21,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None,
|
||||
version: str | None = None,
|
||||
batch_size: int = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
@@ -32,7 +32,6 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
|
||||
collate_fn: Callable = None,
|
||||
writer: Writer = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
# storage = None,
|
||||
):
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
|
||||
@@ -63,7 +63,7 @@ class AlohaEnv(AbstractEnv):
|
||||
|
||||
def _make_env(self):
|
||||
if not _has_gym:
|
||||
raise ImportError("Cannot import gym.")
|
||||
raise ImportError("Cannot import gymnasium.")
|
||||
|
||||
if not self.from_pixels:
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -50,7 +50,7 @@ class PushtEnv(AbstractEnv):
|
||||
|
||||
def _make_env(self):
|
||||
if not _has_gym:
|
||||
raise ImportError("Cannot import gym.")
|
||||
raise ImportError("Cannot import gymnasium.")
|
||||
|
||||
# TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
|
||||
# from lerobot.common.envs.pusht.pusht_env import PushTEnv
|
||||
|
||||
@@ -49,10 +49,8 @@ class SimxarmEnv(AbstractEnv):
|
||||
)
|
||||
|
||||
def _make_env(self):
|
||||
# if not _has_simxarm:
|
||||
# raise ImportError("Cannot import simxarm.")
|
||||
if not _has_gym:
|
||||
raise ImportError("Cannot import gym.")
|
||||
raise ImportError("Cannot import gymnasium.")
|
||||
|
||||
import gymnasium
|
||||
|
||||
@@ -231,8 +229,6 @@ class SimxarmEnv(AbstractEnv):
|
||||
|
||||
def _set_seed(self, seed: Optional[int]):
|
||||
set_global_seed(seed)
|
||||
# self._env.seed(seed)
|
||||
# self._env.action_space.seed(seed)
|
||||
# self.set_seed(seed)
|
||||
logging.warning("simxarm env is not seeded")
|
||||
self._seed = seed
|
||||
# TODO(aliberts): change self._reset so that it takes in a seed value
|
||||
logging.warning("simxarm env is not properly seeded")
|
||||
|
||||
@@ -4,11 +4,11 @@ import gymnasium as gym
|
||||
import numpy as np
|
||||
from gymnasium.wrappers import TimeLimit
|
||||
|
||||
from lerobot.common.envs.simxarm.simxarm.task.base import Base as Base
|
||||
from lerobot.common.envs.simxarm.simxarm.task.lift import Lift
|
||||
from lerobot.common.envs.simxarm.simxarm.task.peg_in_box import PegInBox
|
||||
from lerobot.common.envs.simxarm.simxarm.task.push import Push
|
||||
from lerobot.common.envs.simxarm.simxarm.task.reach import Reach
|
||||
from lerobot.common.envs.simxarm.simxarm.tasks.base import Base as Base
|
||||
from lerobot.common.envs.simxarm.simxarm.tasks.lift import Lift
|
||||
from lerobot.common.envs.simxarm.simxarm.tasks.peg_in_box import PegInBox
|
||||
from lerobot.common.envs.simxarm.simxarm.tasks.push import Push
|
||||
from lerobot.common.envs.simxarm.simxarm.tasks.reach import Reach
|
||||
|
||||
TASKS = OrderedDict(
|
||||
(
|
||||
|
||||
@@ -4,7 +4,7 @@ import mujoco
|
||||
import numpy as np
|
||||
from gymnasium_robotics.envs import robot_env
|
||||
|
||||
from lerobot.common.envs.simxarm.simxarm.task import mocap
|
||||
from lerobot.common.envs.simxarm.simxarm.tasks import mocap
|
||||
|
||||
|
||||
class Base(robot_env.MujocoRobotEnv):
|
||||
@@ -231,7 +231,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
|
||||
num_parallel_env = rollout.batch_size[0]
|
||||
if num_parallel_env != 1:
|
||||
# TODO(rcadene): when num_parallel_env > 1, episode needs to be incremented and we need to add tests
|
||||
# TODO(rcadene): when num_parallel_env > 1, rollout["episode"] needs to be properly set and we need to add tests
|
||||
raise NotImplementedError()
|
||||
|
||||
num_max_steps = rollout.batch_size[1]
|
||||
|
||||
Reference in New Issue
Block a user