Add obs queue to pusht, Set n_obs_steps=2 for diffusion (Not fully tested)

This commit is contained in:
Remi Cadene
2024-03-03 13:21:31 +00:00
parent cbbed590a9
commit 0f2fa4d9ef
3 changed files with 85 additions and 14 deletions

View File

@@ -1,7 +1,6 @@
import copy
import time
import einops
import hydra
import torch
import torch.nn as nn
@@ -101,15 +100,13 @@ class DiffusionPolicy(nn.Module):
# TODO(rcadene): remove unused step_count
del step_count
# TODO(rcadene): remove unsqueeze hack...
if observation["image"].ndim == 3:
observation["image"] = observation["image"].unsqueeze(0)
observation["state"] = observation["state"].unsqueeze(0)
# TODO(rcadene): remove unsqueeze hack to add bsize=1
observation["image"] = observation["image"].unsqueeze(0)
observation["state"] = observation["state"].unsqueeze(0)
obs_dict = {
# TODO(rcadene): hack to add temporal dim
"image": einops.rearrange(observation["image"], "b c h w -> b 1 c h w"),
"agent_pos": einops.rearrange(observation["state"], "b c -> b 1 c"),
"image": observation["image"],
"agent_pos": observation["state"],
}
out = self.diffusion.predict_action(obs_dict)