Wandb works, One output dir
This commit is contained in:
@@ -29,7 +29,7 @@ class PushtEnv(EnvBase):
|
||||
image_size=None,
|
||||
seed=1337,
|
||||
device="cpu",
|
||||
max_episode_length=25, # TODO: verify
|
||||
max_episode_length=300,
|
||||
):
|
||||
super().__init__(device=device, batch_size=[])
|
||||
self.frame_skip = frame_skip
|
||||
@@ -53,13 +53,11 @@ class PushtEnv(EnvBase):
|
||||
if not from_pixels:
|
||||
raise NotImplementedError("Use PushTEnv, instead of PushTImageEnv")
|
||||
from diffusion_policy.env.pusht.pusht_image_env import PushTImageEnv
|
||||
from gym.wrappers import TimeLimit
|
||||
|
||||
self._env = PushTImageEnv(render_size=self.image_size)
|
||||
self._env = TimeLimit(self._env, self.max_episode_length)
|
||||
|
||||
self._make_spec()
|
||||
self.set_seed(seed)
|
||||
self._current_seed = self.set_seed(seed)
|
||||
|
||||
def render(self, mode="rgb_array", width=384, height=384):
|
||||
if width != height:
|
||||
@@ -90,7 +88,11 @@ class PushtEnv(EnvBase):
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
td = tensordict
|
||||
if td is None or td.is_empty():
|
||||
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
|
||||
self._current_seed += 1
|
||||
self.set_seed(self._current_seed)
|
||||
raw_obs = self._env.reset()
|
||||
assert self._current_seed == self._env._seed
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user