Compare commits

realman-si ... user/alibe

6 commits:

- bc36fefa8e
- cfbbb4e80a
- b10960140d
- eb56a96e67
- 01b88b208e
- 4c1024e537
.gitignore (vendored, 4 additions)

```diff
@@ -1,3 +1,7 @@
+# Apple
+.DS_Store
+._.DS_Store
+
 # Logging
 logs
 tmp
```
.pre-commit-config.yaml

```diff
@@ -18,11 +18,18 @@ repos:
     hooks:
       - id: pyupgrade
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.2.2
+    rev: v0.3.3
     hooks:
       - id: ruff
        args: [--fix]
      - id: ruff-format
+  - repo: https://github.com/nbQA-dev/nbQA
+    rev: 1.8.4
+    hooks:
+      - id: nbqa-ruff
+      - id: nbqa-pyupgrade
+        args: ["--py310-plus"]
+      - id: nbqa-isort
   - repo: https://github.com/python-poetry/poetry
     rev: 1.8.0
     hooks:
```
README.md

````diff
@@ -30,7 +30,7 @@ conda activate lerobot
 curl -sSL https://install.python-poetry.org | python -
 ```

-Install dependencies
+Install the project
 ```
 poetry install
 ```
````

````diff
@@ -48,6 +48,13 @@ wandb login

 ## Usage

+### Example
+
+To use the [notebook example](./examples/pretrained.ipynb), install the project with jupyter dependencies
+```
+poetry install --with examples
+```
+
 ### Train

````
examples/notebook_utils.py (new file, 65 lines)

```python
# ruff: noqa
from pathlib import Path
from pprint import pprint

from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig

CONFIG_DIR = "../lerobot/configs"
DEFAULT_CONFIG = "default"


def config_notebook(
    policy: str = "diffusion",
    env: str = "pusht",
    device: str = "cpu",
    config_name=DEFAULT_CONFIG,
    config_path=CONFIG_DIR,
    pretrained_model_path: str = None,
    print_config: bool = False,
) -> DictConfig:
    GlobalHydra.instance().clear()
    initialize(config_path=config_path)
    overrides = [
        f"env={env}",
        f"policy={policy}",
        f"device={device}",
        f"policy.pretrained_model_path={pretrained_model_path}",
        f"eval_episodes=1",
        f"env.episode_length=200",
    ]
    cfg = compose(config_name=config_name, overrides=overrides)
    if print_config:
        pprint(OmegaConf.to_container(cfg))

    return cfg


def notebook():
    """tmp"""
    from pathlib import Path

    from examples.notebook_utils import config_notebook
    from lerobot.scripts.eval import eval

    # Select policy and env
    POLICY = "act"  # "tdmpc" | "diffusion"
    ENV = "aloha"  # "pusht" | "simxarm"

    # Select device
    DEVICE = "mps"  # "cuda" | "mps"

    # Generated videos will be written here
    OUT_DIR = Path("./outputs")
    OUT_EXAMPLE = OUT_DIR / "eval" / "eval_episode_0.mp4"

    # Setup config
    cfg = config_notebook(policy=POLICY, env=ENV, device=DEVICE, print_config=False)

    eval(cfg, out_dir=OUT_DIR)


if __name__ == "__main__":
    notebook()
```
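For readers unfamiliar with the pattern used in `config_notebook`: Hydra's Compose API builds a config programmatically instead of through `@hydra.main`, and `GlobalHydra.instance().clear()` is needed because a notebook kernel may still hold Hydra state from a previous cell. A minimal, self-contained sketch of that pattern (the config directory and `default` name mirror the constants above; the override values are purely illustrative):

```python
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf

# Reset any Hydra state left over from a previous cell or import,
# otherwise initialize() raises on the second run.
GlobalHydra.instance().clear()

# config_path is interpreted relative to the calling file.
initialize(config_path="../lerobot/configs")

# Compose the "default" config and apply overrides, as config_notebook does.
cfg = compose(
    config_name="default",
    overrides=["env=pusht", "policy=diffusion", "device=cpu"],
)
print(OmegaConf.to_yaml(cfg))
```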
examples/pretrained.ipynb (new file, 101 lines)

```json
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from pprint import pprint\n",
    "\n",
    "from huggingface_hub import snapshot_download\n",
    "from hydra import compose, initialize\n",
    "from hydra.core.global_hydra import GlobalHydra\n",
    "from IPython.display import Video\n",
    "from omegaconf import OmegaConf\n",
    "from omegaconf.dictconfig import DictConfig\n",
    "\n",
    "from examples.notebook_utils import config_notebook\n",
    "from examples.pretrained_script import download_eval_pretrained\n",
    "from lerobot.scripts.eval import eval\n",
    "\n",
    "# Select policy and env\n",
    "POLICY = \"diffusion\" # \"tdmpc\" | \"diffusion\"\n",
    "ENV = \"pusht\" # \"pusht\" | \"simxarm\"\n",
    "\n",
    "# Select device\n",
    "DEVICE = \"mps\" # \"cuda\" | \"mps\"\n",
    "\n",
    "# Generated videos will be written here\n",
    "OUT_DIR = Path(\"./outputs\")\n",
    "OUT_EXAMPLE = OUT_DIR / \"eval\" / \"eval_episode_0.mp4\"\n",
    "\n",
    "PRETRAINED_REPO = \"lerobot/diffusion_policy_pusht_image\"\n",
    "pretrained_folder = Path(snapshot_download(repo_id=PRETRAINED_REPO, repo_type=\"model\", revision=\"v1.0\"))\n",
    "pretrained_model_path = pretrained_folder / \"model.pt\"\n",
    "\n",
    "cfg_path = pretrained_folder / \"config.yaml\"\n",
    "GlobalHydra.instance().clear()\n",
    "\n",
    "print(pretrained_folder)\n",
    "\n",
    "initialize(config_path=\"../../../.cache/huggingface/hub/models--lerobot--diffusion_policy_pusht_image/snapshots/163d168f5c193c356b82e3bf6bbf5b4eeaa780d7\")\n",
    "overrides = [\n",
    "    f\"env={ENV}\",\n",
    "    f\"policy={POLICY}\",\n",
    "    f\"device={DEVICE}\",\n",
    "    f\"+policy.pretrained_model_path={pretrained_model_path}\",\n",
    "    f\"eval_episodes=1\",\n",
    "    f\"+env.episode_length=200\",\n",
    "]\n",
    "cfg = compose(config_name=\"config\", overrides=overrides)\n",
    "pprint(OmegaConf.to_container(cfg))\n",
    "# Setup config\n",
    "#cfg = config_notebook(cfg_path, policy=POLICY, env=ENV, device=DEVICE, print_config=False, pretrained_model_path=pretrained_model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# eval(cfg, out_dir=OUT_DIR)\n",
    "download_eval_pretrained(OUT_DIR, cfg)\n",
    "Video(OUT_EXAMPLE, embed=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lerobot",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
```
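The hard-coded `initialize(config_path=...)` cell above points into the local Hugging Face Hub cache that `snapshot_download` populates (`~/.cache/huggingface/hub/models--<org>--<name>/snapshots/<commit>`). A standalone sketch of just that download step, reusing the same repo id and revision as the notebook; everything else here is illustrative:

```python
from pathlib import Path

from huggingface_hub import snapshot_download

# Download (or reuse from the local cache) the pretrained policy snapshot.
pretrained_folder = Path(
    snapshot_download(
        repo_id="lerobot/diffusion_policy_pusht_image",
        repo_type="model",
        revision="v1.0",
    )
)

# The snapshot holds the weights and the training config used to rebuild the policy.
pretrained_model_path = pretrained_folder / "model.pt"
cfg_path = pretrained_folder / "config.yaml"
print(pretrained_folder, pretrained_model_path.exists(), cfg_path.exists())
```

Passing the returned folder to Hydra, rather than a hand-written cache path, would arguably make the notebook independent of the user's cache location.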
examples/pretrained_script.py (new file, 62 lines)

```python
import logging
from pathlib import Path

import torch
from tensordict.nn import TensorDictModule

from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import get_safe_torch_device, init_logging, set_seed
from lerobot.scripts.eval import eval_policy


def download_eval_pretrained(out_dir, cfg):
    if out_dir is None:
        raise NotImplementedError()

    init_logging()

    # Check device is available
    get_safe_torch_device(cfg.device, log=True)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    set_seed(cfg.seed)

    log_output_dir(out_dir)

    logging.info("make_offline_buffer")
    offline_buffer = make_offline_buffer(cfg)

    logging.info("make_env")
    env = make_env(cfg, transform=offline_buffer.transform)

    if cfg.policy.pretrained_model_path:
        policy = make_policy(cfg)
        policy = TensorDictModule(
            policy,
            in_keys=["observation", "step_count"],
            out_keys=["action"],
        )
    else:
        # when policy is None, rollout a random policy
        policy = None

    metrics = eval_policy(
        env,
        policy=policy,
        save_video=True,
        video_dir=Path(out_dir) / "eval",
        fps=cfg.env.fps,
        max_steps=cfg.env.episode_length,
        num_episodes=cfg.eval_episodes,
    )
    print(metrics)

    logging.info("End of eval")


if __name__ == "__main__":
    download_eval_pretrained()
```
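Both this script and the `eval()` hunk below wrap the policy in a `TensorDictModule` so that `eval_policy` can call it on a TensorDict of rollout data. A minimal, self-contained sketch of that wrapping; the linear layer is a toy stand-in for a real policy and uses a single input key rather than the `["observation", "step_count"]` pair above:

```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

# Toy stand-in for a policy network: observation vector -> action vector.
net = torch.nn.Linear(4, 2)

# The wrapper reads the listed in_keys from a TensorDict and writes out_keys back into it.
policy = TensorDictModule(net, in_keys=["observation"], out_keys=["action"])

td = TensorDict({"observation": torch.randn(8, 4)}, batch_size=[8])
policy(td)                  # populates td["action"] in place
print(td["action"].shape)   # torch.Size([8, 2])
```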
```diff
@@ -202,7 +202,7 @@ class DiffusionPolicy(AbstractPolicy):
         torch.save(self.state_dict(), fp)

     def load(self, fp):
-        d = torch.load(fp)
+        d = torch.load(fp, map_location=torch.device(self.device))
         missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
         if len(missing_keys) > 0:
             assert all(k.startswith("ema_diffusion.") for k in missing_keys)
```
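For context, `map_location` remaps the device tags stored in a checkpoint at load time, so a file serialized on a CUDA machine can still be opened on a cpu- or mps-only machine, and `strict=False` tolerates keys such as the `ema_diffusion.*` weights that may be absent from the current module. A small self-contained sketch (the module and file path are illustrative, not the policy's own):

```python
import torch

# Illustrative module and path; the real policy saves its own state_dict.
model = torch.nn.Linear(4, 2)
torch.save(model.state_dict(), "/tmp/model_example.pt")

# map_location remaps every tensor in the checkpoint to the requested device,
# so a CUDA-saved file still loads on a machine without a GPU.
state_dict = torch.load("/tmp/model_example.pt", map_location=torch.device("cpu"))

# strict=False returns (missing_keys, unexpected_keys) instead of raising
# when the checkpoint and the module do not match exactly.
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
print(missing_keys, unexpected_keys)
```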
```diff
@@ -145,16 +145,24 @@ def eval(cfg: dict, out_dir=None):
     logging.info("make_env")
     env = make_env(cfg, transform=offline_buffer.transform)

-    if cfg.policy.pretrained_model_path:
-        policy = make_policy(cfg)
-        policy = TensorDictModule(
-            policy,
-            in_keys=["observation", "step_count"],
-            out_keys=["action"],
-        )
-    else:
-        # when policy is None, rollout a random policy
-        policy = None
+    # WIP
+    policy = make_policy(cfg)
+    policy = TensorDictModule(
+        policy,
+        in_keys=["observation", "step_count"],
+        out_keys=["action"],
+    )
+    # TODO(aliberts, Cadene): fetch pretrained model from HF hub
+    # if cfg.policy.pretrained_model_path:
+    #     policy = make_policy(cfg)
+    #     policy = TensorDictModule(
+    #         policy,
+    #         in_keys=["observation", "step_count"],
+    #         out_keys=["action"],
+    #     )
+    # else:
+    #     # when policy is None, rollout a random policy
+    #     policy = None

     metrics = eval_policy(
         env,
```
```diff
@@ -25,15 +25,6 @@ def train_cli(cfg: dict):
     )


-def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
-    from hydra import compose, initialize
-
-    hydra.core.global_hydra.GlobalHydra.instance().clear()
-    initialize(config_path=config_path)
-    cfg = compose(config_name=config_name)
-    train(cfg, out_dir=out_dir, job_name=job_name)
-
-
 def log_train_info(logger, info, step, cfg, offline_buffer, is_offline):
     loss = info["loss"]
     grad_norm = info["grad_norm"]
```
poetry.lock (generated, 1602 lines changed): file diff suppressed because it is too large.
pyproject.toml

```diff
@@ -60,6 +60,13 @@ debugpy = "^1.8.1"
 pytest = "^8.1.0"

+
+[tool.poetry.group.examples]
+optional = true
+
+[tool.poetry.group.examples.dependencies]
+jupyter = "^1.0.0"
+
 [tool.ruff]
 line-length = 110
 target-version = "py310"

```