Merge branch 'zdy'
This commit is contained in:
43
desktop_env/evaluators/getters/file.py
Normal file
43
desktop_env/evaluators/getters/file.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from typing import Dict
|
||||
|
||||
import os
|
||||
import requests
|
||||
|
||||
def get_cloud_file(env, config: Dict[str, str]) -> str:
|
||||
"""
|
||||
Config:
|
||||
path (str): the url to download from
|
||||
dest (str): file name of the downloaded file
|
||||
"""
|
||||
|
||||
_path = os.path.join(env.cache_dir, config["dest"])
|
||||
if os.path.exists(_path):
|
||||
return _path
|
||||
|
||||
url = config["path"]
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(_path, 'wb') as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
|
||||
return _path
|
||||
|
||||
def get_vm_file(env, config: Dict[str, str]) -> str:
|
||||
"""
|
||||
Config:
|
||||
path (str): absolute path on the VM to fetch
|
||||
dest (str): file name of the downloaded file
|
||||
"""
|
||||
|
||||
_path = os.path.join(env.cache_dir, config["dest"])
|
||||
if os.path.exists(_path):
|
||||
return _path
|
||||
|
||||
file = env.controller.get_file(config["path"])
|
||||
with open(_path, "wb") as f:
|
||||
f.write(file)
|
||||
|
||||
return _path
|
||||
Reference in New Issue
Block a user