Add Normalize, non_blocking=True in tdmpc, tdmpc run (TODO: diffusion)
This commit is contained in:
@@ -5,6 +5,7 @@ from copy import deepcopy
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
from tensordict import TensorDict
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -126,19 +127,30 @@ class TDMPC(nn.Module):
|
||||
@torch.no_grad()
|
||||
def forward(self, observation, step_count):
|
||||
t0 = step_count.item() == 0
|
||||
|
||||
# TODO(rcadene): remove unsqueeze hack...
|
||||
if observation["image"].ndim == 3:
|
||||
observation["image"] = observation["image"].unsqueeze(0)
|
||||
observation["state"] = observation["state"].unsqueeze(0)
|
||||
|
||||
obs = {
|
||||
"rgb": observation["image"],
|
||||
"state": observation["state"],
|
||||
# TODO(rcadene): remove contiguous hack...
|
||||
"rgb": observation["image"].contiguous(),
|
||||
"state": observation["state"].contiguous(),
|
||||
}
|
||||
return self.act(obs, t0=t0, step=self.step.item())
|
||||
action = self.act(obs, t0=t0, step=self.step.item())
|
||||
|
||||
# TODO(rcadene): hack to postprocess action (e.g. unnormalize)
|
||||
# action = action * self.action_std + self.action_mean
|
||||
return action
|
||||
|
||||
@torch.no_grad()
|
||||
def act(self, obs, t0=False, step=None):
|
||||
"""Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
|
||||
if isinstance(obs, dict):
|
||||
obs = {k: o.detach().unsqueeze(0) for k, o in obs.items()}
|
||||
obs = {k: o.detach() for k, o in obs.items()}
|
||||
else:
|
||||
obs = obs.detach().unsqueeze(0)
|
||||
obs = obs.detach()
|
||||
z = self.model.encode(obs)
|
||||
if self.cfg.mpc:
|
||||
a = self.plan(z, t0=t0, step=step)
|
||||
@@ -315,26 +327,20 @@ class TDMPC(nn.Module):
|
||||
# trajectory t = 256, horizon h = 5
|
||||
# (t h) ... -> h t ...
|
||||
batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
|
||||
batch = batch.to(self.device)
|
||||
|
||||
obs = {
|
||||
"rgb": batch["observation", "image"][FIRST_FRAME].float(),
|
||||
"state": batch["observation", "state"][FIRST_FRAME],
|
||||
"rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
|
||||
"state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
|
||||
}
|
||||
action = batch["action"]
|
||||
action = batch["action"].to(self.device, non_blocking=True)
|
||||
next_obses = {
|
||||
"rgb": batch["next", "observation", "image"].float(),
|
||||
"state": batch["next", "observation", "state"],
|
||||
"rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
|
||||
"state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
|
||||
}
|
||||
reward = batch["next", "reward"]
|
||||
reward = batch["next", "reward"].to(self.device, non_blocking=True)
|
||||
|
||||
# TODO(rcadene): add non_blocking=True
|
||||
# for key in obs:
|
||||
# obs[key] = obs[key].to(self.device, non_blocking=True)
|
||||
# next_obses[key] = next_obses[key].to(self.device, non_blocking=True)
|
||||
|
||||
# action = action.to(self.device, non_blocking=True)
|
||||
# reward = reward.to(self.device, non_blocking=True)
|
||||
idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
|
||||
weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)
|
||||
|
||||
# TODO(rcadene): rearrange directly in offline dataset
|
||||
if reward.ndim == 2:
|
||||
@@ -347,9 +353,6 @@ class TDMPC(nn.Module):
|
||||
# Neither does `batch["next", "terminated"]`
|
||||
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
|
||||
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
|
||||
|
||||
idxs = batch["index"][FIRST_FRAME]
|
||||
weights = batch["_weight"][FIRST_FRAME, :, None]
|
||||
return obs, action, next_obses, reward, mask, done, idxs, weights
|
||||
|
||||
batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
|
||||
|
||||
Reference in New Issue
Block a user