Refactor env queue, Training diffusion works (Still not converging)
This commit is contained in:
@@ -69,7 +69,7 @@ def make_offline_buffer(cfg, sampler=None):
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||
)
|
||||
elif cfg.env.name == "pusht":
|
||||
offline_buffer = PushtExperienceReplay(
|
||||
@@ -79,7 +79,7 @@ def make_offline_buffer(cfg, sampler=None):
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||
)
|
||||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
Reference in New Issue
Block a user