Fix diffusion (rm transpose), Add prefetch

This commit is contained in:
Cadene
2024-02-28 17:45:01 +00:00
parent cf5063e50e
commit ac90b9c3ee
6 changed files with 52 additions and 11 deletions

View File

@@ -119,9 +119,9 @@ class DiffusionPolicy(nn.Module):
assert batch_size % num_slices == 0
def process_batch(batch, horizon, num_slices):
# trajectory t = 256, horizon h = 5
# (t h) ... -> h t ...
batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
# trajectory t = 64, horizon h = 16
# (t h) ... -> t h ...
batch = batch.reshape(num_slices, horizon) # .transpose(1, 0).contiguous()
out = {
"obs": {
@@ -132,7 +132,10 @@ class DiffusionPolicy(nn.Module):
}
return out
batch = replay_buffer.sample(batch_size)
if self.cfg.balanced_sampling:
batch = replay_buffer.sample(batch_size)
else:
batch = replay_buffer.sample()
batch = process_batch(batch, self.cfg.horizon, num_slices)
loss = self.diffusion.compute_loss(batch)
@@ -149,4 +152,17 @@ class DiffusionPolicy(nn.Module):
"total_loss": loss.item(),
"lr": self.lr_scheduler.get_last_lr()[0],
}
# TODO(rcadene): remove hardcoding
# in diffusion_policy, len(dataloader) is 168 for a batch_size of 64
if step % 168 == 0:
self.global_step += 1
return metrics
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
self.load_state_dict(d)