diff --git a/lerobot/common/policies/__init__.py b/lerobot/common/policies/__init__.py index b73ba5f4e..8b3c25194 100644 --- a/lerobot/common/policies/__init__.py +++ b/lerobot/common/policies/__init__.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .act.configuration_act import ACTConfig as ACTConfig -from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig -from .pi0.configuration_pi0 import PI0Config as PI0Config -from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig -from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig +from . import act, diffusion, pi0, tdmpc, vqbet +from .factory import make_policy + +__all__ = ["act", "diffusion", "pi0", "tdmpc", "vqbet", "make_policy"] diff --git a/lerobot/common/policies/act/__init__.py b/lerobot/common/policies/act/__init__.py new file mode 100644 index 000000000..8b3b2ba2d --- /dev/null +++ b/lerobot/common/policies/act/__init__.py @@ -0,0 +1,4 @@ +from .configuration_act import ACTConfig +from .modeling_act import ACT + +__all__ = ["ACTConfig", "ACT"] diff --git a/lerobot/common/policies/diffusion/__init__.py b/lerobot/common/policies/diffusion/__init__.py new file mode 100644 index 000000000..092ea4636 --- /dev/null +++ b/lerobot/common/policies/diffusion/__init__.py @@ -0,0 +1,4 @@ +from .configuration_diffusion import DiffusionConfig +from .modeling_diffusion import DiffusionPolicy + +__all__ = ["DiffusionConfig", "DiffusionPolicy"] diff --git a/lerobot/common/policies/pi0/__init__.py b/lerobot/common/policies/pi0/__init__.py new file mode 100644 index 000000000..077dee700 --- /dev/null +++ b/lerobot/common/policies/pi0/__init__.py @@ -0,0 +1,4 @@ +from .configuration_pi0 import PI0Config +from .modeling_pi0 import PI0Policy + +__all__ = ["PI0Config", "PI0Policy"] diff --git a/lerobot/common/policies/tdmpc/__init__.py b/lerobot/common/policies/tdmpc/__init__.py new file mode 100644 index 000000000..83b7b14a4 --- /dev/null +++ b/lerobot/common/policies/tdmpc/__init__.py @@ -0,0 +1,4 @@ +from .configuration_tdmpc import TDMPCConfig +from .modeling_tdmpc import TDMPCPolicy + +__all__ = ["TDMPCConfig", "TDMPCPolicy"] diff --git a/lerobot/common/policies/vqbet/__init__.py b/lerobot/common/policies/vqbet/__init__.py new file mode 100644 index 000000000..600b604f2 --- /dev/null +++ b/lerobot/common/policies/vqbet/__init__.py @@ -0,0 +1,4 @@ +from .configuration_vqbet import VQBeTConfig +from .modeling_vqbet import VQBeTPolicy + +__all__ = ["VQBeTConfig", "VQBeTPolicy"]