diff --git a/src/openpi/shared/download.py b/src/openpi/shared/download.py index 2b025b7..44f927f 100644 --- a/src/openpi/shared/download.py +++ b/src/openpi/shared/download.py @@ -75,9 +75,10 @@ def maybe_download(url: str, **kwargs) -> pathlib.Path: if _is_openpi_url(url): # Download with openpi credentials. # TODO(ury): Remove once the bucket becomes public. - boto_session = boto3.Session( + boto_session = boto3.session.Session( aws_access_key_id="AKIA4MTWIIQIZBO44C62", aws_secret_access_key="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", + region_name="us-east-1", ) _download_boto3(url, scratch_path, boto_session=boto_session) elif url.startswith("s3://"): @@ -136,13 +137,14 @@ def _download_boto3( return bucket_name, prefix bucket_name, prefix = validate_and_parse_url(url) - s3api = boto3.resource("s3") + session = boto_session or boto3.Session() + + s3api = session.resource("s3") bucket = s3api.Bucket(bucket_name) objects = list(bucket.objects.filter(Prefix=prefix)) total_size = sum(obj.size for obj in objects) - session = boto_session or boto3.Session() s3t = _get_s3_transfer_manager(session, workers) def transfer(