test_datasets.py are passing!
This commit is contained in:
@@ -429,7 +429,7 @@ class TDMPCPolicy(nn.Module):
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
action = batch["action"]
|
||||
reward = batch["next.reward"][:, :, None] # add extra channel dimension
|
||||
reward = batch["next.reward"]
|
||||
# idxs = batch["index"] # TODO(rcadene): use idxs to update sampling weights
|
||||
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
|
||||
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
|
||||
|
||||
Reference in New Issue
Block a user