fix train.py, stats, eval.py (training is running)

This commit is contained in:
Cadene
2024-04-05 09:31:39 +00:00
parent c93ce35d8c
commit 5af00d0c1e
11 changed files with 76 additions and 72 deletions

View File

@@ -112,16 +112,19 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
if max_num_samples is None:
max_num_samples = len(dataset)
else:
raise NotImplementedError("We need to set shuffle=True, but this violate an assert for now.")
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=batch_size,
shuffle=True,
shuffle=False,
# pin_memory=cfg.device != "cpu",
drop_last=False,
)
# these einops patterns will be used to aggregate batches and compute statistics
stats_patterns = {
"action": "b c -> c",
"observation.state": "b c -> c",
@@ -142,9 +145,9 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
first_batch = None
running_item_count = 0 # for online mean computation
for i, batch in enumerate(
tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
):
this_batch_size = batch.batch_size[0]
this_batch_size = len(batch["index"])
running_item_count += this_batch_size
if first_batch is None:
first_batch = deepcopy(batch)
@@ -166,8 +169,10 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
first_batch_ = None
running_item_count = 0 # for online std computation
for i, batch in enumerate(tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")):
this_batch_size = batch.batch_size[0]
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
):
this_batch_size = len(batch["index"])
running_item_count += this_batch_size
# Sanity check to make sure the batches are still in the same order as before.
if first_batch_ is None: