forked from tangger/lerobot
Fix unit tests, Refactor, Add pusht env, (TODO pusht replay buffer, image preprocessing)
@@ -10,6 +10,7 @@ from torchrl.data.datasets.d4rl import D4RLExperienceReplay
from torchrl.data.datasets.openx import OpenXExperienceReplay
from torchrl.data.replay_buffers import PrioritizedSliceSampler

from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger
@@ -26,11 +27,17 @@ def train(cfg: dict):

    env = make_env(cfg)
    policy = TDMPC(cfg)
    # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
    # policy.step = 25000
    # # ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
    # # policy.step = 100000
    # policy.load(ckpt_path)
    if cfg.pretrained_model_path:
        ckpt_path = (
            "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
        )
        if "offline" in cfg.pretrained_model_path:
            policy.step = 25000
        elif "final" in cfg.pretrained_model_path:
            policy.step = 100000
        else:
            raise NotImplementedError()
        policy.load(ckpt_path)

    td_policy = TensorDictModule(
        policy,
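The commented-out checkpoint lines above are superseded by the branch on cfg.pretrained_model_path, which sets policy.step from the checkpoint name before loading. Below is a minimal sketch of that mapping factored into a helper; the function name step_for_checkpoint and the direct use of cfg.pretrained_model_path as the load path are assumptions (the diff itself still hardcodes ckpt_path to the offline.pt file).

# Sketch only, not part of the diff: the name-to-step mapping used above,
# pulled into a hypothetical helper.
def step_for_checkpoint(path: str) -> int:
    if "offline" in path:
        return 25000   # value the diff assigns for the offline checkpoint
    if "final" in path:
        return 100000  # value the diff assigns for the final checkpoint
    raise NotImplementedError(path)

# possible usage inside train(), assuming the configured path is the one to load:
#     policy.step = step_for_checkpoint(cfg.pretrained_model_path)
#     policy.load(cfg.pretrained_model_path)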
@@ -40,32 +47,7 @@ def train(cfg: dict):

    # initialize offline dataset

    dataset_id = f"xarm_{cfg.task}_medium"

    num_traj_per_batch = cfg.batch_size  # // cfg.horizon
    # TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
    # We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
    sampler = PrioritizedSliceSampler(
        max_capacity=100_000,
        alpha=cfg.per_alpha,
        beta=cfg.per_beta,
        num_slices=num_traj_per_batch,
        strict_length=False,
    )

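The TODO above notes that with strict_length=False the sampler can return fewer than cfg.batch_size transitions. Here is a minimal sketch of the padding it suggests, under the assumption that repeating already-sampled transitions is acceptable; the function name pad_batch is hypothetical, and the same logic could instead be registered as a replay-buffer transform.

# Sketch only, not part of the diff: pad a sampled TensorDict along its first
# batch dimension so that it always has exactly `batch_size` transitions.
import torch
from tensordict import TensorDict


def pad_batch(td: TensorDict, batch_size: int) -> TensorDict:
    missing = batch_size - td.batch_size[0]
    if missing <= 0:
        return td
    # duplicate randomly chosen transitions from the same batch to fill the gap
    pad_idx = torch.randint(0, td.batch_size[0], (missing,))
    return torch.cat([td, td[pad_idx]], dim=0)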
    # TODO(rcadene): add PrioritizedSliceSampler inside Simxarm to not have to `sampler.extend(index)` here
    offline_buffer = SimxarmExperienceReplay(
        dataset_id,
        # download="force",
        download=True,
        streaming=False,
        root="data",
        sampler=sampler,
    )

    num_steps = len(offline_buffer)
    index = torch.arange(0, num_steps, 1)
    sampler.extend(index)
    offline_buffer = make_offline_buffer(cfg)

    if cfg.balanced_sampling:
        online_sampler = PrioritizedSliceSampler(
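The hand-built SimxarmExperienceReplay block above is replaced by a single call to make_offline_buffer(cfg), imported from lerobot.common.datasets.factory, which lines up with the commit message's move toward multiple environments (pusht). The factory body is not part of this diff, so the following is an assumed shape rather than the actual implementation, and the cfg.env field used for dispatch is likewise an assumption.

# Sketch only: one plausible shape for make_offline_buffer, reusing the sampler
# setup that the inline code above built by hand. cfg.env is assumed here.
import torch
from torchrl.data.replay_buffers import PrioritizedSliceSampler

from lerobot.common.datasets.simxarm import SimxarmExperienceReplay


def make_offline_buffer(cfg):
    sampler = PrioritizedSliceSampler(
        max_capacity=100_000,
        alpha=cfg.per_alpha,
        beta=cfg.per_beta,
        num_slices=cfg.batch_size,
        strict_length=False,
    )
    if cfg.env == "simxarm":
        offline_buffer = SimxarmExperienceReplay(
            f"xarm_{cfg.task}_medium",
            download=True,
            streaming=False,
            root="data",
            sampler=sampler,
        )
    elif cfg.env == "pusht":
        # per the commit message, the pusht replay buffer is still a TODO
        raise NotImplementedError("pusht replay buffer not implemented yet")
    else:
        raise ValueError(f"unknown env: {cfg.env}")

    # prime the prioritized sampler with one priority slot per stored transition
    sampler.extend(torch.arange(len(offline_buffer)))
    return offline_buffer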