Refactor env queue, Training diffusion works (Still not converging)

This commit is contained in:
Remi Cadene
2024-03-04 10:59:43 +00:00
parent fddd9f0311
commit cfc304e870
11 changed files with 96 additions and 111 deletions

View File

@@ -101,14 +101,18 @@ class PushtEnv(EnvBase):
obs = self._format_raw_obs(raw_obs)
if self.num_prev_obs > 0:
# remove all previous observations
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.clear()
self._prev_obs_image_queue = deque(
[obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue.clear()
# copy the current observation n times
obs = self._stack_prev_obs(obs)
self._prev_obs_state_queue = deque(
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
td = TensorDict(
{
@@ -121,40 +125,6 @@ class PushtEnv(EnvBase):
raise NotImplementedError()
return td
def _stack_prev_obs(self, obs):
"""When the queue is empty, copy the current observation n times."""
assert self.num_prev_obs > 0
def stack_update_queue(prev_obs_queue, obs, num_prev_obs):
# get n most recent observations
prev_obs = list(prev_obs_queue)[-num_prev_obs:]
# if not enough observations, copy the oldest observation until we obtain n observations
if len(prev_obs) == 0:
prev_obs = [obs] * num_prev_obs # queue is empty when env reset
elif len(prev_obs) < num_prev_obs:
prev_obs = [prev_obs[0] for _ in range(num_prev_obs - len(prev_obs))] + prev_obs
# stack n most recent observations with the current observation
stacked_obs = torch.stack(prev_obs + [obs], dim=0)
# add current observation to the queue
# automatically remove oldest observation when queue is full
prev_obs_queue.appendleft(obs)
return stacked_obs
stacked_obs = {}
if "image" in obs:
stacked_obs["image"] = stack_update_queue(
self._prev_obs_image_queue, obs["image"], self.num_prev_obs
)
if "state" in obs:
stacked_obs["state"] = stack_update_queue(
self._prev_obs_state_queue, obs["state"], self.num_prev_obs
)
return stacked_obs
def _step(self, tensordict: TensorDict):
td = tensordict
action = td["action"].numpy()
@@ -176,7 +146,14 @@ class PushtEnv(EnvBase):
obs = self._format_raw_obs(raw_obs)
if self.num_prev_obs > 0:
obs = self._stack_prev_obs(obs)
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"])
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
td = TensorDict(
{