test_datasets.py are passing!

This commit is contained in:
Cadene
2024-04-08 14:02:03 +00:00
parent e1ac5dc62f
commit 70aaf1c4cb
109 changed files with 90 additions and 228 deletions

View File

@@ -429,7 +429,7 @@ class TDMPCPolicy(nn.Module):
batch[key] = batch[key].transpose(1, 0)
action = batch["action"]
reward = batch["next.reward"][:, :, None] # add extra channel dimension
reward = batch["next.reward"]
# idxs = batch["index"] # TODO(rcadene): use idxs to update sampling weights
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)