Refactor env queue; training diffusion works (still not converging)
This commit is contained in:
@@ -101,14 +101,18 @@ class PushtEnv(EnvBase):
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
# remove all previous observations
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue.clear()
|
||||
self._prev_obs_image_queue = deque(
|
||||
[obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue.clear()
|
||||
|
||||
# copy the current observation n times
|
||||
obs = self._stack_prev_obs(obs)
|
||||
self._prev_obs_state_queue = deque(
|
||||
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
|
||||
)
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
@@ -121,40 +125,6 @@ class PushtEnv(EnvBase):
|
||||
raise NotImplementedError()
|
||||
return td
|
||||
|
||||
def _stack_prev_obs(self, obs):
    """Stack each observation with its ``num_prev_obs`` predecessors.

    For each of the keys ``"image"`` and ``"state"`` present in ``obs``, the
    ``self.num_prev_obs`` most recent entries of the matching history queue
    are stacked in chronological order (oldest first) with the current
    observation placed last, using ``torch.stack(..., dim=0)``. The current
    observation is then recorded in that queue.

    When the queue is empty the current observation is repeated to fill the
    history; when it is only partly filled, its oldest entry is repeated.

    Args:
        obs: Mapping that may contain ``"image"`` and/or ``"state"`` tensors.

    Returns:
        A new dict with one stacked tensor per handled key. Keys other than
        ``"image"`` and ``"state"`` are not carried over.
    """
    assert self.num_prev_obs > 0

    def stack_and_record(history, current, num_prev):
        # Most recent entries, ordered oldest -> newest.
        recent = list(history)[-num_prev:]

        if not recent:
            # Nothing recorded yet: repeat the current observation.
            recent = [current] * num_prev
        elif len(recent) < num_prev:
            # Not enough history: left-pad with the oldest recorded entry.
            recent = [recent[0]] * (num_prev - len(recent)) + recent

        stacked = torch.stack([*recent, current], dim=0)

        # Record on the right so the queue stays oldest -> newest, which is
        # what the slice and the padding above rely on. (The previous
        # appendleft() reversed the history and scrambled the frame order.)
        # If the deque was created with a maxlen, the oldest entry is
        # dropped automatically once it is full.
        history.append(current)
        return stacked

    stacked_obs = {}
    for key in ("image", "state"):
        if key in obs:
            # Look the queue up lazily so a missing modality is never touched.
            history = getattr(self, f"_prev_obs_{key}_queue")
            stacked_obs[key] = stack_and_record(history, obs[key], self.num_prev_obs)
    return stacked_obs
|
||||
|
||||
def _step(self, tensordict: TensorDict):
|
||||
td = tensordict
|
||||
action = td["action"].numpy()
|
||||
@@ -176,7 +146,14 @@ class PushtEnv(EnvBase):
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
obs = self._stack_prev_obs(obs)
|
||||
stacked_obs = {}
|
||||
if "image" in obs:
|
||||
self._prev_obs_image_queue.append(obs["image"])
|
||||
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
|
||||
if "state" in obs:
|
||||
self._prev_obs_state_queue.append(obs["state"])
|
||||
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
|
||||
obs = stacked_obs
|
||||
|
||||
td = TensorDict(
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user