Refactor env queue, Training diffusion works (Still not converging)

This commit is contained in:
Remi Cadene
2024-03-04 10:59:43 +00:00
parent fddd9f0311
commit cfc304e870
11 changed files with 96 additions and 111 deletions

View File

@@ -69,7 +69,7 @@ def make_offline_buffer(cfg, sampler=None):
sampler=sampler,
batch_size=batch_size,
pin_memory=pin_memory,
prefetch=prefetch,
prefetch=prefetch if isinstance(prefetch, int) else None,
)
elif cfg.env.name == "pusht":
offline_buffer = PushtExperienceReplay(
@@ -79,7 +79,7 @@ def make_offline_buffer(cfg, sampler=None):
sampler=sampler,
batch_size=batch_size,
pin_memory=pin_memory,
prefetch=prefetch,
prefetch=prefetch if isinstance(prefetch, int) else None,
)
else:
raise ValueError(cfg.env.name)