online training works (loss goes down), remove repeat_action, eval_policy outputs episodes data, eval_policy uses max_episodes_rendered

This commit is contained in:
Cadene
2024-04-10 11:34:01 +00:00
parent 19e7661b8d
commit 06573d7f67
11 changed files with 219 additions and 211 deletions

View File

@@ -105,7 +105,7 @@ class AlohaDataset(torch.utils.data.Dataset):
@property
def num_samples(self) -> int:
return len(self.data_dict["index"])
return len(self.data_dict["index"]) if "index" in self.data_dict else 0
@property
def num_episodes(self) -> int:

View File

@@ -119,7 +119,7 @@ class PushtDataset(torch.utils.data.Dataset):
@property
def num_samples(self) -> int:
return len(self.data_dict["index"])
return len(self.data_dict["index"]) if "index" in self.data_dict else 0
@property
def num_episodes(self) -> int:

View File

@@ -60,7 +60,7 @@ class XarmDataset(torch.utils.data.Dataset):
@property
def num_samples(self) -> int:
return len(self.data_dict["index"])
return len(self.data_dict["index"]) if "index" in self.data_dict else 0
@property
def num_episodes(self) -> int:
@@ -126,7 +126,8 @@ class XarmDataset(torch.utils.data.Dataset):
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
action = torch.tensor(dataset_dict["actions"][idx0:idx1])
# TODO(rcadene): concat the last "next_observations" to "observations"
# TODO(rcadene): we have a missing last frame which is the observation when the env is done
# it is critical to have this frame for tdmpc to predict a "done observation/state"
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])