online training works (loss goes down), remove repeat_action, eval_policy outputs episodes data, eval_policy uses max_episodes_rendered

This commit is contained in:
Cadene
2024-04-10 11:34:01 +00:00
parent 19e7661b8d
commit 06573d7f67
11 changed files with 219 additions and 211 deletions

View File

@@ -105,7 +105,7 @@ class AlohaDataset(torch.utils.data.Dataset):
@property
def num_samples(self) -> int:
return len(self.data_dict["index"])
return len(self.data_dict["index"]) if "index" in self.data_dict else 0
@property
def num_episodes(self) -> int:

View File

@@ -119,7 +119,7 @@ class PushtDataset(torch.utils.data.Dataset):
@property
def num_samples(self) -> int:
return len(self.data_dict["index"])
return len(self.data_dict["index"]) if "index" in self.data_dict else 0
@property
def num_episodes(self) -> int:

View File

@@ -60,7 +60,7 @@ class XarmDataset(torch.utils.data.Dataset):
@property
def num_samples(self) -> int:
return len(self.data_dict["index"])
return len(self.data_dict["index"]) if "index" in self.data_dict else 0
@property
def num_episodes(self) -> int:
@@ -126,7 +126,8 @@ class XarmDataset(torch.utils.data.Dataset):
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
action = torch.tensor(dataset_dict["actions"][idx0:idx1])
# TODO(rcadene): concat the last "next_observations" to "observations"
# TODO(rcadene): we have a missing last frame which is the observation when the env is done
# it is critical to have this frame for tdmpc to predict a "done observation/state"
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])

View File

@@ -35,9 +35,9 @@ def make_policy(cfg):
if cfg.policy.pretrained_model_path:
# TODO(rcadene): hack for old pretrained models from fowm
if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
if "offline" in cfg.pretrained_model_path:
if "offline" in cfg.policy.pretrained_model_path:
policy.step[0] = 25000
elif "final" in cfg.pretrained_model_path:
elif "final" in cfg.policy.pretrained_model_path:
policy.step[0] = 100000
else:
raise NotImplementedError()

View File

@@ -333,94 +333,6 @@ class TDMPCPolicy(nn.Module):
"""Main update function. Corresponds to one iteration of the model learning."""
start_time = time.time()
# num_slices = self.cfg.batch_size
# batch_size = self.cfg.horizon * num_slices
# if demo_buffer is None:
# demo_batch_size = 0
# else:
# # Update oversampling ratio
# demo_pc_batch = h.linear_schedule(self.cfg.demo_schedule, step)
# demo_num_slices = int(demo_pc_batch * self.batch_size)
# demo_batch_size = self.cfg.horizon * demo_num_slices
# batch_size -= demo_batch_size
# num_slices -= demo_num_slices
# replay_buffer._sampler.num_slices = num_slices
# demo_buffer._sampler.num_slices = demo_num_slices
# assert demo_batch_size % self.cfg.horizon == 0
# assert demo_batch_size % demo_num_slices == 0
# assert batch_size % self.cfg.horizon == 0
# assert batch_size % num_slices == 0
# # Sample from interaction dataset
# def process_batch(batch, horizon, num_slices):
# # trajectory t = 256, horizon h = 5
# # (t h) ... -> h t ...
# batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
# obs = {
# "rgb": batch["observation", "image"][FIRST_FRAME].to(self.device, non_blocking=True),
# "state": batch["observation", "state"][FIRST_FRAME].to(self.device, non_blocking=True),
# }
# action = batch["action"].to(self.device, non_blocking=True)
# next_obses = {
# "rgb": batch["next", "observation", "image"].to(self.device, non_blocking=True),
# "state": batch["next", "observation", "state"].to(self.device, non_blocking=True),
# }
# reward = batch["next", "reward"].to(self.device, non_blocking=True)
# idxs = batch["index"][FIRST_FRAME].to(self.device, non_blocking=True)
# weights = batch["_weight"][FIRST_FRAME, :, None].to(self.device, non_blocking=True)
# # TODO(rcadene): rearrange directly in offline dataset
# if reward.ndim == 2:
# reward = einops.rearrange(reward, "h t -> h t 1")
# assert reward.ndim == 3
# assert reward.shape == (horizon, num_slices, 1)
# # We dont use `batch["next", "done"]` since it only indicates the end of an
# # episode, but not the end of the trajectory of an episode.
# # Neither does `batch["next", "terminated"]`
# done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
# mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
# return obs, action, next_obses, reward, mask, done, idxs, weights
# batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
# obs, action, next_obses, reward, mask, done, idxs, weights = process_batch(
# batch, self.cfg.horizon, num_slices
# )
# Sample from demonstration dataset
# if demo_batch_size > 0:
# demo_batch = demo_buffer.sample(demo_batch_size)
# (
# demo_obs,
# demo_action,
# demo_next_obses,
# demo_reward,
# demo_mask,
# demo_done,
# demo_idxs,
# demo_weights,
# ) = process_batch(demo_batch, self.cfg.horizon, demo_num_slices)
# if isinstance(obs, dict):
# obs = {k: torch.cat([obs[k], demo_obs[k]]) for k in obs}
# next_obses = {k: torch.cat([next_obses[k], demo_next_obses[k]], dim=1) for k in next_obses}
# else:
# obs = torch.cat([obs, demo_obs])
# next_obses = torch.cat([next_obses, demo_next_obses], dim=1)
# action = torch.cat([action, demo_action], dim=1)
# reward = torch.cat([reward, demo_reward], dim=1)
# mask = torch.cat([mask, demo_mask], dim=1)
# done = torch.cat([done, demo_done], dim=1)
# idxs = torch.cat([idxs, demo_idxs])
# weights = torch.cat([weights, demo_weights])
batch_size = batch["index"].shape[0]
# TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
@@ -534,6 +446,7 @@ class TDMPCPolicy(nn.Module):
)
self.optim.step()
# TODO(rcadene): implement PrioritizedSampling by modifying sampler.weights with priorities computed by a criterion
# if self.cfg.per:
# # Update priorities
# priorities = priority_loss.clamp(max=1e4).detach()