Add diffusion policy (train and eval works, TODO: reproduce results)

This commit is contained in:
Cadene
2024-02-28 15:21:30 +00:00
parent f1708c8a37
commit cf5063e50e
5 changed files with 125 additions and 31 deletions

View File

@@ -122,11 +122,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
start_time = time.time()
step = 0 # number of policy update
print("First eval_policy_and_log with a random model or pretrained")
eval_policy_and_log(
env, td_policy, step, online_episode_idx, start_time, cfg, L, is_offline=True
)
for offline_step in range(cfg.offline_steps):
if offline_step == 0:
print("Start offline training on a fixed dataset")