test_examples are passing
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
|
||||
from torchrl.data.replay_buffers import SamplerWithoutReplacement
|
||||
from pathlib import Path
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.datasets.aloha import AlohaDataset
|
||||
@@ -9,16 +8,13 @@ from lerobot.scripts.visualize_dataset import render_dataset
|
||||
print(lerobot.available_datasets)
|
||||
# >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']
|
||||
|
||||
# we use this sampler to sample 1 frame after the other
|
||||
sampler = SamplerWithoutReplacement(shuffle=False)
|
||||
|
||||
dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler, root=os.environ.get("DATA_DIR"))
|
||||
# TODO(rcadene): remove DATA_DIR
|
||||
dataset = AlohaDataset("aloha_sim_transfer_cube_human", root=Path(os.environ.get("DATA_DIR")))
|
||||
|
||||
video_paths = render_dataset(
|
||||
dataset,
|
||||
out_dir="outputs/visualize_dataset/example",
|
||||
max_num_samples=300,
|
||||
fps=50,
|
||||
max_num_episodes=1,
|
||||
)
|
||||
print(video_paths)
|
||||
# ['outputs/visualize_dataset/example/episode_0.mp4']
|
||||
|
||||
@@ -9,9 +9,8 @@ from pathlib import Path
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
|
||||
from lerobot.common.utils import init_hydra_config
|
||||
|
||||
@@ -37,19 +36,33 @@ policy = DiffusionPolicy(
|
||||
cfg_obs_encoder=cfg.obs_encoder,
|
||||
cfg_optimizer=cfg.optimizer,
|
||||
cfg_ema=cfg.ema,
|
||||
n_action_steps=cfg.n_action_steps,
|
||||
**cfg.policy,
|
||||
)
|
||||
policy.train()
|
||||
|
||||
offline_buffer = make_offline_buffer(cfg)
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# create dataloader for offline training
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=cfg.policy.batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=cfg.device != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
for step, batch in enumerate(dataloader):
|
||||
info = policy(batch, step)
|
||||
|
||||
if step % cfg.log_freq == 0:
|
||||
num_samples = (step + 1) * cfg.policy.batch_size
|
||||
loss = info["loss"]
|
||||
update_s = info["update_s"]
|
||||
print(f"step:{step} samples:{num_samples} loss:{loss:.3f} update_time:{update_s:.3f}(seconds)")
|
||||
|
||||
for offline_step in trange(cfg.offline_steps):
|
||||
train_info = policy.update(offline_buffer, offline_step)
|
||||
if offline_step % cfg.log_freq == 0:
|
||||
print(train_info)
|
||||
|
||||
# Save the policy, configuration, and normalization stats for later use.
|
||||
policy.save(output_directory / "model.pt")
|
||||
OmegaConf.save(cfg, output_directory / "config.yaml")
|
||||
torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")
|
||||
torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth")
|
||||
|
||||
Reference in New Issue
Block a user