Refactor env queue, Training diffusion works (Still not converging)
This commit is contained in:
@@ -69,7 +69,7 @@ def make_offline_buffer(cfg, sampler=None):
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||
)
|
||||
elif cfg.env.name == "pusht":
|
||||
offline_buffer = PushtExperienceReplay(
|
||||
@@ -79,7 +79,7 @@ def make_offline_buffer(cfg, sampler=None):
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||
)
|
||||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
@@ -143,13 +143,24 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||
in_keys=[
|
||||
# ("observation", "image"),
|
||||
("observation", "state"),
|
||||
# TODO(rcadene): for tdmpc, we might want image and state
|
||||
# ("next", "observation", "image"),
|
||||
("next", "observation", "state"),
|
||||
# ("next", "observation", "state"),
|
||||
("action"),
|
||||
],
|
||||
mode="min_max",
|
||||
)
|
||||
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, min_max_spec
|
||||
transform.stats["observation", "state", "min"] = torch.tensor(
|
||||
[13.456424, 32.938293], dtype=torch.float32
|
||||
)
|
||||
transform.stats["observation", "state", "max"] = torch.tensor(
|
||||
[496.14618, 510.9579], dtype=torch.float32
|
||||
)
|
||||
transform.stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
transform.stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
|
||||
if writer is None:
|
||||
writer = ImmutableDatasetWriter()
|
||||
if collate_fn is None:
|
||||
|
||||
Reference in New Issue
Block a user