From bf30fa3d4c1714f6f6ec7a2e799cf256cc5755ff Mon Sep 17 00:00:00 2001 From: uzhilinsky Date: Thu, 6 Feb 2025 16:14:18 -0800 Subject: [PATCH] Force pi0_libero to be re-downloaded (#275) --- src/openpi/shared/download.py | 88 +++++++++++++++++++++-------------- 1 file changed, 53 insertions(+), 35 deletions(-) diff --git a/src/openpi/shared/download.py b/src/openpi/shared/download.py index 95b1c5f..d919477 100644 --- a/src/openpi/shared/download.py +++ b/src/openpi/shared/download.py @@ -37,7 +37,7 @@ def get_cache_dir() -> pathlib.Path: return cache_dir -def maybe_download(url: str, **kwargs) -> pathlib.Path: +def maybe_download(url: str, *, force_download: bool = False, **kwargs) -> pathlib.Path: """Download a file or directory from a remote filesystem to the local cache, and return the local path. If the local file already exists, it will be returned directly. @@ -47,6 +47,7 @@ def maybe_download(url: str, **kwargs) -> pathlib.Path: Args: url: URL to the file to download. + force_download: If True, the file will be downloaded even if it already exists in the cache. **kwargs: Additional arguments to pass to fsspec. Returns: @@ -67,33 +68,56 @@ def maybe_download(url: str, **kwargs) -> pathlib.Path: local_path = cache_dir / parsed.netloc / parsed.path.strip("/") local_path = local_path.resolve() - # Check if file already exists in cache. - if local_path.exists() and not _invalidate_expired_cache(cache_dir, local_path): - return local_path - - # Download file from remote file system. - logger.info(f"Downloading {url} to {local_path}") - with filelock.FileLock(local_path.with_suffix(".lock")): - scratch_path = local_path.with_suffix(".partial") - - if _is_openpi_url(url): - # Download without credentials. - _download_boto3( - url, - scratch_path, - boto_session=boto3.Session( - region_name="us-west-1", - ), - botocore_config=botocore.config.Config(signature_version=botocore.UNSIGNED), - ) - elif url.startswith("s3://"): - # Download with default boto3 credentials. - _download_boto3(url, scratch_path) + # Check if the cache should be invalidated. + invalidate_cache = False + if local_path.exists(): + if force_download or _should_invalidate_cache(cache_dir, local_path): + invalidate_cache = True else: - _download_fsspec(url, scratch_path, **kwargs) + return local_path - shutil.move(scratch_path, local_path) - _ensure_permissions(local_path) + try: + lock_path = local_path.with_suffix(".lock") + with filelock.FileLock(lock_path): + # Ensure consistent permissions for the lock file. + _ensure_permissions(lock_path) + # First, remove the existing cache if it is expired. + if invalidate_cache: + logger.info(f"Removing expired cached entry: {local_path}") + if local_path.is_dir(): + shutil.rmtree(local_path) + else: + local_path.unlink() + + # Download the data to a local cache. + logger.info(f"Downloading {url} to {local_path}") + scratch_path = local_path.with_suffix(".partial") + + if _is_openpi_url(url): + # Download without credentials. + _download_boto3( + url, + scratch_path, + boto_session=boto3.Session( + region_name="us-west-1", + ), + botocore_config=botocore.config.Config(signature_version=botocore.UNSIGNED), + ) + elif url.startswith("s3://"): + # Download with default boto3 credentials. + _download_boto3(url, scratch_path) + else: + _download_fsspec(url, scratch_path, **kwargs) + + shutil.move(scratch_path, local_path) + _ensure_permissions(local_path) + + except PermissionError as e: + msg = ( + f"Local file permission error was encountered while downloading {url}. " + f"Please try again after removing the cached data using: `rm -rf {local_path}*`" + ) + raise PermissionError(msg) from e return local_path @@ -285,11 +309,12 @@ def _get_mtime(year: int, month: int, day: int) -> float: # Partial matching will be used from top to bottom and the first match will be chosen. # Cached entries will be retained only if they are newer than the expiration timestamp. _INVALIDATE_CACHE_DIRS: dict[re.Pattern, float] = { + re.compile("openpi-assets/checkpoints/pi0_libero"): _get_mtime(2025, 2, 6), re.compile("openpi-assets/checkpoints/"): _get_mtime(2025, 2, 3), } -def _invalidate_expired_cache(cache_dir: pathlib.Path, local_path: pathlib.Path) -> bool: +def _should_invalidate_cache(cache_dir: pathlib.Path, local_path: pathlib.Path) -> bool: """Invalidate the cache if it is expired. Return True if the cache was invalidated.""" assert local_path.exists(), f"File not found at {local_path}" @@ -298,13 +323,6 @@ def _invalidate_expired_cache(cache_dir: pathlib.Path, local_path: pathlib.Path) for pattern, expire_time in _INVALIDATE_CACHE_DIRS.items(): if pattern.match(relative_path): # Remove if not newer than the expiration timestamp. - if local_path.stat().st_mtime <= expire_time: - logger.info(f"Removing expired cached entry: {local_path}") - if local_path.is_dir(): - shutil.rmtree(local_path) - else: - local_path.unlink() - return True - return False + return local_path.stat().st_mtime <= expire_time return False