Small fix and improve logging message

Cadene
2024-02-27 11:44:26 +00:00
parent 21670dce90
commit 7df542445c
5 changed files with 37 additions and 16 deletions

View File

@@ -22,21 +22,24 @@ python setup.py develop
 ```
 python lerobot/scripts/train.py \
---config-name=pusht hydra.job.name=pusht
+hydra.job.name=pusht \
+env=pusht
 ```

 ### Visualize offline buffer
 ```
 python lerobot/scripts/visualize_dataset.py \
---config-name=pusht hydra.run.dir=tmp/$(date +"%Y_%m_%d")
+hydra.run.dir=tmp/$(date +"%Y_%m_%d") \
+env=pusht
 ```

 ### Visualize online buffer / Eval
 ```
 python lerobot/scripts/eval.py \
---config-name=pusht hydra.run.dir=tmp/$(date +"%Y_%m_%d")
+hydra.run.dir=tmp/$(date +"%Y_%m_%d") \
+env=pusht
 ```
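The commands above now select the environment via Hydra's `env=pusht` group override instead of a per-environment `--config-name`. As a hedged illustration only (the repo's real entry point and config tree may differ; `configs/default.yaml` and the `env/` config group are assumptions here), a minimal Hydra app that accepts these overrides looks like:

```
# sketch.py -- minimal sketch, assuming configs/default.yaml plus an env/
# config group; not the repo's actual entry point.
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="configs", config_name="default", version_base="1.2")
def main(cfg: DictConfig) -> None:
    # `env=pusht` on the CLI composes configs/env/pusht.yaml into cfg.env;
    # `hydra.job.name=...` and `hydra.run.dir=...` override Hydra's own config.
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    main()
```

Run it as `python sketch.py env=pusht hydra.job.name=pusht` to mirror the README commands.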

View File

@@ -3,6 +3,7 @@
 eval_episodes: 50
 eval_freq: 7500
 save_freq: 75000
+log_freq: 250
 # TODO: same as simxarm, need to adjust
 offline_steps: 25000
 online_steps: 25000
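The new `log_freq: 250` key bounds how often training metrics are printed. A self-contained sketch of how a loop typically consumes such a knob (the stand-in `cfg` mirrors the YAML above; the loss is a placeholder, not lerobot's actual logger):

```
# Hedged sketch: consuming log_freq in a training loop. SimpleNamespace
# stands in for the Hydra config; the loss stands in for
# policy.update(offline_buffer, step).
from types import SimpleNamespace

cfg = SimpleNamespace(offline_steps=25000, log_freq=250)

for step in range(cfg.offline_steps):
    loss = 1.0 / (step + 1)  # placeholder metric
    if step % cfg.log_freq == 0:
        print(f"step={step} loss={loss:.5f}")
```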

View File

@@ -21,6 +21,9 @@ past_action_visible: False
 keypoint_visible_rate: 1.0
 obs_as_global_cond: True

+offline_steps: 50000
+online_steps: 0
+
 policy:
   name: diffusion
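Putting `offline_steps`/`online_steps` in this policy config lets the diffusion setup carry its own training budget while later config sources still win on merge. A small OmegaConf sketch of that precedence (the literal dicts are illustrative, not the repo's files):

```
# Hedged sketch of OmegaConf merge precedence: later sources override
# earlier ones, so a policy config can redefine the default step budget.
from omegaconf import OmegaConf

defaults = OmegaConf.create({"offline_steps": 25000, "online_steps": 25000})
diffusion = OmegaConf.create({"offline_steps": 50000, "online_steps": 0})

cfg = OmegaConf.merge(defaults, diffusion)
assert cfg.offline_steps == 50000 and cfg.online_steps == 0
print(OmegaConf.to_yaml(cfg))
```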

View File

@@ -5,6 +5,7 @@ import hydra
 import imageio
 import numpy as np
 import torch
+import tqdm
 from tensordict.nn import TensorDictModule
 from termcolor import colored
 from torchrl.envs import EnvBase
@@ -32,7 +33,7 @@ def eval_policy(
     max_rewards = []
     successes = []
     threads = []
-    for i in range(num_episodes):
+    for i in tqdm.tqdm(range(num_episodes)):
         tensordict = env.reset()
         ep_frames = []
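Wrapping the episode loop in `tqdm.tqdm` is all it takes to get a live progress bar during evaluation; the loop body is untouched. A self-contained illustration (the sleep stands in for `env.reset()` and the policy rollout):

```
import time

import tqdm

num_episodes = 50
for i in tqdm.tqdm(range(num_episodes), desc="eval"):
    time.sleep(0.01)  # placeholder for env.reset() + policy rollout
```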

View File

@@ -50,7 +50,7 @@ def log_training_metrics(L, metrics, step, online_episode_idx, start_time, is_of
 def eval_policy_and_log(
-    env, td_policy, step, online_episode_idx, start_time, is_offline, cfg, L
+    env, td_policy, step, online_episode_idx, start_time, cfg, L, is_offline
 ):
     common_metrics = {
         "episode": online_episode_idx,
@@ -83,7 +83,10 @@ def train(cfg: dict, out_dir=None, job_name=None):
     set_seed(cfg.seed)
     print(colored("Work dir:", "yellow", attrs=["bold"]), out_dir)

+    print("make_env")
     env = make_env(cfg)
+
+    print("make_policy")
     policy = make_policy(cfg)

     td_policy = TensorDictModule(
@@ -92,12 +95,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
         out_keys=["action"],
     )

-    # initialize offline dataset
+    print("make_offline_buffer")
     offline_buffer = make_offline_buffer(cfg)

     # TODO(rcadene): move balanced_sampling, per_alpha, per_beta outside policy
     if cfg.policy.balanced_sampling:
+        print("make online_buffer")
         num_traj_per_batch = cfg.policy.batch_size
         online_sampler = PrioritizedSliceSampler(
@@ -117,15 +120,16 @@ def train(cfg: dict, out_dir=None, job_name=None):
     online_episode_idx = 0
     start_time = time.time()
-    step = 0
+    step = 0  # number of policy update

-    # First eval with a random model or pretrained
+    print("First eval_policy_and_log with a random model or pretrained")
     eval_policy_and_log(
-        env, td_policy, step, online_episode_idx, start_time, is_offline, cfg, L
+        env, td_policy, step, online_episode_idx, start_time, cfg, L, is_offline=True
     )

-    # Train offline
-    for _ in range(cfg.offline_steps):
+    for offline_step in range(cfg.offline_steps):
+        if offline_step == 0:
+            print("Start offline training on a fixed dataset")
         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         metrics = policy.update(offline_buffer, step)
@@ -136,7 +140,14 @@ def train(cfg: dict, out_dir=None, job_name=None):
         if step > 0 and step % cfg.eval_freq == 0:
             eval_policy_and_log(
-                env, td_policy, step, online_episode_idx, start_time, is_offline, cfg, L
+                env,
+                td_policy,
+                step,
+                online_episode_idx,
+                start_time,
+                cfg,
+                L,
+                is_offline=True,
             )

         if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
@@ -145,10 +156,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
         step += 1

-    # Train online
     demo_buffer = offline_buffer if cfg.policy.balanced_sampling else None
-    for _ in range(cfg.online_steps):
+    for env_step in range(cfg.online_steps):
+        if env_step == 0:
+            print("Start online training by interacting with environment")
         # TODO: use SyncDataCollector for that?
+        # TODO: add configurable number of rollout? (default=1)
         with torch.no_grad():
             rollout = env.rollout(
                 max_steps=cfg.env.episode_length,
@@ -191,9 +204,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
                 step,
                 online_episode_idx,
                 start_time,
-                is_offline,
                 cfg,
                 L,
+                is_offline=False,
             )

         if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
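Beyond the extra stage prints, the substantive change here is the `eval_policy_and_log` signature: `is_offline` moves to the end and every call site now passes it by keyword (`is_offline=True` / `is_offline=False`), so the flag can no longer be confused with a neighboring positional argument. A hedged sketch of the pattern; the `*` making the flag keyword-only is a suggestion for enforcing this, not what the commit itself does:

```
# Sketch: trailing boolean flags are safest passed (or forced) by keyword.
def eval_policy_and_log(env, td_policy, step, online_episode_idx,
                        start_time, cfg, L, *, is_offline):
    phase = "offline" if is_offline else "online"
    print(f"eval_policy_and_log: step={step} ({phase})")


# Call sites mirror the diff: the flag is explicit at a glance.
eval_policy_and_log(None, None, 0, 0, 0.0, None, None, is_offline=True)
eval_policy_and_log(None, None, 0, 0, 0.0, None, None, is_offline=False)
```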