Refactor configs to have env in seperate yaml + Fix training
This commit is contained in:
@@ -62,13 +62,14 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
|
||||
offline_buffer = make_offline_buffer(cfg)
|
||||
|
||||
if cfg.balanced_sampling:
|
||||
num_traj_per_batch = cfg.batch_size
|
||||
# TODO(rcadene): move balanced_sampling, per_alpha, per_beta outside policy
|
||||
if cfg.policy.balanced_sampling:
|
||||
num_traj_per_batch = cfg.policy.batch_size
|
||||
|
||||
online_sampler = PrioritizedSliceSampler(
|
||||
max_capacity=100_000,
|
||||
alpha=cfg.per_alpha,
|
||||
beta=cfg.per_beta,
|
||||
alpha=cfg.policy.per_alpha,
|
||||
beta=cfg.policy.per_beta,
|
||||
num_slices=num_traj_per_batch,
|
||||
strict_length=True,
|
||||
)
|
||||
@@ -92,7 +93,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
_step = step + num_updates
|
||||
rollout_metrics = {}
|
||||
|
||||
if step >= cfg.offline_steps:
|
||||
# TODO(rcadene): move offline_steps outside policy
|
||||
if step >= cfg.policy.offline_steps:
|
||||
is_offline = False
|
||||
|
||||
# TODO: use SyncDataCollector for that?
|
||||
@@ -118,7 +120,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
"avg_max_reward": np.nanmean(ep_max_reward),
|
||||
"pc_success": np.nanmean(ep_success) * 100,
|
||||
}
|
||||
num_updates = len(rollout) * cfg.utd
|
||||
num_updates = len(rollout) * cfg.policy.utd
|
||||
_step = min(step + len(rollout), cfg.train_steps)
|
||||
|
||||
# Update model
|
||||
@@ -128,8 +130,10 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
else:
|
||||
train_metrics = policy.update(
|
||||
online_buffer,
|
||||
step + i // cfg.utd,
|
||||
demo_buffer=offline_buffer if cfg.balanced_sampling else None,
|
||||
step + i // cfg.policy.utd,
|
||||
demo_buffer=(
|
||||
offline_buffer if cfg.policy.balanced_sampling else None
|
||||
),
|
||||
)
|
||||
|
||||
# Log training metrics
|
||||
|
||||
Reference in New Issue
Block a user