Cleanup policies imports
This commit is contained in:
@@ -24,7 +24,7 @@ from lerobot.common.datasets.lerobot_dataset import (
|
|||||||
MultiLeRobotDataset,
|
MultiLeRobotDataset,
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.transforms import ImageTransforms
|
from lerobot.common.datasets.transforms import ImageTransforms
|
||||||
from lerobot.common.policies.config import PreTrainedConfig
|
from lerobot.common.policies import PreTrainedConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
|
|
||||||
IMAGENET_STATS = {
|
IMAGENET_STATS = {
|
||||||
|
|||||||
@@ -13,6 +13,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from . import act, diffusion, pi0, tdmpc, vqbet
|
from . import act, diffusion, pi0, tdmpc, vqbet
|
||||||
|
from .config import PreTrainedConfig
|
||||||
from .factory import make_policy
|
from .factory import make_policy
|
||||||
|
from .pretrained import PreTrainedPolicy
|
||||||
|
|
||||||
__all__ = ["act", "diffusion", "pi0", "tdmpc", "vqbet", "make_policy"]
|
__all__ = ["act", "diffusion", "pi0", "tdmpc", "vqbet", "make_policy", "PreTrainedConfig", "PreTrainedPolicy"]
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from .configuration_act import ACTConfig
|
from .configuration_act import ACTConfig
|
||||||
from .modeling_act import ACT
|
from .modeling_act import ACTPolicy
|
||||||
|
|
||||||
__all__ = ["ACTConfig", "ACT"]
|
__all__ = ["ACTConfig", "ACTPolicy"]
|
||||||
|
|||||||
@@ -22,53 +22,59 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
|||||||
from lerobot.common.datasets.utils import dataset_to_policy_features
|
from lerobot.common.datasets.utils import dataset_to_policy_features
|
||||||
from lerobot.common.envs.configs import EnvConfig
|
from lerobot.common.envs.configs import EnvConfig
|
||||||
from lerobot.common.envs.utils import env_to_policy_features
|
from lerobot.common.envs.utils import env_to_policy_features
|
||||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
|
||||||
from lerobot.common.policies.config import PreTrainedConfig
|
|
||||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
|
||||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
|
||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
|
||||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
|
||||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
|
||||||
from lerobot.configs.types import FeatureType
|
from lerobot.configs.types import FeatureType
|
||||||
|
|
||||||
|
from .config import PreTrainedConfig
|
||||||
|
from .pretrained import PreTrainedPolicy
|
||||||
|
|
||||||
|
|
||||||
def get_policy_class(name: str) -> PreTrainedPolicy:
|
def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||||
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
|
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
|
||||||
if name == "tdmpc":
|
if name == "act":
|
||||||
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
from .act import ACTPolicy
|
||||||
|
|
||||||
return TDMPCPolicy
|
|
||||||
elif name == "diffusion":
|
|
||||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
|
||||||
|
|
||||||
return DiffusionPolicy
|
|
||||||
elif name == "act":
|
|
||||||
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
|
||||||
|
|
||||||
return ACTPolicy
|
return ACTPolicy
|
||||||
elif name == "vqbet":
|
elif name == "diffusion":
|
||||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
from .diffusion import DiffusionPolicy
|
||||||
|
|
||||||
return VQBeTPolicy
|
return DiffusionPolicy
|
||||||
elif name == "pi0":
|
elif name == "pi0":
|
||||||
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
from .pi0 import PI0Policy
|
||||||
|
|
||||||
return PI0Policy
|
return PI0Policy
|
||||||
|
elif name == "tdmpc":
|
||||||
|
from .tdmpc import TDMPCPolicy
|
||||||
|
|
||||||
|
return TDMPCPolicy
|
||||||
|
elif name == "vqbet":
|
||||||
|
from .vqbet import VQBeTPolicy
|
||||||
|
|
||||||
|
return VQBeTPolicy
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||||
|
|
||||||
|
|
||||||
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||||
if policy_type == "tdmpc":
|
if policy_type == "act":
|
||||||
return TDMPCConfig(**kwargs)
|
from .act import ACTConfig
|
||||||
elif policy_type == "diffusion":
|
|
||||||
return DiffusionConfig(**kwargs)
|
|
||||||
elif policy_type == "act":
|
|
||||||
return ACTConfig(**kwargs)
|
return ACTConfig(**kwargs)
|
||||||
elif policy_type == "vqbet":
|
elif policy_type == "diffusion":
|
||||||
return VQBeTConfig(**kwargs)
|
from .diffusion import DiffusionConfig
|
||||||
|
|
||||||
|
return DiffusionConfig(**kwargs)
|
||||||
elif policy_type == "pi0":
|
elif policy_type == "pi0":
|
||||||
|
from .pi0 import PI0Config
|
||||||
|
|
||||||
return PI0Config(**kwargs)
|
return PI0Config(**kwargs)
|
||||||
|
elif policy_type == "tdmpc":
|
||||||
|
from .tdmpc import TDMPCConfig
|
||||||
|
|
||||||
|
return TDMPCConfig(**kwargs)
|
||||||
|
elif policy_type == "vqbet":
|
||||||
|
from .vqbet import VQBeTConfig
|
||||||
|
|
||||||
|
return VQBeTConfig(**kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||||
|
|
||||||
|
|||||||
@@ -15,8 +15,7 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.policies.config import PreTrainedConfig
|
from lerobot.common.policies import PreTrainedConfig, make_policy
|
||||||
from lerobot.common.policies.factory import make_policy
|
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
|
|||||||
@@ -19,8 +19,7 @@ from pathlib import Path
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||||
from lerobot.common.policies.config import PreTrainedConfig
|
from lerobot.common.policies import PreTrainedConfig, make_policy
|
||||||
from lerobot.common.policies.factory import make_policy
|
|
||||||
|
|
||||||
|
|
||||||
def display(tensor: torch.Tensor):
|
def display(tensor: torch.Tensor):
|
||||||
|
|||||||
@@ -26,9 +26,10 @@ from safetensors.torch import load_model as load_model_as_safetensor
|
|||||||
from safetensors.torch import save_model as save_model_as_safetensor
|
from safetensors.torch import save_model as save_model_as_safetensor
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
from lerobot.common.policies.config import PreTrainedConfig
|
|
||||||
from lerobot.common.utils.hub import HubMixin
|
from lerobot.common.utils.hub import HubMixin
|
||||||
|
|
||||||
|
from .config import PreTrainedConfig
|
||||||
|
|
||||||
T = TypeVar("T", bound="PreTrainedPolicy")
|
T = TypeVar("T", bound="PreTrainedPolicy")
|
||||||
|
|
||||||
DEFAULT_POLICY_CARD = """
|
DEFAULT_POLICY_CARD = """
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import draccus
|
import draccus
|
||||||
|
|
||||||
from lerobot.common.policies.config import PreTrainedConfig
|
from lerobot.common.policies import PreTrainedConfig
|
||||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from dataclasses import dataclass, field
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from lerobot.common import envs, policies # noqa: F401
|
from lerobot.common import envs, policies # noqa: F401
|
||||||
from lerobot.common.policies.config import PreTrainedConfig
|
from lerobot.common.policies import PreTrainedConfig
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.default import EvalConfig
|
from lerobot.configs.default import EvalConfig
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from huggingface_hub.errors import HfHubHTTPError
|
|||||||
from lerobot.common import envs
|
from lerobot.common import envs
|
||||||
from lerobot.common.optim import OptimizerConfig
|
from lerobot.common.optim import OptimizerConfig
|
||||||
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
||||||
from lerobot.common.policies.config import PreTrainedConfig
|
from lerobot.common.policies import PreTrainedConfig
|
||||||
from lerobot.common.utils.hub import HubMixin
|
from lerobot.common.utils.hub import HubMixin
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
|
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
|
||||||
|
|||||||
@@ -141,7 +141,7 @@ from pprint import pformat
|
|||||||
|
|
||||||
# from safetensors.torch import load_file, save_file
|
# from safetensors.torch import load_file, save_file
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies import make_policy
|
||||||
from lerobot.common.robot_devices.control_configs import (
|
from lerobot.common.robot_devices.control_configs import (
|
||||||
CalibrateControlConfig,
|
CalibrateControlConfig,
|
||||||
ControlPipelineConfig,
|
ControlPipelineConfig,
|
||||||
|
|||||||
@@ -67,8 +67,7 @@ from tqdm import trange
|
|||||||
|
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.envs.utils import preprocess_observation
|
from lerobot.common.envs.utils import preprocess_observation
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies import PreTrainedPolicy, make_policy
|
||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
|
||||||
from lerobot.common.policies.utils import get_device_from_parameters
|
from lerobot.common.policies.utils import get_device_from_parameters
|
||||||
from lerobot.common.utils.io_utils import write_video
|
from lerobot.common.utils.io_utils import write_video
|
||||||
from lerobot.common.utils.random_utils import set_seed
|
from lerobot.common.utils.random_utils import set_seed
|
||||||
|
|||||||
@@ -29,8 +29,7 @@ from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
|||||||
from lerobot.common.datasets.utils import cycle
|
from lerobot.common.datasets.utils import cycle
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.optim.factory import make_optimizer_and_scheduler
|
from lerobot.common.optim.factory import make_optimizer_and_scheduler
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies import PreTrainedPolicy, make_policy
|
||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
|
||||||
from lerobot.common.policies.utils import get_device_from_parameters
|
from lerobot.common.policies.utils import get_device_from_parameters
|
||||||
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
|
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
|
||||||
from lerobot.common.utils.random_utils import set_seed
|
from lerobot.common.utils.random_utils import set_seed
|
||||||
|
|||||||
@@ -21,7 +21,8 @@ from safetensors.torch import save_file
|
|||||||
|
|
||||||
from lerobot.common.datasets.factory import make_dataset
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
from lerobot.common.optim.factory import make_optimizer_and_scheduler
|
from lerobot.common.optim.factory import make_optimizer_and_scheduler
|
||||||
from lerobot.common.policies.factory import make_policy, make_policy_config
|
from lerobot.common.policies import make_policy
|
||||||
|
from lerobot.common.policies.factory import make_policy_config
|
||||||
from lerobot.common.utils.random_utils import set_seed
|
from lerobot.common.utils.random_utils import set_seed
|
||||||
from lerobot.configs.default import DatasetConfig
|
from lerobot.configs.default import DatasetConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
|
|||||||
@@ -41,9 +41,8 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from lerobot.common.policies import PreTrainedConfig, make_policy
|
||||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||||
from lerobot.common.policies.config import PreTrainedConfig
|
|
||||||
from lerobot.common.policies.factory import make_policy
|
|
||||||
from lerobot.common.robot_devices.control_configs import (
|
from lerobot.common.robot_devices.control_configs import (
|
||||||
CalibrateControlConfig,
|
CalibrateControlConfig,
|
||||||
RecordControlConfig,
|
RecordControlConfig,
|
||||||
|
|||||||
Reference in New Issue
Block a user