Move normalization to policy for act and diffusion (#90)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
@@ -232,7 +232,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
 
     logging.info("make_policy")
-    policy = make_policy(cfg)
+    policy = make_policy(cfg, dataset_stats=offline_dataset.stats)
 
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
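The substance of the change: instead of the dataset applying a normalization transform to batches before they reach the policy, make_policy now hands the dataset statistics to the policy, which normalizes its inputs (and un-normalizes its outputs) itself. A minimal sketch of that pattern, assuming a hypothetical NormalizedPolicy class and a stats dict keyed like lerobot batches ("observation.state", "action"); this is illustrative, not the repository's actual implementation:

```python
import torch
from torch import nn


class NormalizedPolicy(nn.Module):
    """Illustrative sketch only; not the repository's real policy classes."""

    def __init__(self, dataset_stats: dict, obs_dim: int, action_dim: int):
        super().__init__()
        # Register the stats as buffers so they follow .to(device) and are
        # saved in the checkpoint: a trained policy then carries its own
        # normalization constants instead of depending on a dataset transform.
        self.register_buffer("obs_mean", dataset_stats["observation.state"]["mean"])
        self.register_buffer("obs_std", dataset_stats["observation.state"]["std"])
        self.register_buffer("act_mean", dataset_stats["action"]["mean"])
        self.register_buffer("act_std", dataset_stats["action"]["std"])
        self.net = nn.Linear(obs_dim, action_dim)

    def forward(self, batch: dict) -> torch.Tensor:
        # Normalize inputs inside the policy ...
        obs = (batch["observation.state"] - self.obs_mean) / (self.obs_std + 1e-8)
        # ... and un-normalize outputs, so callers see raw action units.
        return self.net(obs) * self.act_std + self.act_mean
```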
@@ -339,7 +339,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
     eval_info = eval_policy(
         rollout_env,
         policy,
-        transform=offline_dataset.transform,
         return_episode_data=True,
         seed=cfg.seed,
     )
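Because the policy now owns its statistics, the evaluation path needs no dataset-side transform, which is why transform=offline_dataset.transform disappears from the eval_policy call above: rollout observations can be fed to the policy directly. A quick usage check of the sketch (same hypothetical names):

```python
# Toy stats in the assumed format; real values come from the dataset.
stats = {
    "observation.state": {"mean": torch.zeros(4), "std": torch.ones(4)},
    "action": {"mean": torch.zeros(2), "std": torch.ones(2)},
}
policy = NormalizedPolicy(stats, obs_dim=4, action_dim=2)
action = policy({"observation.state": torch.randn(1, 4)})  # raw action units out
```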