Force pi0_libero to be re-downloaded (#275)
This commit is contained in:
@@ -37,7 +37,7 @@ def get_cache_dir() -> pathlib.Path:
|
|||||||
return cache_dir
|
return cache_dir
|
||||||
|
|
||||||
|
|
||||||
def maybe_download(url: str, **kwargs) -> pathlib.Path:
|
def maybe_download(url: str, *, force_download: bool = False, **kwargs) -> pathlib.Path:
|
||||||
"""Download a file or directory from a remote filesystem to the local cache, and return the local path.
|
"""Download a file or directory from a remote filesystem to the local cache, and return the local path.
|
||||||
|
|
||||||
If the local file already exists, it will be returned directly.
|
If the local file already exists, it will be returned directly.
|
||||||
@@ -47,6 +47,7 @@ def maybe_download(url: str, **kwargs) -> pathlib.Path:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
url: URL to the file to download.
|
url: URL to the file to download.
|
||||||
|
force_download: If True, the file will be downloaded even if it already exists in the cache.
|
||||||
**kwargs: Additional arguments to pass to fsspec.
|
**kwargs: Additional arguments to pass to fsspec.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -67,33 +68,56 @@ def maybe_download(url: str, **kwargs) -> pathlib.Path:
|
|||||||
local_path = cache_dir / parsed.netloc / parsed.path.strip("/")
|
local_path = cache_dir / parsed.netloc / parsed.path.strip("/")
|
||||||
local_path = local_path.resolve()
|
local_path = local_path.resolve()
|
||||||
|
|
||||||
# Check if file already exists in cache.
|
# Check if the cache should be invalidated.
|
||||||
if local_path.exists() and not _invalidate_expired_cache(cache_dir, local_path):
|
invalidate_cache = False
|
||||||
return local_path
|
if local_path.exists():
|
||||||
|
if force_download or _should_invalidate_cache(cache_dir, local_path):
|
||||||
# Download file from remote file system.
|
invalidate_cache = True
|
||||||
logger.info(f"Downloading {url} to {local_path}")
|
|
||||||
with filelock.FileLock(local_path.with_suffix(".lock")):
|
|
||||||
scratch_path = local_path.with_suffix(".partial")
|
|
||||||
|
|
||||||
if _is_openpi_url(url):
|
|
||||||
# Download without credentials.
|
|
||||||
_download_boto3(
|
|
||||||
url,
|
|
||||||
scratch_path,
|
|
||||||
boto_session=boto3.Session(
|
|
||||||
region_name="us-west-1",
|
|
||||||
),
|
|
||||||
botocore_config=botocore.config.Config(signature_version=botocore.UNSIGNED),
|
|
||||||
)
|
|
||||||
elif url.startswith("s3://"):
|
|
||||||
# Download with default boto3 credentials.
|
|
||||||
_download_boto3(url, scratch_path)
|
|
||||||
else:
|
else:
|
||||||
_download_fsspec(url, scratch_path, **kwargs)
|
return local_path
|
||||||
|
|
||||||
shutil.move(scratch_path, local_path)
|
try:
|
||||||
_ensure_permissions(local_path)
|
lock_path = local_path.with_suffix(".lock")
|
||||||
|
with filelock.FileLock(lock_path):
|
||||||
|
# Ensure consistent permissions for the lock file.
|
||||||
|
_ensure_permissions(lock_path)
|
||||||
|
# First, remove the existing cache if it is expired.
|
||||||
|
if invalidate_cache:
|
||||||
|
logger.info(f"Removing expired cached entry: {local_path}")
|
||||||
|
if local_path.is_dir():
|
||||||
|
shutil.rmtree(local_path)
|
||||||
|
else:
|
||||||
|
local_path.unlink()
|
||||||
|
|
||||||
|
# Download the data to a local cache.
|
||||||
|
logger.info(f"Downloading {url} to {local_path}")
|
||||||
|
scratch_path = local_path.with_suffix(".partial")
|
||||||
|
|
||||||
|
if _is_openpi_url(url):
|
||||||
|
# Download without credentials.
|
||||||
|
_download_boto3(
|
||||||
|
url,
|
||||||
|
scratch_path,
|
||||||
|
boto_session=boto3.Session(
|
||||||
|
region_name="us-west-1",
|
||||||
|
),
|
||||||
|
botocore_config=botocore.config.Config(signature_version=botocore.UNSIGNED),
|
||||||
|
)
|
||||||
|
elif url.startswith("s3://"):
|
||||||
|
# Download with default boto3 credentials.
|
||||||
|
_download_boto3(url, scratch_path)
|
||||||
|
else:
|
||||||
|
_download_fsspec(url, scratch_path, **kwargs)
|
||||||
|
|
||||||
|
shutil.move(scratch_path, local_path)
|
||||||
|
_ensure_permissions(local_path)
|
||||||
|
|
||||||
|
except PermissionError as e:
|
||||||
|
msg = (
|
||||||
|
f"Local file permission error was encountered while downloading {url}. "
|
||||||
|
f"Please try again after removing the cached data using: `rm -rf {local_path}*`"
|
||||||
|
)
|
||||||
|
raise PermissionError(msg) from e
|
||||||
|
|
||||||
return local_path
|
return local_path
|
||||||
|
|
||||||
@@ -285,11 +309,12 @@ def _get_mtime(year: int, month: int, day: int) -> float:
|
|||||||
# Partial matching will be used from top to bottom and the first match will be chosen.
|
# Partial matching will be used from top to bottom and the first match will be chosen.
|
||||||
# Cached entries will be retained only if they are newer than the expiration timestamp.
|
# Cached entries will be retained only if they are newer than the expiration timestamp.
|
||||||
_INVALIDATE_CACHE_DIRS: dict[re.Pattern, float] = {
|
_INVALIDATE_CACHE_DIRS: dict[re.Pattern, float] = {
|
||||||
|
re.compile("openpi-assets/checkpoints/pi0_libero"): _get_mtime(2025, 2, 6),
|
||||||
re.compile("openpi-assets/checkpoints/"): _get_mtime(2025, 2, 3),
|
re.compile("openpi-assets/checkpoints/"): _get_mtime(2025, 2, 3),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _invalidate_expired_cache(cache_dir: pathlib.Path, local_path: pathlib.Path) -> bool:
|
def _should_invalidate_cache(cache_dir: pathlib.Path, local_path: pathlib.Path) -> bool:
|
||||||
"""Invalidate the cache if it is expired. Return True if the cache was invalidated."""
|
"""Invalidate the cache if it is expired. Return True if the cache was invalidated."""
|
||||||
|
|
||||||
assert local_path.exists(), f"File not found at {local_path}"
|
assert local_path.exists(), f"File not found at {local_path}"
|
||||||
@@ -298,13 +323,6 @@ def _invalidate_expired_cache(cache_dir: pathlib.Path, local_path: pathlib.Path)
|
|||||||
for pattern, expire_time in _INVALIDATE_CACHE_DIRS.items():
|
for pattern, expire_time in _INVALIDATE_CACHE_DIRS.items():
|
||||||
if pattern.match(relative_path):
|
if pattern.match(relative_path):
|
||||||
# Remove if not newer than the expiration timestamp.
|
# Remove if not newer than the expiration timestamp.
|
||||||
if local_path.stat().st_mtime <= expire_time:
|
return local_path.stat().st_mtime <= expire_time
|
||||||
logger.info(f"Removing expired cached entry: {local_path}")
|
|
||||||
if local_path.is_dir():
|
|
||||||
shutil.rmtree(local_path)
|
|
||||||
else:
|
|
||||||
local_path.unlink()
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
Reference in New Issue
Block a user