Add Prod transform, Add test_factory
This commit is contained in:
@@ -50,6 +50,8 @@ def train(cfg: dict):
|
||||
offline_buffer = make_offline_buffer(cfg)
|
||||
|
||||
if cfg.balanced_sampling:
|
||||
num_traj_per_batch = cfg.batch_size
|
||||
|
||||
online_sampler = PrioritizedSliceSampler(
|
||||
max_capacity=100_000,
|
||||
alpha=cfg.per_alpha,
|
||||
|
||||
Reference in New Issue
Block a user