Fix diffusion (rm transpose), Add prefetch

This commit is contained in:
Cadene
2024-02-28 17:45:01 +00:00
parent cf5063e50e
commit ac90b9c3ee
6 changed files with 52 additions and 11 deletions

View File

@@ -359,7 +359,11 @@ class TDMPC(nn.Module):
weights = batch["_weight"][FIRST_FRAME, :, None]
return obs, action, next_obses, reward, mask, done, idxs, weights
batch = replay_buffer.sample(batch_size)
if self.cfg.balanced_sampling:
batch = replay_buffer.sample(batch_size)
else:
batch = replay_buffer.sample()
obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
batch, self.cfg.horizon, num_slices
)