Clean up + read PER alpha and beta from config (previously hard-coded to 0.7 and 0.9)
This commit is contained in:
@@ -21,6 +21,15 @@ python setup.py develop
|
|||||||
- [x] self.step=100000 should be updated at every step to adjust to horizon of planner
|
- [x] self.step=100000 should be updated at every step to adjust to horizon of planner
|
||||||
- [ ] prefetch replay buffer to speed up training
|
- [ ] prefetch replay buffer to speed up training
|
||||||
- [ ] parallelize env to speed up eval
|
- [ ] parallelize env to speed up eval
|
||||||
|
- [ ] clean checkpointing / loading
|
||||||
|
- [ ] clean logging
|
||||||
|
- [ ] clean config
|
||||||
|
- [ ] clean hyperparameter tuning
|
||||||
|
- [ ] add pusht
|
||||||
|
- [ ] add aloha
|
||||||
|
- [ ] add act
|
||||||
|
- [ ] add diffusion
|
||||||
|
- [ ] add aloha 2
|
||||||
|
|
||||||
## Contribute
|
## Contribute
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,6 @@
|
|||||||
import pickle
|
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
import imageio
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from tensordict.nn import TensorDictModule
|
from tensordict.nn import TensorDictModule
|
||||||
@@ -19,7 +16,6 @@ from lerobot.common.logger import Logger
|
|||||||
from lerobot.common.tdmpc import TDMPC
|
from lerobot.common.tdmpc import TDMPC
|
||||||
from lerobot.common.utils import set_seed
|
from lerobot.common.utils import set_seed
|
||||||
from lerobot.scripts.eval import eval_policy
|
from lerobot.scripts.eval import eval_policy
|
||||||
from rl.torchrl.collectors.collectors import SyncDataCollector
|
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_name="default", config_path="../configs")
|
@hydra.main(version_base=None, config_name="default", config_path="../configs")
|
||||||
@@ -30,11 +26,11 @@ def train(cfg: dict):
|
|||||||
|
|
||||||
env = make_env(cfg)
|
env = make_env(cfg)
|
||||||
policy = TDMPC(cfg)
|
policy = TDMPC(cfg)
|
||||||
ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
|
# ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
|
||||||
policy.step = 25000
|
# policy.step = 25000
|
||||||
# ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
|
# # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
|
||||||
# policy.step = 100000
|
# # policy.step = 100000
|
||||||
policy.load(ckpt_path)
|
# policy.load(ckpt_path)
|
||||||
|
|
||||||
td_policy = TensorDictModule(
|
td_policy = TensorDictModule(
|
||||||
policy,
|
policy,
|
||||||
@@ -51,8 +47,8 @@ def train(cfg: dict):
|
|||||||
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
|
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
|
||||||
sampler = PrioritizedSliceSampler(
|
sampler = PrioritizedSliceSampler(
|
||||||
max_capacity=100_000,
|
max_capacity=100_000,
|
||||||
alpha=0.7,
|
alpha=cfg.per_alpha,
|
||||||
beta=0.9,
|
beta=cfg.per_beta,
|
||||||
num_slices=num_traj_per_batch,
|
num_slices=num_traj_per_batch,
|
||||||
strict_length=False,
|
strict_length=False,
|
||||||
)
|
)
|
||||||
@@ -74,8 +70,8 @@ def train(cfg: dict):
|
|||||||
if cfg.balanced_sampling:
|
if cfg.balanced_sampling:
|
||||||
online_sampler = PrioritizedSliceSampler(
|
online_sampler = PrioritizedSliceSampler(
|
||||||
max_capacity=100_000,
|
max_capacity=100_000,
|
||||||
alpha=0.7,
|
alpha=cfg.per_alpha,
|
||||||
beta=0.9,
|
beta=cfg.per_beta,
|
||||||
num_slices=num_traj_per_batch,
|
num_slices=num_traj_per_batch,
|
||||||
strict_length=False,
|
strict_length=False,
|
||||||
)
|
)
|
||||||
@@ -83,18 +79,8 @@ def train(cfg: dict):
|
|||||||
online_buffer = TensorDictReplayBuffer(
|
online_buffer = TensorDictReplayBuffer(
|
||||||
storage=LazyMemmapStorage(100_000),
|
storage=LazyMemmapStorage(100_000),
|
||||||
sampler=online_sampler,
|
sampler=online_sampler,
|
||||||
# batch_size=3,
|
|
||||||
# pin_memory=False,
|
|
||||||
# prefetch=3,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Observation encoder
|
|
||||||
# Dynamics predictor
|
|
||||||
# Reward predictor
|
|
||||||
# Policy
|
|
||||||
# Qs state-action value predictor
|
|
||||||
# V state value predictor
|
|
||||||
|
|
||||||
L = Logger(cfg.log_dir, cfg)
|
L = Logger(cfg.log_dir, cfg)
|
||||||
|
|
||||||
online_episode_idx = 0
|
online_episode_idx = 0
|
||||||
@@ -103,9 +89,6 @@ def train(cfg: dict):
|
|||||||
last_log_step = 0
|
last_log_step = 0
|
||||||
last_save_step = 0
|
last_save_step = 0
|
||||||
|
|
||||||
# TODO(rcadene): remove
|
|
||||||
step = 25000
|
|
||||||
|
|
||||||
while step < cfg.train_steps:
|
while step < cfg.train_steps:
|
||||||
is_offline = True
|
is_offline = True
|
||||||
num_updates = cfg.episode_length
|
num_updates = cfg.episode_length
|
||||||
@@ -126,26 +109,11 @@ def train(cfg: dict):
|
|||||||
)
|
)
|
||||||
online_buffer.extend(rollout)
|
online_buffer.extend(rollout)
|
||||||
|
|
||||||
# Collect trajectory
|
|
||||||
# obs = env.reset()
|
|
||||||
# episode = Episode(cfg, obs)
|
|
||||||
# success = False
|
|
||||||
# while not episode.done:
|
|
||||||
# action = policy.act(obs, step=step, t0=episode.first)
|
|
||||||
# obs, reward, done, info = env.step(action.cpu().numpy())
|
|
||||||
# reward = reward_normalizer(reward)
|
|
||||||
# mask = 1.0 if (not done or "TimeLimit.truncated" in info) else 0.0
|
|
||||||
# success = info.get('success', False)
|
|
||||||
# episode += (obs, action, reward, done, mask, success)
|
|
||||||
|
|
||||||
ep_reward = rollout["next", "reward"].sum()
|
ep_reward = rollout["next", "reward"].sum()
|
||||||
ep_success = rollout["next", "success"].any()
|
ep_success = rollout["next", "success"].any()
|
||||||
|
|
||||||
online_episode_idx += 1
|
online_episode_idx += 1
|
||||||
rollout_metrics = {
|
rollout_metrics = {
|
||||||
# 'episode_reward': episode.cumulative_reward,
|
|
||||||
# 'episode_success': float(success),
|
|
||||||
# 'episode_length': len(episode)
|
|
||||||
"avg_reward": np.nanmean(ep_reward),
|
"avg_reward": np.nanmean(ep_reward),
|
||||||
"pc_success": np.nanmean(ep_success) * 100,
|
"pc_success": np.nanmean(ep_success) * 100,
|
||||||
}
|
}
|
||||||
@@ -190,10 +158,6 @@ def train(cfg: dict):
|
|||||||
# TODO(rcadene): add step, env_step, L.video
|
# TODO(rcadene): add step, env_step, L.video
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO(rcadene):
|
|
||||||
# if hasattr(env, "get_normalized_score"):
|
|
||||||
# eval_metrics['normalized_score'] = env.get_normalized_score(eval_metrics["episode_reward"]) * 100.0
|
|
||||||
|
|
||||||
common_metrics.update(eval_metrics)
|
common_metrics.update(eval_metrics)
|
||||||
|
|
||||||
L.log(common_metrics, category="eval")
|
L.log(common_metrics, category="eval")
|
||||||
|
|||||||
Reference in New Issue
Block a user