forked from tangger/lerobot
rename reward classifier
This commit is contained in:
@@ -24,10 +24,10 @@ from lerobot.common.envs.configs import EnvConfig
|
||||
from lerobot.common.envs.utils import env_to_policy_features
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
@@ -64,8 +64,8 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
return SACPolicy
|
||||
elif name == "hilserl_classifier":
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
elif name == "reward_classifier":
|
||||
from lerobot.common.policies.reward_model.modeling_classifier import Classifier
|
||||
|
||||
return Classifier
|
||||
else:
|
||||
@@ -85,8 +85,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi0fast":
|
||||
return PI0FASTConfig(**kwargs)
|
||||
elif policy_type == "hilserl_classifier":
|
||||
return ClassifierConfig(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ def create_stats_buffers(
|
||||
|
||||
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
||||
if stats and key in stats:
|
||||
# NOTE:(maractingi, azouitine): Change the order of these conditions becuase in online environments we don't have dataset stats
|
||||
# NOTE:(maractingi, azouitine): Change the order of these conditions because in online environments we don't have dataset stats
|
||||
# Therefore, we don't access to full stats of the data, some elements either have min-max or mean-std only
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
if "mean" not in stats[key] or "std" not in stats[key]:
|
||||
|
||||
@@ -7,12 +7,12 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass(name="hilserl_classifier")
|
||||
@PreTrainedConfig.register_subclass(name="reward_classifier")
|
||||
@dataclass
|
||||
class ClassifierConfig(PreTrainedConfig):
|
||||
"""Configuration for the Classifier model."""
|
||||
class RewardClassifierConfig(PreTrainedConfig):
|
||||
"""Configuration for the Reward Classifier model."""
|
||||
|
||||
name: str = "hilserl_classifier"
|
||||
name: str = "reward_classifier"
|
||||
num_classes: int = 2
|
||||
hidden_dim: int = 256
|
||||
dropout_rate: float = 0.1
|
||||
@@ -5,11 +5,9 @@ import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.constants import OBS_IMAGE
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
|
||||
ClassifierConfig,
|
||||
)
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -39,12 +37,12 @@ class ClassifierOutput:
|
||||
class Classifier(PreTrainedPolicy):
|
||||
"""Image classifier built on top of a pre-trained encoder."""
|
||||
|
||||
name = "hilserl_classifier"
|
||||
config_class = ClassifierConfig
|
||||
name = "reward_classifier"
|
||||
config_class = RewardClassifierConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ClassifierConfig,
|
||||
config: RewardClassifierConfig,
|
||||
dataset_stats: Dict[str, Dict[str, Tensor]] | None = None,
|
||||
):
|
||||
from transformers import AutoModel
|
||||
Reference in New Issue
Block a user