Fix diffusion (rm transpose), Add prefetch
This commit is contained in:
@@ -359,7 +359,11 @@ class TDMPC(nn.Module):
|
||||
weights = batch["_weight"][FIRST_FRAME, :, None]
|
||||
return obs, action, next_obses, reward, mask, done, idxs, weights
|
||||
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
if self.cfg.balanced_sampling:
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
else:
|
||||
batch = replay_buffer.sample()
|
||||
|
||||
obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
|
||||
batch, self.cfg.horizon, num_slices
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user