Refactor env queue, Training diffusion works (Still not converging)

This commit is contained in:
Remi Cadene
2024-03-04 10:59:43 +00:00
parent fddd9f0311
commit cfc304e870
11 changed files with 96 additions and 111 deletions

View File

@@ -69,7 +69,7 @@ def make_offline_buffer(cfg, sampler=None):
sampler=sampler,
batch_size=batch_size,
pin_memory=pin_memory,
prefetch=prefetch,
prefetch=prefetch if isinstance(prefetch, int) else None,
)
elif cfg.env.name == "pusht":
offline_buffer = PushtExperienceReplay(
@@ -79,7 +79,7 @@ def make_offline_buffer(cfg, sampler=None):
sampler=sampler,
batch_size=batch_size,
pin_memory=pin_memory,
prefetch=prefetch,
prefetch=prefetch if isinstance(prefetch, int) else None,
)
else:
raise ValueError(cfg.env.name)

View File

@@ -143,13 +143,24 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
in_keys=[
# ("observation", "image"),
("observation", "state"),
# TODO(rcadene): for tdmpc, we might want image and state
# ("next", "observation", "image"),
("next", "observation", "state"),
# ("next", "observation", "state"),
("action"),
],
mode="min_max",
)
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, min_max_spec
transform.stats["observation", "state", "min"] = torch.tensor(
[13.456424, 32.938293], dtype=torch.float32
)
transform.stats["observation", "state", "max"] = torch.tensor(
[496.14618, 510.9579], dtype=torch.float32
)
transform.stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
transform.stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
if writer is None:
writer = ImmutableDatasetWriter()
if collate_fn is None: