Improve code logic for password & resolution
This commit is contained in:
@@ -27,12 +27,6 @@ import dotenv
|
|||||||
# Load environment variables from .env file
|
# Load environment variables from .env file
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
if os.environ.get("PROVIDER_NAME") == "aws":
|
|
||||||
os.environ["CLIENT_PASSWORD"] = os.environ.get("CLIENT_PASSWORD_AWS", "osworld-public-evaluation")
|
|
||||||
else:
|
|
||||||
os.environ["CLIENT_PASSWORD"] = os.environ.get("CLIENT_PASSWORD", "password")
|
|
||||||
|
|
||||||
CLIENT_PASSWORD = os.environ["CLIENT_PASSWORD"]
|
|
||||||
|
|
||||||
PROXY_CONFIG_FILE = os.getenv("PROXY_CONFIG_FILE", "evaluation_examples/settings/proxy/dataimpulse.json") # Default proxy config file
|
PROXY_CONFIG_FILE = os.getenv("PROXY_CONFIG_FILE", "evaluation_examples/settings/proxy/dataimpulse.json") # Default proxy config file
|
||||||
|
|
||||||
@@ -45,7 +39,7 @@ init_proxy_pool(PROXY_CONFIG_FILE) # initialize the global proxy pool
|
|||||||
MAX_RETRIES = 20
|
MAX_RETRIES = 20
|
||||||
|
|
||||||
class SetupController:
|
class SetupController:
|
||||||
def __init__(self, vm_ip: str, server_port: int = 5000, chromium_port: int = 9222, vlc_port: int = 8080, cache_dir: str = "cache"):
|
def __init__(self, vm_ip: str, server_port: int = 5000, chromium_port: int = 9222, vlc_port: int = 8080, cache_dir: str = "cache", client_password: str = "", screen_width: int = 1920, screen_height: int = 1080):
|
||||||
self.vm_ip: str = vm_ip
|
self.vm_ip: str = vm_ip
|
||||||
self.server_port: int = server_port
|
self.server_port: int = server_port
|
||||||
self.chromium_port: int = chromium_port
|
self.chromium_port: int = chromium_port
|
||||||
@@ -54,6 +48,9 @@ class SetupController:
|
|||||||
self.http_server_setup_root: str = f"http://{vm_ip}:{server_port}/setup"
|
self.http_server_setup_root: str = f"http://{vm_ip}:{server_port}/setup"
|
||||||
self.cache_dir: str = cache_dir
|
self.cache_dir: str = cache_dir
|
||||||
self.use_proxy: bool = False
|
self.use_proxy: bool = False
|
||||||
|
self.client_password: str = client_password
|
||||||
|
self.screen_width: int = screen_width
|
||||||
|
self.screen_height: int = screen_height
|
||||||
|
|
||||||
def reset_cache_dir(self, cache_dir: str):
|
def reset_cache_dir(self, cache_dir: str):
|
||||||
self.cache_dir = cache_dir
|
self.cache_dir = cache_dir
|
||||||
@@ -304,22 +301,31 @@ class SetupController:
|
|||||||
terminates: bool = False
|
terminates: bool = False
|
||||||
nb_failings = 0
|
nb_failings = 0
|
||||||
|
|
||||||
def replace_screen_env_in_command(command_list):
|
def replace_screen_env_in_command(command):
|
||||||
width = int(os.environ.get("SCREEN_WIDTH", 1920))
|
password = self.client_password
|
||||||
height = int(os.environ.get("SCREEN_HEIGHT", 1080))
|
width = self.screen_width
|
||||||
|
height = self.screen_height
|
||||||
width_half = str(width // 2)
|
width_half = str(width // 2)
|
||||||
height_half = str(height // 2)
|
height_half = str(height // 2)
|
||||||
new_command_list = []
|
new_command_list = []
|
||||||
for item in command_list:
|
new_command = ""
|
||||||
if isinstance(item, str):
|
if isinstance(command, str):
|
||||||
|
new_command = command.replace("{CLIENT_PASSWORD}", password)
|
||||||
|
new_command = new_command.replace("{SCREEN_WIDTH_HALF}", width_half)
|
||||||
|
new_command = new_command.replace("{SCREEN_HEIGHT_HALF}", height_half)
|
||||||
|
new_command = new_command.replace("{SCREEN_WIDTH}", str(width))
|
||||||
|
new_command = new_command.replace("{SCREEN_HEIGHT}", str(height))
|
||||||
|
return new_command
|
||||||
|
else:
|
||||||
|
for item in command:
|
||||||
|
item = item.replace("{CLIENT_PASSWORD}", password)
|
||||||
item = item.replace("{SCREEN_WIDTH_HALF}", width_half)
|
item = item.replace("{SCREEN_WIDTH_HALF}", width_half)
|
||||||
item = item.replace("{SCREEN_HEIGHT_HALF}", height_half)
|
item = item.replace("{SCREEN_HEIGHT_HALF}", height_half)
|
||||||
item = item.replace("{SCREEN_WIDTH}", str(width))
|
item = item.replace("{SCREEN_WIDTH}", str(width))
|
||||||
item = item.replace("{SCREEN_HEIGHT}", str(height))
|
item = item.replace("{SCREEN_HEIGHT}", str(height))
|
||||||
new_command_list.append(item)
|
new_command_list.append(item)
|
||||||
return new_command_list
|
return new_command_list
|
||||||
if isinstance(command, list):
|
command = replace_screen_env_in_command(command)
|
||||||
command = replace_screen_env_in_command(command)
|
|
||||||
payload = json.dumps({"command": command, "shell": shell})
|
payload = json.dumps({"command": command, "shell": shell})
|
||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
|
|
||||||
@@ -467,7 +473,7 @@ class SetupController:
|
|||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
logger.error("An error occurred while trying to send the request: %s", e)
|
logger.error("An error occurred while trying to send the request: %s", e)
|
||||||
|
|
||||||
def _proxy_setup(self, client_password: str = CLIENT_PASSWORD):
|
def _proxy_setup(self, client_password: str = ""):
|
||||||
"""Setup system-wide proxy configuration using proxy pool
|
"""Setup system-wide proxy configuration using proxy pool
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -26,18 +26,19 @@ class DesktopEnv(gym.Env):
|
|||||||
"""
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
provider_name: str = "vmware",
|
provider_name: str = "aws",
|
||||||
region: str = None,
|
region: str = None,
|
||||||
path_to_vm: str = None,
|
path_to_vm: str = None,
|
||||||
snapshot_name: str = "init_state",
|
snapshot_name: str = "init_state",
|
||||||
action_space: str = "computer_13",
|
action_space: str = "computer_13",
|
||||||
cache_dir: str = "cache",
|
cache_dir: str = "cache",
|
||||||
screen_size: Tuple[int] = (int(os.environ.get("SCREEN_WIDTH", 1920)), int(os.environ.get("SCREEN_HEIGHT", 1080))),
|
screen_size: Tuple[int] = (1920, 1080),
|
||||||
headless: bool = False,
|
headless: bool = False,
|
||||||
require_a11y_tree: bool = True,
|
require_a11y_tree: bool = True,
|
||||||
require_terminal: bool = False,
|
require_terminal: bool = False,
|
||||||
os_type: str = "Ubuntu",
|
os_type: str = "Ubuntu",
|
||||||
enable_proxy: bool = False,
|
enable_proxy: bool = False,
|
||||||
|
client_password: str = "",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -59,6 +60,16 @@ class DesktopEnv(gym.Env):
|
|||||||
self.region = region
|
self.region = region
|
||||||
self.provider_name = provider_name
|
self.provider_name = provider_name
|
||||||
self.enable_proxy = enable_proxy # Store proxy enablement setting
|
self.enable_proxy = enable_proxy # Store proxy enablement setting
|
||||||
|
if client_password == "":
|
||||||
|
if self.provider_name == "aws":
|
||||||
|
self.client_password = "osworld-public-evaluation"
|
||||||
|
else:
|
||||||
|
self.client_password = "password"
|
||||||
|
else:
|
||||||
|
self.client_password = client_password
|
||||||
|
|
||||||
|
self.screen_width = screen_size[0]
|
||||||
|
self.screen_height = screen_size[1]
|
||||||
|
|
||||||
# Default
|
# Default
|
||||||
self.server_port = 5000
|
self.server_port = 5000
|
||||||
@@ -88,7 +99,7 @@ class DesktopEnv(gym.Env):
|
|||||||
if provider_name in {"vmware", "virtualbox"} else path_to_vm
|
if provider_name in {"vmware", "virtualbox"} else path_to_vm
|
||||||
else:
|
else:
|
||||||
|
|
||||||
self.path_to_vm = self.manager.get_vm_path(os_type=self.os_type, region=region)
|
self.path_to_vm = self.manager.get_vm_path(os_type=self.os_type, region=region, screen_size=(self.screen_width, self.screen_height))
|
||||||
try:
|
try:
|
||||||
self.snapshot_name = snapshot_name
|
self.snapshot_name = snapshot_name
|
||||||
self.cache_dir_base: str = cache_dir
|
self.cache_dir_base: str = cache_dir
|
||||||
@@ -136,7 +147,7 @@ class DesktopEnv(gym.Env):
|
|||||||
self.vnc_port = int(vm_ip_ports[3])
|
self.vnc_port = int(vm_ip_ports[3])
|
||||||
self.vlc_port = int(vm_ip_ports[4])
|
self.vlc_port = int(vm_ip_ports[4])
|
||||||
self.controller = PythonController(vm_ip=self.vm_ip, server_port=self.server_port)
|
self.controller = PythonController(vm_ip=self.vm_ip, server_port=self.server_port)
|
||||||
self.setup_controller = SetupController(vm_ip=self.vm_ip, server_port=self.server_port, chromium_port=self.chromium_port, vlc_port=self.vlc_port, cache_dir=self.cache_dir_base)
|
self.setup_controller = SetupController(vm_ip=self.vm_ip, server_port=self.server_port, chromium_port=self.chromium_port, vlc_port=self.vlc_port, cache_dir=self.cache_dir_base, client_password=self.client_password, screen_width=self.screen_width, screen_height=self.screen_height)
|
||||||
|
|
||||||
def _revert_to_snapshot(self):
|
def _revert_to_snapshot(self):
|
||||||
# Revert to certain snapshot of the virtual machine, and refresh the path to vm and ip of vm
|
# Revert to certain snapshot of the virtual machine, and refresh the path to vm and ip of vm
|
||||||
@@ -197,7 +208,7 @@ class DesktopEnv(gym.Env):
|
|||||||
if task_config is not None:
|
if task_config is not None:
|
||||||
if task_config.get("proxy", False) and self.enable_proxy:
|
if task_config.get("proxy", False) and self.enable_proxy:
|
||||||
# If using proxy and proxy is enabled, set up the proxy configuration
|
# If using proxy and proxy is enabled, set up the proxy configuration
|
||||||
self.setup_controller._proxy_setup()
|
self.setup_controller._proxy_setup(self.client_password)
|
||||||
self._set_task_info(task_config)
|
self._set_task_info(task_config)
|
||||||
self.setup_controller.reset_cache_dir(self.cache_dir)
|
self.setup_controller.reset_cache_dir(self.cache_dir)
|
||||||
logger.info("Setting up environment...")
|
logger.info("Setting up environment...")
|
||||||
|
|||||||
@@ -164,11 +164,11 @@ def _allocate_vm(region=DEFAULT_REGION, screen_size=(1920, 1080)):
|
|||||||
return instance_id
|
return instance_id
|
||||||
|
|
||||||
|
|
||||||
def _allocate_vm_with_proxy(region=DEFAULT_REGION, proxy_config_file=None):
|
def _allocate_vm_with_proxy(region=DEFAULT_REGION, proxy_config_file=None, screen_size=(1920, 1080)):
|
||||||
"""Allocate a VM with proxy configuration"""
|
"""Allocate a VM with proxy configuration"""
|
||||||
if not PROXY_SUPPORT_AVAILABLE:
|
if not PROXY_SUPPORT_AVAILABLE:
|
||||||
logger.warning("Proxy support not available, falling back to regular VM allocation")
|
logger.warning("Proxy support not available, falling back to regular VM allocation")
|
||||||
return _allocate_vm(region)
|
return _allocate_vm(region, screen_size=screen_size)
|
||||||
|
|
||||||
from desktop_env.providers.aws.provider_with_proxy import AWSProviderWithProxy
|
from desktop_env.providers.aws.provider_with_proxy import AWSProviderWithProxy
|
||||||
|
|
||||||
@@ -268,11 +268,11 @@ class AWSVMManager(VMManager):
|
|||||||
def _list_free_vms(self, region=DEFAULT_REGION):
|
def _list_free_vms(self, region=DEFAULT_REGION):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_vm_path(self, region=DEFAULT_REGION, **kwargs):
|
def get_vm_path(self, region=DEFAULT_REGION, screen_size=(1920, 1080), **kwargs):
|
||||||
if self.proxy_config_file:
|
if self.proxy_config_file:
|
||||||
logger.info("Allocating a new VM with proxy configuration in region: {}".format(region))
|
logger.info("Allocating a new VM with proxy configuration in region: {}".format(region))
|
||||||
new_vm_path = _allocate_vm_with_proxy(region, self.proxy_config_file)
|
new_vm_path = _allocate_vm_with_proxy(region, self.proxy_config_file, screen_size=screen_size)
|
||||||
else:
|
else:
|
||||||
logger.info("Allocating a new VM in region: {}".format(region))
|
logger.info("Allocating a new VM in region: {}".format(region))
|
||||||
new_vm_path = _allocate_vm(region)
|
new_vm_path = _allocate_vm(region, screen_size=screen_size)
|
||||||
return new_vm_path
|
return new_vm_path
|
||||||
@@ -7,7 +7,7 @@
|
|||||||
{
|
{
|
||||||
"type": "execute",
|
"type": "execute",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"command": "echo password | sudo -S apt update -y && echo password | sudo -S apt install jq -y",
|
"command": "echo {CLIENT_PASSWORD} | sudo -S apt update -y && echo {CLIENT_PASSWORD} | sudo -S apt install jq -y",
|
||||||
"shell": true
|
"shell": true
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -18,7 +18,7 @@
|
|||||||
{
|
{
|
||||||
"type": "execute",
|
"type": "execute",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"command": "echo password | sudo -S apt-get update -y && echo password | sudo -S apt-get install unzip -y && unzip /home/user/Desktop/helloExtension.zip -d /home/user/Desktop/ && rm /home/user/Desktop/helloExtension.zip",
|
"command": "echo {CLIENT_PASSWORD} | sudo -S apt-get update -y && echo {CLIENT_PASSWORD} | sudo -S apt-get install unzip -y && unzip /home/user/Desktop/helloExtension.zip -d /home/user/Desktop/ && rm /home/user/Desktop/helloExtension.zip",
|
||||||
"shell": true
|
"shell": true
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
{
|
{
|
||||||
"type": "command",
|
"type": "command",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"command": "echo password | sudo -S apt-get update && echo password | sudo -S apt-get install sysstat",
|
"command": "echo {CLIENT_PASSWORD} | sudo -S apt-get update && echo {CLIENT_PASSWORD} | sudo -S apt-get install sysstat",
|
||||||
"shell": "true"
|
"shell": "true"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,7 +25,7 @@
|
|||||||
{
|
{
|
||||||
"type": "command",
|
"type": "command",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"command": "echo password | sudo -S apt install xsel && xsel -bc",
|
"command": "echo {CLIENT_PASSWORD} | sudo -S apt install xsel && xsel -bc",
|
||||||
"shell": "true"
|
"shell": "true"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -61,14 +61,7 @@
|
|||||||
{
|
{
|
||||||
"type": "command",
|
"type": "command",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"command": "echo password | sudo -S pip install pysrt",
|
"command": "echo {CLIENT_PASSWORD} | sudo -S pip install pysrt",
|
||||||
"shell": "true"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "command",
|
|
||||||
"parameters": {
|
|
||||||
"command": "echo osworld-public-evaluation | sudo -S pip install pysrt",
|
|
||||||
"shell": "true"
|
"shell": "true"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,7 +18,7 @@
|
|||||||
{
|
{
|
||||||
"type": "execute",
|
"type": "execute",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"command": "echo password | sudo -S apt-get update -y && echo password | sudo -S apt-get install unzip -y && unzip /home/user/Desktop/helloExtension.zip -d /home/user/Desktop/ && rm /home/user/Desktop/helloExtension.zip",
|
"command": "echo {CLIENT_PASSWORD} | sudo -S apt-get update -y && echo {CLIENT_PASSWORD} | sudo -S apt-get install unzip -y && unzip /home/user/Desktop/helloExtension.zip -d /home/user/Desktop/ && rm /home/user/Desktop/helloExtension.zip",
|
||||||
"shell": true
|
"shell": true
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -21,7 +21,7 @@
|
|||||||
{
|
{
|
||||||
"type": "execute",
|
"type": "execute",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"command": "echo 'password' | sudo -S apt-get install -y expect",
|
"command": "echo {CLIENT_PASSWORD} | sudo -S apt-get install -y expect",
|
||||||
"shell": true
|
"shell": true
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
{
|
{
|
||||||
"type": "execute",
|
"type": "execute",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"command": "echo 'password' | sudo -S mkdir ~/Desktop/todo_list_Jan_1",
|
"command": "echo {CLIENT_PASSWORD} | sudo -S mkdir ~/Desktop/todo_list_Jan_1",
|
||||||
"shell": true
|
"shell": true
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ class_ns_windows = "https://accessibility.windows.example.org/ns/class"
|
|||||||
import ast
|
import ast
|
||||||
from typing import Dict, Any, Optional, Union
|
from typing import Dict, Any, Optional, Union
|
||||||
|
|
||||||
OPERATOR_PROMPT = f"""\n\n Here are some helpful tips:\n - computer.clipboard, computer.sync_file, computer.sync_shared_folder, computer.computer_output_citation are disabled.\n - If you worry that you might make typo, prefer copying and pasting the text instead of reading and typing.\n - My computer's password is \"{os.environ["CLIENT_PASSWORD"]}\", feel free to use it when you need sudo rights.\n - For the thunderbird account \"anonym-x2024@outlook.com\", the password is \"gTCI\";=@y7|QJ0nDa_kN3Sb&>\".\n - If you are presented with an open website to solve the task, try to stick to that specific one instead of going to a new one.\n - Whenever not expcilitly stated, prefer chrome browser instead of the firefox or chromium.\n - You have full authority to execute any action without my permission. I won't be watching so please don't ask for confirmation.\n - You must initialize the computer to solve the task. Do not try to answer the question without initializing the computer.\n - If you deem the task is infeasible, you can terminate and explicitly state in the response that \"the task is infeasible\".\n """
|
OPERATOR_PROMPT = """\n\n Here are some helpful tips:\n - computer.clipboard, computer.sync_file, computer.sync_shared_folder, computer.computer_output_citation are disabled.\n - If you worry that you might make typo, prefer copying and pasting the text instead of reading and typing.\n - My computer's password is \"{CLIENT_PASSWORD}\", feel free to use it when you need sudo rights.\n - For the thunderbird account \"anonym-x2024@outlook.com\", the password is \"gTCI\";=@y7|QJ0nDa_kN3Sb&>\".\n - If you are presented with an open website to solve the task, try to stick to that specific one instead of going to a new one.\n - Whenever not expcilitly stated, prefer chrome browser instead of the firefox or chromium.\n - You have full authority to execute any action without my permission. I won't be watching so please don't ask for confirmation.\n - You must initialize the computer to solve the task. Do not try to answer the question without initializing the computer.\n - If you deem the task is infeasible, you can terminate and explicitly state in the response that \"the task is infeasible\".\n """
|
||||||
|
|
||||||
class Action:
|
class Action:
|
||||||
"""Action class for the agent."""
|
"""Action class for the agent."""
|
||||||
@@ -213,7 +213,11 @@ class OpenAICUAAgent:
|
|||||||
observation_type="screenshot_a11y_tree",
|
observation_type="screenshot_a11y_tree",
|
||||||
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
|
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
|
||||||
max_trajectory_length=100,
|
max_trajectory_length=100,
|
||||||
a11y_tree_max_tokens=10000
|
a11y_tree_max_tokens=10000,
|
||||||
|
client_password="",
|
||||||
|
provider_name="aws",
|
||||||
|
screen_width=1920,
|
||||||
|
screen_height=1080
|
||||||
):
|
):
|
||||||
self.env = env
|
self.env = env
|
||||||
self.platform = platform
|
self.platform = platform
|
||||||
@@ -231,12 +235,22 @@ class OpenAICUAAgent:
|
|||||||
self.actions = []
|
self.actions = []
|
||||||
self.observations = []
|
self.observations = []
|
||||||
|
|
||||||
|
self.screen_width = screen_width
|
||||||
|
self.screen_height = screen_height
|
||||||
|
|
||||||
self.tools = [{
|
self.tools = [{
|
||||||
"type": "computer_use_preview",
|
"type": "computer_use_preview",
|
||||||
"display_width": int(os.environ["SCREEN_WIDTH"]),
|
"display_width": self.screen_width,
|
||||||
"display_height": int(os.environ["SCREEN_HEIGHT"]),
|
"display_height": self.screen_height,
|
||||||
"environment": "linux" if platform == "ubuntu" else "windows"
|
"environment": "linux" if platform == "ubuntu" else "windows"
|
||||||
}]
|
}]
|
||||||
|
if client_password == "":
|
||||||
|
if provider_name == "aws":
|
||||||
|
self.client_password = "osworld-public-evaluation"
|
||||||
|
else:
|
||||||
|
self.client_password = "password"
|
||||||
|
else:
|
||||||
|
self.client_password = client_password
|
||||||
|
|
||||||
if observation_type == "screenshot":
|
if observation_type == "screenshot":
|
||||||
if action_space == "computer_13":
|
if action_space == "computer_13":
|
||||||
@@ -630,6 +644,7 @@ class OpenAICUAAgent:
|
|||||||
"""
|
"""
|
||||||
Predict the next action(s) based on the current observation.
|
Predict the next action(s) based on the current observation.
|
||||||
"""
|
"""
|
||||||
|
prompt = OPERATOR_PROMPT.replace("{CLIENT_PASSWORD}", self.client_password)
|
||||||
|
|
||||||
base64_image = encode_image(obs["screenshot"])
|
base64_image = encode_image(obs["screenshot"])
|
||||||
if self.cua_messages == []:
|
if self.cua_messages == []:
|
||||||
@@ -642,7 +657,7 @@ class OpenAICUAAgent:
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "input_text",
|
"type": "input_text",
|
||||||
"text": "\n " + instruction + OPERATOR_PROMPT,
|
"text": "\n " + instruction + prompt,
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -78,6 +78,18 @@ def config() -> argparse.Namespace:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--region", type=str, default="us-east-1", help="AWS region for the VM"
|
"--region", type=str, default="us-east-1", help="AWS region for the VM"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--client_password", type=str, default="", help="Client password"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--screen_width", type=int, default=1920, help="Screen width"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--screen_height", type=int, default=1080, help="Screen height"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
|
|
||||||
@@ -180,19 +192,20 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
|
|||||||
|
|
||||||
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
||||||
REGION = args.region
|
REGION = args.region
|
||||||
screen_size = (int(os.environ["SCREEN_WIDTH"]), int(os.environ["SCREEN_HEIGHT"]))
|
screen_size = (args.screen_width, args.screen_height)
|
||||||
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
|
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
|
||||||
env = DesktopEnv(
|
env = DesktopEnv(
|
||||||
path_to_vm=args.path_to_vm,
|
path_to_vm=args.path_to_vm,
|
||||||
action_space=args.action_space,
|
action_space=args.action_space,
|
||||||
provider_name=os.environ["PROVIDER_NAME"],
|
provider_name=args.provider_name,
|
||||||
region=REGION,
|
region=REGION,
|
||||||
snapshot_name=ami_id,
|
snapshot_name=ami_id,
|
||||||
screen_size=screen_size,
|
screen_size=screen_size,
|
||||||
headless=args.headless,
|
headless=args.headless,
|
||||||
os_type="Ubuntu",
|
os_type="Ubuntu",
|
||||||
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||||
enable_proxy=True
|
enable_proxy=True,
|
||||||
|
client_password=args.client_password
|
||||||
)
|
)
|
||||||
active_environments.append(env)
|
active_environments.append(env)
|
||||||
agent = OpenAICUAAgent(
|
agent = OpenAICUAAgent(
|
||||||
@@ -204,6 +217,10 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
|
|||||||
action_space=args.action_space,
|
action_space=args.action_space,
|
||||||
observation_type=args.observation_type,
|
observation_type=args.observation_type,
|
||||||
max_trajectory_length=args.max_trajectory_length,
|
max_trajectory_length=args.max_trajectory_length,
|
||||||
|
client_password=args.client_password,
|
||||||
|
provider_name=args.provider_name,
|
||||||
|
screen_width=args.screen_width,
|
||||||
|
screen_height=args.screen_height
|
||||||
)
|
)
|
||||||
logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")
|
logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user