import time

import hydra
import numpy as np
import torch
from tensordict.nn import TensorDictModule
from termcolor import colored
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
from torchrl.data.datasets.openx import OpenXExperienceReplay
from torchrl.data.replay_buffers import PrioritizedSliceSampler

from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger
from lerobot.common.tdmpc import TDMPC
from lerobot.common.utils import set_seed
from lerobot.scripts.eval import eval_policy


@hydra.main(version_base=None, config_name="default", config_path="../configs")
def train(cfg: dict):
    assert torch.cuda.is_available()
    set_seed(cfg.seed)
    print(colored("Work dir:", "yellow", attrs=["bold"]), cfg.log_dir)

    env = make_env(cfg)
    policy = TDMPC(cfg)

    # Load a pretrained checkpoint (hardcoded dev path)
    ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
    policy.step = 25000
    # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
    # policy.step = 100000
    policy.load(ckpt_path)

    td_policy = TensorDictModule(
        policy,
        in_keys=["observation", "step_count"],
        out_keys=["action"],
    )

    # Initialize offline dataset
    dataset_id = f"xarm_{cfg.task}_medium"
    num_traj_per_batch = cfg.batch_size  # // cfg.horizon
    # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
    # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
    sampler = PrioritizedSliceSampler(
        max_capacity=100_000,
        alpha=0.7,
        beta=0.9,
        num_slices=num_traj_per_batch,
        strict_length=False,
    )

    # TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
    offline_buffer = SimxarmExperienceReplay(
        dataset_id,
        # download="force",
        download=True,
        streaming=False,
        root="data",
        sampler=sampler,
    )
    num_steps = len(offline_buffer)
    index = torch.arange(0, num_steps, 1)
    sampler.extend(index)

    if cfg.balanced_sampling:
        online_sampler = PrioritizedSliceSampler(
            max_capacity=100_000,
            alpha=0.7,
            beta=0.9,
            num_slices=num_traj_per_batch,
            strict_length=False,
        )
        online_buffer = TensorDictReplayBuffer(
            storage=LazyMemmapStorage(100_000),
            sampler=online_sampler,
            # batch_size=3,
            # pin_memory=False,
            # prefetch=3,
        )

    # TD-MPC model components (for reference):
    # - Observation encoder
    # - Dynamics predictor
    # - Reward predictor
    # - Policy
    # - Qs state-action value predictor
    # - V state value predictor

    L = Logger(cfg.log_dir, cfg)

    online_episode_idx = 0
    start_time = time.time()
    step = 0
    last_log_step = 0
    last_save_step = 0

    # TODO(rcadene): remove
    step = 25000

    while step < cfg.train_steps:
        is_offline = True
        num_updates = cfg.episode_length
        _step = step + num_updates
        rollout_metrics = {}

        if step >= cfg.offline_steps:
            is_offline = False

            # TODO: use torchrl's SyncDataCollector for this?
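            # Collect one episode of online experience with the current policy.
            # Note: env.rollout() stops early when the episode terminates, so the
            # trajectory can be shorter than cfg.episode_length (see assert below).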
            rollout = env.rollout(
                max_steps=cfg.episode_length,
                policy=td_policy,
            )
            assert len(rollout) <= cfg.episode_length
            rollout["episode"] = torch.tensor(
                [online_episode_idx] * len(rollout), dtype=torch.int
            )
            online_buffer.extend(rollout)

            # Previous (non-torchrl) collection loop, kept for reference:
            # obs = env.reset()
            # episode = Episode(cfg, obs)
            # success = False
            # while not episode.done:
            #     action = policy.act(obs, step=step, t0=episode.first)
            #     obs, reward, done, info = env.step(action.cpu().numpy())
            #     reward = reward_normalizer(reward)
            #     mask = 1.0 if (not done or "TimeLimit.truncated" in info) else 0.0
            #     success = info.get('success', False)
            #     episode += (obs, action, reward, done, mask, success)

            ep_reward = rollout["next", "reward"].sum()
            ep_success = rollout["next", "success"].any()
            online_episode_idx += 1
            rollout_metrics = {
                # 'episode_reward': episode.cumulative_reward,
                # 'episode_success': float(success),
                # 'episode_length': len(episode)
                # Use .item() so metrics are plain Python floats even when the
                # rollout tensors live on GPU (np.nanmean fails on CUDA tensors).
                "avg_reward": ep_reward.item(),
                "pc_success": float(ep_success.item()) * 100,
            }
            num_updates = len(rollout) * cfg.utd
            _step = min(step + len(rollout), cfg.train_steps)

        # Update model
        train_metrics = {}
        if is_offline:
            for i in range(num_updates):
                train_metrics.update(policy.update(offline_buffer, step + i))
        else:
            for i in range(num_updates):
                train_metrics.update(
                    policy.update(
                        online_buffer,
                        step + i // cfg.utd,
                        demo_buffer=offline_buffer if cfg.balanced_sampling else None,
                    )
                )

        # Log training metrics
        env_step = int(_step * cfg.action_repeat)
        common_metrics = {
            "episode": online_episode_idx,
            "step": _step,
            "env_step": env_step,
            "total_time": time.time() - start_time,
            "is_offline": float(is_offline),
        }
        train_metrics.update(common_metrics)
        train_metrics.update(rollout_metrics)
        L.log(train_metrics, category="train")

        # Evaluate policy periodically
        if step == 0 or env_step - last_log_step >= cfg.eval_freq:
            eval_metrics = eval_policy(
                env,
                td_policy,
                num_episodes=cfg.eval_episodes,
                # TODO(rcadene): add step, env_step, L.video
            )
            # TODO(rcadene):
            # if hasattr(env, "get_normalized_score"):
            #     eval_metrics["normalized_score"] = env.get_normalized_score(eval_metrics["episode_reward"]) * 100.0
            common_metrics.update(eval_metrics)
            L.log(common_metrics, category="eval")
            last_log_step = env_step - env_step % cfg.eval_freq

        # Save model periodically
        # if cfg.save_model and env_step - last_save_step >= cfg.save_freq:
        #     L.save_model(policy, identifier=env_step)
        #     print(f"Model has been checkpointed at step {env_step}")
        #     last_save_step = env_step - env_step % cfg.save_freq
        # if cfg.save_model and is_offline and _step >= cfg.offline_steps:
        #     # Save the model after offline training
        #     L.save_model(policy, identifier="offline")

        step = _step

    # Alternative torchrl datasets, kept for reference:
    # dataset_d4rl = D4RLExperienceReplay(
    #     dataset_id="maze2d-umaze-v1",
    #     split_trajs=False,
    #     batch_size=1,
    #     # needs: from torchrl.data.replay_buffers import SamplerWithoutReplacement
    #     sampler=SamplerWithoutReplacement(drop_last=False),
    #     prefetch=4,
    #     direct_download=True,
    # )
    # dataset_openx = OpenXExperienceReplay(
    #     "cmu_stretch",
    #     batch_size=1,
    #     num_slices=1,
    #     # download="force",
    #     streaming=False,
    #     root="data",
    # )


if __name__ == "__main__":
    train()
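
# Example invocation (hypothetical overrides; the exact schema of
# ../configs/default.yaml is assumed, but every field below is read from cfg above):
#   python train.py task=lift seed=1 offline_steps=25000 train_steps=100000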