Compare commits

..

12 Commits

Author SHA1 Message Date
Alexander Soare
d374873849 use Path type instead of str 2024-03-15 13:15:34 +00:00
Remi
9c88071bc7 Merge pull request #28 from Cadene/user/rcadene/2024_03_14_hf_dataset
Download datasets from hugging face
2024-03-15 13:52:13 +01:00
Cadene
5805a7ffb1 small fix in type + comments 2024-03-15 12:44:52 +00:00
Cadene
41521f7e96 self.root is Path or None + The following packages are already present in the pyproject.toml and will be skipped:
- huggingface-hub

If you want to update it to the latest compatible version, you can use `poetry update package`.
If you prefer to upgrade it to the latest available version, you can use `poetry add package@latest`.

Nothing to add.
2024-03-15 10:56:46 +00:00
Cadene
b10c9507d4 Small fix 2024-03-15 00:36:55 +00:00
Cadene
a311d38796 Add aloha + improve readme 2024-03-15 00:30:11 +00:00
Cadene
19730b3412 Add pusht on hf dataset (WIP) 2024-03-14 16:59:37 +00:00
Simon Alibert
95e84079ef Merge pull request #25 from Cadene/user/aliberts/2024_03_13_ci_fix
CI env fix
2024-03-14 15:24:56 +01:00
Simon Alibert
8e856f1bf7 Update readme 2024-03-14 15:24:38 +01:00
Simon Alibert
8c2b47752a Remove cuda env copy 2024-03-14 13:55:35 +01:00
Simon Alibert
f515cb6efd Add dm-control 2024-03-14 13:42:03 +01:00
Simon Alibert
c3f8d14fd8 CI env fix 2024-03-14 13:29:27 +01:00
13 changed files with 3379 additions and 163 deletions

3126
.github/poetry/cpu/poetry.lock generated vendored Normal file

File diff suppressed because it is too large Load Diff

107
.github/poetry/cpu/pyproject.toml vendored Normal file
View File

@@ -0,0 +1,107 @@
[tool.poetry]
name = "lerobot"
version = "0.1.0"
description = "Le robot is learning"
authors = [
"Rémi Cadène <re.cadene@gmail.com>",
"Simon Alibert <alibert.sim@gmail.com>",
]
repository = "https://github.com/Cadene/lerobot"
readme = "README.md"
license = "MIT"
classifiers=[
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Topic :: Software Development :: Build Tools",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.10",
]
packages = [{include = "lerobot"}]
[tool.poetry.dependencies]
python = "^3.10"
cython = "^3.0.8"
termcolor = "^2.4.0"
omegaconf = "^2.3.0"
dm-env = "^1.6"
pandas = "^2.2.1"
wandb = "^0.16.3"
moviepy = "^1.0.3"
imageio = {extras = ["pyav"], version = "^2.34.0"}
gdown = "^5.1.0"
hydra-core = "^1.3.2"
einops = "^0.7.0"
pygame = "^2.5.2"
pymunk = "^6.6.0"
zarr = "^2.17.0"
shapely = "^2.0.3"
scikit-image = "^0.22.0"
numba = "^0.59.0"
mpmath = "^1.3.0"
torch = {version = "^2.2.1", source = "torch-cpu"}
tensordict = {git = "https://github.com/pytorch/tensordict"}
torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"}
mujoco = "^3.1.2"
mujoco-py = "^2.1.2.14"
gym = "^0.26.2"
opencv-python = "^4.9.0.80"
diffusers = "^0.26.3"
torchvision = {version = "^0.17.1", source = "torch-cpu"}
h5py = "^3.10.0"
dm = "^1.3"
dm-control = "^1.0.16"
huggingface-hub = "^0.21.4"
[tool.poetry.group.dev.dependencies]
pre-commit = "^3.6.2"
debugpy = "^1.8.1"
pytest = "^8.1.0"
[[tool.poetry.source]]
name = "torch-cpu"
url = "https://download.pytorch.org/whl/cpu"
priority = "supplemental"
[tool.ruff]
line-length = 110
target-version = "py310"
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".mypy_cache",
".nox",
".pants.d",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"venv",
]
[tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
[tool.poetry-dynamic-versioning]
enable = true
[build-system]
requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"]
build-backend = "poetry_dynamic_versioning.backend"

View File

@@ -46,8 +46,8 @@ jobs:
id: restore-poetry-cache id: restore-poetry-cache
uses: actions/cache/restore@v3 uses: actions/cache/restore@v3
with: with:
path: ~/.local # the path depends on the OS path: ~/.local
key: poetry-${{ env.POETRY_VERSION }} # increment to reset cache key: poetry-${{ env.POETRY_VERSION }}
- name: Install Poetry - name: Install Poetry
if: steps.restore-poetry-cache.outputs.cache-hit != 'true' if: steps.restore-poetry-cache.outputs.cache-hit != 'true'
@@ -64,8 +64,8 @@ jobs:
id: save-poetry-cache id: save-poetry-cache
uses: actions/cache/save@v3 uses: actions/cache/save@v3
with: with:
path: ~/.local # the path depends on the OS path: ~/.local
key: poetry-${{ env.POETRY_VERSION }} # increment to reset cache key: poetry-${{ env.POETRY_VERSION }}
- name: Configure Poetry - name: Configure Poetry
run: poetry config virtualenvs.in-project true run: poetry config virtualenvs.in-project true
@@ -73,6 +73,10 @@ jobs:
#---------------------------------------------- #----------------------------------------------
# install dependencies # install dependencies
#---------------------------------------------- #----------------------------------------------
# TODO(aliberts): move to gpu runners
- name: Select cpu dependencies # HACK
run: cp -t . .github/poetry/cpu/pyproject.toml .github/poetry/cpu/poetry.lock
- name: Load cached venv - name: Load cached venv
id: restore-dependencies-cache id: restore-dependencies-cache
uses: actions/cache/restore@v3 uses: actions/cache/restore@v3
@@ -80,18 +84,10 @@ jobs:
path: .venv path: .venv
key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }} key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }}
- name: Info
run: |
sudo du -sh /tmp
sudo df -h
- name: Install dependencies - name: Install dependencies
if: steps.restore-dependencies-cache.outputs.cache-hit != 'true' if: steps.restore-dependencies-cache.outputs.cache-hit != 'true'
run: | run: |
mkdir ~/tmp mkdir ~/tmp
echo $TMPDIR
echo $TEMP
echo $TMP
poetry install --no-interaction --no-root poetry install --no-interaction --no-root
- name: Save cached venv - name: Save cached venv

