[Port HIL_SERL] Final fixes for the Reward Classifier (#598)
This commit is contained in:
@@ -246,7 +246,7 @@ def record(
|
||||
num_processes=num_image_writer_processes,
|
||||
num_threads=num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
)
|
||||
sanity_check_dataset_robot_compatibility(dataset, robot, fps, video)
|
||||
sanity_check_dataset_robot_compatibility(dataset, robot, fps, video, extra_features)
|
||||
else:
|
||||
# Create empty dataset or load existing saved episodes
|
||||
sanity_check_dataset_name(repo_id, policy)
|
||||
|
||||
@@ -183,8 +183,14 @@ def record(
|
||||
resume: bool = False,
|
||||
local_files_only: bool = False,
|
||||
run_compute_stats: bool = True,
|
||||
assign_rewards: bool = False,
|
||||
) -> LeRobotDataset:
|
||||
# Load pretrained policy
|
||||
|
||||
extra_features = (
|
||||
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
|
||||
)
|
||||
|
||||
policy = None
|
||||
if pretrained_policy_name_or_path is not None:
|
||||
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
|
||||
@@ -197,7 +203,7 @@ def record(
|
||||
raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")
|
||||
|
||||
# initialize listener before sim env
|
||||
listener, events = init_keyboard_listener()
|
||||
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
|
||||
|
||||
# create sim env
|
||||
env = env()
|
||||
@@ -237,6 +243,7 @@ def record(
|
||||
}
|
||||
|
||||
features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None}
|
||||
features = {**features, **extra_features}
|
||||
|
||||
# Create empty dataset or load existing saved episodes
|
||||
sanity_check_dataset_name(repo_id, policy)
|
||||
@@ -288,6 +295,13 @@ def record(
|
||||
"timestamp": env_timestamp,
|
||||
}
|
||||
|
||||
# Overwrite environment reward with manually assigned reward
|
||||
if assign_rewards:
|
||||
frame["next.reward"] = events["next.reward"]
|
||||
|
||||
# Should success always be false to match what we do in control_utils?
|
||||
frame["next.success"] = False
|
||||
|
||||
for key in image_keys:
|
||||
if not key.startswith("observation.image"):
|
||||
frame["observation.image." + key] = observation[key]
|
||||
@@ -472,6 +486,13 @@ if __name__ == "__main__":
|
||||
default=0,
|
||||
help="Resume recording on an existing dataset.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--assign-rewards",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Enables the assignation of rewards to frames (by default no assignation). When enabled, assign a 0 reward to frames until the space bar is pressed which assign a 1 reward. Press the space bar a second time to assign a 0 reward. The reward assigned is reset to 0 when the episode ends.",
|
||||
)
|
||||
|
||||
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
|
||||
parser_replay.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
|
||||
@@ -45,7 +45,7 @@ from lerobot.common.utils.utils import (
|
||||
)
|
||||
|
||||
|
||||
def get_model(cfg, logger):
|
||||
def get_model(cfg, logger): # noqa I001
|
||||
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
|
||||
model = Classifier(classifier_config)
|
||||
if cfg.resume:
|
||||
@@ -64,6 +64,12 @@ def create_balanced_sampler(dataset, cfg):
|
||||
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
|
||||
|
||||
|
||||
def support_amp(device: torch.device, cfg: DictConfig) -> bool:
|
||||
# Check if the device supports AMP
|
||||
# Here is an example of the issue that says that MPS doesn't support AMP properply
|
||||
return cfg.training.use_amp and device.type in ("cuda", "cpu")
|
||||
|
||||
|
||||
def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
|
||||
# Single epoch training loop with AMP support and progress tracking
|
||||
model.train()
|
||||
@@ -77,7 +83,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
|
||||
labels = batch[cfg.training.label_key].float().to(device)
|
||||
|
||||
# Forward pass with optional AMP
|
||||
with torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
|
||||
with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext():
|
||||
outputs = model(images)
|
||||
loss = criterion(outputs.logits, labels)
|
||||
|
||||
@@ -119,7 +125,10 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
|
||||
samples = []
|
||||
running_loss = 0
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
|
||||
with (
|
||||
torch.no_grad(),
|
||||
torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
|
||||
):
|
||||
for batch in tqdm(val_loader, desc="Validation"):
|
||||
images = batch[cfg.training.image_key].to(device)
|
||||
labels = batch[cfg.training.label_key].float().to(device)
|
||||
@@ -170,7 +179,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
|
||||
return accuracy, eval_info
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_path="../configs", config_name="hilserl_classifier")
|
||||
@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
|
||||
def train(cfg: DictConfig) -> None:
|
||||
# Main training pipeline with support for resuming training
|
||||
logging.info(OmegaConf.to_yaml(cfg))
|
||||
|
||||
Reference in New Issue
Block a user