[Port HIL_SERL] Final fixes for the Reward Classifier (#598)
This commit is contained in:
committed by
Michel Aractingi
parent
e5801f467f
commit
d1d6ffd23c
@@ -4,7 +4,6 @@ from typing import Optional
|
||||
import torch
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
from .configuration_classifier import ClassifierConfig
|
||||
|
||||
@@ -44,6 +43,8 @@ class Classifier(
|
||||
name = "classifier"
|
||||
|
||||
def __init__(self, config: ClassifierConfig):
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
|
||||
@@ -333,7 +333,6 @@ class Critic(nn.Module):
|
||||
value = self.output_layer(x)
|
||||
return value.squeeze(-1)
|
||||
|
||||
|
||||
class Policy(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -342,12 +342,16 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||
|
||||
|
||||
def sanity_check_dataset_robot_compatibility(
|
||||
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
|
||||
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool, extra_features: dict = None
|
||||
) -> None:
|
||||
features_from_robot = get_features_from_robot(robot, use_videos)
|
||||
if extra_features is not None:
|
||||
features_from_robot.update(extra_features)
|
||||
|
||||
fields = [
|
||||
("robot_type", dataset.meta.robot_type, robot.robot_type),
|
||||
("fps", dataset.fps, fps),
|
||||
("features", dataset.features, get_features_from_robot(robot, use_videos)),
|
||||
("features", dataset.features, features_from_robot),
|
||||
]
|
||||
|
||||
mismatches = []
|
||||
|
||||
Reference in New Issue
Block a user