Force pi0_libero to be re-downloaded (#275)

This commit is contained in:
uzhilinsky
2025-02-06 16:14:18 -08:00
committed by GitHub
parent f543cb1d87
commit bf30fa3d4c

View File

@@ -37,7 +37,7 @@ def get_cache_dir() -> pathlib.Path:
return cache_dir return cache_dir
def maybe_download(url: str, **kwargs) -> pathlib.Path: def maybe_download(url: str, *, force_download: bool = False, **kwargs) -> pathlib.Path:
"""Download a file or directory from a remote filesystem to the local cache, and return the local path. """Download a file or directory from a remote filesystem to the local cache, and return the local path.
If the local file already exists, it will be returned directly. If the local file already exists, it will be returned directly.
@@ -47,6 +47,7 @@ def maybe_download(url: str, **kwargs) -> pathlib.Path:
Args: Args:
url: URL to the file to download. url: URL to the file to download.
force_download: If True, the file will be downloaded even if it already exists in the cache.
**kwargs: Additional arguments to pass to fsspec. **kwargs: Additional arguments to pass to fsspec.
Returns: Returns:
@@ -67,33 +68,56 @@ def maybe_download(url: str, **kwargs) -> pathlib.Path:
local_path = cache_dir / parsed.netloc / parsed.path.strip("/") local_path = cache_dir / parsed.netloc / parsed.path.strip("/")
local_path = local_path.resolve() local_path = local_path.resolve()
# Check if file already exists in cache. # Check if the cache should be invalidated.
if local_path.exists() and not _invalidate_expired_cache(cache_dir, local_path): invalidate_cache = False
return local_path if local_path.exists():
if force_download or _should_invalidate_cache(cache_dir, local_path):
# Download file from remote file system. invalidate_cache = True
logger.info(f"Downloading {url} to {local_path}")
with filelock.FileLock(local_path.with_suffix(".lock")):
scratch_path = local_path.with_suffix(".partial")
if _is_openpi_url(url):
# Download without credentials.
_download_boto3(
url,
scratch_path,
boto_session=boto3.Session(
region_name="us-west-1",
),
botocore_config=botocore.config.Config(signature_version=botocore.UNSIGNED),
)
elif url.startswith("s3://"):
# Download with default boto3 credentials.
_download_boto3(url, scratch_path)
else: else:
_download_fsspec(url, scratch_path, **kwargs) return local_path
shutil.move(scratch_path, local_path) try:
_ensure_permissions(local_path) lock_path = local_path.with_suffix(".lock")
with filelock.FileLock(lock_path):
# Ensure consistent permissions for the lock file.
_ensure_permissions(lock_path)
# First, remove the existing cache if it is expired.
if invalidate_cache:
logger.info(f"Removing expired cached entry: {local_path}")
if local_path.is_dir():
shutil.rmtree(local_path)
else:
local_path.unlink()
# Download the data to a local cache.
logger.info(f"Downloading {url} to {local_path}")
scratch_path = local_path.with_suffix(".partial")
if _is_openpi_url(url):
# Download without credentials.
_download_boto3(
url,
scratch_path,
boto_session=boto3.Session(
region_name="us-west-1",
),
botocore_config=botocore.config.Config(signature_version=botocore.UNSIGNED),
)
elif url.startswith("s3://"):
# Download with default boto3 credentials.
_download_boto3(url, scratch_path)
else:
_download_fsspec(url, scratch_path, **kwargs)
shutil.move(scratch_path, local_path)
_ensure_permissions(local_path)
except PermissionError as e:
msg = (
f"Local file permission error was encountered while downloading {url}. "
f"Please try again after removing the cached data using: `rm -rf {local_path}*`"
)
raise PermissionError(msg) from e
return local_path return local_path
@@ -285,11 +309,12 @@ def _get_mtime(year: int, month: int, day: int) -> float:
# Partial matching will be used from top to bottom and the first match will be chosen. # Partial matching will be used from top to bottom and the first match will be chosen.
# Cached entries will be retained only if they are newer than the expiration timestamp. # Cached entries will be retained only if they are newer than the expiration timestamp.
_INVALIDATE_CACHE_DIRS: dict[re.Pattern, float] = { _INVALIDATE_CACHE_DIRS: dict[re.Pattern, float] = {
re.compile("openpi-assets/checkpoints/pi0_libero"): _get_mtime(2025, 2, 6),
re.compile("openpi-assets/checkpoints/"): _get_mtime(2025, 2, 3), re.compile("openpi-assets/checkpoints/"): _get_mtime(2025, 2, 3),
} }
def _invalidate_expired_cache(cache_dir: pathlib.Path, local_path: pathlib.Path) -> bool: def _should_invalidate_cache(cache_dir: pathlib.Path, local_path: pathlib.Path) -> bool:
"""Invalidate the cache if it is expired. Return True if the cache was invalidated.""" """Invalidate the cache if it is expired. Return True if the cache was invalidated."""
assert local_path.exists(), f"File not found at {local_path}" assert local_path.exists(), f"File not found at {local_path}"
@@ -298,13 +323,6 @@ def _invalidate_expired_cache(cache_dir: pathlib.Path, local_path: pathlib.Path)
for pattern, expire_time in _INVALIDATE_CACHE_DIRS.items(): for pattern, expire_time in _INVALIDATE_CACHE_DIRS.items():
if pattern.match(relative_path): if pattern.match(relative_path):
# Remove if not newer than the expiration timestamp. # Remove if not newer than the expiration timestamp.
if local_path.stat().st_mtime <= expire_time: return local_path.stat().st_mtime <= expire_time
logger.info(f"Removing expired cached entry: {local_path}")
if local_path.is_dir():
shutil.rmtree(local_path)
else:
local_path.unlink()
return True
return False
return False return False