Wandb works, One output dir
This commit is contained in:
@@ -11,6 +11,7 @@ def make_env(cfg):
|
||||
"from_pixels": cfg.from_pixels,
|
||||
"pixels_only": cfg.pixels_only,
|
||||
"image_size": cfg.image_size,
|
||||
"max_episode_length": cfg.episode_length,
|
||||
}
|
||||
|
||||
if cfg.env == "simxarm":
|
||||
|
||||
@@ -29,7 +29,7 @@ class PushtEnv(EnvBase):
|
||||
image_size=None,
|
||||
seed=1337,
|
||||
device="cpu",
|
||||
max_episode_length=25, # TODO: verify
|
||||
max_episode_length=300,
|
||||
):
|
||||
super().__init__(device=device, batch_size=[])
|
||||
self.frame_skip = frame_skip
|
||||
@@ -53,13 +53,11 @@ class PushtEnv(EnvBase):
|
||||
if not from_pixels:
|
||||
raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
|
||||
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
|
||||
from gym.wrappers import TimeLimit
|
||||
|
||||
self._env = PushTImageEnv(render_size=self.image_size)
|
||||
self._env = TimeLimit(self._env, self.max_episode_length)
|
||||
|
||||
self._make_spec()
|
||||
self.set_seed(seed)
|
||||
self._current_seed = self.set_seed(seed)
|
||||
|
||||
def render(self, mode="rgb_array", width=384, height=384):
|
||||
if width != height:
|
||||
@@ -90,7 +88,11 @@ class PushtEnv(EnvBase):
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
td = tensordict
|
||||
if td is None or td.is_empty():
|
||||
# We need to handle seed iteration ourselves, since self._env.reset() relies on an internal _seed.
|
||||
self._current_seed += 1
|
||||
self.set_seed(self._current_seed)
|
||||
raw_obs = self._env.reset()
|
||||
assert self._current_seed == self._env._seed
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
|
||||
@@ -49,7 +49,6 @@ class SimxarmEnv(EnvBase):
|
||||
raise ImportError("Cannot import gym.")
|
||||
|
||||
import gym
|
||||
from gym.wrappers import TimeLimit
|
||||
from simxarm import TASKS
|
||||
|
||||
if self.task not in TASKS:
|
||||
@@ -58,7 +57,6 @@ class SimxarmEnv(EnvBase):
|
||||
)
|
||||
|
||||
self._env = TASKS[self.task]["env"]()
|
||||
self._env = TimeLimit(self._env, TASKS[self.task]["episode_length"])
|
||||
|
||||
MAX_NUM_ACTIONS = 4
|
||||
num_actions = len(TASKS[self.task]["action_space"])
|
||||
|
||||
Reference in New Issue
Block a user