forked from tangger/lerobot
Cleanup policies imports
This commit is contained in:
@@ -24,7 +24,7 @@ from lerobot.common.datasets.lerobot_dataset import (
|
||||
MultiLeRobotDataset,
|
||||
)
|
||||
from lerobot.common.datasets.transforms import ImageTransforms
|
||||
from lerobot.common.policies.config import PreTrainedConfig
|
||||
from lerobot.common.policies import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
IMAGENET_STATS = {
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
from . import act, diffusion, pi0, tdmpc, vqbet
|
||||
from .config import PreTrainedConfig
|
||||
from .factory import make_policy
|
||||
from .pretrained import PreTrainedPolicy
|
||||
|
||||
__all__ = ["act", "diffusion", "pi0", "tdmpc", "vqbet", "make_policy"]
|
||||
__all__ = ["act", "diffusion", "pi0", "tdmpc", "vqbet", "make_policy", "PreTrainedConfig", "PreTrainedPolicy"]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .configuration_act import ACTConfig
|
||||
from .modeling_act import ACT
|
||||
from .modeling_act import ACTPolicy
|
||||
|
||||
__all__ = ["ACTConfig", "ACT"]
|
||||
__all__ = ["ACTConfig", "ACTPolicy"]
|
||||
|
||||
@@ -22,53 +22,59 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.common.envs.configs import EnvConfig
|
||||
from lerobot.common.envs.utils import env_to_policy_features
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.config import PreTrainedConfig
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.configs.types import FeatureType
|
||||
|
||||
from .config import PreTrainedConfig
|
||||
from .pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
|
||||
if name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
||||
|
||||
return TDMPCPolicy
|
||||
elif name == "diffusion":
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
return DiffusionPolicy
|
||||
elif name == "act":
|
||||
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||
if name == "act":
|
||||
from .act import ACTPolicy
|
||||
|
||||
return ACTPolicy
|
||||
elif name == "vqbet":
|
||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
||||
elif name == "diffusion":
|
||||
from .diffusion import DiffusionPolicy
|
||||
|
||||
return VQBeTPolicy
|
||||
return DiffusionPolicy
|
||||
elif name == "pi0":
|
||||
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
||||
from .pi0 import PI0Policy
|
||||
|
||||
return PI0Policy
|
||||
elif name == "tdmpc":
|
||||
from .tdmpc import TDMPCPolicy
|
||||
|
||||
return TDMPCPolicy
|
||||
elif name == "vqbet":
|
||||
from .vqbet import VQBeTPolicy
|
||||
|
||||
return VQBeTPolicy
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
|
||||
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
if policy_type == "tdmpc":
|
||||
return TDMPCConfig(**kwargs)
|
||||
elif policy_type == "diffusion":
|
||||
return DiffusionConfig(**kwargs)
|
||||
elif policy_type == "act":
|
||||
if policy_type == "act":
|
||||
from .act import ACTConfig
|
||||
|
||||
return ACTConfig(**kwargs)
|
||||
elif policy_type == "vqbet":
|
||||
return VQBeTConfig(**kwargs)
|
||||
elif policy_type == "diffusion":
|
||||
from .diffusion import DiffusionConfig
|
||||
|
||||
return DiffusionConfig(**kwargs)
|
||||
elif policy_type == "pi0":
|
||||
from .pi0 import PI0Config
|
||||
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "tdmpc":
|
||||
from .tdmpc import TDMPCConfig
|
||||
|
||||
return TDMPCConfig(**kwargs)
|
||||
elif policy_type == "vqbet":
|
||||
from .vqbet import VQBeTConfig
|
||||
|
||||
return VQBeTConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
@@ -15,8 +15,7 @@
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.config import PreTrainedConfig
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies import PreTrainedConfig, make_policy
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
@@ -19,8 +19,7 @@ from pathlib import Path
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.common.policies.config import PreTrainedConfig
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies import PreTrainedConfig, make_policy
|
||||
|
||||
|
||||
def display(tensor: torch.Tensor):
|
||||
|
||||
@@ -26,9 +26,10 @@ from safetensors.torch import load_model as load_model_as_safetensor
|
||||
from safetensors.torch import save_model as save_model_as_safetensor
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.policies.config import PreTrainedConfig
|
||||
from lerobot.common.utils.hub import HubMixin
|
||||
|
||||
from .config import PreTrainedConfig
|
||||
|
||||
T = TypeVar("T", bound="PreTrainedPolicy")
|
||||
|
||||
DEFAULT_POLICY_CARD = """
|
||||
|
||||
@@ -17,7 +17,7 @@ from pathlib import Path
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.policies.config import PreTrainedConfig
|
||||
from lerobot.common.policies import PreTrainedConfig
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.configs import parser
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.common import envs, policies # noqa: F401
|
||||
from lerobot.common.policies.config import PreTrainedConfig
|
||||
from lerobot.common.policies import PreTrainedConfig
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.default import EvalConfig
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from huggingface_hub.errors import HfHubHTTPError
|
||||
from lerobot.common import envs
|
||||
from lerobot.common.optim import OptimizerConfig
|
||||
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.common.policies.config import PreTrainedConfig
|
||||
from lerobot.common.policies import PreTrainedConfig
|
||||
from lerobot.common.utils.hub import HubMixin
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
|
||||
|
||||
@@ -141,7 +141,7 @@ from pprint import pformat
|
||||
|
||||
# from safetensors.torch import load_file, save_file
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies import make_policy
|
||||
from lerobot.common.robot_devices.control_configs import (
|
||||
CalibrateControlConfig,
|
||||
ControlPipelineConfig,
|
||||
|
||||
@@ -67,8 +67,7 @@ from tqdm import trange
|
||||
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies import PreTrainedPolicy, make_policy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.io_utils import write_video
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
|
||||
@@ -29,8 +29,7 @@ from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies import PreTrainedPolicy, make_policy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
|
||||
@@ -21,7 +21,8 @@ from safetensors.torch import save_file
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.common.policies.factory import make_policy, make_policy_config
|
||||
from lerobot.common.policies import make_policy
|
||||
from lerobot.common.policies.factory import make_policy_config
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
@@ -41,9 +41,8 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.policies import PreTrainedConfig, make_policy
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.config import PreTrainedConfig
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.control_configs import (
|
||||
CalibrateControlConfig,
|
||||
RecordControlConfig,
|
||||
|
||||
Reference in New Issue
Block a user