ready for review

This commit is contained in:
Alexander Soare
2024-04-16 13:43:58 +01:00
parent 03b08eb74e
commit 9c2f10bd04
4 changed files with 26 additions and 16 deletions

View File

@@ -11,6 +11,7 @@ from lerobot.common.utils import init_hydra_config
from lerobot.scripts.eval import eval
# Get a pretrained policy from the hub.
# TODO(alexander-soare): This no longer works until we upload a new model that uses the current configs.
hub_id = "lerobot/diffusion_policy_pusht_image"
folder = Path(snapshot_download(hub_id))
# OR uncomment the following to evaluate a policy from the local outputs/train folder.

View File

@@ -11,7 +11,6 @@ import torch
from omegaconf import OmegaConf
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.utils import init_hydra_config
@@ -26,8 +25,8 @@ device = torch.device("cuda")
log_freq = 250
# Set up the dataset.
cfg = init_hydra_config("lerobot/configs/default.yaml", overrides=["env=pusht"])
dataset = make_dataset(cfg)
hydra_cfg = init_hydra_config("lerobot/configs/default.yaml", overrides=["env=pusht"])
dataset = make_dataset(hydra_cfg)
# Set up the the policy.
# Policies are initialized with a configuration class, in this case `DiffusionConfig`.
@@ -50,17 +49,25 @@ dataloader = torch.utils.data.DataLoader(
)
# Run training loop.
dataloader = cycle(dataloader)
for step in range(training_steps):
batch = {k: v.to(device, non_blocking=True) for k, v in next(dataloader).items()}
info = policy(batch)
if step % log_freq == 0:
num_samples = (step + 1) * cfg.batch_size
loss = info["loss"]
update_s = info["update_s"]
print(f"step: {step} samples: {num_samples} loss: {loss:.3f} update_time: {update_s:.3f} (seconds)")
step = 0
done = False
while not done:
for batch in dataloader:
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
info = policy(batch)
if step % log_freq == 0:
num_samples = (step + 1) * cfg.batch_size
loss = info["loss"]
update_s = info["update_s"]
print(
f"step: {step} samples: {num_samples} loss: {loss:.3f} update_time: {update_s:.3f} (seconds)"
)
step += 1
if step >= training_steps:
done = True
break
# Save the policy, configuration, and normalization stats for later use.
policy.save(output_directory / "model.pt")
OmegaConf.save(cfg, output_directory / "config.yaml")
OmegaConf.save(hydra_cfg, output_directory / "config.yaml")
torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth")