fix train.py, stats, eval.py (training is running)
This commit is contained in:
@@ -91,15 +91,17 @@ class AlohaDataset(torch.utils.data.Dataset):
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
|
||||
data_dir = self.root / f"{self.dataset_id}"
|
||||
if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
|
||||
self.data_dict = torch.load(data_dir / "data_dict.pth")
|
||||
self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
|
||||
self.data_dir = self.root / f"{self.dataset_id}"
|
||||
if (self.data_dir / "data_dict.pth").exists() and (
|
||||
self.data_dir / "data_ids_per_episode.pth"
|
||||
).exists():
|
||||
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
|
||||
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
|
||||
else:
|
||||
self._download_and_preproc_obsolete()
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(self.data_dict, data_dir / "data_dict.pth")
|
||||
torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(self.data_dict, self.data_dir / "data_dict.pth")
|
||||
torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
|
||||
@@ -105,15 +105,17 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
|
||||
data_dir = self.root / f"{self.dataset_id}"
|
||||
if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
|
||||
self.data_dict = torch.load(data_dir / "data_dict.pth")
|
||||
self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
|
||||
self.data_dir = self.root / f"{self.dataset_id}"
|
||||
if (self.data_dir / "data_dict.pth").exists() and (
|
||||
self.data_dir / "data_ids_per_episode.pth"
|
||||
).exists():
|
||||
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
|
||||
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
|
||||
else:
|
||||
self._download_and_preproc_obsolete()
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(self.data_dict, data_dir / "data_dict.pth")
|
||||
torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(self.data_dict, self.data_dir / "data_dict.pth")
|
||||
torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
|
||||
@@ -46,15 +46,17 @@ class SimxarmDataset(torch.utils.data.Dataset):
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
|
||||
data_dir = self.root / f"{self.dataset_id}"
|
||||
if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
|
||||
self.data_dict = torch.load(data_dir / "data_dict.pth")
|
||||
self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
|
||||
self.data_dir = self.root / f"{self.dataset_id}"
|
||||
if (self.data_dir / "data_dict.pth").exists() and (
|
||||
self.data_dir / "data_ids_per_episode.pth"
|
||||
).exists():
|
||||
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
|
||||
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
|
||||
else:
|
||||
self._download_and_preproc_obsolete()
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(self.data_dict, data_dir / "data_dict.pth")
|
||||
torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(self.data_dict, self.data_dir / "data_dict.pth")
|
||||
torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
|
||||
@@ -112,16 +112,19 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
|
||||
|
||||
if max_num_samples is None:
|
||||
max_num_samples = len(dataset)
|
||||
else:
|
||||
raise NotImplementedError("We need to set shuffle=True, but this violate an assert for now.")
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
shuffle=False,
|
||||
# pin_memory=cfg.device != "cpu",
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
# these einops patterns will be used to aggregate batches and compute statistics
|
||||
stats_patterns = {
|
||||
"action": "b c -> c",
|
||||
"observation.state": "b c -> c",
|
||||
@@ -142,9 +145,9 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
|
||||
first_batch = None
|
||||
running_item_count = 0 # for online mean computation
|
||||
for i, batch in enumerate(
|
||||
tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
|
||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
|
||||
):
|
||||
this_batch_size = batch.batch_size[0]
|
||||
this_batch_size = len(batch["index"])
|
||||
running_item_count += this_batch_size
|
||||
if first_batch is None:
|
||||
first_batch = deepcopy(batch)
|
||||
@@ -166,8 +169,10 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
|
||||
|
||||
first_batch_ = None
|
||||
running_item_count = 0 # for online std computation
|
||||
for i, batch in enumerate(tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")):
|
||||
this_batch_size = batch.batch_size[0]
|
||||
for i, batch in enumerate(
|
||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
|
||||
):
|
||||
this_batch_size = len(batch["index"])
|
||||
running_item_count += this_batch_size
|
||||
# Sanity check to make sure the batches are still in the same order as before.
|
||||
if first_batch_ is None:
|
||||
|
||||
Reference in New Issue
Block a user