Add Normalize, non_blocking=True in tdmpc, tdmpc run (TODO: diffusion)

This commit is contained in:
Remi Cadene
2024-03-02 15:53:29 +00:00
parent b5a2f460ea
commit 1ae6205269
9 changed files with 183 additions and 67 deletions

View File

@@ -5,6 +5,7 @@ from copy import deepcopy
import einops
import numpy as np
from tensordict import TensorDict
import torch
import torch.nn as nn
@@ -126,19 +127,30 @@ class TDMPC(nn.Module):
@torch.no_grad()
def forward(self, observation, step_count):
t0 = step_count.item() == 0
# TODO(rcadene): remove unsqueeze hack...
if observation["image"].ndim == 3:
observation["image"] = observation["image"].unsqueeze(0)
observation["state"] = observation["state"].unsqueeze(0)
obs = {
"rgb": observation["image"],
"state": observation["state"],
# TODO(rcadene): remove contiguous hack...
"rgb": observation["image"].contiguous(),
"state": observation["state"].contiguous(),
}
return self.act(obs, t0=t0, step=self.step.item())
action = self.act(obs, t0=t0, step=self.step.item())
# TODO(rcadene): hack to postprocess action (e.g. unnormalize)
# action = action * self.action_std + self.action_mean
return action
@torch.no_grad()
def act(self, obs, t0=False, step=None):
"""Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
if isinstance(obs, dict):
obs = {k: o.detach().unsqueeze(0) for k, o in obs.items()}
obs = {k: o.detach() for k, o in obs.items()}
else:
obs = obs.detach().unsqueeze(0)
obs = obs.detach()
z = self.model.encode(obs)
if self.cfg.mpc:
a = self.plan(z, t0=t0, step=step)
@@ -315,26 +327,20 @@ class TDMPC(nn.Module):
# trajectory t = 256, horizon h = 5
# (t h) ... -> h t ...
batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
batch = batch.to(self.device)
obs = {
"rgb": batch["observation", "image"][FIRST_FRAME].float(),
"state": batch["observation", "state"][FIRST_FRAME],
"rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
"state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
}
action = batch["action"]
action = batch["action"].to(self.device, non_blocking=True)
next_obses = {
"rgb": batch["next", "observation", "image"].float(),
"state": batch["next", "observation", "state"],
"rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
"state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
}
reward = batch["next", "reward"]
reward = batch["next", "reward"].to(self.device, non_blocking=True)
# TODO(rcadene): add non_blocking=True
# for key in obs:
# obs[key] = obs[key].to(self.device, non_blocking=True)
# next_obses[key] = next_obses[key].to(self.device, non_blocking=True)
# action = action.to(self.device, non_blocking=True)
# reward = reward.to(self.device, non_blocking=True)
idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)
# TODO(rcadene): rearrange directly in offline dataset
if reward.ndim == 2:
@@ -347,9 +353,6 @@ class TDMPC(nn.Module):
# Neither does `batch["next", "terminated"]`
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
idxs = batch["index"][FIRST_FRAME]
weights = batch["_weight"][FIRST_FRAME, :, None]
return obs, action, next_obses, reward, mask, done, idxs, weights
batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()