Fixes for the reward classifier

This commit is contained in:
Michel Aractingi
2025-04-15 18:12:21 +02:00
committed by Michel Aractingi
parent 54c3c6d684
commit 3b24ad3c84
4 changed files with 64 additions and 69 deletions

View File

@@ -362,20 +362,20 @@ class RewardWrapper(gym.Wrapper):
"""
self.env = env
# NOTE: We got 15% speedup by compiling the model
self.reward_classifier = torch.compile(reward_classifier)
if isinstance(device, str):
device = torch.device(device)
self.device = device
self.reward_classifier = torch.compile(reward_classifier)
self.reward_classifier.to(self.device)
def step(self, action):
observation, reward, terminated, truncated, info = self.env.step(action)
images = [
observation[key].to(self.device, non_blocking=self.device.type == "cuda")
observation, _, terminated, truncated, info = self.env.step(action)
images = {
key: observation[key].to(self.device, non_blocking=self.device.type == "cuda")
for key in observation
if "image" in key
]
}
start_time = time.perf_counter()
with torch.inference_mode():
success = (
@@ -1184,7 +1184,9 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
)
# Add reward computation and control wrappers
# env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
reward_classifier = init_reward_classifier(cfg)
if reward_classifier is not None:
env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
if cfg.wrapper.use_gripper:
env = GripperActionWrapper(env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold)
@@ -1227,26 +1229,34 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
return env
def get_classifier(cfg):
if (
cfg.wrapper.reward_classifier_pretrained_path is None
or cfg.wrapper.reward_classifier_config_file is None
):
def init_reward_classifier(cfg):
"""
Load a reward classifier policy from a pretrained path if configured.
Args:
cfg: The environment configuration containing classifier paths
Returns:
The loaded classifier model or None if not configured
"""
if cfg.reward_classifier_pretrained_path is None:
return None
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
ClassifierConfig,
)
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
# Get device from config or default to CPU
device = getattr(cfg, "device", "cpu")
# Load the classifier directly using from_pretrained
classifier = Classifier.from_pretrained(
pretrained_name_or_path=cfg.reward_classifier_pretrained_path,
)
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
model = Classifier(classifier_config)
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
model = model.to(device)
return model
# Ensure model is on the correct device
classifier.to(device)
classifier.eval() # Set to evaluation mode
return classifier
def record_dataset(env, policy, cfg):