Add obs queue to pusht, Set n_obs_steps=2 for diffusion (Not fully tested)
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import copy
|
||||
import time
|
||||
|
||||
import einops
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -101,15 +100,13 @@ class DiffusionPolicy(nn.Module):
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
# TODO(rcadene): remove unsqueeze hack...
|
||||
if observation["image"].ndim == 3:
|
||||
observation["image"] = observation["image"].unsqueeze(0)
|
||||
observation["state"] = observation["state"].unsqueeze(0)
|
||||
# TODO(rcadene): remove unsqueeze hack to add bsize=1
|
||||
observation["image"] = observation["image"].unsqueeze(0)
|
||||
observation["state"] = observation["state"].unsqueeze(0)
|
||||
|
||||
obs_dict = {
|
||||
# TODO(rcadene): hack to add temporal dim
|
||||
"image": einops.rearrange(observation["image"], "b c h w -> b 1 c h w"),
|
||||
"agent_pos": einops.rearrange(observation["state"], "b c -> b 1 c"),
|
||||
"image": observation["image"],
|
||||
"agent_pos": observation["state"],
|
||||
}
|
||||
out = self.diffusion.predict_action(obs_dict)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user