WIP

WIP train.py works, loss going down

WIP eval.py

Fix

WIP (eval running, TODO: verify results reproduced)

Eval works! (testing reproducibility)

WIP

pretrained model pusht reproduces same results as torchrl

pretrained model pusht reproduces same results as torchrl

Remove AbstractPolicy, Move all queues in select_action

WIP test_datasets passed (TODO: re-enable NormalizeTransform)
This commit is contained in:
Cadene
2024-03-31 15:05:25 +00:00
parent 920e0d118b
commit 1cdfbc8b52
17 changed files with 826 additions and 621 deletions

View File

@@ -1,18 +1,20 @@
import copy
import logging
import time
from collections import deque
import hydra
import torch
from torch import nn
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
from lerobot.common.policies.utils import populate_queues
from lerobot.common.utils import get_safe_torch_device
class DiffusionPolicy(AbstractPolicy):
class DiffusionPolicy(nn.Module):
name = "diffusion"
def __init__(
@@ -38,8 +40,12 @@ class DiffusionPolicy(AbstractPolicy):
# parameters passed to step
**kwargs,
):
super().__init__(n_action_steps)
super().__init__()
self.cfg = cfg
self.n_obs_steps = n_obs_steps
self.n_action_steps = n_action_steps
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape)
@@ -100,76 +106,58 @@ class DiffusionPolicy(AbstractPolicy):
last_epoch=self.global_step - 1,
)
def reset(self):
"""
Clear observation and action queues. Should be called on `env.reset()`
"""
self._queues = {
"observation.image": deque(maxlen=self.n_obs_steps),
"observation.state": deque(maxlen=self.n_obs_steps),
"action": deque(maxlen=self.n_action_steps),
}
@torch.no_grad()
def select_actions(self, observation, step_count):
def select_action(self, batch, step):
"""
Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
"""
# TODO(rcadene): remove unused step_count
del step_count
# TODO(rcadene): remove unused step
del step
assert "observation.image" in batch
assert "observation.state" in batch
assert len(batch) == 2
obs_dict = {
"image": observation["image"],
"agent_pos": observation["state"],
}
if self.training:
out = self.diffusion.predict_action(obs_dict)
else:
out = self.ema_diffusion.predict_action(obs_dict)
action = out["action"]
self._queues = populate_queues(self._queues, batch)
if len(self._queues["action"]) == 0:
# stack n latest observations from the queue
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
obs_dict = {
"image": batch["observation.image"],
"agent_pos": batch["observation.state"],
}
if self.training:
out = self.diffusion.predict_action(obs_dict)
else:
out = self.ema_diffusion.predict_action(obs_dict)
self._queues["action"].extend(out["action"].transpose(0, 1))
action = self._queues["action"].popleft()
return action
def update(self, replay_buffer, step):
def forward(self, batch, step):
start_time = time.time()
self.diffusion.train()
num_slices = self.cfg.batch_size
batch_size = self.cfg.horizon * num_slices
assert batch_size % self.cfg.horizon == 0
assert batch_size % num_slices == 0
def process_batch(batch, horizon, num_slices):
# trajectory t = 64, horizon h = 16
# (t h) ... -> t h ...
batch = batch.reshape(num_slices, horizon) # .transpose(1, 0).contiguous()
# |-1|0|1|2|3|4|5|6|7|8|9|10|11|12|13|14| timestamps: 16
# |o|o| observations: 2
# | |a|a|a|a|a|a|a|a| actions executed: 8
# |p|p|p|p|p|p|p|p|p|p|p| p| p| p| p| p| actions predicted: 16
# note: we predict the action needed to go from t=-1 to t=0 similarly to an inverse kinematic model
image = batch["observation", "image"]
state = batch["observation", "state"]
action = batch["action"]
assert image.shape[1] == horizon
assert state.shape[1] == horizon
assert action.shape[1] == horizon
if not (horizon == 16 and self.cfg.n_obs_steps == 2):
raise NotImplementedError()
# keep first 2 observations of the slice corresponding to t=[-1,0]
image = image[:, : self.cfg.n_obs_steps]
state = state[:, : self.cfg.n_obs_steps]
out = {
"obs": {
"image": image.to(self.device, non_blocking=True),
"agent_pos": state.to(self.device, non_blocking=True),
},
"action": action.to(self.device, non_blocking=True),
}
return out
batch = replay_buffer.sample(batch_size)
batch = process_batch(batch, self.cfg.horizon, num_slices)
data_s = time.time() - start_time
loss = self.diffusion.compute_loss(batch)
obs_dict = {
"image": batch["observation.image"],
"agent_pos": batch["observation.state"],
}
loss = self.diffusion.compute_loss(obs_dict)
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(