View File

@@ -103,6 +103,18 @@ pre-commit install
pre-commit run -a pre-commit run -a
``` ```
**Adding dependencies (temporary)**
Right now, for the CI to work, whenever a new dependency is added it needs to be also added to the cpu env, eg:
```
# Run in this directory, adds the package to the main env with cuda
poetry add some-package
# Adds the same package to the cpu env
cd .github/poetry/cpu && poetry add some-package
```
**Tests** **Tests**
Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already). Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already).
@@ -134,6 +146,25 @@ Run tests
DATA_DIR="tests/data" pytest -sx tests DATA_DIR="tests/data" pytest -sx tests
``` ```
**Datasets**
To add a pytorch rl dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access:
```
huggingface-cli login --token $HUGGINGFACE_TOKEN --add-to-git-credential
```
Then you can upload it to the hub with:
```
HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload --repo-type dataset $HF_USER/$DATASET data/$DATASET
```
For instance, for [cadene/pusht](https://huggingface.co/datasets/cadene/pusht), we used:
```
HF_USER=cadene
DATASET=pusht
```
## Acknowledgment ## Acknowledgment
- Our Diffusion policy and Pusht environment are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/) - Our Diffusion policy and Pusht environment are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/)
- Our TDMPC policy and Simxarm environment are adapted from [FOWM](https://www.yunhaifeng.com/FOWM/) - Our TDMPC policy and Simxarm environment are adapted from [FOWM](https://www.yunhaifeng.com/FOWM/)

View File

@@ -1,4 +1,3 @@
import abc
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Callable from typing import Callable
@@ -7,8 +6,8 @@ import einops
import torch import torch
import torchrl import torchrl
import tqdm import tqdm
from huggingface_hub import snapshot_download
from tensordict import TensorDict from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
@@ -23,7 +22,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
batch_size: int = None, batch_size: int = None,
*, *,
shuffle: bool = True, shuffle: bool = True,
root: Path = None, root: Path | None = None,
pin_memory: bool = False, pin_memory: bool = False,
prefetch: int = None, prefetch: int = None,
sampler: SliceSampler = None, sampler: SliceSampler = None,
@@ -33,11 +32,8 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
): ):
self.dataset_id = dataset_id self.dataset_id = dataset_id
self.shuffle = shuffle self.shuffle = shuffle
self.root = _get_root_dir(self.dataset_id) if root is None else root self.root = root
self.root = Path(self.root) storage = self._download_or_load_dataset()
self.data_dir = self.root / self.dataset_id
storage = self._download_or_load_storage()
super().__init__( super().__init__(
storage=storage, storage=storage,
@@ -89,7 +85,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
self._transform = transform self._transform = transform
def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict: def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
stats_path = self.data_dir / "stats.pth" stats_path = Path(self.data_dir) / "stats.pth"
if stats_path.exists(): if stats_path.exists():
stats = torch.load(stats_path) stats = torch.load(stats_path)
else: else:
@@ -98,19 +94,12 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
torch.save(stats, stats_path) torch.save(stats, stats_path)
return stats return stats
@abc.abstractmethod def _download_or_load_dataset(self) -> torch.StorageBase:
def _download_and_preproc(self) -> torch.StorageBase: if self.root is None:
raise NotImplementedError() self.data_dir = snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset")
def _download_or_load_storage(self):
if not self._is_downloaded():
storage = self._download_and_preproc()
else: else:
storage = TensorStorage(TensorDict.load_memmap(self.data_dir)) self.data_dir = self.root / self.dataset_id
return storage return TensorStorage(TensorDict.load_memmap(self.data_dir))
def _is_downloaded(self) -> bool:
return self.data_dir.is_dir()
def _compute_stats(self, num_batch=100, batch_size=32): def _compute_stats(self, num_batch=100, batch_size=32):
rb = TensorDictReplayBuffer( rb = TensorDictReplayBuffer(

View File

@@ -87,7 +87,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
batch_size: int = None, batch_size: int = None,
*, *,
shuffle: bool = True, shuffle: bool = True,
root: Path = None, root: Path | None = None,
pin_memory: bool = False, pin_memory: bool = False,
prefetch: int = None, prefetch: int = None,
sampler: SliceSampler = None, sampler: SliceSampler = None,
@@ -124,8 +124,9 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
def image_keys(self) -> list: def image_keys(self) -> list:
return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]] return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
def _download_and_preproc(self): def _download_and_preproc_obsolete(self):
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw" assert self.root is not None
raw_dir = self.root / f"{self.dataset_id}_raw"
if not raw_dir.is_dir(): if not raw_dir.is_dir():
download(raw_dir, self.dataset_id) download(raw_dir, self.dataset_id)
@@ -174,7 +175,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
if ep_id == 0: if ep_id == 0:
# hack to initialize tensordict data structure to store episodes # hack to initialize tensordict data structure to store episodes
td_data = ep_td[0].expand(total_num_frames).memmap_like(self.data_dir) td_data = ep_td[0].expand(total_num_frames).memmap_like(self.root / f"{self.dataset_id}")
td_data[idxtd : idxtd + len(ep_td)] = ep_td td_data[idxtd : idxtd + len(ep_td)] = ep_td
idxtd = idxtd + len(ep_td) idxtd = idxtd + len(ep_td)

View File

@@ -7,7 +7,10 @@ from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
from lerobot.common.envs.transforms import NormalizeTransform, Prod from lerobot.common.envs.transforms import NormalizeTransform, Prod
DATA_DIR = Path(os.environ.get("DATA_DIR", "data")) # DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
# to load a subset of our datasets for faster continuous integration.
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
def make_offline_buffer( def make_offline_buffer(
@@ -77,9 +80,9 @@ def make_offline_buffer(
offline_buffer = clsfunc( offline_buffer = clsfunc(
dataset_id=dataset_id, dataset_id=dataset_id,
root=DATA_DIR,
sampler=sampler, sampler=sampler,
batch_size=batch_size, batch_size=batch_size,
root=DATA_DIR,
pin_memory=pin_memory, pin_memory=pin_memory,
prefetch=prefetch if isinstance(prefetch, int) else None, prefetch=prefetch if isinstance(prefetch, int) else None,
) )

View File

@@ -90,7 +90,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
batch_size: int = None, batch_size: int = None,
*, *,
shuffle: bool = True, shuffle: bool = True,
root: Path = None, root: Path | None = None,
pin_memory: bool = False, pin_memory: bool = False,
prefetch: int = None, prefetch: int = None,
sampler: SliceSampler = None, sampler: SliceSampler = None,
@@ -111,8 +111,9 @@ class PushtExperienceReplay(AbstractExperienceReplay):
transform=transform, transform=transform,
) )
def _download_and_preproc(self): def _download_and_preproc_obsolete(self):
raw_dir = self.data_dir.parent / f"{self.data_dir.name}_raw" assert self.root is not None
raw_dir = self.root / f"{self.dataset_id}_raw"
zarr_path = (raw_dir / PUSHT_ZARR).resolve() zarr_path = (raw_dir / PUSHT_ZARR).resolve()
if not zarr_path.is_dir(): if not zarr_path.is_dir():
raw_dir.mkdir(parents=True, exist_ok=True) raw_dir.mkdir(parents=True, exist_ok=True)
@@ -208,7 +209,7 @@ class PushtExperienceReplay(AbstractExperienceReplay):
if episode_id == 0: if episode_id == 0:
# hack to initialize tensordict data structure to store episodes # hack to initialize tensordict data structure to store episodes
td_data = ep_td[0].expand(total_frames).memmap_like(self.data_dir) td_data = ep_td[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}")
td_data[idxtd : idxtd + len(ep_td)] = ep_td td_data[idxtd : idxtd + len(ep_td)] = ep_td

View File

@@ -43,7 +43,7 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
batch_size: int = None, batch_size: int = None,
*, *,
shuffle: bool = True, shuffle: bool = True,
root: Path = None, root: Path | None = None,
pin_memory: bool = False, pin_memory: bool = False,
prefetch: int = None, prefetch: int = None,
sampler: SliceSampler = None, sampler: SliceSampler = None,
@@ -64,11 +64,12 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
transform=transform, transform=transform,
) )
def _download_and_preproc(self): def _download_and_preproc_obsolete(self):
assert self.root is not None
# TODO(rcadene): finish download # TODO(rcadene): finish download
download() download()
dataset_path = self.data_dir / "buffer.pkl" dataset_path = self.root / f"{self.dataset_id}_raw" / "buffer.pkl"
print(f"Using offline dataset '{dataset_path}'") print(f"Using offline dataset '{dataset_path}'")
with open(dataset_path, "rb") as f: with open(dataset_path, "rb") as f:
dataset_dict = pickle.load(f) dataset_dict = pickle.load(f)
@@ -110,7 +111,7 @@ class SimxarmExperienceReplay(AbstractExperienceReplay):
if episode_id == 0: if episode_id == 0:
# hack to initialize tensordict data structure to store episodes # hack to initialize tensordict data structure to store episodes
td_data = episode[0].expand(total_frames).memmap_like(self.data_dir) td_data = episode[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}")
td_data[idx0:idx1] = episode td_data[idx0:idx1] = episode

View File

@@ -4,113 +4,9 @@ from typing import Optional
from tensordict import TensorDict from tensordict import TensorDict
from torchrl.envs import EnvBase from torchrl.envs import EnvBase
from torchrl.envs.utils import _terminated_or_truncated, step_mdp
class EnvBaseWithMultiStepRollouts(EnvBase): class AbstractEnv(EnvBase):
"""Adds handling of policies that output action trajectories to be execute with a fixed horizon."""
def _rollout_stop_early(
self,
*,
tensordict,
auto_cast_to_device,
max_steps,
policy,
policy_device,
env_device,
callback,
):
"""Override adds handling of multi-step policies."""
tensordicts = []
step_ix = 0
do_break = False
while not do_break:
if auto_cast_to_device:
if policy_device is not None:
tensordict = tensordict.to(policy_device, non_blocking=True)
else:
tensordict.clear_device_()
tensordict = policy(tensordict)
if auto_cast_to_device:
if env_device is not None:
tensordict = tensordict.to(env_device, non_blocking=True)
else:
tensordict.clear_device_()
for action in tensordict["action"].clone():
tensordict["action"] = action
tensordict = self.step(tensordict)
tensordicts.append(tensordict.clone(False))
if step_ix == max_steps - 1:
# we don't truncated as one could potentially continue the run
do_break = True
break
tensordict = step_mdp(
tensordict,
keep_other=True,
exclude_action=False,
exclude_reward=True,
reward_keys=self.reward_keys,
action_keys=self.action_keys,
done_keys=self.done_keys,
)
# done and truncated are in done_keys
# We read if any key is done.
any_done = _terminated_or_truncated(
tensordict,
full_done_spec=self.output_spec["full_done_spec"],
key=None,
)
if any_done:
break
if callback is not None:
callback(self, tensordict)
step_ix += 1
return tensordicts
def _rollout_nonstop(
self,
*,
tensordict,
auto_cast_to_device,
max_steps,
policy,
policy_device,
env_device,
callback,
):
"""Override adds handling of multi-step policies."""
tensordicts = []
tensordict_ = tensordict
for i in range(max_steps):
if auto_cast_to_device:
if policy_device is not None:
tensordict_ = tensordict_.to(policy_device, non_blocking=True)
else:
tensordict_.clear_device_()
tensordict_ = policy(tensordict_)
if auto_cast_to_device:
if env_device is not None:
tensordict_ = tensordict_.to(env_device, non_blocking=True)
else:
tensordict_.clear_device_()
tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
tensordicts.append(tensordict)
if i == max_steps - 1:
# we don't truncated as one could potentially continue the run
break
if callback is not None:
callback(self, tensordict)
return tensordicts
class AbstractEnv(EnvBaseWithMultiStepRollouts):
def __init__( def __init__(
self, self,
task, task,

View File

@@ -4,16 +4,7 @@ import torch
from tensordict import TensorDictBase from tensordict import TensorDictBase
from tensordict.nn import dispatch from tensordict.nn import dispatch
from tensordict.utils import NestedKey from tensordict.utils import NestedKey
from torchrl.envs.transforms import ObservationTransform, Transform, TransformedEnv from torchrl.envs.transforms import ObservationTransform, Transform
from torchrl.envs.transforms.transforms import _TEnvPostInit
from lerobot.common.envs.abstract import EnvBaseWithMultiStepRollouts
class TransformedEnv(EnvBaseWithMultiStepRollouts, TransformedEnv, metaclass=_TEnvPostInit):
"""Keep method overrides from EnvBaseWithMultiStepRollouts."""
pass
class Prod(ObservationTransform): class Prod(ObservationTransform):

75
poetry.lock generated
View File

@@ -838,6 +838,78 @@ files = [
[package.dependencies] [package.dependencies]
numpy = ">=1.17.3" numpy = ">=1.17.3"
[[package]]
name = "hf-transfer"
version = "0.1.6"
description = ""
optional = false
python-versions = ">=3.7"
files = [
{file = "hf_transfer-0.1.6-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:6fd3d61f9229d27def007e53540412507b74ac2fdb1a29985ae0b6a5137749a2"},
{file = "hf_transfer-0.1.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b043bb78df1225de043eb041de9d97783fcca14a0bdc1b1d560fc172fc21b648"},
{file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7db60dd18eae4fa6ea157235fb82196cde5313995b396d1b591aad3b790a7f8f"},
{file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:30d31dbab9b5a558cce407b8728e39d87d7af1ef8745ddb90187e9ae0b9e1e90"},
{file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f6b368bddd757efc7af3126ba81f9ac8f9435e2cc00902cb3d64f2be28d8f719"},
{file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa2086d8aefaaa3e144e167324574882004c0cec49bf2d0638ec4b74732d8da0"},
{file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:45d8985a0940bfe1535cb4ca781f5c11e47c83798ef3373ee1f5d57bbe527a9c"},
{file = "hf_transfer-0.1.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2f42b89735f1cde22f2a795d1f0915741023235666be7de45879e533c7d6010c"},
{file = "hf_transfer-0.1.6-cp310-none-win32.whl", hash = "sha256:2d2c4c4613f3ad45b6ce6291e347b2d3ba1b86816635681436567e461cb3c961"},
{file = "hf_transfer-0.1.6-cp310-none-win_amd64.whl", hash = "sha256:78b0eed8d8dce60168a46e584b9742b816af127d7e410a713e12c31249195342"},
{file = "hf_transfer-0.1.6-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f1d8c172153f9a6cdaecf137612c42796076f61f6bea1072c90ac2e17c1ab6fa"},
{file = "hf_transfer-0.1.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2c601996351f90c514a75a0eeb02bf700b1ad1db2d946cbfe4b60b79e29f0b2f"},
{file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e585c808405557d3f5488f385706abb696997bbae262ea04520757e30836d9d"},
{file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ec51af1e8cf4268c268bd88932ade3d7ca895a3c661b42493503f02610ae906b"},
{file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d106fdf996332f6df3ed3fab6d6332df82e8c1fb4b20fd81a491ca4d2ab5616a"},
{file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e9c2ee9e9fde5a0319cc0e8ddfea10897482bc06d5709b10a238f1bc2ebcbc0b"},
{file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f394ea32bc7802b061e549d3133efc523b4ae4fd19bf4b74b183ca6066eef94e"},
{file = "hf_transfer-0.1.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4282f09902114cd67fca98a1a1bad569a44521a8395fedf327e966714f68b977"},
{file = "hf_transfer-0.1.6-cp311-none-win32.whl", hash = "sha256:276dbf307d5ab6f1bcbf57b5918bfcf9c59d6848ccb28242349e1bb5985f983b"},
{file = "hf_transfer-0.1.6-cp311-none-win_amd64.whl", hash = "sha256:fa475175c51451186bea804471995fa8e7b2a48a61dcca55534911dc25955527"},
{file = "hf_transfer-0.1.6-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:23d157a67acfa00007799323a1c441b2bbacc7dee625b016b7946fe0e25e6c89"},
{file = "hf_transfer-0.1.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6067342a2864b988f861cd2d31bd78eb1e84d153a3f6df38485b6696d9ad3013"},
{file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91cfcb3070e205b58fa8dc8bcb6a62ccc40913fcdb9cd1ff7c364c8e3aa85345"},
{file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb76064ac5165d5eeaaf8d0903e8bf55477221ecc2a4a4d69f0baca065ab905b"},
{file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9dabd3a177d83028f164984cf4dd859f77ec1e20c97a6f307ff8fcada0785ef1"},
{file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d0bf4254e44f64a26e0a5b73b5d7e8d91bb36870718fb4f8e126ec943ff4c805"},
{file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d32c1b106f38f336ceb21531f4db9b57d777b9a33017dafdb6a5316388ebe50"},
{file = "hf_transfer-0.1.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff05aba3c83921e5c7635ba9f07c693cc893350c447644824043aeac27b285f5"},
{file = "hf_transfer-0.1.6-cp312-none-win32.whl", hash = "sha256:051ef0c55607652cb5974f59638da035773254b9a07d7ee5b574fe062de4c9d1"},
{file = "hf_transfer-0.1.6-cp312-none-win_amd64.whl", hash = "sha256:716fb5c574fcbdd8092ce73f9b6c66f42e3544337490f77c60ec07df02bd081b"},
{file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c0c981134a55965e279cb7be778c1ccaf93f902fc9ebe31da4f30caf824cc4d"},
{file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1ef1f145f04c5b573915bcb1eb5db4039c74f6b46fce73fc473c4287e613b623"},
{file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d0a7609b004db3347dbb7796df45403eceb171238210d054d93897d6d84c63a4"},
{file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60f0864bf5996773dbd5f8ae4d1649041f773fe9d5769f4c0eeb5553100acef3"},
{file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d01e55d630ffe70a4f5d0ed576a04c6a48d7c65ca9a7d18f2fca385f20685a9"},
{file = "hf_transfer-0.1.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d855946c5062b665190de15b2bdbd4c8eddfee35350bfb7564592e23d36fbbd3"},
{file = "hf_transfer-0.1.6-cp37-none-win32.whl", hash = "sha256:fd40b2409cfaf3e8aba20169ee09552f69140e029adeec261b988903ff0c8f6f"},
{file = "hf_transfer-0.1.6-cp37-none-win_amd64.whl", hash = "sha256:0e0eba49d46d3b5481919aea0794aec625fbc6ecdf13fe7e0e9f3fc5d5ad5971"},
{file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7e669fecb29fc454449739f9f53ed9253197e7c19e6a6eaa0f08334207af4287"},
{file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:89f701802892e5eb84f89f402686861f87dc227d6082b05f4e9d9b4e8015a3c3"},
{file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b6f2b0c8b95b01409275d789a9b74d5f2e146346f985d384bf50ec727caf1ccc"},
{file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:aa855a2fa262792a230f9efcdb5da6d431b747d1861d2a69fe7834b19aea077e"},
{file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa8ca349afb2f0713475426946261eb2035e4efb50ebd2c1d5ad04f395f4217"},
{file = "hf_transfer-0.1.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01255f043996bc7d1bae62d8afc5033a90c7e36ce308b988eeb84afe0a69562f"},
{file = "hf_transfer-0.1.6-cp38-none-win32.whl", hash = "sha256:60b1db183e8a7540cd4f8b2160ff4de55f77cb0c3fc6a10be1e7c30eb1b2bdeb"},
{file = "hf_transfer-0.1.6-cp38-none-win_amd64.whl", hash = "sha256:fb8be3cba6aaa50ab2e9dffbd25c8eb2046785eeff642cf0cdd0dd9ae6be3539"},
{file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d09af35e3e3f09b664e6429e9a0dc200f29c5bdfd88bdd9666de51183b1fe202"},
{file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a4505bd707cc14d85c800f961fad8ca76f804a8ad22fbb7b1a217d8d0c15e6a5"},
{file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c453fd8b0be9740faa23cecd1f28ee9ead7d900cefa64ff836960c503a744c9"},
{file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:13cb8884e718a78c3b81a8cdec9c7ac196dd42961fce55c3ccff3dd783e5ad7a"},
{file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39cd39df171a2b5404de69c4e6cd14eee47f6fe91c1692f939bfb9e59a0110d8"},
{file = "hf_transfer-0.1.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ff0629ee9f98df57a783599602eb498f9ec3619dc69348b12e4d9d754abf0e9"},
{file = "hf_transfer-0.1.6-cp39-none-win32.whl", hash = "sha256:164a6ce445eb0cc7c645f5b6e1042c003d33292520c90052b6325f30c98e4c5f"},
{file = "hf_transfer-0.1.6-cp39-none-win_amd64.whl", hash = "sha256:11b8b4b73bf455f13218c5f827698a30ae10998ca31b8264b51052868c7a9f11"},
{file = "hf_transfer-0.1.6-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:16957ba057376a99ea361074ce1094f61b58e769defa6be2422ae59c0b6a6530"},
{file = "hf_transfer-0.1.6-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7db952112e3b8ee1a5cbf500d2443e9ce4fb893281c5310a3e31469898628005"},
{file = "hf_transfer-0.1.6-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d39d826a7344f5e39f438d62632acd00467aa54a083b66496f61ef67a9885a56"},
{file = "hf_transfer-0.1.6-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4e2653fbfa92e7651db73d99b697c8684e7345c479bd6857da80bed6138abb2"},
{file = "hf_transfer-0.1.6-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:144277e6a86add10b90ec3b583253aec777130312256bfc8d5ade5377e253807"},
{file = "hf_transfer-0.1.6-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3bb53bcd16365313b2aa0dbdc28206f577d70770f31249cdabc387ac5841edcc"},
{file = "hf_transfer-0.1.6-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:990d73a5a68d8261980f146c51f4c5f9995314011cb225222021ad7c39f3af2d"},
{file = "hf_transfer-0.1.6-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:652406037029ab9b4097b4c5f29321bad5f64c2b46fbff142509d918aec87c29"},
{file = "hf_transfer-0.1.6.tar.gz", hash = "sha256:deb505a7d417d7055fd7b3549eadb91dfe782941261f3344025c486c16d1d2f9"},
]
[[package]] [[package]]
name = "huggingface-hub" name = "huggingface-hub"
version = "0.21.4" version = "0.21.4"
@@ -852,6 +924,7 @@ files = [
[package.dependencies] [package.dependencies]
filelock = "*" filelock = "*"
fsspec = ">=2023.5.0" fsspec = ">=2023.5.0"
hf-transfer = {version = ">=0.1.4", optional = true, markers = "extra == \"hf_transfer\""}
packaging = ">=20.9" packaging = ">=20.9"
pyyaml = ">=5.1" pyyaml = ">=5.1"
requests = "*" requests = "*"
@@ -3254,4 +3327,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "3d82309a7b2388d774b56ceb6f6906ef0732d8cedda0d76cc84a30e239949be8" content-hash = "ee86b84a795e6a3e9c2d79f244a87b55589adbe46d549ac38adf48be27c04cf9"

View File

@@ -50,6 +50,7 @@ diffusers = "^0.26.3"
torchvision = "^0.17.1" torchvision = "^0.17.1"
h5py = "^3.10.0" h5py = "^3.10.0"
dm-control = "1.0.14" dm-control = "1.0.14"
huggingface-hub = {extras = ["hf-transfer"], version = "^0.21.4"}
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]