Minor fixes for #47
This commit is contained in:
@@ -21,7 +21,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
version: str | None,
|
version: str | None = None,
|
||||||
batch_size: int = None,
|
batch_size: int = None,
|
||||||
*,
|
*,
|
||||||
shuffle: bool = True,
|
shuffle: bool = True,
|
||||||
@@ -32,7 +32,6 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
|
|||||||
collate_fn: Callable = None,
|
collate_fn: Callable = None,
|
||||||
writer: Writer = None,
|
writer: Writer = None,
|
||||||
transform: "torchrl.envs.Transform" = None,
|
transform: "torchrl.envs.Transform" = None,
|
||||||
# storage = None,
|
|
||||||
):
|
):
|
||||||
self.dataset_id = dataset_id
|
self.dataset_id = dataset_id
|
||||||
self.version = version
|
self.version = version
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ class AlohaEnv(AbstractEnv):
|
|||||||
|
|
||||||
def _make_env(self):
|
def _make_env(self):
|
||||||
if not _has_gym:
|
if not _has_gym:
|
||||||
raise ImportError("Cannot import gym.")
|
raise ImportError("Cannot import gymnasium.")
|
||||||
|
|
||||||
if not self.from_pixels:
|
if not self.from_pixels:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ class PushtEnv(AbstractEnv):
|
|||||||
|
|
||||||
def _make_env(self):
|
def _make_env(self):
|
||||||
if not _has_gym:
|
if not _has_gym:
|
||||||
raise ImportError("Cannot import gym.")
|
raise ImportError("Cannot import gymnasium.")
|
||||||
|
|
||||||
# TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
|
# TODO(rcadene) (PushTEnv is similar to PushTImageEnv, but without the image rendering, it's faster to iterate on)
|
||||||
# from lerobot.common.envs.pusht.pusht_env import PushTEnv
|
# from lerobot.common.envs.pusht.pusht_env import PushTEnv
|
||||||
|
|||||||
@@ -49,10 +49,8 @@ class SimxarmEnv(AbstractEnv):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _make_env(self):
|
def _make_env(self):
|
||||||
# if not _has_simxarm:
|
|
||||||
# raise ImportError("Cannot import simxarm.")
|
|
||||||
if not _has_gym:
|
if not _has_gym:
|
||||||
raise ImportError("Cannot import gym.")
|
raise ImportError("Cannot import gymnasium.")
|
||||||
|
|
||||||
import gymnasium
|
import gymnasium
|
||||||
|
|
||||||
@@ -231,8 +229,6 @@ class SimxarmEnv(AbstractEnv):
|
|||||||
|
|
||||||
def _set_seed(self, seed: Optional[int]):
|
def _set_seed(self, seed: Optional[int]):
|
||||||
set_global_seed(seed)
|
set_global_seed(seed)
|
||||||
# self._env.seed(seed)
|
|
||||||
# self._env.action_space.seed(seed)
|
|
||||||
# self.set_seed(seed)
|
|
||||||
logging.warning("simxarm env is not seeded")
|
|
||||||
self._seed = seed
|
self._seed = seed
|
||||||
|
# TODO(aliberts): change self._reset so that it takes in a seed value
|
||||||
|
logging.warning("simxarm env is not properly seeded")
|
||||||
|
|||||||
@@ -4,11 +4,11 @@ import gymnasium as gym
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from gymnasium.wrappers import TimeLimit
|
from gymnasium.wrappers import TimeLimit
|
||||||
|
|
||||||
from lerobot.common.envs.simxarm.simxarm.task.base import Base as Base
|
from lerobot.common.envs.simxarm.simxarm.tasks.base import Base as Base
|
||||||
from lerobot.common.envs.simxarm.simxarm.task.lift import Lift
|
from lerobot.common.envs.simxarm.simxarm.tasks.lift import Lift
|
||||||
from lerobot.common.envs.simxarm.simxarm.task.peg_in_box import PegInBox
|
from lerobot.common.envs.simxarm.simxarm.tasks.peg_in_box import PegInBox
|
||||||
from lerobot.common.envs.simxarm.simxarm.task.push import Push
|
from lerobot.common.envs.simxarm.simxarm.tasks.push import Push
|
||||||
from lerobot.common.envs.simxarm.simxarm.task.reach import Reach
|
from lerobot.common.envs.simxarm.simxarm.tasks.reach import Reach
|
||||||
|
|
||||||
TASKS = OrderedDict(
|
TASKS = OrderedDict(
|
||||||
(
|
(
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import mujoco
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from gymnasium_robotics.envs import robot_env
|
from gymnasium_robotics.envs import robot_env
|
||||||
|
|
||||||
from lerobot.common.envs.simxarm.simxarm.task import mocap
|
from lerobot.common.envs.simxarm.simxarm.tasks import mocap
|
||||||
|
|
||||||
|
|
||||||
class Base(robot_env.MujocoRobotEnv):
|
class Base(robot_env.MujocoRobotEnv):
|
||||||
@@ -231,7 +231,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||||||
|
|
||||||
num_parallel_env = rollout.batch_size[0]
|
num_parallel_env = rollout.batch_size[0]
|
||||||
if num_parallel_env != 1:
|
if num_parallel_env != 1:
|
||||||
# TODO(rcadene): when num_parallel_env > 1, episode needs to be incremented and we need to add tests
|
# TODO(rcadene): when num_parallel_env > 1, rollout["episode"] needs to be properly set and we need to add tests
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
num_max_steps = rollout.batch_size[1]
|
num_max_steps = rollout.batch_size[1]
|
||||||
|
|||||||
Reference in New Issue
Block a user