Merge pull request #6 from Cadene/user/rcadene/2024_03_04_diffusion
Make diffusion work
This commit is contained in:
@@ -1,20 +1,15 @@
|
||||
import copy
|
||||
import time
|
||||
|
||||
import einops
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
|
||||
from diffusion_policy.model.common.lr_scheduler import get_scheduler
|
||||
from diffusion_policy.model.vision.model_getter import get_resnet
|
||||
|
||||
from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||
from .multi_image_obs_encoder import MultiImageObsEncoder
|
||||
|
||||
FIRST_ACTION = 0
|
||||
|
||||
|
||||
class DiffusionPolicy(nn.Module):
|
||||
def __init__(
|
||||
@@ -42,8 +37,8 @@ class DiffusionPolicy(nn.Module):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
|
||||
noise_scheduler = DDPMScheduler(**cfg_noise_scheduler)
|
||||
rgb_model = get_resnet(**cfg_rgb_model)
|
||||
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
|
||||
rgb_model = hydra.utils.instantiate(cfg_rgb_model)
|
||||
obs_encoder = MultiImageObsEncoder(
|
||||
rgb_model=rgb_model,
|
||||
**cfg_obs_encoder,
|
||||
@@ -101,20 +96,17 @@ class DiffusionPolicy(nn.Module):
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
# TODO(rcadene): remove unsqueeze hack...
|
||||
if observation["image"].ndim == 3:
|
||||
observation["image"] = observation["image"].unsqueeze(0)
|
||||
observation["state"] = observation["state"].unsqueeze(0)
|
||||
# TODO(rcadene): remove unsqueeze hack to add bsize=1
|
||||
observation["image"] = observation["image"].unsqueeze(0)
|
||||
observation["state"] = observation["state"].unsqueeze(0)
|
||||
|
||||
obs_dict = {
|
||||
# TODO(rcadene): hack to add temporal dim
|
||||
"image": einops.rearrange(observation["image"], "b c h w -> b 1 c h w"),
|
||||
"agent_pos": einops.rearrange(observation["state"], "b c -> b 1 c"),
|
||||
"image": observation["image"],
|
||||
"agent_pos": observation["state"],
|
||||
}
|
||||
out = self.diffusion.predict_action(obs_dict)
|
||||
|
||||
# TODO(rcadene): add possibility to return >1 timestemps
|
||||
action = out["action"].squeeze(0)[FIRST_ACTION]
|
||||
action = out["action"].squeeze(0)
|
||||
return action
|
||||
|
||||
def update(self, replay_buffer, step):
|
||||
@@ -133,16 +125,36 @@ class DiffusionPolicy(nn.Module):
|
||||
# (t h) ... -> t h ...
|
||||
batch = batch.reshape(num_slices, horizon) # .transpose(1, 0).contiguous()
|
||||
|
||||
# |-1|0|1|2|3|4|5|6|7|8|9|10|11|12|13|14| timestamps: 16
|
||||
# |o|o| observations: 2
|
||||
# | |a|a|a|a|a|a|a|a| actions executed: 8
|
||||
# |p|p|p|p|p|p|p|p|p|p|p| p| p| p| p| p| actions predicted: 16
|
||||
# note: we predict the action needed to go from t=-1 to t=0 similarly to an inverse kinematic model
|
||||
|
||||
image = batch["observation", "image"]
|
||||
state = batch["observation", "state"]
|
||||
action = batch["action"]
|
||||
assert image.shape[1] == horizon
|
||||
assert state.shape[1] == horizon
|
||||
assert action.shape[1] == horizon
|
||||
|
||||
if not (horizon == 16 and self.cfg.n_obs_steps == 2):
|
||||
raise NotImplementedError()
|
||||
|
||||
# keep first 2 observations of the slice corresponding to t=[-1,0]
|
||||
image = image[:, : self.cfg.n_obs_steps]
|
||||
state = state[:, : self.cfg.n_obs_steps]
|
||||
|
||||
out = {
|
||||
"obs": {
|
||||
"image": batch["observation", "image"].to(self.device, non_blocking=True),
|
||||
"agent_pos": batch["observation", "state"].to(self.device, non_blocking=True),
|
||||
"image": image.to(self.device, non_blocking=True),
|
||||
"agent_pos": state.to(self.device, non_blocking=True),
|
||||
},
|
||||
"action": batch["action"].to(self.device, non_blocking=True),
|
||||
"action": action.to(self.device, non_blocking=True),
|
||||
}
|
||||
return out
|
||||
|
||||
batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
batch = process_batch(batch, self.cfg.horizon, num_slices)
|
||||
|
||||
data_s = time.time() - start_time
|
||||
|
||||
Reference in New Issue
Block a user