Online finetuning runs (sometimes crash because of nans)
This commit is contained in:
@@ -77,18 +77,16 @@ class SimxarmEnv(EnvBase):
|
||||
|
||||
def _format_raw_obs(self, raw_obs):
|
||||
if self.from_pixels:
|
||||
camera = self.render(
|
||||
image = self.render(
|
||||
mode="rgb_array", width=self.image_size, height=self.image_size
|
||||
)
|
||||
camera = camera.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
|
||||
camera = torch.tensor(camera.copy(), dtype=torch.uint8)
|
||||
image = image.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
|
||||
image = torch.tensor(image.copy(), dtype=torch.uint8)
|
||||
|
||||
obs = {"camera": camera}
|
||||
obs = {"image": image}
|
||||
|
||||
if not self.pixels_only:
|
||||
obs["robot_state"] = torch.tensor(
|
||||
self._env.robot_state, dtype=torch.float32
|
||||
)
|
||||
obs["state"] = torch.tensor(self._env.robot_state, dtype=torch.float32)
|
||||
else:
|
||||
obs = {"state": torch.tensor(raw_obs["observation"], dtype=torch.float32)}
|
||||
|
||||
@@ -136,7 +134,7 @@ class SimxarmEnv(EnvBase):
|
||||
def _make_spec(self):
|
||||
obs = {}
|
||||
if self.from_pixels:
|
||||
obs["camera"] = BoundedTensorSpec(
|
||||
obs["image"] = BoundedTensorSpec(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(3, self.image_size, self.image_size),
|
||||
@@ -144,7 +142,7 @@ class SimxarmEnv(EnvBase):
|
||||
device=self.device,
|
||||
)
|
||||
if not self.pixels_only:
|
||||
obs["robot_state"] = UnboundedContinuousTensorSpec(
|
||||
obs["state"] = UnboundedContinuousTensorSpec(
|
||||
shape=(len(self._env.robot_state),),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
|
||||
@@ -96,8 +96,7 @@ class TDMPC(nn.Module):
|
||||
self.model_target.eval()
|
||||
self.batch_size = cfg.batch_size
|
||||
|
||||
# TODO(rcadene): clean
|
||||
self.step = 100000
|
||||
self.step = 0
|
||||
|
||||
def state_dict(self):
|
||||
"""Retrieve state dict of TOLD model, including slow-moving target network."""
|
||||
@@ -120,8 +119,8 @@ class TDMPC(nn.Module):
|
||||
def forward(self, observation, step_count):
|
||||
t0 = step_count.item() == 0
|
||||
obs = {
|
||||
"rgb": observation["camera"],
|
||||
"state": observation["robot_state"],
|
||||
"rgb": observation["image"],
|
||||
"state": observation["state"],
|
||||
}
|
||||
return self.act(obs, t0=t0, step=self.step)
|
||||
|
||||
@@ -298,65 +297,81 @@ class TDMPC(nn.Module):
|
||||
def update(self, replay_buffer, step, demo_buffer=None):
|
||||
"""Main update function. Corresponds to one iteration of the model learning."""
|
||||
|
||||
if demo_buffer is not None:
|
||||
# Update oversampling ratio
|
||||
self.demo_batch_size = int(
|
||||
h.linear_schedule(self.cfg.demo_schedule, step) * self.batch_size
|
||||
)
|
||||
replay_buffer.cfg.batch_size = self.batch_size - self.demo_batch_size
|
||||
demo_buffer.cfg.batch_size = self.demo_batch_size
|
||||
num_slices = self.cfg.batch_size
|
||||
batch_size = self.cfg.horizon * num_slices
|
||||
|
||||
if demo_buffer is None:
|
||||
demo_batch_size = 0
|
||||
else:
|
||||
self.demo_batch_size = 0
|
||||
# Update oversampling ratio
|
||||
demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
|
||||
demo_num_slices = int(demo_pc_batch * self.batch_size)
|
||||
demo_batch_size = self.cfg.horizon * demo_num_slices
|
||||
batch_size -= demo_batch_size
|
||||
num_slices -= demo_num_slices
|
||||
replay_buffer._sampler.num_slices = num_slices
|
||||
demo_buffer._sampler.num_slices = demo_num_slices
|
||||
|
||||
assert demo_batch_size % self.cfg.horizon == 0
|
||||
assert demo_batch_size % demo_num_slices == 0
|
||||
|
||||
assert batch_size % self.cfg.horizon == 0
|
||||
assert batch_size % num_slices == 0
|
||||
|
||||
# Sample from interaction dataset
|
||||
|
||||
# to not have to mask
|
||||
# batch_size = (self.cfg.batch_size // self.cfg.horizon) * self.cfg.horizon
|
||||
batch_size = self.cfg.horizon * self.cfg.batch_size
|
||||
def process_batch(batch, horizon, num_slices):
|
||||
# trajectory t = 256, horizon h = 5
|
||||
# (t h) ... -> h t ...
|
||||
batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
|
||||
batch = batch.to("cuda")
|
||||
|
||||
FIRST_FRAME = 0
|
||||
obs = {
|
||||
"rgb": batch["observation", "image"][FIRST_FRAME].float(),
|
||||
"state": batch["observation", "state"][FIRST_FRAME],
|
||||
}
|
||||
action = batch["action"]
|
||||
next_obses = {
|
||||
"rgb": batch["next", "observation", "image"].float(),
|
||||
"state": batch["next", "observation", "state"],
|
||||
}
|
||||
reward = batch["next", "reward"]
|
||||
|
||||
# TODO(rcadene): rearrange directly in offline dataset
|
||||
if reward.ndim == 2:
|
||||
reward = einops.rearrange(reward, "h t -> h t 1")
|
||||
|
||||
assert reward.ndim == 3
|
||||
assert reward.shape == (horizon, num_slices, 1)
|
||||
# We dont use `batch["next", "done"]` since it only indicates the end of an
|
||||
# episode, but not the end of the trajectory of an episode.
|
||||
# Neither does `batch["next", "terminated"]`
|
||||
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
|
||||
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
|
||||
|
||||
idxs = batch["index"][FIRST_FRAME]
|
||||
weights = batch["_weight"][FIRST_FRAME, :, None]
|
||||
return obs, action, next_obses, reward, mask, done, idxs, weights
|
||||
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
|
||||
# trajectory t = 256, horizon h = 5
|
||||
# (t h) ... -> h t ...
|
||||
batch = (
|
||||
batch.reshape(self.cfg.batch_size, self.cfg.horizon)
|
||||
.transpose(1, 0)
|
||||
.contiguous()
|
||||
obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
|
||||
batch, self.cfg.horizon, num_slices
|
||||
)
|
||||
batch = batch.to("cuda")
|
||||
|
||||
FIRST_FRAME = 0
|
||||
obs = {
|
||||
"rgb": batch["observation", "image"][FIRST_FRAME].float(),
|
||||
"state": batch["observation", "state"][FIRST_FRAME],
|
||||
}
|
||||
action = batch["action"]
|
||||
next_obses = {
|
||||
"rgb": batch["next", "observation", "image"].float(),
|
||||
"state": batch["next", "observation", "state"],
|
||||
}
|
||||
reward = batch["next", "reward"]
|
||||
reward = einops.rearrange(reward, "h t -> h t 1")
|
||||
# We dont use `batch["next", "done"]` since it only indicates the end of an
|
||||
# episode, but not the end of the trajectory of an episode.
|
||||
# Neither does `batch["next", "terminated"]`
|
||||
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
|
||||
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
|
||||
|
||||
idxs = batch["frame_id"][FIRST_FRAME]
|
||||
weights = batch["_weight"][FIRST_FRAME, :, None]
|
||||
|
||||
# Sample from demonstration dataset
|
||||
if self.demo_batch_size > 0:
|
||||
if demo_batch_size > 0:
|
||||
demo_batch = demo_buffer.sample(demo_batch_size)
|
||||
(
|
||||
demo_obs,
|
||||
demo_next_obses,
|
||||
demo_action,
|
||||
demo_next_obses,
|
||||
demo_reward,
|
||||
demo_mask,
|
||||
demo_done,
|
||||
demo_idxs,
|
||||
demo_weights,
|
||||
) = demo_buffer.sample()
|
||||
) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)
|
||||
|
||||
if isinstance(obs, dict):
|
||||
obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
|
||||
@@ -440,9 +455,9 @@ class TDMPC(nn.Module):
|
||||
q_value_loss += (rho * h.mse(qs[q], td_targets) * loss_mask).sum(dim=0)
|
||||
priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)
|
||||
|
||||
self.expectile = h.linear_schedule(self.cfg.expectile, step)
|
||||
expectile = h.linear_schedule(self.cfg.expectile, step)
|
||||
v_value_loss = (
|
||||
rho * h.l2_expectile(v_target - v, expectile=self.expectile) * loss_mask
|
||||
rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask
|
||||
).sum(dim=0)
|
||||
|
||||
total_loss = (
|
||||
@@ -464,17 +479,12 @@ class TDMPC(nn.Module):
|
||||
if self.cfg.per:
|
||||
# Update priorities
|
||||
priorities = priority_loss.clamp(max=1e4).detach()
|
||||
# normalize between [0,1] to fit torchrl specification
|
||||
priorities /= 1e4
|
||||
priorities = priorities.clamp(max=1.0)
|
||||
replay_buffer.update_priority(
|
||||
idxs[: self.cfg.batch_size],
|
||||
priorities[: self.cfg.batch_size],
|
||||
idxs[:num_slices],
|
||||
priorities[:num_slices],
|
||||
)
|
||||
if self.demo_batch_size > 0:
|
||||
demo_buffer.update_priority(
|
||||
demo_idxs, priorities[self.cfg.batch_size :]
|
||||
)
|
||||
if demo_batch_size > 0:
|
||||
demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
|
||||
|
||||
# Update policy + target network
|
||||
_, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)
|
||||
@@ -493,10 +503,12 @@ class TDMPC(nn.Module):
|
||||
"weighted_loss": float(weighted_loss.mean().item()),
|
||||
"grad_norm": float(grad_norm),
|
||||
}
|
||||
for key in ["demo_batch_size", "expectile"]:
|
||||
if hasattr(self, key):
|
||||
metrics[key] = getattr(self, key)
|
||||
# for key in ["demo_batch_size", "expectile"]:
|
||||
# if hasattr(self, key):
|
||||
metrics["demo_batch_size"] = demo_batch_size
|
||||
metrics["expectile"] = expectile
|
||||
metrics.update(value_info)
|
||||
metrics.update(pi_update_info)
|
||||
|
||||
self.step = step
|
||||
return metrics
|
||||
|
||||
Reference in New Issue
Block a user