Refactor datasets into LeRobotDataset (#91)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Remi
2024-04-25 12:23:12 +02:00
committed by GitHub
parent e760e4cd63
commit 659c69a1c0
90 changed files with 167 additions and 352 deletions

View File

@@ -8,31 +8,25 @@ Example:
print(lerobot.available_envs)
print(lerobot.available_tasks_per_env)
print(lerobot.available_datasets)
print(lerobot.available_datasets_per_env)
print(lerobot.available_policies)
print(lerobot.available_policies_per_env)
```
When implementing a new dataset class (e.g. `AlohaDataset`) follow these steps:
- Update `available_datasets` in `lerobot/__init__.py`
- Set the required `available_datasets` class attribute using the previously updated `lerobot.available_datasets`
When implementing a new dataset loadable with LeRobotDataset follow these steps:
- Update `available_datasets_per_env` in `lerobot/__init__.py`
When implementing a new environment (e.g. `gym_aloha`), follow these steps:
- Update `available_envs`, `available_tasks_per_env` and `available_datasets` in `lerobot/__init__.py`
- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py`
When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
- Update `available_policies` in `lerobot/__init__.py`
- Update `available_policies` and `available_policies_per_env`, in `lerobot/__init__.py`
- Set the required `name` class attribute.
- Update variables in `tests/test_available.py` by importing your new Policy class
"""
from lerobot.__version__ import __version__ # noqa: F401
available_envs = [
"aloha",
"pusht",
"xarm",
]
available_tasks_per_env = {
"aloha": [
"AlohaInsertion-v0",
@@ -41,22 +35,24 @@ available_tasks_per_env = {
"pusht": ["PushT-v0"],
"xarm": ["XarmLift-v0"],
}
available_envs = list(available_tasks_per_env.keys())
available_datasets = {
available_datasets_per_env = {
"aloha": [
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",
"aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
"lerobot/aloha_sim_insertion_human",
"lerobot/aloha_sim_insertion_scripted",
"lerobot/aloha_sim_transfer_cube_human",
"lerobot/aloha_sim_transfer_cube_scripted",
],
"pusht": ["pusht"],
"pusht": ["lerobot/pusht"],
"xarm": [
"xarm_lift_medium",
"xarm_lift_medium_replay",
"xarm_push_medium",
"xarm_push_medium_replay",
"lerobot/xarm_lift_medium",
"lerobot/xarm_lift_medium_replay",
"lerobot/xarm_push_medium",
"lerobot/xarm_push_medium_replay",
],
}
available_datasets = [dataset for datasets in available_datasets_per_env.values() for dataset in datasets]
available_policies = [
"act",
@@ -71,10 +67,12 @@ available_policies_per_env = {
}
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
env_dataset_pairs = [(env, dataset) for env, datasets in available_datasets.items() for dataset in datasets]
env_dataset_pairs = [
(env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets
]
env_dataset_policy_triplets = [
(env, dataset, policy)
for env, datasets in available_datasets.items()
for env, datasets in available_datasets_per_env.items()
for dataset in datasets
for policy in available_policies_per_env[env]
]