Clean up + alpha/beta now come from the config (previously hardcoded to 0.7 and 0.9)

Cadene
2024-02-16 16:27:54 +00:00
parent 0cdd23dcac
commit 0b4084f0f8
2 changed files with 18 additions and 45 deletions

View File

@@ -21,6 +21,15 @@ python setup.py develop
 - [x] self.step=100000 should be updated at every step to adjust to horizon of planner
 - [ ] prefetch replay buffer to speedup training
 - [ ] parallelize env to speedup eval
+- [ ] clean checkpointing / loading
+- [ ] clean logging
+- [ ] clean config
+- [ ] clean hyperparameter tuning
+- [ ] add pusht
+- [ ] add aloha
+- [ ] add act
+- [ ] add diffusion
+- [ ] add aloha 2
 ## Contribute

View File

@@ -1,9 +1,6 @@
-import pickle
 import time
-from pathlib import Path
 import hydra
-import imageio
 import numpy as np
 import torch
 from tensordict.nn import TensorDictModule
@@ -19,7 +16,6 @@ from lerobot.common.logger import Logger
 from lerobot.common.tdmpc import TDMPC
 from lerobot.common.utils import set_seed
 from lerobot.scripts.eval import eval_policy
-from rl.torchrl.collectors.collectors import SyncDataCollector
 @hydra.main(version_base=None, config_name="default", config_path="../configs")
@@ -30,11 +26,11 @@ def train(cfg: dict):
 env = make_env(cfg)
 policy = TDMPC(cfg)
-ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
-policy.step = 25000
-# ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
-# policy.step = 100000
-policy.load(ckpt_path)
+# ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
+# policy.step = 25000
+# # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
+# # policy.step = 100000
+# policy.load(ckpt_path)
 td_policy = TensorDictModule(
 policy,
@@ -51,8 +47,8 @@ def train(cfg: dict):
 # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
 sampler = PrioritizedSliceSampler(
 max_capacity=100_000,
-alpha=0.7,
-beta=0.9,
+alpha=cfg.per_alpha,
+beta=cfg.per_beta,
 num_slices=num_traj_per_batch,
 strict_length=False,
 )
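This hunk (and the matching one for the online sampler in the next hunk) replaces the hardcoded prioritized-replay exponents with values read from the Hydra config. A minimal sketch of what the new code expects, assuming the config loaded by @hydra.main (presumably the default config under ../configs, which is not touched by this commit) defines per_alpha and per_beta with the previous defaults; the OmegaConf stand-in and num_slices value below are illustrative only:

# Hypothetical sketch, not part of this commit: the sampler now reads the
# prioritized-replay exponents from the config instead of hardcoding 0.7 and 0.9.
from omegaconf import OmegaConf
from torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler

# Stand-in for the Hydra config; the real values would live in the default config file.
cfg = OmegaConf.create({"per_alpha": 0.7, "per_beta": 0.9})

sampler = PrioritizedSliceSampler(
    max_capacity=100_000,
    alpha=cfg.per_alpha,  # priority exponent: how strongly priorities skew sampling
    beta=cfg.per_beta,    # importance-sampling correction exponent
    num_slices=16,        # placeholder; the script passes num_traj_per_batch here
    strict_length=False,
)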
@@ -74,8 +70,8 @@ def train(cfg: dict):
 if cfg.balanced_sampling:
 online_sampler = PrioritizedSliceSampler(
 max_capacity=100_000,
-alpha=0.7,
-beta=0.9,
+alpha=cfg.per_alpha,
+beta=cfg.per_beta,
 num_slices=num_traj_per_batch,
 strict_length=False,
 )
@@ -83,18 +79,8 @@ def train(cfg: dict):
 online_buffer = TensorDictReplayBuffer(
 storage=LazyMemmapStorage(100_000),
 sampler=online_sampler,
-# batch_size=3,
-# pin_memory=False,
-# prefetch=3,
 )
-# Observation encoder
-# Dynamics predictor
-# Reward predictor
-# Policy
-# Qs state-action value predictor
-# V state value predictor
 L = Logger(cfg.log_dir, cfg)
 online_episode_idx = 0
@@ -103,9 +89,6 @@ def train(cfg: dict):
 last_log_step = 0
 last_save_step = 0
-# TODO(rcadene): remove
-step = 25000
 while step < cfg.train_steps:
 is_offline = True
 num_updates = cfg.episode_length
@@ -126,26 +109,11 @@ def train(cfg: dict):
 )
 online_buffer.extend(rollout)
-# Collect trajectory
-# obs = env.reset()
-# episode = Episode(cfg, obs)
-# success = False
-# while not episode.done:
-# action = policy.act(obs, step=step, t0=episode.first)
-# obs, reward, done, info = env.step(action.cpu().numpy())
-# reward = reward_normalizer(reward)
-# mask = 1.0 if (not done or "TimeLimit.truncated" in info) else 0.0
-# success = info.get('success', False)
-# episode += (obs, action, reward, done, mask, success)
 ep_reward = rollout["next", "reward"].sum()
 ep_success = rollout["next", "success"].any()
 online_episode_idx += 1
 rollout_metrics = {
-# 'episode_reward': episode.cumulative_reward,
-# 'episode_success': float(success),
-# 'episode_length': len(episode)
 "avg_reward": np.nanmean(ep_reward),
 "pc_success": np.nanmean(ep_success) * 100,
 }
@@ -190,10 +158,6 @@ def train(cfg: dict):
 # TODO(rcadene): add step, env_step, L.video
 )
-# TODO(rcadene):
-# if hasattr(env, "get_normalized_score"):
-# eval_metrics['normalized_score'] = env.get_normalized_score(eval_metrics["episode_reward"]) * 100.0
 common_metrics.update(eval_metrics)
 L.log(common_metrics, category="eval")