Merge remote-tracking branch 'upstream/main' into refactor_dp
@@ -29,9 +29,9 @@ def make_policy(cfg):
     if cfg.policy.pretrained_model_path:
         # TODO(rcadene): hack for old pretrained models from fowm
         if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
-            if "offline" in cfg.pretrained_model_path:
+            if "offline" in cfg.policy.pretrained_model_path:
                 policy.step[0] = 25000
-            elif "final" in cfg.pretrained_model_path:
+            elif "final" in cfg.policy.pretrained_model_path:
                 policy.step[0] = 100000
             else:
                 raise NotImplementedError()
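For context on the hunk above: the fix consistently reads the checkpoint path from cfg.policy.pretrained_model_path (the bare cfg.pretrained_model_path no longer matches the config layout), and the substring checks restore the training step at which the old fowm checkpoint was saved. A minimal table-driven sketch of the same hack; the mapping and helper below are illustrative, not part of the repository:

# Hypothetical rewrite of the hack above in table form; names are illustrative.
STEP_BY_CHECKPOINT = {"offline": 25000, "final": 100000}

def restore_step(pretrained_model_path: str) -> int:
    """Map an old fowm checkpoint path to the training step it was saved at."""
    for key, step in STEP_BY_CHECKPOINT.items():
        if key in pretrained_model_path:
            return step
    raise NotImplementedError(pretrained_model_path)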
@@ -333,94 +333,6 @@ class TDMPCPolicy(nn.Module):
         """Main update function. Corresponds to one iteration of the model learning."""
         start_time = time.time()

-        # num_slices = self.cfg.batch_size
-        # batch_size = self.cfg.horizon * num_slices
-
-        # if demo_buffer is None:
-        #     demo_batch_size = 0
-        # else:
-        #     # Update oversampling ratio
-        #     demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
-        #     demo_num_slices = int(demo_pc_batch * self.batch_size)
-        #     demo_batch_size = self.cfg.horizon * demo_num_slices
-        #     batch_size -= demo_batch_size
-        #     num_slices -= demo_num_slices
-        #     replay_buffer._sampler.num_slices = num_slices
-        #     demo_buffer._sampler.num_slices = demo_num_slices
-
-        #     assert demo_batch_size % self.cfg.horizon == 0
-        #     assert demo_batch_size % demo_num_slices == 0
-
-        # assert batch_size % self.cfg.horizon == 0
-        # assert batch_size % num_slices == 0
-
-        # # Sample from interaction dataset
-
-        # def process_batch(batch, horizon, num_slices):
-        #     # trajectory t = 256, horizon h = 5
-        #     # (t h) ... -> h t ...
-        #     batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
-
-        #     obs = {
-        #         "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
-        #         "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
-        #     }
-        #     action = batch["action"].to(self.device, non_blocking=True)
-        #     next_obses = {
-        #         "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
-        #         "state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
-        #     }
-        #     reward = batch["next", "reward"].to(self.device, non_blocking=True)
-
-        #     idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
-        #     weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)
-
-        #     # TODO(rcadene): rearrange directly in offline dataset
-        #     if reward.ndim == 2:
-        #         reward = einops.rearrange(reward, "h t -> h t 1")
-
-        #     assert reward.ndim == 3
-        #     assert reward.shape == (horizon, num_slices, 1)
-        #     # We dont use `batch["next", "done"]` since it only indicates the end of an
-        #     # episode, but not the end of the trajectory of an episode.
-        #     # Neither does `batch["next", "terminated"]`
-        #     done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
-        #     mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
-        #     return obs, action, next_obses, reward, mask, done, idxs, weights
-
-        # batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
-
-        # obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
-        #     batch, self.cfg.horizon, num_slices
-        # )
-
-        # Sample from demonstration dataset
-        # if demo_batch_size > 0:
-        #     demo_batch = demo_buffer.sample(demo_batch_size)
-        #     (
-        #         demo_obs,
-        #         demo_action,
-        #         demo_next_obses,
-        #         demo_reward,
-        #         demo_mask,
-        #         demo_done,
-        #         demo_idxs,
-        #         demo_weights,
-        #     ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)
-
-        #     if isinstance(obs, dict):
-        #         obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
-        #         next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
-        #     else:
-        #         obs = torch.cat([obs, demo_obs])
-        #         next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
-        #     action = torch.cat([action, demo_action], dim=1)
-        #     reward = torch.cat([reward, demo_reward], dim=1)
-        #     mask = torch.cat([mask, demo_mask], dim=1)
-        #     done = torch.cat([done, demo_done], dim=1)
-        #     idxs = torch.cat([idxs, demo_idxs])
-        #     weights = torch.cat([weights, demo_weights])
-
         batch_size = batch["index"].shape[0]

         # TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
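In the removed code, t indexes the 256 sampled slices and h the 5-step horizon, so tensors were time-major (horizon, batch, ...); the surviving TODO asks for batch-first tensors instead. A minimal sketch of that conversion with einops, using illustrative shapes only (the real keys and dimensions come from the dataset, not from this sketch):

import einops
import torch

# Illustrative shapes only: 5-step horizon, 256 slices, 7-dim state.
horizon, num_slices, state_dim = 5, 256, 7

# Time-major layout used by the removed code: (horizon, batch, channels).
state_hb = torch.randn(horizon, num_slices, state_dim)

# Batch-first layout the TODO asks for: (batch, time/horizon, channels).
state_bh = einops.rearrange(state_hb, "h t c -> t h c")
assert state_bh.shape == (num_slices, horizon, state_dim)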
@@ -534,6 +446,7 @@ class TDMPCPolicy(nn.Module):
         )
         self.optim.step()

+        # TODO(rcadene): implement PrioritizedSampling by modifying sampler.weights with priorities computed by a criterion
         # if self.cfg.per:
         #     # Update priorities
         #     priorities = priority_loss.clamp(max=1e4).detach()
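The added TODO sketches prioritized sampling via the sampler's weights, in the spirit of the commented-out per branch above. One plausible shape for that update, assuming a hypothetical sampler that exposes a writable 1D weights tensor (neither the helper nor that attribute is defined in this diff):

import torch

def update_priorities(sampler, idxs: torch.Tensor, td_errors: torch.Tensor,
                      alpha: float = 0.6, eps: float = 1e-6) -> None:
    # Hypothetical PER update: samples with larger TD error get sampled more often.
    priorities = td_errors.abs().clamp(max=1e4).detach()
    sampler.weights[idxs] = (priorities + eps) ** alpha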