From 598bb496b0a1a8a9ffff6548dda29dc6cbd08e65 Mon Sep 17 00:00:00 2001 From: Cadene Date: Sun, 25 Feb 2024 10:50:23 +0000 Subject: [PATCH] Add policies/factory, Add test, Add _self_ in config --- README.md | 7 ++----- lerobot/common/policies/__init__.py | 0 lerobot/common/policies/factory.py | 21 +++++++++++++++++++ lerobot/common/{ => policies}/tdmpc.py | 2 +- lerobot/common/{ => policies}/tdmpc_helper.py | 0 lerobot/configs/default.yaml | 1 + lerobot/configs/pusht.yaml | 1 + lerobot/scripts/eval.py | 12 ++--------- lerobot/scripts/train.py | 15 ++----------- test/__init__.py | 0 test/test_envs.py | 12 +++-------- test/test_policies.py | 17 +++++++++++++++ test/utils.py | 11 ++++++++++ 13 files changed, 61 insertions(+), 38 deletions(-) create mode 100644 lerobot/common/policies/__init__.py create mode 100644 lerobot/common/policies/factory.py rename lerobot/common/{ => policies}/tdmpc.py (99%) rename lerobot/common/{ => policies}/tdmpc_helper.py (100%) create mode 100644 test/__init__.py create mode 100644 test/test_policies.py create mode 100644 test/utils.py diff --git a/README.md b/README.md index ce4f3dd4..c037b49c 100644 --- a/README.md +++ b/README.md @@ -89,9 +89,6 @@ eval_episodes=7 **style** ``` -isort lerobot -black lerobot -isort test -black test -pylint lerobot +isort lerobot && isort test && black lerobot && black test +pylint lerobot && pylint test # not enforce for now ``` diff --git a/lerobot/common/policies/__init__.py b/lerobot/common/policies/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py new file mode 100644 index 00000000..82e54476 --- /dev/null +++ b/lerobot/common/policies/factory.py @@ -0,0 +1,21 @@ +from lerobot.common.policies.tdmpc import TDMPC + + +def make_policy(cfg): + if cfg.policy == "tdmpc": + policy = TDMPC(cfg) + else: + raise ValueError(cfg.policy) + + if cfg.pretrained_model_path: + # TODO(rcadene): hack for old pretrained models from fowm + if cfg.policy == "tdmpc" and "fowm" in cfg.pretrained_model_path: + if "offline" in cfg.pretrained_model_path: + policy.step[0] = 25000 + elif "final" in cfg.pretrained_model_path: + policy.step[0] = 100000 + else: + raise NotImplementedError() + policy.load(cfg.pretrained_model_path) + + return policy diff --git a/lerobot/common/tdmpc.py b/lerobot/common/policies/tdmpc.py similarity index 99% rename from lerobot/common/tdmpc.py rename to lerobot/common/policies/tdmpc.py index f8133279..d3a3c19e 100644 --- a/lerobot/common/tdmpc.py +++ b/lerobot/common/policies/tdmpc.py @@ -5,7 +5,7 @@ import numpy as np import torch import torch.nn as nn -import lerobot.common.tdmpc_helper as h +import lerobot.common.policies.tdmpc_helper as h class TOLD(nn.Module): diff --git a/lerobot/common/tdmpc_helper.py b/lerobot/common/policies/tdmpc_helper.py similarity index 100% rename from lerobot/common/tdmpc_helper.py rename to lerobot/common/policies/tdmpc_helper.py diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index c0c21713..f0339e5e 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -40,6 +40,7 @@ state_dim: 4 action_dim: 4 # TDMPC +policy: tdmpc # planning mpc: true diff --git a/lerobot/configs/pusht.yaml b/lerobot/configs/pusht.yaml index c700f6b2..d7166309 100644 --- a/lerobot/configs/pusht.yaml +++ b/lerobot/configs/pusht.yaml @@ -1,4 +1,5 @@ defaults: + - _self_ - default hydra: diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 12c5d14b..7b3357d9 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -10,7 +10,7 @@ from termcolor import colored from torchrl.envs import EnvBase from lerobot.common.envs.factory import make_env -from lerobot.common.tdmpc import TDMPC +from lerobot.common.policies.factory import make_policy from lerobot.common.utils import set_seed @@ -111,15 +111,7 @@ def eval(cfg: dict, out_dir=None): env = make_env(cfg) if cfg.pretrained_model_path: - policy = TDMPC(cfg) - if "offline" in cfg.pretrained_model_path: - policy.step[0] = 25000 - elif "final" in cfg.pretrained_model_path: - policy.step[0] = 100000 - else: - raise NotImplementedError() - policy.load(cfg.pretrained_model_path) - + policy = make_policy(cfg) policy = TensorDictModule( policy, in_keys=["observation", "step_count"], diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 40b9d30a..6c91b83f 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -11,10 +11,9 @@ from torchrl.data.datasets.openx import OpenXExperienceReplay from torchrl.data.replay_buffers import PrioritizedSliceSampler from lerobot.common.datasets.factory import make_offline_buffer -from lerobot.common.datasets.simxarm import SimxarmExperienceReplay from lerobot.common.envs.factory import make_env from lerobot.common.logger import Logger -from lerobot.common.tdmpc import TDMPC +from lerobot.common.policies.factory import make_policy from lerobot.common.utils import set_seed from lerobot.scripts.eval import eval_policy @@ -51,17 +50,7 @@ def train(cfg: dict, out_dir=None, job_name=None): print(colored("Work dir:", "yellow", attrs=["bold"]), out_dir) env = make_env(cfg) - policy = TDMPC(cfg) - if cfg.pretrained_model_path: - # TODO(rcadene): hack for old pretrained models from fowm - if "fowm" in cfg.pretrained_model_path: - if "offline" in cfg.pretrained_model_path: - policy.step[0] = 25000 - elif "final" in cfg.pretrained_model_path: - policy.step[0] = 100000 - else: - raise NotImplementedError() - policy.load(cfg.pretrained_model_path) + policy = make_policy(cfg) td_policy = TensorDictModule( policy, diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/test_envs.py b/test/test_envs.py index b5c730e3..433b719d 100644 --- a/test/test_envs.py +++ b/test/test_envs.py @@ -6,6 +6,8 @@ from lerobot.common.envs.factory import make_env from lerobot.common.envs.pusht import PushtEnv from lerobot.common.envs.simxarm import SimxarmEnv +from .utils import init_config + def print_spec_rollout(env): print("observation_spec:", env.observation_spec) @@ -83,14 +85,6 @@ def test_pusht(from_pixels, pixels_only): ], ) def test_factory(config_name): - import hydra - from hydra import compose, initialize - - config_path = "../lerobot/configs" - hydra.core.global_hydra.GlobalHydra.instance().clear() - initialize(config_path=config_path) - cfg = compose(config_name=config_name) - + cfg = init_config(config_name) env = make_env(cfg) - check_env_specs(env) diff --git a/test/test_policies.py b/test/test_policies.py new file mode 100644 index 00000000..062e58f0 --- /dev/null +++ b/test/test_policies.py @@ -0,0 +1,17 @@ +import pytest + +from lerobot.common.policies.factory import make_policy + +from .utils import init_config + + +@pytest.mark.parametrize( + "config_name", + [ + "default", + "pusht", + ], +) +def test_factory(config_name): + cfg = init_config(config_name) + policy = make_policy(cfg) diff --git a/test/utils.py b/test/utils.py new file mode 100644 index 00000000..be8583f5 --- /dev/null +++ b/test/utils.py @@ -0,0 +1,11 @@ +import hydra +from hydra import compose, initialize + +CONFIG_PATH = "../lerobot/configs" + + +def init_config(config_name): + hydra.core.global_hydra.GlobalHydra.instance().clear() + initialize(config_path=CONFIG_PATH) + cfg = compose(config_name=config_name) + return cfg