diff --git a/.gitignore b/.gitignore
index 8001b6950..096cfb50d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -121,6 +121,7 @@ celerybeat.pid
 # Environments
 .env
 .venv
+env/
 venv/
 env.bak/
 venv.bak/
diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index 026917018..3427c4829 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -38,7 +38,13 @@ from lerobot.common.policies.act.configuration_act import ACTConfig
 from lerobot.common.policies.normalize import Normalize, Unnormalize


-class ACTPolicy(nn.Module, PyTorchModelHubMixin):
+class ACTPolicy(
+    nn.Module,
+    PyTorchModelHubMixin,
+    library_name="lerobot",
+    repo_url="https://github.com/huggingface/lerobot",
+    tags=["robotics", "act"],
+):
     """
     Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
     Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index 486085374..308a8be3c 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -43,7 +43,13 @@ from lerobot.common.policies.utils import (
 )


-class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
+class DiffusionPolicy(
+    nn.Module,
+    PyTorchModelHubMixin,
+    library_name="lerobot",
+    repo_url="https://github.com/huggingface/lerobot",
+    tags=["robotics", "diffusion-policy"],
+):
     """
     Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
     (paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
index 7dbffcefc..d97c4824c 100644
--- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py
+++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py
@@ -41,7 +41,13 @@ from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
 from lerobot.common.policies.utils import get_device_from_parameters, populate_queues


-class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
+class TDMPCPolicy(
+    nn.Module,
+    PyTorchModelHubMixin,
+    library_name="lerobot",
+    repo_url="https://github.com/huggingface/lerobot",
+    tags=["robotics", "tdmpc"],
+):
     """Implementation of TD-MPC learning + inference.

     Please note several warnings for this policy.
diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py
index a73acb4f3..87cf59f19 100644
--- a/lerobot/common/policies/vqbet/modeling_vqbet.py
+++ b/lerobot/common/policies/vqbet/modeling_vqbet.py
@@ -38,7 +38,13 @@ from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ
 # ruff: noqa: N806


-class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
+class VQBeTPolicy(
+    nn.Module,
+    PyTorchModelHubMixin,
+    library_name="lerobot",
+    repo_url="https://github.com/huggingface/lerobot",
+    tags=["robotics", "vqbet"],
+):
     """
     VQ-BeT Policy as per "Behavior Generation with Latent Actions"
     """
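
For context, here is a minimal sketch of how the class keyword arguments added above flow through `PyTorchModelHubMixin`. This is not lerobot code: `ToyPolicy` and its single linear layer are made up for illustration, and only the `library_name`/`repo_url`/`tags` keywords mirror the diff. The mixin captures these at subclass-definition time and reuses them when the model is saved or pushed, so the generated model card carries the lerobot library name, repo link, and policy tags.

```python
import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import nn


class ToyPolicy(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="lerobot",
    repo_url="https://github.com/huggingface/lerobot",
    tags=["robotics", "act"],
):
    """Hypothetical stand-in for ACTPolicy/DiffusionPolicy/etc., just to exercise the mixin."""

    def __init__(self, hidden_dim: int = 16):
        super().__init__()
        self.net = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


policy = ToyPolicy()
# Writes the weights plus a config.json built from the captured init kwargs.
policy.save_pretrained("toy-policy")
reloaded = ToyPolicy.from_pretrained("toy-policy")

# push_to_hub() additionally uploads the files and generates a model card whose
# metadata includes the library_name, repo_url, and tags declared on the class.
# policy.push_to_hub("your-username/toy-policy")  # requires being logged in to the Hub
```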