Refactor configs to have env in seperate yaml + Fix training
This commit is contained in:
@@ -10,28 +10,29 @@ def make_offline_buffer(cfg, sampler=None):
|
||||
overwrite_sampler = sampler is not None
|
||||
|
||||
if not overwrite_sampler:
|
||||
num_traj_per_batch = cfg.batch_size # // cfg.horizon
|
||||
# TODO(rcadene): move batch_size outside
|
||||
num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon
|
||||
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
|
||||
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
|
||||
sampler = PrioritizedSliceSampler(
|
||||
max_capacity=100_000,
|
||||
alpha=cfg.per_alpha,
|
||||
beta=cfg.per_beta,
|
||||
alpha=cfg.policy.per_alpha,
|
||||
beta=cfg.policy.per_beta,
|
||||
num_slices=num_traj_per_batch,
|
||||
strict_length=False,
|
||||
)
|
||||
|
||||
if cfg.env == "simxarm":
|
||||
if cfg.env.name == "simxarm":
|
||||
# TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
|
||||
offline_buffer = SimxarmExperienceReplay(
|
||||
f"xarm_{cfg.task}_medium",
|
||||
f"xarm_{cfg.env.task}_medium",
|
||||
# download="force",
|
||||
download=True,
|
||||
streaming=False,
|
||||
root="data",
|
||||
sampler=sampler,
|
||||
)
|
||||
elif cfg.env == "pusht":
|
||||
elif cfg.env.name == "pusht":
|
||||
offline_buffer = PushtExperienceReplay(
|
||||
"pusht",
|
||||
# download="force",
|
||||
@@ -41,7 +42,7 @@ def make_offline_buffer(cfg, sampler=None):
|
||||
sampler=sampler,
|
||||
)
|
||||
else:
|
||||
raise ValueError(cfg.env)
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
if not overwrite_sampler:
|
||||
num_steps = len(offline_buffer)
|
||||
|
||||
Reference in New Issue
Block a user