forked from tangger/lerobot
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
97cb7a2362 | ||
|
|
b6c216b590 | ||
|
|
2b270d085b |
38
Makefile
38
Makefile
@@ -20,6 +20,8 @@ build-gpu:
|
|||||||
test-end-to-end:
|
test-end-to-end:
|
||||||
${MAKE} test-act-ete-train
|
${MAKE} test-act-ete-train
|
||||||
${MAKE} test-act-ete-eval
|
${MAKE} test-act-ete-eval
|
||||||
|
${MAKE} test-act-ete-train-amp
|
||||||
|
${MAKE} test-act-ete-eval-amp
|
||||||
${MAKE} test-diffusion-ete-train
|
${MAKE} test-diffusion-ete-train
|
||||||
${MAKE} test-diffusion-ete-eval
|
${MAKE} test-diffusion-ete-eval
|
||||||
${MAKE} test-tdmpc-ete-train
|
${MAKE} test-tdmpc-ete-train
|
||||||
@@ -29,6 +31,7 @@ test-end-to-end:
|
|||||||
test-act-ete-train:
|
test-act-ete-train:
|
||||||
python lerobot/scripts/train.py \
|
python lerobot/scripts/train.py \
|
||||||
policy=act \
|
policy=act \
|
||||||
|
policy.dim_model=64 \
|
||||||
env=aloha \
|
env=aloha \
|
||||||
wandb.enable=False \
|
wandb.enable=False \
|
||||||
training.offline_steps=2 \
|
training.offline_steps=2 \
|
||||||
@@ -51,9 +54,40 @@ test-act-ete-eval:
|
|||||||
env.episode_length=8 \
|
env.episode_length=8 \
|
||||||
device=cpu \
|
device=cpu \
|
||||||
|
|
||||||
|
test-act-ete-train-amp:
|
||||||
|
python lerobot/scripts/train.py \
|
||||||
|
policy=act \
|
||||||
|
policy.dim_model=64 \
|
||||||
|
env=aloha \
|
||||||
|
wandb.enable=False \
|
||||||
|
training.offline_steps=2 \
|
||||||
|
training.online_steps=0 \
|
||||||
|
eval.n_episodes=1 \
|
||||||
|
eval.batch_size=1 \
|
||||||
|
device=cpu \
|
||||||
|
training.save_model=true \
|
||||||
|
training.save_freq=2 \
|
||||||
|
policy.n_action_steps=20 \
|
||||||
|
policy.chunk_size=20 \
|
||||||
|
training.batch_size=2 \
|
||||||
|
hydra.run.dir=tests/outputs/act/ \
|
||||||
|
use_amp=true
|
||||||
|
|
||||||
|
test-act-ete-eval-amp:
|
||||||
|
python lerobot/scripts/eval.py \
|
||||||
|
-p tests/outputs/act/checkpoints/000002 \
|
||||||
|
eval.n_episodes=1 \
|
||||||
|
eval.batch_size=1 \
|
||||||
|
env.episode_length=8 \
|
||||||
|
device=cpu \
|
||||||
|
use_amp=true
|
||||||
|
|
||||||
test-diffusion-ete-train:
|
test-diffusion-ete-train:
|
||||||
python lerobot/scripts/train.py \
|
python lerobot/scripts/train.py \
|
||||||
policy=diffusion \
|
policy=diffusion \
|
||||||
|
policy.down_dims=\[64,128,256\] \
|
||||||
|
policy.diffusion_step_embed_dim=32 \
|
||||||
|
policy.num_inference_steps=10 \
|
||||||
env=pusht \
|
env=pusht \
|
||||||
wandb.enable=False \
|
wandb.enable=False \
|
||||||
training.offline_steps=2 \
|
training.offline_steps=2 \
|
||||||
@@ -74,6 +108,7 @@ test-diffusion-ete-eval:
|
|||||||
env.episode_length=8 \
|
env.episode_length=8 \
|
||||||
device=cpu \
|
device=cpu \
|
||||||
|
|
||||||
|
# TODO(alexander-soare): Restore online_steps to 2 when it is reinstated.
|
||||||
test-tdmpc-ete-train:
|
test-tdmpc-ete-train:
|
||||||
python lerobot/scripts/train.py \
|
python lerobot/scripts/train.py \
|
||||||
policy=tdmpc \
|
policy=tdmpc \
|
||||||
@@ -82,7 +117,7 @@ test-tdmpc-ete-train:
|
|||||||
dataset_repo_id=lerobot/xarm_lift_medium \
|
dataset_repo_id=lerobot/xarm_lift_medium \
|
||||||
wandb.enable=False \
|
wandb.enable=False \
|
||||||
training.offline_steps=2 \
|
training.offline_steps=2 \
|
||||||
training.online_steps=2 \
|
training.online_steps=0 \
|
||||||
eval.n_episodes=1 \
|
eval.n_episodes=1 \
|
||||||
eval.batch_size=1 \
|
eval.batch_size=1 \
|
||||||
env.episode_length=2 \
|
env.episode_length=2 \
|
||||||
@@ -100,7 +135,6 @@ test-tdmpc-ete-eval:
|
|||||||
env.episode_length=8 \
|
env.episode_length=8 \
|
||||||
device=cpu \
|
device=cpu \
|
||||||
|
|
||||||
|
|
||||||
test-default-ete-eval:
|
test-default-ete-eval:
|
||||||
python lerobot/scripts/eval.py \
|
python lerobot/scripts/eval.py \
|
||||||
--config lerobot/configs/default.yaml \
|
--config lerobot/configs/default.yaml \
|
||||||
|
|||||||
@@ -43,8 +43,7 @@ def get_cameras(hdf5_data):
|
|||||||
|
|
||||||
|
|
||||||
def check_format(raw_dir) -> bool:
|
def check_format(raw_dir) -> bool:
|
||||||
# only frames from simulation are uncompressed
|
compressed_images = None
|
||||||
compressed_images = "sim" not in raw_dir.name
|
|
||||||
|
|
||||||
hdf5_paths = list(raw_dir.glob("episode_*.hdf5"))
|
hdf5_paths = list(raw_dir.glob("episode_*.hdf5"))
|
||||||
assert len(hdf5_paths) != 0
|
assert len(hdf5_paths) != 0
|
||||||
@@ -62,18 +61,20 @@ def check_format(raw_dir) -> bool:
|
|||||||
for camera in get_cameras(data):
|
for camera in get_cameras(data):
|
||||||
assert num_frames == data[f"/observations/images/{camera}"].shape[0]
|
assert num_frames == data[f"/observations/images/{camera}"].shape[0]
|
||||||
|
|
||||||
if compressed_images:
|
assert data[f"/observations/images/{camera}"].ndim in [2, 4]
|
||||||
assert data[f"/observations/images/{camera}"].ndim == 2
|
if data[f"/observations/images/{camera}"].ndim == 2:
|
||||||
|
assert compressed_images is None or compressed_images
|
||||||
|
compressed_images = True
|
||||||
else:
|
else:
|
||||||
|
assert compressed_images is None or not compressed_images
|
||||||
|
compressed_images = False
|
||||||
assert data[f"/observations/images/{camera}"].ndim == 4
|
assert data[f"/observations/images/{camera}"].ndim == 4
|
||||||
b, h, w, c = data[f"/observations/images/{camera}"].shape
|
b, h, w, c = data[f"/observations/images/{camera}"].shape
|
||||||
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
|
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
|
||||||
|
return compressed_images
|
||||||
|
|
||||||
|
|
||||||
def load_from_raw(raw_dir, out_dir, fps, video, debug):
|
def load_from_raw(raw_dir, out_dir, fps, video, debug, compressed_images):
|
||||||
# only frames from simulation are uncompressed
|
|
||||||
compressed_images = "sim" not in raw_dir.name
|
|
||||||
|
|
||||||
hdf5_files = list(raw_dir.glob("*.hdf5"))
|
hdf5_files = list(raw_dir.glob("*.hdf5"))
|
||||||
ep_dicts = []
|
ep_dicts = []
|
||||||
episode_data_index = {"from": [], "to": []}
|
episode_data_index = {"from": [], "to": []}
|
||||||
@@ -199,12 +200,12 @@ def to_hf_dataset(data_dict, video) -> Dataset:
|
|||||||
|
|
||||||
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
|
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
|
||||||
# sanity check
|
# sanity check
|
||||||
check_format(raw_dir)
|
compressed_images = check_format(raw_dir)
|
||||||
|
|
||||||
if fps is None:
|
if fps is None:
|
||||||
fps = 50
|
fps = 50
|
||||||
|
|
||||||
data_dir, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug)
|
data_dir, episode_data_index = load_from_raw(raw_dir, out_dir, fps, video, debug, compressed_images)
|
||||||
hf_dataset = to_hf_dataset(data_dir, video)
|
hf_dataset = to_hf_dataset(data_dir, video)
|
||||||
|
|
||||||
info = {
|
info = {
|
||||||
|
|||||||
@@ -10,6 +10,9 @@ hydra:
|
|||||||
name: default
|
name: default
|
||||||
|
|
||||||
device: cuda # cpu
|
device: cuda # cpu
|
||||||
|
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
||||||
|
# automatic gradient scaling is used.
|
||||||
|
use_amp: false
|
||||||
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
||||||
# AND for the evaluation environments.
|
# AND for the evaluation environments.
|
||||||
seed: ???
|
seed: ???
|
||||||
@@ -17,6 +20,7 @@ dataset_repo_id: lerobot/pusht
|
|||||||
|
|
||||||
training:
|
training:
|
||||||
offline_steps: ???
|
offline_steps: ???
|
||||||
|
# NOTE: `online_steps` is not implemented yet. It's here as a placeholder.
|
||||||
online_steps: ???
|
online_steps: ???
|
||||||
online_steps_between_rollouts: ???
|
online_steps_between_rollouts: ???
|
||||||
online_sampling_ratio: 0.5
|
online_sampling_ratio: 0.5
|
||||||
|
|||||||
14
lerobot/configs/env/aloha_thom.yaml
vendored
Normal file
14
lerobot/configs/env/aloha_thom.yaml
vendored
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
fps: 50
|
||||||
|
|
||||||
|
env:
|
||||||
|
name: aloha
|
||||||
|
task: AlohaInsertion-v0
|
||||||
|
from_pixels: True
|
||||||
|
pixels_only: False
|
||||||
|
image_size: [3, 480, 640]
|
||||||
|
episode_length: 500
|
||||||
|
fps: ${fps}
|
||||||
|
state_dim: 6
|
||||||
|
action_dim: 6
|
||||||
77
lerobot/configs/policy/act_thom.yaml
Normal file
77
lerobot/configs/policy/act_thom.yaml
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
seed: 1000
|
||||||
|
dataset_repo_id: lerobot/aloha_sim_insertion_human
|
||||||
|
|
||||||
|
training:
|
||||||
|
offline_steps: 20000
|
||||||
|
online_steps: 0
|
||||||
|
eval_freq: 100000
|
||||||
|
save_freq: 200
|
||||||
|
log_freq: 200
|
||||||
|
save_model: true
|
||||||
|
|
||||||
|
batch_size: 8
|
||||||
|
lr: 1e-5
|
||||||
|
lr_backbone: 1e-5
|
||||||
|
weight_decay: 1e-4
|
||||||
|
grad_clip_norm: 10
|
||||||
|
online_steps_between_rollouts: 1
|
||||||
|
|
||||||
|
delta_timestamps:
|
||||||
|
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||||
|
|
||||||
|
eval:
|
||||||
|
n_episodes: 50
|
||||||
|
batch_size: 50
|
||||||
|
|
||||||
|
# See `configuration_act.py` for more details.
|
||||||
|
policy:
|
||||||
|
name: act
|
||||||
|
|
||||||
|
# Input / output structure.
|
||||||
|
n_obs_steps: 1
|
||||||
|
chunk_size: 100 # chunk_size
|
||||||
|
n_action_steps: 100
|
||||||
|
|
||||||
|
input_shapes:
|
||||||
|
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||||
|
observation.images: [3, 480, 640]
|
||||||
|
observation.state: ["${env.state_dim}"]
|
||||||
|
output_shapes:
|
||||||
|
action: ["${env.action_dim}"]
|
||||||
|
|
||||||
|
# Normalization / Unnormalization
|
||||||
|
input_normalization_modes:
|
||||||
|
observation.images.front: mean_std
|
||||||
|
observation.state: mean_std
|
||||||
|
output_normalization_modes:
|
||||||
|
action: mean_std
|
||||||
|
|
||||||
|
# Architecture.
|
||||||
|
# Vision backbone.
|
||||||
|
vision_backbone: resnet18
|
||||||
|
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||||
|
replace_final_stride_with_dilation: false
|
||||||
|
# Transformer layers.
|
||||||
|
pre_norm: false
|
||||||
|
dim_model: 512
|
||||||
|
n_heads: 8
|
||||||
|
dim_feedforward: 3200
|
||||||
|
feedforward_activation: relu
|
||||||
|
n_encoder_layers: 4
|
||||||
|
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||||
|
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||||
|
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||||
|
n_decoder_layers: 1
|
||||||
|
# VAE.
|
||||||
|
use_vae: true
|
||||||
|
latent_dim: 32
|
||||||
|
n_vae_encoder_layers: 4
|
||||||
|
|
||||||
|
# Inference.
|
||||||
|
temporal_ensemble_momentum: null
|
||||||
|
|
||||||
|
# Training and loss computation.
|
||||||
|
dropout: 0.1
|
||||||
|
kl_weight: 10.0
|
||||||
@@ -5,7 +5,8 @@ dataset_repo_id: lerobot/xarm_lift_medium
|
|||||||
|
|
||||||
training:
|
training:
|
||||||
offline_steps: 25000
|
offline_steps: 25000
|
||||||
online_steps: 25000
|
# TODO(alexander-soare): uncomment when online training gets reinstated
|
||||||
|
online_steps: 0 # 25000 not implemented yet
|
||||||
eval_freq: 5000
|
eval_freq: 5000
|
||||||
online_steps_between_rollouts: 1
|
online_steps_between_rollouts: 1
|
||||||
online_sampling_ratio: 0.5
|
online_sampling_ratio: 0.5
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
from contextlib import nullcontext
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from datetime import datetime as dt
|
from datetime import datetime as dt
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -520,7 +521,7 @@ def eval(
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
# Check device is available
|
# Check device is available
|
||||||
get_safe_torch_device(hydra_cfg.device, log=True)
|
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
@@ -539,16 +540,17 @@ def eval(
|
|||||||
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
|
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
|
|
||||||
info = eval_policy(
|
with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
|
||||||
env,
|
info = eval_policy(
|
||||||
policy,
|
env,
|
||||||
hydra_cfg.eval.n_episodes,
|
policy,
|
||||||
max_episodes_rendered=10,
|
hydra_cfg.eval.n_episodes,
|
||||||
video_dir=Path(out_dir) / "eval",
|
max_episodes_rendered=10,
|
||||||
start_seed=hydra_cfg.seed,
|
video_dir=Path(out_dir) / "eval",
|
||||||
enable_progbar=True,
|
start_seed=hydra_cfg.seed,
|
||||||
enable_inner_progbar=True,
|
enable_progbar=True,
|
||||||
)
|
enable_inner_progbar=True,
|
||||||
|
)
|
||||||
print(info["aggregated"])
|
print(info["aggregated"])
|
||||||
|
|
||||||
# Save info
|
# Save info
|
||||||
|
|||||||
@@ -15,15 +15,15 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from contextlib import nullcontext
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import datasets
|
|
||||||
import hydra
|
import hydra
|
||||||
import torch
|
import torch
|
||||||
from datasets import concatenate_datasets
|
|
||||||
from datasets.utils import disable_progress_bars, enable_progress_bars
|
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
|
from torch.cuda.amp import GradScaler
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_dataset
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
from lerobot.common.datasets.utils import cycle
|
from lerobot.common.datasets.utils import cycle
|
||||||
@@ -31,6 +31,7 @@ from lerobot.common.envs.factory import make_env
|
|||||||
from lerobot.common.logger import Logger, log_output_dir
|
from lerobot.common.logger import Logger, log_output_dir
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
|
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
|
||||||
|
from lerobot.common.policies.utils import get_device_from_parameters
|
||||||
from lerobot.common.utils.utils import (
|
from lerobot.common.utils.utils import (
|
||||||
format_big_number,
|
format_big_number,
|
||||||
get_safe_torch_device,
|
get_safe_torch_device,
|
||||||
@@ -69,7 +70,6 @@ def make_optimizer_and_scheduler(cfg, policy):
|
|||||||
cfg.training.adam_eps,
|
cfg.training.adam_eps,
|
||||||
cfg.training.adam_weight_decay,
|
cfg.training.adam_weight_decay,
|
||||||
)
|
)
|
||||||
assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
|
|
||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
|
|
||||||
lr_scheduler = get_scheduler(
|
lr_scheduler = get_scheduler(
|
||||||
@@ -87,21 +87,40 @@ def make_optimizer_and_scheduler(cfg, policy):
|
|||||||
return optimizer, lr_scheduler
|
return optimizer, lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
|
def update_policy(
|
||||||
|
policy,
|
||||||
|
batch,
|
||||||
|
optimizer,
|
||||||
|
grad_clip_norm,
|
||||||
|
grad_scaler: GradScaler,
|
||||||
|
lr_scheduler=None,
|
||||||
|
use_amp: bool = False,
|
||||||
|
):
|
||||||
"""Returns a dictionary of items for logging."""
|
"""Returns a dictionary of items for logging."""
|
||||||
start_time = time.time()
|
start_time = time.perf_counter()
|
||||||
|
device = get_device_from_parameters(policy)
|
||||||
policy.train()
|
policy.train()
|
||||||
output_dict = policy.forward(batch)
|
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
|
||||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
output_dict = policy.forward(batch)
|
||||||
loss = output_dict["loss"]
|
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||||
loss.backward()
|
loss = output_dict["loss"]
|
||||||
|
grad_scaler.scale(loss).backward()
|
||||||
|
|
||||||
|
# Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
|
||||||
|
grad_scaler.unscale_(optimizer)
|
||||||
|
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
policy.parameters(),
|
policy.parameters(),
|
||||||
grad_clip_norm,
|
grad_clip_norm,
|
||||||
error_if_nonfinite=False,
|
error_if_nonfinite=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizer.step()
|
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
|
||||||
|
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
|
||||||
|
grad_scaler.step(optimizer)
|
||||||
|
# Updates the scale for next iteration.
|
||||||
|
grad_scaler.update()
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
if lr_scheduler is not None:
|
if lr_scheduler is not None:
|
||||||
@@ -115,7 +134,7 @@ def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
|
|||||||
"loss": loss.item(),
|
"loss": loss.item(),
|
||||||
"grad_norm": float(grad_norm),
|
"grad_norm": float(grad_norm),
|
||||||
"lr": optimizer.param_groups[0]["lr"],
|
"lr": optimizer.param_groups[0]["lr"],
|
||||||
"update_s": time.time() - start_time,
|
"update_s": time.perf_counter() - start_time,
|
||||||
**{k: v for k, v in output_dict.items() if k != "loss"},
|
**{k: v for k, v in output_dict.items() if k != "loss"},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -211,103 +230,6 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
|
|||||||
logger.log_dict(info, step, mode="eval")
|
logger.log_dict(info, step, mode="eval")
|
||||||
|
|
||||||
|
|
||||||
def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
|
|
||||||
"""
|
|
||||||
Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from online dataset (on average).
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- n_off (int): Number of offline samples, each with a sampling weight of 1.
|
|
||||||
- n_on (int): Number of online samples.
|
|
||||||
- pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5).
|
|
||||||
|
|
||||||
The total weight of offline samples is n_off * 1.0.
|
|
||||||
The total weight of offline samples is n_on * w.
|
|
||||||
The total combined weight of all samples is n_off + n_on * w.
|
|
||||||
The fraction of the weight that is online is n_on * w / (n_off + n_on * w).
|
|
||||||
We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on.
|
|
||||||
The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1))
|
|
||||||
"""
|
|
||||||
assert 0.0 <= pc_on <= 1.0
|
|
||||||
return -(n_off * pc_on) / (n_on * (pc_on - 1))
|
|
||||||
|
|
||||||
|
|
||||||
def add_episodes_inplace(
|
|
||||||
online_dataset: torch.utils.data.Dataset,
|
|
||||||
concat_dataset: torch.utils.data.ConcatDataset,
|
|
||||||
sampler: torch.utils.data.WeightedRandomSampler,
|
|
||||||
hf_dataset: datasets.Dataset,
|
|
||||||
episode_data_index: dict[str, torch.Tensor],
|
|
||||||
pc_online_samples: float,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Modifies the online_dataset, concat_dataset, and sampler in place by integrating
|
|
||||||
new episodes from hf_dataset into the online_dataset, updating the concatenated
|
|
||||||
dataset's structure and adjusting the sampling strategy based on the specified
|
|
||||||
percentage of online samples.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
|
|
||||||
- concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
|
|
||||||
offline and online datasets, used for sampling purposes.
|
|
||||||
- sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
|
|
||||||
reflect changes in the dataset sizes and specified sampling weights.
|
|
||||||
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
|
|
||||||
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
|
|
||||||
They indicate the start index and end index of each episode in the dataset.
|
|
||||||
- pc_online_samples (float): The target percentage of samples that should come from
|
|
||||||
the online dataset during sampling operations.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
- AssertionError: If the first episode_id or index in hf_dataset is not 0
|
|
||||||
"""
|
|
||||||
first_episode_idx = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
|
|
||||||
last_episode_idx = hf_dataset.select_columns("episode_index")[-1]["episode_index"].item()
|
|
||||||
first_index = hf_dataset.select_columns("index")[0]["index"].item()
|
|
||||||
last_index = hf_dataset.select_columns("index")[-1]["index"].item()
|
|
||||||
# sanity check
|
|
||||||
assert first_episode_idx == 0, f"{first_episode_idx=} is not 0"
|
|
||||||
assert first_index == 0, f"{first_index=} is not 0"
|
|
||||||
assert first_index == episode_data_index["from"][first_episode_idx].item()
|
|
||||||
assert last_index == episode_data_index["to"][last_episode_idx].item() - 1
|
|
||||||
|
|
||||||
if len(online_dataset) == 0:
|
|
||||||
# initialize online dataset
|
|
||||||
online_dataset.hf_dataset = hf_dataset
|
|
||||||
online_dataset.episode_data_index = episode_data_index
|
|
||||||
else:
|
|
||||||
# get the starting indices of the new episodes and frames to be added
|
|
||||||
start_episode_idx = last_episode_idx + 1
|
|
||||||
start_index = last_index + 1
|
|
||||||
|
|
||||||
def shift_indices(episode_index, index):
|
|
||||||
# note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
|
|
||||||
example = {"episode_index": episode_index + start_episode_idx, "index": index + start_index}
|
|
||||||
return example
|
|
||||||
|
|
||||||
disable_progress_bars() # map has a tqdm progress bar
|
|
||||||
hf_dataset = hf_dataset.map(shift_indices, input_columns=["episode_index", "index"])
|
|
||||||
enable_progress_bars()
|
|
||||||
|
|
||||||
episode_data_index["from"] += start_index
|
|
||||||
episode_data_index["to"] += start_index
|
|
||||||
|
|
||||||
# extend online dataset
|
|
||||||
online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
|
|
||||||
|
|
||||||
# update the concatenated dataset length used during sampling
|
|
||||||
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
|
|
||||||
|
|
||||||
# update the sampling weights for each frame so that online frames get sampled a certain percentage of times
|
|
||||||
len_online = len(online_dataset)
|
|
||||||
len_offline = len(concat_dataset) - len_online
|
|
||||||
weight_offline = 1.0
|
|
||||||
weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples)
|
|
||||||
sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset))
|
|
||||||
|
|
||||||
# update the total number of samples used during sampling
|
|
||||||
sampler.num_samples = len(concat_dataset)
|
|
||||||
|
|
||||||
|
|
||||||
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
|
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
|
||||||
if out_dir is None:
|
if out_dir is None:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@@ -316,11 +238,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||||||
|
|
||||||
init_logging()
|
init_logging()
|
||||||
|
|
||||||
if cfg.training.online_steps > 0 and cfg.eval.batch_size > 1:
|
if cfg.training.online_steps > 0:
|
||||||
logging.warning("eval.batch_size > 1 not supported for online training steps")
|
raise NotImplementedError("Online training is not implemented yet.")
|
||||||
|
|
||||||
# Check device is available
|
# Check device is available
|
||||||
get_safe_torch_device(cfg.device, log=True)
|
device = get_safe_torch_device(cfg.device, log=True)
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
@@ -338,6 +260,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||||||
# Create optimizer and scheduler
|
# Create optimizer and scheduler
|
||||||
# Temporary hack to move optimizer out of policy
|
# Temporary hack to move optimizer out of policy
|
||||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||||
|
grad_scaler = GradScaler(enabled=cfg.use_amp)
|
||||||
|
|
||||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||||
@@ -358,14 +281,15 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||||||
def evaluate_and_checkpoint_if_needed(step):
|
def evaluate_and_checkpoint_if_needed(step):
|
||||||
if step % cfg.training.eval_freq == 0:
|
if step % cfg.training.eval_freq == 0:
|
||||||
logging.info(f"Eval policy at step {step}")
|
logging.info(f"Eval policy at step {step}")
|
||||||
eval_info = eval_policy(
|
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
||||||
eval_env,
|
eval_info = eval_policy(
|
||||||
policy,
|
eval_env,
|
||||||
cfg.eval.n_episodes,
|
policy,
|
||||||
video_dir=Path(out_dir) / "eval",
|
cfg.eval.n_episodes,
|
||||||
max_episodes_rendered=4,
|
video_dir=Path(out_dir) / "eval",
|
||||||
start_seed=cfg.seed,
|
max_episodes_rendered=4,
|
||||||
)
|
start_seed=cfg.seed,
|
||||||
|
)
|
||||||
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
|
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
|
||||||
if cfg.wandb.enable:
|
if cfg.wandb.enable:
|
||||||
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||||
@@ -389,36 +313,38 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||||||
num_workers=4,
|
num_workers=4,
|
||||||
batch_size=cfg.training.batch_size,
|
batch_size=cfg.training.batch_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
pin_memory=cfg.device != "cpu",
|
pin_memory=device.type != "cpu",
|
||||||
drop_last=False,
|
drop_last=False,
|
||||||
)
|
)
|
||||||
dl_iter = cycle(dataloader)
|
dl_iter = cycle(dataloader)
|
||||||
|
|
||||||
policy.train()
|
policy.train()
|
||||||
step = 0 # number of policy update (forward + backward + optim)
|
|
||||||
is_offline = True
|
is_offline = True
|
||||||
for offline_step in range(cfg.training.offline_steps):
|
for offline_step in tqdm(range(cfg.training.offline_steps)):
|
||||||
if offline_step == 0:
|
if offline_step == 0:
|
||||||
logging.info("Start offline training on a fixed dataset")
|
logging.info("Start offline training on a fixed dataset")
|
||||||
batch = next(dl_iter)
|
batch = next(dl_iter)
|
||||||
|
|
||||||
for key in batch:
|
for key in batch:
|
||||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
batch[key] = batch[key].to(device, non_blocking=True)
|
||||||
|
|
||||||
train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
|
train_info = update_policy(
|
||||||
|
policy,
|
||||||
|
batch,
|
||||||
|
optimizer,
|
||||||
|
cfg.training.grad_clip_norm,
|
||||||
|
grad_scaler=grad_scaler,
|
||||||
|
lr_scheduler=lr_scheduler,
|
||||||
|
use_amp=cfg.use_amp,
|
||||||
|
)
|
||||||
|
|
||||||
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
|
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
|
||||||
if step % cfg.training.log_freq == 0:
|
if offline_step % cfg.training.log_freq == 0:
|
||||||
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
|
log_train_info(logger, train_info, offline_step, cfg, offline_dataset, is_offline)
|
||||||
|
|
||||||
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
|
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
|
||||||
# so we pass in step + 1.
|
# so we pass in step + 1.
|
||||||
evaluate_and_checkpoint_if_needed(step + 1)
|
evaluate_and_checkpoint_if_needed(offline_step + 1)
|
||||||
|
|
||||||
step += 1
|
|
||||||
|
|
||||||
# create an env dedicated to online episodes collection from policy rollout
|
|
||||||
online_training_env = make_env(cfg, n_envs=1)
|
|
||||||
|
|
||||||
# create an empty online dataset similar to offline dataset
|
# create an empty online dataset similar to offline dataset
|
||||||
online_dataset = deepcopy(offline_dataset)
|
online_dataset = deepcopy(offline_dataset)
|
||||||
@@ -436,58 +362,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||||||
num_workers=4,
|
num_workers=4,
|
||||||
batch_size=cfg.training.batch_size,
|
batch_size=cfg.training.batch_size,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
pin_memory=cfg.device != "cpu",
|
pin_memory=device.type != "cpu",
|
||||||
drop_last=False,
|
drop_last=False,
|
||||||
)
|
)
|
||||||
dl_iter = cycle(dataloader)
|
|
||||||
|
|
||||||
online_step = 0
|
|
||||||
is_offline = False
|
|
||||||
for env_step in range(cfg.training.online_steps):
|
|
||||||
if env_step == 0:
|
|
||||||
logging.info("Start online training by interacting with environment")
|
|
||||||
|
|
||||||
policy.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
eval_info = eval_policy(
|
|
||||||
online_training_env,
|
|
||||||
policy,
|
|
||||||
n_episodes=1,
|
|
||||||
return_episode_data=True,
|
|
||||||
start_seed=cfg.training.online_env_seed,
|
|
||||||
enable_progbar=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
add_episodes_inplace(
|
|
||||||
online_dataset,
|
|
||||||
concat_dataset,
|
|
||||||
sampler,
|
|
||||||
hf_dataset=eval_info["episodes"]["hf_dataset"],
|
|
||||||
episode_data_index=eval_info["episodes"]["episode_data_index"],
|
|
||||||
pc_online_samples=cfg.training.online_sampling_ratio,
|
|
||||||
)
|
|
||||||
|
|
||||||
policy.train()
|
|
||||||
for _ in range(cfg.training.online_steps_between_rollouts):
|
|
||||||
batch = next(dl_iter)
|
|
||||||
|
|
||||||
for key in batch:
|
|
||||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
|
||||||
|
|
||||||
train_info = update_policy(policy, batch, optimizer, cfg.training.grad_clip_norm, lr_scheduler)
|
|
||||||
|
|
||||||
if step % cfg.training.log_freq == 0:
|
|
||||||
log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
|
|
||||||
|
|
||||||
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
|
|
||||||
# so we pass in step + 1.
|
|
||||||
evaluate_and_checkpoint_if_needed(step + 1)
|
|
||||||
|
|
||||||
step += 1
|
|
||||||
online_step += 1
|
|
||||||
|
|
||||||
eval_env.close()
|
eval_env.close()
|
||||||
online_training_env.close()
|
|
||||||
logging.info("End of training")
|
logging.info("End of training")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user