Force pi0_libero to be re-downloaded (#275)

Author: uzhilinsky
Date: 2025-02-06 16:14:18 -08:00
Committer: GitHub
parent f543cb1d87
commit bf30fa3d4c


@@ -37,7 +37,7 @@ def get_cache_dir() -> pathlib.Path:
     return cache_dir
 
 
-def maybe_download(url: str, **kwargs) -> pathlib.Path:
+def maybe_download(url: str, *, force_download: bool = False, **kwargs) -> pathlib.Path:
     """Download a file or directory from a remote filesystem to the local cache, and return the local path.
 
     If the local file already exists, it will be returned directly.
@@ -47,6 +47,7 @@ def maybe_download(url: str, **kwargs) -> pathlib.Path:
     Args:
         url: URL to the file to download.
+        force_download: If True, the file will be downloaded even if it already exists in the cache.
        **kwargs: Additional arguments to pass to fsspec.
 
     Returns:
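For context, a hedged usage sketch of the new keyword-only flag. The import path and the checkpoint URL are assumptions for illustration; only force_download itself comes from this commit.

from openpi.shared import download  # assumed module path

# Returns the cached copy if it exists and has not been invalidated.
path = download.maybe_download("s3://openpi-assets/checkpoints/pi0_libero")

# New in this commit: ignore any cached copy and download again.
path = download.maybe_download("s3://openpi-assets/checkpoints/pi0_libero", force_download=True)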
@@ -67,33 +68,56 @@ def maybe_download(url: str, **kwargs) -> pathlib.Path:
     local_path = cache_dir / parsed.netloc / parsed.path.strip("/")
     local_path = local_path.resolve()
 
-    # Check if file already exists in cache.
-    if local_path.exists() and not _invalidate_expired_cache(cache_dir, local_path):
-        return local_path
-
-    # Download file from remote file system.
-    logger.info(f"Downloading {url} to {local_path}")
-    with filelock.FileLock(local_path.with_suffix(".lock")):
-        scratch_path = local_path.with_suffix(".partial")
-        if _is_openpi_url(url):
-            # Download without credentials.
-            _download_boto3(
-                url,
-                scratch_path,
-                boto_session=boto3.Session(
-                    region_name="us-west-1",
-                ),
-                botocore_config=botocore.config.Config(signature_version=botocore.UNSIGNED),
-            )
-        elif url.startswith("s3://"):
-            # Download with default boto3 credentials.
-            _download_boto3(url, scratch_path)
-        else:
-            _download_fsspec(url, scratch_path, **kwargs)
-
-        shutil.move(scratch_path, local_path)
-        _ensure_permissions(local_path)
+    # Check if the cache should be invalidated.
+    invalidate_cache = False
+    if local_path.exists():
+        if force_download or _should_invalidate_cache(cache_dir, local_path):
+            invalidate_cache = True
+        else:
+            return local_path
+
+    try:
+        lock_path = local_path.with_suffix(".lock")
+        with filelock.FileLock(lock_path):
+            # Ensure consistent permissions for the lock file.
+            _ensure_permissions(lock_path)
+
+            # First, remove the existing cache if it is expired.
+            if invalidate_cache:
+                logger.info(f"Removing expired cached entry: {local_path}")
+                if local_path.is_dir():
+                    shutil.rmtree(local_path)
+                else:
+                    local_path.unlink()
+
+            # Download the data to a local cache.
+            logger.info(f"Downloading {url} to {local_path}")
+            scratch_path = local_path.with_suffix(".partial")
+            if _is_openpi_url(url):
+                # Download without credentials.
+                _download_boto3(
+                    url,
+                    scratch_path,
+                    boto_session=boto3.Session(
+                        region_name="us-west-1",
+                    ),
+                    botocore_config=botocore.config.Config(signature_version=botocore.UNSIGNED),
+                )
+            elif url.startswith("s3://"):
+                # Download with default boto3 credentials.
+                _download_boto3(url, scratch_path)
+            else:
+                _download_fsspec(url, scratch_path, **kwargs)
+
+            shutil.move(scratch_path, local_path)
+            _ensure_permissions(local_path)
+    except PermissionError as e:
+        msg = (
+            f"Local file permission error was encountered while downloading {url}. "
+            f"Please try again after removing the cached data using: `rm -rf {local_path}*`"
+        )
+        raise PermissionError(msg) from e
 
     return local_path
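The rewritten body makes the ordering explicit: decide up front whether the cached entry must go, then do all destructive work (removal, download, move into place) while holding the file lock, and surface permission problems with a recovery hint. Below is a minimal, self-contained sketch of that lock-plus-scratch-file pattern; fetch_cached and the urllib transport are illustrative stand-ins for openpi's boto3/fsspec helpers, not the project's code.

import pathlib
import shutil
import urllib.request

import filelock


def fetch_cached(url: str, local_path: pathlib.Path, *, force_download: bool = False) -> pathlib.Path:
    """Sketch: download url into local_path unless a usable cached copy already exists."""
    if local_path.exists() and not force_download:
        return local_path

    local_path.parent.mkdir(parents=True, exist_ok=True)
    # Serialize concurrent downloads of the same file across processes.
    with filelock.FileLock(local_path.with_suffix(".lock")):
        if force_download and local_path.exists():
            local_path.unlink()  # Drop the stale copy before re-downloading.
        if not local_path.exists():
            # Download to a scratch file so readers never observe a partial file,
            # then move the finished file into place.
            scratch_path = local_path.with_suffix(".partial")
            urllib.request.urlretrieve(url, scratch_path)
            shutil.move(scratch_path, local_path)
    return local_path

Writing to a .partial scratch file and moving it into place afterwards means a crashed download never leaves a truncated file at the cached path.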
@@ -285,11 +309,12 @@ def _get_mtime(year: int, month: int, day: int) -> float:
 # Partial matching will be used from top to bottom and the first match will be chosen.
 # Cached entries will be retained only if they are newer than the expiration timestamp.
 _INVALIDATE_CACHE_DIRS: dict[re.Pattern, float] = {
+    re.compile("openpi-assets/checkpoints/pi0_libero"): _get_mtime(2025, 2, 6),
     re.compile("openpi-assets/checkpoints/"): _get_mtime(2025, 2, 3),
 }
 
 
-def _invalidate_expired_cache(cache_dir: pathlib.Path, local_path: pathlib.Path) -> bool:
+def _should_invalidate_cache(cache_dir: pathlib.Path, local_path: pathlib.Path) -> bool:
     """Invalidate the cache if it is expired. Return True if the cache was invalidated."""
 
     assert local_path.exists(), f"File not found at {local_path}"
@@ -298,13 +323,6 @@ def _invalidate_expired_cache(cache_dir: pathlib.Path, local_path: pathlib.Path)
     for pattern, expire_time in _INVALIDATE_CACHE_DIRS.items():
         if pattern.match(relative_path):
             # Remove if not newer than the expiration timestamp.
-            if local_path.stat().st_mtime <= expire_time:
-                logger.info(f"Removing expired cached entry: {local_path}")
-                if local_path.is_dir():
-                    shutil.rmtree(local_path)
-                else:
-                    local_path.unlink()
-                return True
-            return False
+            return local_path.stat().st_mtime <= expire_time
 
     return False
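_INVALIDATE_CACHE_DIRS is a first-match table: the more specific pi0_libero pattern is listed before the generic checkpoints pattern so it gets the newer 2025-02-06 cutoff, and the renamed _should_invalidate_cache now only reports whether an entry is stale (actual removal moved into maybe_download above). A self-contained sketch of that lookup follows; _expires_at, should_invalidate, and the relative-path computation are illustrative, and it assumes, as the st_mtime comparison suggests, that _get_mtime yields a POSIX timestamp.

import datetime
import pathlib
import re


def _expires_at(year: int, month: int, day: int) -> float:
    # Stand-in for _get_mtime: the cutoff as a POSIX timestamp (midnight UTC).
    return datetime.datetime(year, month, day, tzinfo=datetime.timezone.utc).timestamp()


# Checked top to bottom; the first pattern that matches decides the cutoff.
RULES: dict[re.Pattern, float] = {
    re.compile("openpi-assets/checkpoints/pi0_libero"): _expires_at(2025, 2, 6),
    re.compile("openpi-assets/checkpoints/"): _expires_at(2025, 2, 3),
}


def should_invalidate(cache_dir: pathlib.Path, local_path: pathlib.Path) -> bool:
    """True if the cached entry is older than the cutoff of the first matching rule."""
    relative_path = str(local_path.relative_to(cache_dir))
    for pattern, expire_time in RULES.items():
        if pattern.match(relative_path):
            # Keep the entry only if it is strictly newer than the cutoff.
            return local_path.stat().st_mtime <= expire_time
    return False

For example, a pi0_libero checkpoint downloaded on 2025-02-04 would be kept by the generic rule's 2025-02-03 cutoff but is invalidated by the pi0_libero rule's 2025-02-06 cutoff, which is why the specific pattern must come first and why this commit forces pi0_libero to be re-downloaded.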