Merge remote-tracking branch 'upstream/main' into refactor_dp

This commit is contained in:
Alexander Soare
2024-04-11 17:52:10 +01:00
29 changed files with 545 additions and 603 deletions

View File

@@ -29,9 +29,9 @@ def make_policy(cfg):
if cfg.policy.pretrained_model_path:
# TODO(rcadene): hack for old pretrained models from fowm
if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
if "offline" in cfg.pretrained_model_path:
if "offline" in cfg.policy.pretrained_model_path:
policy.step[0] = 25000
elif "final" in cfg.pretrained_model_path:
elif "final" in cfg.policy.pretrained_model_path:
policy.step[0] = 100000
else:
raise NotImplementedError()

View File

@@ -333,94 +333,6 @@ class TDMPCPolicy(nn.Module):
"""Main update function. Corresponds to one iteration of the model learning."""
start_time = time.time()
# num_slices = self.cfg.batch_size
# batch_size = self.cfg.horizon * num_slices
# if demo_buffer is None:
# demo_batch_size = 0
# else:
# # Update oversampling ratio
# demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
# demo_num_slices = int(demo_pc_batch * self.batch_size)
# demo_batch_size = self.cfg.horizon * demo_num_slices
# batch_size -= demo_batch_size
# num_slices -= demo_num_slices
# replay_buffer._sampler.num_slices = num_slices
# demo_buffer._sampler.num_slices = demo_num_slices
# assert demo_batch_size % self.cfg.horizon == 0
# assert demo_batch_size % demo_num_slices == 0
# assert batch_size % self.cfg.horizon == 0
# assert batch_size % num_slices == 0
# # Sample from interaction dataset
# def process_batch(batch, horizon, num_slices):
# # trajectory t = 256, horizon h = 5
# # (t h) ... -> h t ...
# batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
# obs = {
# "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
# "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
# }
# action = batch["action"].to(self.device, non_blocking=True)
# next_obses = {
# "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
# "state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
# }
# reward = batch["next", "reward"].to(self.device, non_blocking=True)
# idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
# weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)
# # TODO(rcadene): rearrange directly in offline dataset
# if reward.ndim == 2:
# reward = einops.rearrange(reward, "h t -> h t 1")
# assert reward.ndim == 3
# assert reward.shape == (horizon, num_slices, 1)
# # We dont use `batch["next", "done"]` since it only indicates the end of an
# # episode, but not the end of the trajectory of an episode.
# # Neither does `batch["next", "terminated"]`
# done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
# mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
# return obs, action, next_obses, reward, mask, done, idxs, weights
# batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
# obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
# batch, self.cfg.horizon, num_slices
# )
# Sample from demonstration dataset
# if demo_batch_size > 0:
# demo_batch = demo_buffer.sample(demo_batch_size)
# (
# demo_obs,
# demo_action,
# demo_next_obses,
# demo_reward,
# demo_mask,
# demo_done,
# demo_idxs,
# demo_weights,
# ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)
# if isinstance(obs, dict):
# obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
# next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
# else:
# obs = torch.cat([obs, demo_obs])
# next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
# action = torch.cat([action, demo_action], dim=1)
# reward = torch.cat([reward, demo_reward], dim=1)
# mask = torch.cat([mask, demo_mask], dim=1)
# done = torch.cat([done, demo_done], dim=1)
# idxs = torch.cat([idxs, demo_idxs])
# weights = torch.cat([weights, demo_weights])
batch_size = batch["index"].shape[0]
# TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
@@ -534,6 +446,7 @@ class TDMPCPolicy(nn.Module):
)
self.optim.step()
# TODO(rcadene): implement PrioritizedSampling by modifying sampler.weights with priorities computed by a criterion
# if self.cfg.per:
# # Update priorities
# priorities = priority_loss.clamp(max=1e4).detach()