test_examples are passing

Cadene committed on 2024-04-10 13:45:45 +00:00
parent 6082a7bc73, commit c08003278e
4 changed files with 62 additions and 79 deletions

Changed file:

@@ -1,6 +1,5 @@
 import os
-from torchrl.data.replay_buffers import SamplerWithoutReplacement
 from pathlib import Path
 import lerobot
 from lerobot.common.datasets.aloha import AlohaDataset
@@ -9,16 +8,13 @@ from lerobot.scripts.visualize_dataset import render_dataset
 print(lerobot.available_datasets)
 # >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']
-# we use this sampler to sample 1 frame after the other
-sampler = SamplerWithoutReplacement(shuffle=False)
-dataset = AlohaDataset("aloha_sim_transfer_cube_human", sampler=sampler, root=os.environ.get("DATA_DIR"))
+# TODO(rcadene): remove DATA_DIR
+dataset = AlohaDataset("aloha_sim_transfer_cube_human", root=Path(os.environ.get("DATA_DIR")))
 video_paths = render_dataset(
     dataset,
     out_dir="outputs/visualize_dataset/example",
-    max_num_samples=300,
-    fps=50,
+    max_num_episodes=1,
 )
 print(video_paths)
 # ['outputs/visualize_dataset/example/episode_0.mp4']
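Read end to end, the updated dataset example above becomes roughly the following script (assembled only from the unchanged and added lines of this diff; the DATA_DIR environment variable is assumed to point at the local dataset root):

import os
from pathlib import Path

import lerobot
from lerobot.common.datasets.aloha import AlohaDataset
from lerobot.scripts.visualize_dataset import render_dataset

print(lerobot.available_datasets)

# TODO(rcadene): remove DATA_DIR
dataset = AlohaDataset("aloha_sim_transfer_cube_human", root=Path(os.environ.get("DATA_DIR")))

video_paths = render_dataset(
    dataset,
    out_dir="outputs/visualize_dataset/example",
    max_num_episodes=1,
)
print(video_paths)
# ['outputs/visualize_dataset/example/episode_0.mp4']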

Changed file:

@@ -9,9 +9,8 @@ from pathlib import Path
 import torch
 from omegaconf import OmegaConf
 from tqdm import trange
-from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.policies.diffusion.policy import DiffusionPolicy
 from lerobot.common.utils import init_hydra_config
@@ -37,19 +36,33 @@ policy = DiffusionPolicy(
     cfg_obs_encoder=cfg.obs_encoder,
     cfg_optimizer=cfg.optimizer,
     cfg_ema=cfg.ema,
     n_action_steps=cfg.n_action_steps,
     **cfg.policy,
 )
 policy.train()
-offline_buffer = make_offline_buffer(cfg)
+dataset = make_dataset(cfg)
+# create dataloader for offline training
+dataloader = torch.utils.data.DataLoader(
+    dataset,
+    num_workers=4,
+    batch_size=cfg.policy.batch_size,
+    shuffle=True,
+    pin_memory=cfg.device != "cpu",
+    drop_last=True,
+)
+for step, batch in enumerate(dataloader):
+    info = policy(batch, step)
+    if step % cfg.log_freq == 0:
+        num_samples = (step + 1) * cfg.policy.batch_size
+        loss = info["loss"]
+        update_s = info["update_s"]
+        print(f"step:{step} samples:{num_samples} loss:{loss:.3f} update_time:{update_s:.3f}(seconds)")
-for offline_step in trange(cfg.offline_steps):
-    train_info = policy.update(offline_buffer, offline_step)
-    if offline_step % cfg.log_freq == 0:
-        print(train_info)
 # Save the policy, configuration, and normalization stats for later use.
 policy.save(output_directory / "model.pt")
 OmegaConf.save(cfg, output_directory / "config.yaml")
-torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")
+torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth")
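As a side note, a minimal sketch of how the saved configuration and normalization stats could be read back later, using only standard OmegaConf and torch calls; the output path below is hypothetical, and reloading the policy weights is not shown because only policy.save appears in this diff:

from pathlib import Path

import torch
from omegaconf import OmegaConf

# hypothetical path, standing in for the output_directory used above
output_directory = Path("outputs/train/example")

# training configuration saved with OmegaConf.save
cfg = OmegaConf.load(output_directory / "config.yaml")
# normalization stats saved with torch.save
stats = torch.load(output_directory / "stats.pth")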