[Port HIL_SERL] Final fixes for the Reward Classifier (#598)

This commit is contained in:
Eugene Mironov
2025-01-06 17:34:00 +07:00
committed by GitHub
parent 35de91ef2b
commit c5bca1cf0f
11 changed files with 59 additions and 19 deletions

View File

@@ -246,7 +246,7 @@ def record(
num_processes=num_image_writer_processes,
num_threads=num_image_writer_threads_per_camera * len(robot.cameras),
)
sanity_check_dataset_robot_compatibility(dataset, robot, fps, video)
sanity_check_dataset_robot_compatibility(dataset, robot, fps, video, extra_features)
else:
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy)

View File

@@ -183,8 +183,14 @@ def record(
resume: bool = False,
local_files_only: bool = False,
run_compute_stats: bool = True,
assign_rewards: bool = False,
) -> LeRobotDataset:
# Load pretrained policy
extra_features = (
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
)
policy = None
if pretrained_policy_name_or_path is not None:
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
@@ -197,7 +203,7 @@ def record(
raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")
# initialize listener before sim env
listener, events = init_keyboard_listener()
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
# create sim env
env = env()
@@ -237,6 +243,7 @@ def record(
}
features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None}
features = {**features, **extra_features}
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy)
@@ -288,6 +295,13 @@ def record(
"timestamp": env_timestamp,
}
# Overwrite environment reward with manually assigned reward
if assign_rewards:
frame["next.reward"] = events["next.reward"]
# Should success always be false to match what we do in control_utils?
frame["next.success"] = False
for key in image_keys:
if not key.startswith("observation.image"):
frame["observation.image." + key] = observation[key]
@@ -472,6 +486,13 @@ if __name__ == "__main__":
default=0,
help="Resume recording on an existing dataset.",
)
parser_record.add_argument(
"--assign-rewards",
type=int,
default=0,
help="Enables the assignation of rewards to frames (by default no assignation). When enabled, assign a 0 reward to frames until the space bar is pressed which assign a 1 reward. Press the space bar a second time to assign a 0 reward. The reward assigned is reset to 0 when the episode ends.",
)
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
parser_replay.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"

View File

@@ -45,7 +45,7 @@ from lerobot.common.utils.utils import (
)
def get_model(cfg, logger):
def get_model(cfg, logger): # noqa I001
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
model = Classifier(classifier_config)
if cfg.resume:
@@ -64,6 +64,12 @@ def create_balanced_sampler(dataset, cfg):
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
def support_amp(device: torch.device, cfg: DictConfig) -> bool:
# Check if the device supports AMP
# Here is an example of the issue that says that MPS doesn't support AMP properply
return cfg.training.use_amp and device.type in ("cuda", "cpu")
def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
# Single epoch training loop with AMP support and progress tracking
model.train()
@@ -77,7 +83,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
labels = batch[cfg.training.label_key].float().to(device)
# Forward pass with optional AMP
with torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext():
outputs = model(images)
loss = criterion(outputs.logits, labels)
@@ -119,7 +125,10 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
samples = []
running_loss = 0
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
with (
torch.no_grad(),
torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
):
for batch in tqdm(val_loader, desc="Validation"):
images = batch[cfg.training.image_key].to(device)
labels = batch[cfg.training.label_key].float().to(device)
@@ -170,7 +179,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
return accuracy, eval_info
@hydra.main(version_base="1.2", config_path="../configs", config_name="hilserl_classifier")
@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
def train(cfg: DictConfig) -> None:
# Main training pipeline with support for resuming training
logging.info(OmegaConf.to_yaml(cfg))