forked from tangger/lerobot
Fixes for the reward classifier
This commit is contained in:
committed by
Michel Aractingi
parent
54c3c6d684
commit
3b24ad3c84
@@ -225,12 +225,7 @@ class HILSerlRobotEnvConfig(EnvConfig):
|
|||||||
device: str = "cuda"
|
device: str = "cuda"
|
||||||
push_to_hub: bool = True
|
push_to_hub: bool = True
|
||||||
pretrained_policy_name_or_path: Optional[str] = None
|
pretrained_policy_name_or_path: Optional[str] = None
|
||||||
reward_classifier: dict[str, str | None] = field(
|
reward_classifier_pretrained_path: Optional[str] = None
|
||||||
default_factory=lambda: {
|
|
||||||
"pretrained_path": None,
|
|
||||||
"config_path": None,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def gym_kwargs(self) -> dict:
|
def gym_kwargs(self) -> dict:
|
||||||
return {}
|
return {}
|
||||||
@@ -271,12 +266,7 @@ class ManiskillEnvConfig(EnvConfig):
|
|||||||
"observation.state": OBS_ROBOT,
|
"observation.state": OBS_ROBOT,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
reward_classifier: dict[str, str | None] = field(
|
reward_classifier_pretrained_path: Optional[str] = None
|
||||||
default_factory=lambda: {
|
|
||||||
"pretrained_path": None,
|
|
||||||
"config_path": None,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def gym_kwargs(self) -> dict:
|
def gym_kwargs(self) -> dict:
|
||||||
|
|||||||
@@ -20,10 +20,13 @@ class ClassifierConfig(PreTrainedConfig):
|
|||||||
model_type: str = "cnn" # "transformer" or "cnn"
|
model_type: str = "cnn" # "transformer" or "cnn"
|
||||||
num_cameras: int = 2
|
num_cameras: int = 2
|
||||||
learning_rate: float = 1e-4
|
learning_rate: float = 1e-4
|
||||||
normalization_mode = None
|
weight_decay: float = 0.01
|
||||||
# output_features: Dict[str, PolicyFeature] = field(
|
grad_clip_norm: float = 1.0
|
||||||
# default_factory=lambda: {"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,))}
|
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||||
# )
|
default_factory=lambda: {
|
||||||
|
"VISUAL": NormalizationMode.MEAN_STD,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def observation_delta_indices(self) -> List | None:
|
def observation_delta_indices(self) -> List | None:
|
||||||
@@ -40,8 +43,8 @@ class ClassifierConfig(PreTrainedConfig):
|
|||||||
def get_optimizer_preset(self) -> OptimizerConfig:
|
def get_optimizer_preset(self) -> OptimizerConfig:
|
||||||
return AdamWConfig(
|
return AdamWConfig(
|
||||||
lr=self.learning_rate,
|
lr=self.learning_rate,
|
||||||
weight_decay=0.01,
|
weight_decay=self.weight_decay,
|
||||||
grad_clip_norm=1.0,
|
grad_clip_norm=self.grad_clip_norm,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||||
@@ -49,5 +52,8 @@ class ClassifierConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
def validate_features(self) -> None:
|
def validate_features(self) -> None:
|
||||||
"""Validate feature configurations."""
|
"""Validate feature configurations."""
|
||||||
# Classifier doesn't need specific feature validation
|
has_image = any(key.startswith("observation.image") for key in self.input_features)
|
||||||
pass
|
if not has_image:
|
||||||
|
raise ValueError(
|
||||||
|
"You must provide an image observation (key starting with 'observation.image') in the input features"
|
||||||
|
)
|
||||||
|
|||||||
@@ -139,11 +139,7 @@ class Classifier(PreTrainedPolicy):
|
|||||||
|
|
||||||
def extract_images_and_labels(self, batch: Dict[str, Tensor]) -> Tuple[list, Tensor]:
|
def extract_images_and_labels(self, batch: Dict[str, Tensor]) -> Tuple[list, Tensor]:
|
||||||
"""Extract image tensors and label tensors from batch."""
|
"""Extract image tensors and label tensors from batch."""
|
||||||
# Find image keys in input features
|
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
|
||||||
image_keys = [key for key in self.config.input_features if key.startswith(OBS_IMAGE)]
|
|
||||||
|
|
||||||
# Extract the images and labels
|
|
||||||
images = [batch[key] for key in image_keys]
|
|
||||||
labels = batch["next.reward"]
|
labels = batch["next.reward"]
|
||||||
|
|
||||||
return images, labels
|
return images, labels
|
||||||
@@ -197,9 +193,9 @@ class Classifier(PreTrainedPolicy):
|
|||||||
|
|
||||||
return loss, output_dict
|
return loss, output_dict
|
||||||
|
|
||||||
def predict_reward(self, batch, threshold=0.6):
|
def predict_reward(self, batch, threshold=0.5):
|
||||||
"""Legacy method for compatibility."""
|
"""Legacy method for compatibility."""
|
||||||
images, _ = self.extract_images_and_labels(batch)
|
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
|
||||||
if self.config.num_classes == 2:
|
if self.config.num_classes == 2:
|
||||||
probs = self.predict(images).probabilities
|
probs = self.predict(images).probabilities
|
||||||
logging.debug(f"Predicted reward images: {probs}")
|
logging.debug(f"Predicted reward images: {probs}")
|
||||||
@@ -207,8 +203,6 @@ class Classifier(PreTrainedPolicy):
|
|||||||
else:
|
else:
|
||||||
return torch.argmax(self.predict(images).probabilities, dim=1)
|
return torch.argmax(self.predict(images).probabilities, dim=1)
|
||||||
|
|
||||||
# Methods required by PreTrainedPolicy abstract class
|
|
||||||
|
|
||||||
def get_optim_params(self) -> dict:
|
def get_optim_params(self) -> dict:
|
||||||
"""Return optimizer parameters for the policy."""
|
"""Return optimizer parameters for the policy."""
|
||||||
return {
|
return {
|
||||||
@@ -217,21 +211,16 @@ class Classifier(PreTrainedPolicy):
|
|||||||
"weight_decay": getattr(self.config, "weight_decay", 0.01),
|
"weight_decay": getattr(self.config, "weight_decay", 0.01),
|
||||||
}
|
}
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
"""Reset any stateful components (required by PreTrainedPolicy)."""
|
|
||||||
# Classifier doesn't have stateful components that need resetting
|
|
||||||
pass
|
|
||||||
|
|
||||||
def select_action(self, batch: Dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
"""Return action (class prediction) based on input observation."""
|
"""
|
||||||
images, _ = self.extract_images_and_labels(batch)
|
This method is required by PreTrainedPolicy but not used for reward classifiers.
|
||||||
|
The reward classifier is not an actor and does not select actions.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Reward classifiers do not select actions")
|
||||||
|
|
||||||
with torch.no_grad():
|
def reset(self):
|
||||||
outputs = self.predict(images)
|
"""
|
||||||
|
This method is required by PreTrainedPolicy but not used for reward classifiers.
|
||||||
if self.config.num_classes == 2:
|
The reward classifier has no stateful components, so there is nothing to reset.
|
||||||
# For binary classification return 0 or 1
|
"""
|
||||||
return (outputs.probabilities > 0.5).float()
|
pass
|
||||||
else:
|
|
||||||
# For multi-class return the predicted class
|
|
||||||
return torch.argmax(outputs.probabilities, dim=1)
|
|
||||||
|
|||||||
@@ -362,20 +362,20 @@ class RewardWrapper(gym.Wrapper):
|
|||||||
"""
|
"""
|
||||||
self.env = env
|
self.env = env
|
||||||
|
|
||||||
# NOTE: We got 15% speedup by compiling the model
|
|
||||||
self.reward_classifier = torch.compile(reward_classifier)
|
|
||||||
|
|
||||||
if isinstance(device, str):
|
if isinstance(device, str):
|
||||||
device = torch.device(device)
|
device = torch.device(device)
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
|
self.reward_classifier = torch.compile(reward_classifier)
|
||||||
|
self.reward_classifier.to(self.device)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
observation, reward, terminated, truncated, info = self.env.step(action)
|
observation, _, terminated, truncated, info = self.env.step(action)
|
||||||
images = [
|
images = {
|
||||||
observation[key].to(self.device, non_blocking=self.device.type == "cuda")
|
key: observation[key].to(self.device, non_blocking=self.device.type == "cuda")
|
||||||
for key in observation
|
for key in observation
|
||||||
if "image" in key
|
if "image" in key
|
||||||
]
|
}
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
success = (
|
success = (
|
||||||
@@ -1184,7 +1184,9 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Add reward computation and control wrappers
|
# Add reward computation and control wrappers
|
||||||
# env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
|
reward_classifier = init_reward_classifier(cfg)
|
||||||
|
if reward_classifier is not None:
|
||||||
|
env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
|
||||||
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
|
env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
|
||||||
if cfg.wrapper.use_gripper:
|
if cfg.wrapper.use_gripper:
|
||||||
env = GripperActionWrapper(env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold)
|
env = GripperActionWrapper(env=env, quantization_threshold=cfg.wrapper.gripper_quantization_threshold)
|
||||||
@@ -1227,26 +1229,34 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
|||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
||||||
def get_classifier(cfg):
|
def init_reward_classifier(cfg):
|
||||||
if (
|
"""
|
||||||
cfg.wrapper.reward_classifier_pretrained_path is None
|
Load a reward classifier policy from a pretrained path if configured.
|
||||||
or cfg.wrapper.reward_classifier_config_file is None
|
|
||||||
):
|
Args:
|
||||||
|
cfg: The environment configuration; `reward_classifier_pretrained_path` selects the classifier and `device` (if set) selects where it is placed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The loaded classifier model or None if not configured
|
||||||
|
"""
|
||||||
|
if cfg.reward_classifier_pretrained_path is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
|
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||||
ClassifierConfig,
|
|
||||||
)
|
# Get device from config, falling back to CPU when it is not set
|
||||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
device = getattr(cfg, "device", "cpu")
|
||||||
Classifier,
|
|
||||||
|
# Load the classifier directly using from_pretrained
|
||||||
|
classifier = Classifier.from_pretrained(
|
||||||
|
pretrained_name_or_path=cfg.reward_classifier_pretrained_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
|
# Ensure model is on the correct device
|
||||||
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
|
classifier.to(device)
|
||||||
model = Classifier(classifier_config)
|
classifier.eval() # Set to evaluation mode
|
||||||
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
|
|
||||||
model = model.to(device)
|
return classifier
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def record_dataset(env, policy, cfg):
|
def record_dataset(env, policy, cfg):
|
||||||
|
|||||||
Reference in New Issue
Block a user