forked from tangger/lerobot
[Port HIL_SERL] Final fixes for the Reward Classifier (#598)
This commit is contained in:
committed by
Michel Aractingi
parent
e5801f467f
commit
d1d6ffd23c
@@ -4,7 +4,6 @@ from typing import Optional
|
||||
import torch
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
from .configuration_classifier import ClassifierConfig
|
||||
|
||||
@@ -44,6 +43,8 @@ class Classifier(
|
||||
name = "classifier"
|
||||
|
||||
def __init__(self, config: ClassifierConfig):
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
|
||||
@@ -333,7 +333,6 @@ class Critic(nn.Module):
|
||||
value = self.output_layer(x)
|
||||
return value.squeeze(-1)
|
||||
|
||||
|
||||
class Policy(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user