Merge remote-tracking branch 'upstream/user/rcadene/2024_03_31_remove_torchrl' into refactor_act

This commit is contained in:
Alexander Soare
2024-04-08 15:44:00 +01:00
109 changed files with 81 additions and 228 deletions

View File

@@ -429,7 +429,7 @@ class TDMPCPolicy(nn.Module):
batch[key] = batch[key].transpose(1, 0)
action = batch["action"]
reward = batch["next.reward"][:, :, None] # add extra channel dimension
reward = batch["next.reward"]
# idxs = batch["index"] # TODO(rcadene): use idxs to update sampling weights
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)