Fix unit tests, refactor, add pusht env (TODO: pusht replay buffer, image preprocessing)

commit 3da6ffb2cb (parent fdfb2010fd)
Author: Cadene
Date: 2024-02-20 12:26:57 +00:00

10 changed files with 559 additions and 89 deletions


@@ -71,24 +71,35 @@ def eval(cfg: dict):
     print(colored("Log dir:", "yellow", attrs=["bold"]), cfg.log_dir)
 
     env = make_env(cfg)
 
-    policy = TDMPC(cfg)
-    # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
-    ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
-    policy.load(ckpt_path)
-
-    policy = TensorDictModule(
-        policy,
-        in_keys=["observation", "step_count"],
-        out_keys=["action"],
-    )
+    if cfg.pretrained_model_path:
+        policy = TDMPC(cfg)
+        ckpt_path = (
+            "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
+        )
+        if "offline" in cfg.pretrained_model_path:
+            policy.step = 25000
+        elif "final" in cfg.pretrained_model_path:
+            policy.step = 100000
+        else:
+            raise NotImplementedError()
+        policy.load(ckpt_path)
+
+        policy = TensorDictModule(
+            policy,
+            in_keys=["observation", "step_count"],
+            out_keys=["action"],
+        )
+    else:
+        # when policy is None, rollout a random policy
+        policy = None
 
-    # policy can be None to rollout a random policy
     metrics = eval_policy(
         env,
         policy=policy,
         num_episodes=20,
-        save_video=False,
-        video_dir=Path("tmp/2023_01_29_xarm_lift_final"),
+        save_video=True,
+        video_dir=Path("tmp/2023_02_19_pusht"),
     )
     print(metrics)
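For readers new to the tensordict API used above: TensorDictModule wraps a plain module so that it reads its inputs from a TensorDict (in_keys) and writes its outputs back into it (out_keys), which is what lets eval_policy drive the policy and the env through one shared data structure. A minimal sketch, with a hypothetical DummyPolicy standing in for TDMPC:

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule


class DummyPolicy(torch.nn.Module):
    # stand-in for TDMPC: maps (observation, step_count) to an action
    def forward(self, observation, step_count):
        return torch.zeros(observation.shape[0], 4)


policy = TensorDictModule(
    DummyPolicy(),
    in_keys=["observation", "step_count"],
    out_keys=["action"],
)

td = TensorDict(
    {"observation": torch.randn(2, 84), "step_count": torch.zeros(2, 1)},
    batch_size=[2],
)
policy(td)  # reads the in_keys, writes td["action"] in place
assert td["action"].shape == (2, 4)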


@@ -10,6 +10,7 @@
 from torchrl.data.datasets.d4rl import D4RLExperienceReplay
 from torchrl.data.datasets.openx import OpenXExperienceReplay
 from torchrl.data.replay_buffers import PrioritizedSliceSampler
+from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
 from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger
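The newly imported make_offline_buffer presumably absorbs the inline buffer construction that the last hunk below deletes. A hedged sketch of such a factory, assembled from that deleted code; the actual contents and signature in lerobot.common.datasets.factory may well differ:

import torch
from torchrl.data.replay_buffers import PrioritizedSliceSampler

from lerobot.common.datasets.simxarm import SimxarmExperienceReplay


def make_offline_buffer(cfg):
    # prioritized sampler returning cfg.batch_size slices per sample() call
    sampler = PrioritizedSliceSampler(
        max_capacity=100_000,
        alpha=cfg.per_alpha,
        beta=cfg.per_beta,
        num_slices=cfg.batch_size,
        strict_length=False,
    )
    offline_buffer = SimxarmExperienceReplay(
        f"xarm_{cfg.task}_medium",
        download=True,
        streaming=False,
        root="data",
        sampler=sampler,
    )
    # seed the sampler with one priority entry per stored transition
    sampler.extend(torch.arange(len(offline_buffer)))
    return offline_buffer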
@@ -26,11 +27,17 @@ def train(cfg: dict):
 
     env = make_env(cfg)
     policy = TDMPC(cfg)
-    # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
-    # policy.step = 25000
-    # # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
-    # # policy.step = 100000
-    # policy.load(ckpt_path)
+    if cfg.pretrained_model_path:
+        ckpt_path = (
+            "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
+        )
+        if "offline" in cfg.pretrained_model_path:
+            policy.step = 25000
+        elif "final" in cfg.pretrained_model_path:
+            policy.step = 100000
+        else:
+            raise NotImplementedError()
+        policy.load(ckpt_path)
 
     td_policy = TensorDictModule(
         policy,
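The checkpoint-name-to-step mapping above is now duplicated between eval and train. One way to factor it out, sketched with a hypothetical helper that is not part of this commit:

def step_from_ckpt_name(pretrained_model_path: str) -> int:
    # mirrors the branches above: "offline" checkpoints were saved at
    # step 25000, "final" checkpoints at step 100000
    if "offline" in pretrained_model_path:
        return 25000
    if "final" in pretrained_model_path:
        return 100000
    raise NotImplementedError(pretrained_model_path)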
@@ -40,32 +47,7 @@ def train(cfg: dict):
 
     # initialize offline dataset
-    dataset_id = f"xarm_{cfg.task}_medium"
-
-    num_traj_per_batch = cfg.batch_size  # // cfg.horizon
-    # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
-    # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
-    sampler = PrioritizedSliceSampler(
-        max_capacity=100_000,
-        alpha=cfg.per_alpha,
-        beta=cfg.per_beta,
-        num_slices=num_traj_per_batch,
-        strict_length=False,
-    )
-
-    # TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
-    offline_buffer = SimxarmExperienceReplay(
-        dataset_id,
-        # download="force",
-        download=True,
-        streaming=False,
-        root="data",
-        sampler=sampler,
-    )
-    num_steps = len(offline_buffer)
-    index = torch.arange(0, num_steps, 1)
-    sampler.extend(index)
+    offline_buffer = make_offline_buffer(cfg)
 
     if cfg.balanced_sampling:
         online_sampler = PrioritizedSliceSampler(
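The removed TODO notes that PrioritizedSliceSampler with strict_length=False can return fewer than cfg.batch_size transitions. A hedged sketch of the padding it asks for, written as a plain post-sampling function rather than a torchrl transform (pad_to_batch_size is hypothetical):

import torch
from tensordict import TensorDict


def pad_to_batch_size(batch: TensorDict, batch_size: int) -> TensorDict:
    # duplicate randomly chosen transitions until the leading batch
    # dimension is exactly batch_size
    missing = batch_size - batch.shape[0]
    if missing <= 0:
        return batch
    idx = torch.randint(batch.shape[0], (missing,))
    return torch.cat([batch, batch[idx]], dim=0)

The training loop could then call pad_to_batch_size(batch, cfg.batch_size) on every batch it samples from the buffer.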