Refactor env queue; diffusion training works (still not converging)

This commit is contained in:
Remi Cadene
2024-03-04 10:59:43 +00:00
parent fddd9f0311
commit cfc304e870
11 changed files with 96 additions and 111 deletions

View File

@@ -4,10 +4,8 @@ import time
import hydra
import torch
import torch.nn as nn
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusion_policy.model.common.lr_scheduler import get_scheduler
from diffusion_policy.model.vision.model_getter import get_resnet
from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
from .multi_image_obs_encoder import MultiImageObsEncoder
@@ -39,8 +37,8 @@ class DiffusionPolicy(nn.Module):
super().__init__()
self.cfg = cfg
noise_scheduler = DDPMScheduler(**cfg_noise_scheduler)
rgb_model = get_resnet(**cfg_rgb_model)
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
rgb_model = hydra.utils.instantiate(cfg_rgb_model)
obs_encoder = MultiImageObsEncoder(
rgb_model=rgb_model,
**cfg_obs_encoder,
@@ -127,16 +125,36 @@ class DiffusionPolicy(nn.Module):
# (t h) ... -> t h ...
batch = batch.reshape(num_slices, horizon) # .transpose(1, 0).contiguous()
# |-1|0|1|2|3|4|5|6|7|8|9|10|11|12|13|14| timestamps: 16
# |o|o| observations: 2
# | |a|a|a|a|a|a|a|a| actions executed: 8
# |p|p|p|p|p|p|p|p|p|p|p| p| p| p| p| p| actions predicted: 16
# note: we predict the action needed to go from t=-1 to t=0, similarly to an inverse kinematics model
image = batch["observation", "image"]
state = batch["observation", "state"]
action = batch["action"]
assert image.shape[1] == horizon
assert state.shape[1] == horizon
assert action.shape[1] == horizon
if not (horizon == 16 and self.cfg.n_obs_steps == 2):
raise NotImplementedError()
# keep first 2 observations of the slice corresponding to t=[-1,0]
image = image[:, : self.cfg.n_obs_steps]
state = state[:, : self.cfg.n_obs_steps]
out = {
"obs": {
"image": batch["observation", "image"].to(self.device, non_blocking=True),
"agent_pos": batch["observation", "state"].to(self.device, non_blocking=True),
"image": image.to(self.device, non_blocking=True),
"agent_pos": state.to(self.device, non_blocking=True),
},
"action": batch["action"].to(self.device, non_blocking=True),
"action": action.to(self.device, non_blocking=True),
}
return out
batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
batch = replay_buffer.sample(batch_size)
batch = process_batch(batch, self.cfg.horizon, num_slices)
data_s = time.time() - start_time