fix train.py, stats, eval.py (training is running)

This commit is contained in:
Cadene
2024-04-05 09:31:39 +00:00
parent c93ce35d8c
commit 5af00d0c1e
11 changed files with 76 additions and 72 deletions

View File

@@ -91,15 +91,17 @@ class AlohaDataset(torch.utils.data.Dataset):
self.transform = transform
self.delta_timestamps = delta_timestamps
data_dir = self.root / f"{self.dataset_id}"
if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
self.data_dict = torch.load(data_dir / "data_dict.pth")
self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
self.data_dir = self.root / f"{self.dataset_id}"
if (self.data_dir / "data_dict.pth").exists() and (
self.data_dir / "data_ids_per_episode.pth"
).exists():
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
else:
self._download_and_preproc_obsolete()
data_dir.mkdir(parents=True, exist_ok=True)
torch.save(self.data_dict, data_dir / "data_dict.pth")
torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
self.data_dir.mkdir(parents=True, exist_ok=True)
torch.save(self.data_dict, self.data_dir / "data_dict.pth")
torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
@property
def num_samples(self) -> int:

View File

@@ -105,15 +105,17 @@ class PushtDataset(torch.utils.data.Dataset):
self.transform = transform
self.delta_timestamps = delta_timestamps
data_dir = self.root / f"{self.dataset_id}"
if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
self.data_dict = torch.load(data_dir / "data_dict.pth")
self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
self.data_dir = self.root / f"{self.dataset_id}"
if (self.data_dir / "data_dict.pth").exists() and (
self.data_dir / "data_ids_per_episode.pth"
).exists():
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
else:
self._download_and_preproc_obsolete()
data_dir.mkdir(parents=True, exist_ok=True)
torch.save(self.data_dict, data_dir / "data_dict.pth")
torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
self.data_dir.mkdir(parents=True, exist_ok=True)
torch.save(self.data_dict, self.data_dir / "data_dict.pth")
torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
@property
def num_samples(self) -> int:

View File

@@ -46,15 +46,17 @@ class SimxarmDataset(torch.utils.data.Dataset):
self.transform = transform
self.delta_timestamps = delta_timestamps
data_dir = self.root / f"{self.dataset_id}"
if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
self.data_dict = torch.load(data_dir / "data_dict.pth")
self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
self.data_dir = self.root / f"{self.dataset_id}"
if (self.data_dir / "data_dict.pth").exists() and (
self.data_dir / "data_ids_per_episode.pth"
).exists():
self.data_dict = torch.load(self.data_dir / "data_dict.pth")
self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
else:
self._download_and_preproc_obsolete()
data_dir.mkdir(parents=True, exist_ok=True)
torch.save(self.data_dict, data_dir / "data_dict.pth")
torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
self.data_dir.mkdir(parents=True, exist_ok=True)
torch.save(self.data_dict, self.data_dir / "data_dict.pth")
torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
@property
def num_samples(self) -> int:

View File

@@ -112,16 +112,19 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
if max_num_samples is None:
max_num_samples = len(dataset)
else:
raise NotImplementedError("We need to set shuffle=True, but this violate an assert for now.")
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=batch_size,
shuffle=True,
shuffle=False,
# pin_memory=cfg.device != "cpu",
drop_last=False,
)
# these einops patterns will be used to aggregate batches and compute statistics
stats_patterns = {
"action": "b c -> c",
"observation.state": "b c -> c",
@@ -142,9 +145,9 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
first_batch = None
running_item_count = 0 # for online mean computation
for i, batch in enumerate(
tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
):
this_batch_size = batch.batch_size[0]
this_batch_size = len(batch["index"])
running_item_count += this_batch_size
if first_batch is None:
first_batch = deepcopy(batch)
@@ -166,8 +169,10 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
first_batch_ = None
running_item_count = 0 # for online std computation
for i, batch in enumerate(tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")):
this_batch_size = batch.batch_size[0]
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
):
this_batch_size = len(batch["index"])
running_item_count += this_batch_size
# Sanity check to make sure the batches are still in the same order as before.
if first_batch_ is None:

View File

@@ -243,10 +243,9 @@ class DiffusionUnetImagePolicy(BaseImagePolicy):
result = {"action": action, "action_pred": action_pred}
return result
def compute_loss(self, batch):
assert "valid_mask" not in batch
nobs = batch["obs"]
nactions = batch["action"]
def compute_loss(self, obs_dict, action):
nobs = obs_dict
nactions = action
batch_size = nactions.shape[0]
horizon = nactions.shape[1]

View File

@@ -157,7 +157,8 @@ class DiffusionPolicy(nn.Module):
"image": batch["observation.image"],
"agent_pos": batch["observation.state"],
}
loss = self.diffusion.compute_loss(obs_dict)
action = batch["action"]
loss = self.diffusion.compute_loss(obs_dict, action)
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(

View File

@@ -72,12 +72,12 @@ class NormalizeTransform(Transform):
if inkey not in item:
continue
if self.mode == "mean_std":
mean = self.stats[f"{inkey}.mean"]
std = self.stats[f"{inkey}.std"]
mean = self.stats[inkey]["mean"]
std = self.stats[inkey]["std"]
item[outkey] = (item[inkey] - mean) / (std + 1e-8)
else:
min = self.stats[f"{inkey}.min"]
max = self.stats[f"{inkey}.max"]
min = self.stats[inkey]["min"]
max = self.stats[inkey]["max"]
# normalize to [0,1]
item[outkey] = (item[inkey] - min) / (max - min)
# normalize to [-1, 1]
@@ -89,12 +89,12 @@ class NormalizeTransform(Transform):
if inkey not in item:
continue
if self.mode == "mean_std":
mean = self.stats[f"{inkey}.mean"]
std = self.stats[f"{inkey}.std"]
mean = self.stats[inkey]["mean"]
std = self.stats[inkey]["std"]
item[outkey] = item[inkey] * std + mean
else:
min = self.stats[f"{inkey}.min"]
max = self.stats[f"{inkey}.max"]
min = self.stats[inkey]["min"]
max = self.stats[inkey]["max"]
item[outkey] = (item[inkey] + 1) / 2
item[outkey] = item[outkey] * (max - min) + min
return item