diff --git a/README.md b/README.md
index afa58baf6..f55ee6951 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,15 @@ python setup.py develop
 - [x] self.step=100000 should be updated at every step to adjust to horizon of planner
 - [ ] prefetch replay buffer to speedup training
 - [ ] parallelize env to speedup eval
+- [ ] clean checkpointing / loading
+- [ ] clean logging
+- [ ] clean config
+- [ ] clean hyperparameter tuning
+- [ ] add pusht
+- [ ] add aloha
+- [ ] add act
+- [ ] add diffusion
+- [ ] add aloha 2

 ## Contribute

diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 8ae05cdae..13020b554 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -1,9 +1,6 @@
-import pickle
 import time
-from pathlib import Path

 import hydra
-import imageio
 import numpy as np
 import torch
 from tensordict.nn import TensorDictModule
@@ -19,7 +16,6 @@ from lerobot.common.logger import Logger
 from lerobot.common.tdmpc import TDMPC
 from lerobot.common.utils import set_seed
 from lerobot.scripts.eval import eval_policy
-from rl.torchrl.collectors.collectors import SyncDataCollector


 @hydra.main(version_base=None, config_name="default", config_path="../configs")
@@ -30,11 +26,11 @@ def train(cfg: dict):

     env = make_env(cfg)
     policy = TDMPC(cfg)
-    ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
-    policy.step = 25000
-    # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
-    # policy.step = 100000
-    policy.load(ckpt_path)
+    # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
+    # policy.step = 25000
+    # # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
+    # # policy.step = 100000
+    # policy.load(ckpt_path)

     td_policy = TensorDictModule(
         policy,
@@ -51,8 +47,8 @@ def train(cfg: dict):
     # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
     sampler = PrioritizedSliceSampler(
         max_capacity=100_000,
-        alpha=0.7,
-        beta=0.9,
+        alpha=cfg.per_alpha,
+        beta=cfg.per_beta,
         num_slices=num_traj_per_batch,
         strict_length=False,
     )
@@ -74,8 +70,8 @@ def train(cfg: dict):
     if cfg.balanced_sampling:
         online_sampler = PrioritizedSliceSampler(
             max_capacity=100_000,
-            alpha=0.7,
-            beta=0.9,
+            alpha=cfg.per_alpha,
+            beta=cfg.per_beta,
             num_slices=num_traj_per_batch,
             strict_length=False,
         )
@@ -83,18 +79,8 @@ def train(cfg: dict):
         online_buffer = TensorDictReplayBuffer(
             storage=LazyMemmapStorage(100_000),
             sampler=online_sampler,
-            # batch_size=3,
-            # pin_memory=False,
-            # prefetch=3,
         )

-    # Observation encoder
-    # Dynamics predictor
-    # Reward predictor
-    # Policy
-    # Qs state-action value predictor
-    # V state value predictor
-
     L = Logger(cfg.log_dir, cfg)

     online_episode_idx = 0
@@ -103,9 +89,6 @@ def train(cfg: dict):
     last_log_step = 0
     last_save_step = 0

-    # TODO(rcadene): remove
-    step = 25000
-
     while step < cfg.train_steps:
         is_offline = True
         num_updates = cfg.episode_length
@@ -126,26 +109,11 @@ def train(cfg: dict):
             )
             online_buffer.extend(rollout)

-            # Collect trajectory
-            # obs = env.reset()
-            # episode = Episode(cfg, obs)
-            # success = False
-            # while not episode.done:
-            # action = policy.act(obs, step=step, t0=episode.first)
-            # obs, reward, done, info = env.step(action.cpu().numpy())
-            # reward = reward_normalizer(reward)
-            # mask = 1.0 if (not done or "TimeLimit.truncated" in info) else 0.0
-            # success = info.get('success', False)
-            # episode += (obs, action, reward, done, mask, success)
-
             ep_reward = rollout["next", "reward"].sum()
             ep_success = rollout["next", "success"].any()
             online_episode_idx += 1

             rollout_metrics = {
-                # 'episode_reward': episode.cumulative_reward,
-                # 'episode_success': float(success),
-                # 'episode_length': len(episode)
                 "avg_reward": np.nanmean(ep_reward),
                 "pc_success": np.nanmean(ep_success) * 100,
             }
@@ -190,10 +158,6 @@ def train(cfg: dict):
                 # TODO(rcadene): add step, env_step, L.video
             )

-            # TODO(rcadene):
-            # if hasattr(env, "get_normalized_score"):
-            #     eval_metrics['normalized_score'] = env.get_normalized_score(eval_metrics["episode_reward"]) * 100.0
-
             common_metrics.update(eval_metrics)
             L.log(common_metrics, category="eval")
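
Note: the `cfg.per_alpha` / `cfg.per_beta` values introduced above presuppose matching entries in the Hydra config; those entries are not shown in this patch. Below is a minimal sketch of the resulting sampler wiring, with the config fragment mocked via OmegaConf. The config keys, their 0.7 / 0.9 defaults (mirroring the hardcoded values the diff removes), and the `batch_size` of 256 are assumptions for illustration only.

```python
# Sketch only: prioritized slice sampling driven by config values instead of
# hardcoded alpha/beta. Keys and defaults below are assumptions, not the repo's
# actual config.
from omegaconf import OmegaConf
from torchrl.data.replay_buffers import (
    LazyMemmapStorage,
    PrioritizedSliceSampler,
    TensorDictReplayBuffer,
)

# Hypothetical config fragment; in the repo these keys would live in the Hydra
# config loaded by @hydra.main(config_name="default").
cfg = OmegaConf.create({"per_alpha": 0.7, "per_beta": 0.9, "batch_size": 256})

num_traj_per_batch = cfg.batch_size
sampler = PrioritizedSliceSampler(
    max_capacity=100_000,
    alpha=cfg.per_alpha,  # priority exponent (0 = uniform sampling)
    beta=cfg.per_beta,    # importance-sampling correction exponent
    num_slices=num_traj_per_batch,
    strict_length=False,
)
buffer = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(100_000),
    sampler=sampler,
)
```

Reading alpha/beta from the config keeps the offline and online buffers consistent and lets PER coefficients be swept during hyperparameter tuning, one of the TODO items added to the README above.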