Refactoring VMware Integration and Implementing AWS Support (#44)
* Initailize aws support * Add README for the VM server * Refactor OSWorld for supporting more cloud services. * Initialize vmware and aws implementation v1, waiting for verification * Initlize files for azure, gcp and virtualbox support * Debug on the VMware provider * Fix on aws interface mapping * Fix instance type * Refactor * Clean * hk region; debug * Fix lock * Remove print * Remove key_name requirements when allocating aws vm * Clean README --------- Co-authored-by: XinyuanWangCS <xywang626@gmail.com>
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -183,3 +183,7 @@ at_processing
|
||||
|
||||
test.xlsx
|
||||
test2.xlsx
|
||||
|
||||
# vm info
|
||||
.vms
|
||||
/vm_data
|
||||
|
||||
13
README.md
13
README.md
@@ -31,6 +31,7 @@
|
||||
|
||||
|
||||
## 📢 Updates
|
||||
- 2024-06-15: We refactor the code of environment part to decompose VMware Integration, and start to support other platforms such as VitualBox, AWS, Azure, etc. Hold tight!
|
||||
- 2024-04-11: We released our [paper](https://arxiv.org/abs/2404.07972), [environment and benchmark](https://github.com/xlang-ai/OSWorld), and [project page](https://os-world.github.io/). Check it out!
|
||||
|
||||
## 💾 Installation
|
||||
@@ -58,7 +59,7 @@ Alternatively, you can install the environment without any benchmark tasks:
|
||||
pip install desktop-env
|
||||
```
|
||||
|
||||
2. Install [VMware Workstation Pro](https://www.vmware.com/products/workstation-pro/workstation-pro-evaluation.html) (for systems with Apple Chips, you should install [VMware Fusion](https://www.vmware.com/go/getfusion)) and configure the `vmrun` command. The installation process can refer to [How to install VMware Worksation Pro](./INSTALL_VMWARE.md). Verify the successful installation by running the following:
|
||||
2. Install [VMware Workstation Pro](https://www.vmware.com/products/workstation-pro/workstation-pro-evaluation.html) (for systems with Apple Chips, you should install [VMware Fusion](https://www.vmware.com/go/getfusion)) and configure the `vmrun` command. The installation process can refer to [How to install VMware Worksation Pro](desktop_env/providers/vmware/INSTALL_VMWARE.md). Verify the successful installation by running the following:
|
||||
```bash
|
||||
vmrun -T ws list
|
||||
```
|
||||
@@ -68,12 +69,18 @@ If the installation along with the environment variable set is successful, you w
|
||||
All set! Our setup script will automatically download the necessary virtual machines and configure the environment for you.
|
||||
|
||||
### On AWS or Azure (Virtualized platform)
|
||||
We are working on supporting it 👷. Please hold tight!
|
||||
#### On your AWS
|
||||
See [AWS_GUIDELINE](https://github.com/xlang-ai/OSWorld/blob/main/desktop_env/providers/aws/AWS_GUIDELINE.md)
|
||||
|
||||
#### Others
|
||||
We are working on supporting more 👷. Please hold tight!
|
||||
|
||||
|
||||
## 🚀 Quick Start
|
||||
Run the following minimal example to interact with the environment:
|
||||
|
||||
```python
|
||||
from desktop_env.envs.desktop_env import DesktopEnv
|
||||
from desktop_env.desktop_env import DesktopEnv
|
||||
|
||||
example = {
|
||||
"id": "94d95f96-9699-4208-98ba-3c3119edf9c2",
|
||||
|
||||
@@ -21,8 +21,8 @@ If you are interested in contributing to the project, please check the [CONTRIBU
|
||||
- [x] Add more tasks, maybe scale to 300 for v1.0.0, and create a dynamic leaderboard
|
||||
- [x] Multiprocess support, can enable reinforcement learning to be more efficient
|
||||
- [x] Add support for automatic VM download and configuration, enable auto-scaling management
|
||||
- [ ] VPN setup doc for those who need it
|
||||
- [ ] Support running on platforms that have nested virtualization, e.g. Google Cloud, AWS, etc.
|
||||
- [x] VPN setup doc for those who need it
|
||||
- [x] Support running on platforms that have nested virtualization, e.g. Google Cloud, AWS, etc.
|
||||
- [ ] Prepare for the first release of Windows vm image for the environment
|
||||
- [ ] Be able to run without virtual machine platform VMware Pro, e.g. VirtualBox, or other platforms
|
||||
|
||||
@@ -31,4 +31,4 @@ If you are interested in contributing to the project, please check the [CONTRIBU
|
||||
- [ ] Improve the annotation tool base on DuckTrack, and make it more robust which aligns on accessibility tree
|
||||
- [ ] Annotate the steps of doing the task
|
||||
- [ ] Crawl all resources we explored from the internet, and make it easy to access
|
||||
- [ ] Set up ways for the crowd-sourcing/community to contribute new examples
|
||||
- [ ] Set up ways for the crowdsourcing/community to contribute new examples
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any, Dict, Optional
|
||||
import time
|
||||
import requests
|
||||
|
||||
from desktop_env.envs.actions import KEYBOARD_KEYS
|
||||
from desktop_env.actions import KEYBOARD_KEYS
|
||||
|
||||
logger = logging.getLogger("desktopenv.pycontroller")
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from typing import Callable, Any, Optional, Tuple
|
||||
from typing import List, Dict, Union
|
||||
@@ -12,7 +11,7 @@ import gymnasium as gym
|
||||
from desktop_env.controllers.python import PythonController
|
||||
from desktop_env.controllers.setup import SetupController
|
||||
from desktop_env.evaluators import metrics, getters
|
||||
from . import _get_vm_path
|
||||
from desktop_env.providers import create_vm_manager_and_provider
|
||||
|
||||
logger = logging.getLogger("desktopenv.env")
|
||||
|
||||
@@ -20,25 +19,6 @@ Metric = Callable[[Any, Any], float]
|
||||
Getter = Callable[[gym.Env, Dict[str, Any]], Any]
|
||||
|
||||
|
||||
def _execute_command(command: List[str]) -> None:
|
||||
def _is_contained_in(a, b):
|
||||
for v in set(a):
|
||||
if a.count(v) > b.count(v):
|
||||
return False
|
||||
return True
|
||||
|
||||
# Specially handled for the `vmrun` command in Windows
|
||||
if _is_contained_in(["vmrun", "-T", "ws", "start"], command):
|
||||
p = subprocess.Popen(command)
|
||||
p.wait()
|
||||
else:
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True,
|
||||
encoding="utf-8")
|
||||
if result.returncode != 0:
|
||||
raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m")
|
||||
return result.stdout
|
||||
|
||||
|
||||
class DesktopEnv(gym.Env):
|
||||
"""
|
||||
DesktopEnv with OpenAI Gym interface. It provides a desktop environment for setting and evaluating desktop automation tasks.
|
||||
@@ -46,6 +26,8 @@ class DesktopEnv(gym.Env):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider_name: str = "vmware",
|
||||
region: str = None,
|
||||
path_to_vm: str = None,
|
||||
snapshot_name: str = "init_state",
|
||||
action_space: str = "computer_13",
|
||||
@@ -57,6 +39,8 @@ class DesktopEnv(gym.Env):
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
provider_name (str): virtualization provider name, default to "vmware"
|
||||
region (str): the region for allocate machines, work for cloud services, default to "us-east-1"
|
||||
path_to_vm (str): path to .vmx file
|
||||
snapshot_name (str): snapshot name to revert to, default to "init_state"
|
||||
action_space (str): "computer_13" | "pyautogui"
|
||||
@@ -67,9 +51,13 @@ class DesktopEnv(gym.Env):
|
||||
require_a11y_tree (bool): whether to require accessibility tree
|
||||
require_terminal (bool): whether to require terminal output
|
||||
"""
|
||||
# Initialize VM manager and vitualization provider
|
||||
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region)
|
||||
|
||||
# Initialize environment variables
|
||||
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm if path_to_vm else _get_vm_path())))
|
||||
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm))) if path_to_vm else \
|
||||
self.manager.get_vm_path(region)
|
||||
|
||||
self.snapshot_name = snapshot_name
|
||||
self.cache_dir_base: str = cache_dir
|
||||
# todo: add the logic to get the screen size from the VM
|
||||
@@ -80,14 +68,11 @@ class DesktopEnv(gym.Env):
|
||||
# Initialize emulator and controller
|
||||
logger.info("Initializing...")
|
||||
self._start_emulator()
|
||||
self.vm_ip = self._get_vm_ip()
|
||||
self.controller = PythonController(vm_ip=self.vm_ip)
|
||||
self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir_base)
|
||||
|
||||
# mode: human or machine
|
||||
self.instruction = None
|
||||
assert action_space in ["computer_13", "pyautogui"]
|
||||
self.action_space = action_space
|
||||
self.action_space = action_space # todo: refactor it to the ActType
|
||||
|
||||
# episodic stuffs, like counters, will be updated or reset
|
||||
# when calling self.reset()
|
||||
@@ -95,6 +80,63 @@ class DesktopEnv(gym.Env):
|
||||
self._step_no: int = 0
|
||||
self.action_history: List[Dict[str, any]] = []
|
||||
|
||||
def _start_emulator(self):
|
||||
# Power on the virtual machine
|
||||
self.provider.start_emulator(self.path_to_vm, self.headless)
|
||||
|
||||
# Get the ip from the virtual machine, and setup the controller
|
||||
self.vm_ip = self.provider.get_ip_address(self.path_to_vm)
|
||||
self.controller = PythonController(vm_ip=self.vm_ip)
|
||||
self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir_base)
|
||||
|
||||
def _revert_to_snapshot(self):
|
||||
# Revert to certain snapshot of the virtual machine, and refresh the path to vm and ip of vm
|
||||
# due to the fact it could be changed when implemented by cloud services
|
||||
self.path_to_vm = self.provider.revert_to_snapshot(self.path_to_vm, self.snapshot_name)
|
||||
|
||||
def _save_state(self, snapshot_name=None):
|
||||
# Save the current virtual machine state to a certain snapshot name
|
||||
self.provider.save_state(self.path_to_vm, snapshot_name)
|
||||
|
||||
def close(self):
|
||||
# Close (release) the virtual machine
|
||||
self.provider.stop_emulator(self.path_to_vm)
|
||||
|
||||
def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None) -> Dict[str, Any]:
|
||||
# Reset to certain task in OSWorld
|
||||
logger.info("Resetting environment...")
|
||||
logger.info("Switching task...")
|
||||
logger.info("Setting counters...")
|
||||
self._traj_no += 1
|
||||
self._step_no = 0
|
||||
self.action_history.clear()
|
||||
|
||||
logger.info("Reverting to snapshot to {}...".format(self.snapshot_name))
|
||||
self._revert_to_snapshot()
|
||||
logger.info("Starting emulator...")
|
||||
self._start_emulator()
|
||||
logger.info("Emulator started.")
|
||||
|
||||
if task_config is not None:
|
||||
self._set_task_info(task_config)
|
||||
self.setup_controller.reset_cache_dir(self.cache_dir)
|
||||
logger.info("Setting up environment...")
|
||||
self.setup_controller.setup(self.config)
|
||||
logger.info("Environment setup complete.")
|
||||
|
||||
observation = self._get_obs()
|
||||
return observation
|
||||
|
||||
def _get_obs(self):
|
||||
# We provide screenshot, accessibility_tree (optional), terminal (optional), and instruction.
|
||||
# can be customized and scaled
|
||||
return {
|
||||
"screenshot": self.controller.get_screenshot(),
|
||||
"accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
|
||||
"terminal": self.controller.get_terminal_output() if self.require_terminal else None,
|
||||
"instruction": self.instruction
|
||||
}
|
||||
|
||||
@property
|
||||
def vm_platform(self):
|
||||
return self.controller.get_vm_platform()
|
||||
@@ -103,49 +145,6 @@ class DesktopEnv(gym.Env):
|
||||
def vm_screen_size(self):
|
||||
return self.controller.get_vm_screen_size()
|
||||
|
||||
def _start_emulator(self):
|
||||
while True:
|
||||
try:
|
||||
output = subprocess.check_output("vmrun -T ws list", shell=True, stderr=subprocess.STDOUT)
|
||||
output = output.decode()
|
||||
output: List[str] = output.splitlines()
|
||||
# if self.path_to_vm.lstrip("~/") in output:
|
||||
if self.path_to_vm in output:
|
||||
logger.info("VM is running.")
|
||||
break
|
||||
else:
|
||||
logger.info("Starting VM...")
|
||||
_execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm]) if not self.headless \
|
||||
else _execute_command(["vmrun", "-T", "ws", "start", self.path_to_vm, "nogui"])
|
||||
time.sleep(3)
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"Error executing command: {e.output.decode().strip()}")
|
||||
|
||||
def _get_vm_ip(self):
|
||||
max_retries = 20
|
||||
logger.info("Getting IP Address...")
|
||||
for _ in range(max_retries):
|
||||
try:
|
||||
output = _execute_command(["vmrun", "-T", "ws", "getGuestIPAddress", self.path_to_vm, "-wait"]).strip()
|
||||
logger.info(f"IP address: {output}")
|
||||
return output
|
||||
except Exception as e:
|
||||
print(e)
|
||||
time.sleep(5)
|
||||
logger.info("Retrying...")
|
||||
raise Exception("Failed to get VM IP address!")
|
||||
|
||||
def _save_state(self):
|
||||
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_name])
|
||||
|
||||
def _get_obs(self):
|
||||
return {
|
||||
"screenshot": self.controller.get_screenshot(),
|
||||
"accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
|
||||
"terminal": self.controller.get_terminal_output() if self.require_terminal else None,
|
||||
"instruction": self.instruction
|
||||
}
|
||||
|
||||
def _set_task_info(self, task_config: Dict[str, Any]):
|
||||
self.task_id: str = task_config["id"]
|
||||
self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
|
||||
@@ -198,35 +197,6 @@ class DesktopEnv(gym.Env):
|
||||
or (len(self.metric) == len(self.result_getter) == len(self.expected_getter) == len(
|
||||
self.metric_options)))
|
||||
|
||||
def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None) -> Dict[str, Any]:
|
||||
logger.info("Resetting environment...")
|
||||
|
||||
logger.info("Switching task...")
|
||||
|
||||
logger.info("Setting counters...")
|
||||
self._traj_no += 1
|
||||
self._step_no = 0
|
||||
self.action_history.clear()
|
||||
|
||||
logger.info("Reverting to snapshot to {}...".format(self.snapshot_name))
|
||||
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_name])
|
||||
time.sleep(5)
|
||||
|
||||
logger.info("Starting emulator...")
|
||||
self._start_emulator()
|
||||
logger.info("Emulator started.")
|
||||
|
||||
if task_config is not None:
|
||||
self._set_task_info(task_config)
|
||||
self.setup_controller.reset_cache_dir(self.cache_dir)
|
||||
logger.info("Setting up environment...")
|
||||
self.setup_controller.setup(self.config)
|
||||
time.sleep(5)
|
||||
logger.info("Environment setup complete.")
|
||||
|
||||
observation = self._get_obs()
|
||||
return observation
|
||||
|
||||
def step(self, action, pause=0.5):
|
||||
self._step_no += 1
|
||||
self.action_history.append(action)
|
||||
@@ -319,9 +289,6 @@ class DesktopEnv(gym.Env):
|
||||
|
||||
def render(self, mode='rgb_array'):
|
||||
if mode == 'rgb_array':
|
||||
return self._get_screenshot()
|
||||
return self.controller.get_screenshot()
|
||||
else:
|
||||
raise ValueError('Unsupported render mode: {}'.format(mode))
|
||||
|
||||
def close(self):
|
||||
_execute_command(["vmrun", "stop", self.path_to_vm])
|
||||
0
desktop_env/providers/README.md
Normal file
0
desktop_env/providers/README.md
Normal file
18
desktop_env/providers/__init__.py
Normal file
18
desktop_env/providers/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from desktop_env.providers.base import VMManager, Provider
|
||||
from desktop_env.providers.vmware.manager import VMwareVMManager
|
||||
from desktop_env.providers.vmware.provider import VMwareProvider
|
||||
from desktop_env.providers.aws.manager import AWSVMManager
|
||||
from desktop_env.providers.aws.provider import AWSProvider
|
||||
|
||||
|
||||
def create_vm_manager_and_provider(provider_name: str, region: str):
|
||||
"""
|
||||
Factory function to get the Virtual Machine Manager and Provider instances based on the provided provider name.
|
||||
"""
|
||||
provider_name = provider_name.lower().strip()
|
||||
if provider_name == "vmware":
|
||||
return VMwareVMManager(), VMwareProvider(region)
|
||||
elif provider_name in ["aws", "amazon web services"]:
|
||||
return AWSVMManager(), AWSProvider(region)
|
||||
else:
|
||||
raise NotImplementedError(f"{provider_name} not implemented!")
|
||||
57
desktop_env/providers/aws/AWS_GUIDELINE.md
Normal file
57
desktop_env/providers/aws/AWS_GUIDELINE.md
Normal file
@@ -0,0 +1,57 @@
|
||||
# README for AWS VM Management
|
||||
|
||||
Welcome to the AWS VM Management documentation. Before you proceed with using the code to manage AWS services, please ensure the following variables are set correctly according to your AWS environment.
|
||||
|
||||
## Configuration Variables
|
||||
You need to assign values to several variables crucial for the operation of these scripts on AWS:
|
||||
|
||||
- **`REGISTRY_PATH`**: Sets the file path for VM registration logging.
|
||||
- Example: `'.aws_vms'`
|
||||
- **`DEFAULT_REGION`**: Default AWS region where your instances will be launched.
|
||||
- Example: `"us-east-1"`
|
||||
- **`IMAGE_ID_MAP`**: Dictionary mapping regions to specific AMI IDs that should be used for instance creation.
|
||||
- Example:
|
||||
```python
|
||||
IMAGE_ID_MAP = {
|
||||
"us-east-1": "ami-09bab251951b4272c",
|
||||
# Add other regions and corresponding AMIs
|
||||
}
|
||||
```
|
||||
- **`INSTANCE_TYPE`**: Specifies the type of EC2 instance to be launched.
|
||||
- Example: `"t3.medium"`
|
||||
- **`KEY_NAME`**: Specifies the name of the key pair to be used for the instances.
|
||||
- Example: `"osworld_key"`
|
||||
- **`NETWORK_INTERFACES`**: Configuration settings for network interfaces, which include subnet IDs, security group IDs, and public IP addressing.
|
||||
- Example:
|
||||
```python
|
||||
NETWORK_INTERFACES = {
|
||||
"us-east-1": [
|
||||
{
|
||||
"SubnetId": "subnet-037edfff66c2eb894",
|
||||
"AssociatePublicIpAddress": True,
|
||||
"DeviceIndex": 0,
|
||||
"Groups": ["sg-0342574803206ee9c"]
|
||||
}
|
||||
],
|
||||
# Add configurations for other regions
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
### AWS CLI Configuration
|
||||
Before using these scripts, you must configure your AWS CLI with your credentials. This can be done via the following commands:
|
||||
|
||||
```bash
|
||||
aws configure
|
||||
```
|
||||
This command will prompt you for:
|
||||
- AWS Access Key ID
|
||||
- AWS Secret Access Key
|
||||
- Default region name (Optional, you can press enter)
|
||||
|
||||
Enter your credentials as required. This setup will allow you to interact with AWS services using the credentials provided.
|
||||
|
||||
### Disclaimer
|
||||
Use the provided scripts and configurations at your own risk. Ensure that you understand the AWS pricing model and potential costs associated with deploying instances, as using these scripts might result in charges on your AWS account.
|
||||
|
||||
> **Note:** Ensure all AMI images used in `IMAGE_ID_MAP` are accessible and permissioned correctly for your AWS account, and that they are available in the specified region.
|
||||
0
desktop_env/providers/aws/__init__.py
Normal file
0
desktop_env/providers/aws/__init__.py
Normal file
179
desktop_env/providers/aws/manager.py
Normal file
179
desktop_env/providers/aws/manager.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import os
|
||||
from filelock import FileLock
|
||||
import boto3
|
||||
import psutil
|
||||
|
||||
import logging
|
||||
|
||||
from desktop_env.providers.base import VMManager
|
||||
|
||||
logger = logging.getLogger("desktopenv.providers.aws.AWSVMManager")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
REGISTRY_PATH = '.aws_vms'
|
||||
|
||||
DEFAULT_REGION = "us-east-1"
|
||||
# todo: Add doc for the configuration of image, security group and network interface
|
||||
# todo: public the AMI images
|
||||
IMAGE_ID_MAP = {
|
||||
"us-east-1": "ami-0b0531325a0d5d488",
|
||||
"ap-east-1": "ami-0b92a0bf157fecaa9"
|
||||
}
|
||||
|
||||
INSTANCE_TYPE = "t3.large"
|
||||
NETWORK_INTERFACE_MAP = {
|
||||
"us-east-1": [
|
||||
{
|
||||
"SubnetId": "subnet-037edfff66c2eb894",
|
||||
"AssociatePublicIpAddress": True,
|
||||
"DeviceIndex": 0,
|
||||
"Groups": [
|
||||
"sg-0342574803206ee9c"
|
||||
]
|
||||
}
|
||||
],
|
||||
"ap-east-1": [
|
||||
{
|
||||
"SubnetId": "subnet-011060501be0b589c",
|
||||
"AssociatePublicIpAddress": True,
|
||||
"DeviceIndex": 0,
|
||||
"Groups": [
|
||||
"sg-090470e64df78f6eb"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def _allocate_vm(region=DEFAULT_REGION):
|
||||
run_instances_params = {
|
||||
"MaxCount": 1,
|
||||
"MinCount": 1,
|
||||
"ImageId": IMAGE_ID_MAP[region],
|
||||
"InstanceType": INSTANCE_TYPE,
|
||||
"EbsOptimized": True,
|
||||
"NetworkInterfaces": NETWORK_INTERFACE_MAP[region]
|
||||
}
|
||||
|
||||
ec2_client = boto3.client('ec2', region_name=region)
|
||||
response = ec2_client.run_instances(**run_instances_params)
|
||||
instance_id = response['Instances'][0]['InstanceId']
|
||||
logger.info(f"Waiting for instance {instance_id} to be running...")
|
||||
ec2_client.get_waiter('instance_running').wait(InstanceIds=[instance_id])
|
||||
logger.info(f"Waiting for instance {instance_id} status checks to pass...")
|
||||
ec2_client.get_waiter('instance_status_ok').wait(InstanceIds=[instance_id])
|
||||
logger.info(f"Instance {instance_id} is ready.")
|
||||
|
||||
return instance_id
|
||||
|
||||
|
||||
class AWSVMManager(VMManager):
|
||||
def __init__(self, registry_path=REGISTRY_PATH):
|
||||
self.registry_path = registry_path
|
||||
self.lock = FileLock(".aws_lck", timeout=10)
|
||||
self.initialize_registry()
|
||||
|
||||
def initialize_registry(self):
|
||||
with self.lock: # Locking during initialization
|
||||
if not os.path.exists(self.registry_path):
|
||||
with open(self.registry_path, 'w') as file:
|
||||
file.write('')
|
||||
|
||||
def add_vm(self, vm_path, region=DEFAULT_REGION):
|
||||
with self.lock:
|
||||
with open(self.registry_path, 'r') as file:
|
||||
lines = file.readlines()
|
||||
vm_path_at_vm_region = "{}@{}".format(vm_path, region)
|
||||
new_lines = lines + [f'{vm_path_at_vm_region}|free\n']
|
||||
with open(self.registry_path, 'w') as file:
|
||||
file.writelines(new_lines)
|
||||
|
||||
def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION):
|
||||
with self.lock:
|
||||
new_lines = []
|
||||
with open(self.registry_path, 'r') as file:
|
||||
lines = file.readlines()
|
||||
for line in lines:
|
||||
registered_vm_path, _ = line.strip().split('|')
|
||||
if registered_vm_path == "{}@{}".format(vm_path, region):
|
||||
new_lines.append(f'{registered_vm_path}|{pid}\n')
|
||||
else:
|
||||
new_lines.append(line)
|
||||
with open(self.registry_path, 'w') as file:
|
||||
file.writelines(new_lines)
|
||||
|
||||
def check_and_clean(self):
|
||||
with self.lock: # Lock when cleaning up the registry and vms_dir
|
||||
# Check and clean on the running vms, detect the released ones and mark then as 'free'
|
||||
active_pids = {p.pid for p in psutil.process_iter()}
|
||||
new_lines = []
|
||||
vm_path_at_vm_regions = []
|
||||
|
||||
with open(self.registry_path, 'r') as file:
|
||||
lines = file.readlines()
|
||||
for line in lines:
|
||||
vm_path_at_vm_region, pid_str = line.strip().split('|')
|
||||
vm_path, vm_region = vm_path_at_vm_region.split("@")
|
||||
ec2_client = boto3.client('ec2', region_name=vm_region)
|
||||
|
||||
try:
|
||||
response = ec2_client.describe_instances(InstanceIds=[vm_path])
|
||||
if not response['Reservations'] or response['Reservations'][0]['Instances'][0]['State'][
|
||||
'Name'] in ['terminated', 'shutting-down']:
|
||||
logger.info(f"VM {vm_path} not found or terminated, releasing it.")
|
||||
continue
|
||||
elif response['Reservations'][0]['Instances'][0]['State'][
|
||||
'Name'] == "Stopped":
|
||||
logger.info(f"VM {vm_path} stopped, mark it as free")
|
||||
new_lines.append(f'{vm_path}@{vm_region}|free\n')
|
||||
continue
|
||||
except ec2_client.exceptions.ClientError as e:
|
||||
if 'InvalidInstanceID.NotFound' in str(e):
|
||||
logger.info(f"VM {vm_path} not found, releasing it.")
|
||||
continue
|
||||
|
||||
vm_path_at_vm_regions.append(vm_path_at_vm_region)
|
||||
if pid_str == "free":
|
||||
new_lines.append(line)
|
||||
continue
|
||||
|
||||
if int(pid_str) in active_pids:
|
||||
new_lines.append(line)
|
||||
else:
|
||||
new_lines.append(f'{vm_path_at_vm_region}|free\n')
|
||||
|
||||
with open(self.registry_path, 'w') as file:
|
||||
file.writelines(new_lines)
|
||||
|
||||
# We won't check and clean on the files on aws and delete the unregistered ones
|
||||
# Since this can lead to unexpected delete on other server
|
||||
# PLease do monitor the instances to avoid additional cost
|
||||
|
||||
def list_free_vms(self, region=DEFAULT_REGION):
|
||||
with self.lock: # Lock when reading the registry
|
||||
free_vms = []
|
||||
with open(self.registry_path, 'r') as file:
|
||||
lines = file.readlines()
|
||||
for line in lines:
|
||||
vm_path_at_vm_region, pid_str = line.strip().split('|')
|
||||
vm_path, vm_region = vm_path_at_vm_region.split("@")
|
||||
if pid_str == "free" and vm_region == region:
|
||||
free_vms.append((vm_path, pid_str))
|
||||
|
||||
return free_vms
|
||||
|
||||
def get_vm_path(self, region=DEFAULT_REGION):
|
||||
self.check_and_clean()
|
||||
free_vms_paths = self.list_free_vms(region)
|
||||
if len(free_vms_paths) == 0:
|
||||
# No free virtual machine available, generate a new one
|
||||
logger.info("No free virtual machine available. Generating a new one, which would take a while...☕")
|
||||
new_vm_path = _allocate_vm(region)
|
||||
self.add_vm(new_vm_path, region)
|
||||
self.occupy_vm(new_vm_path, os.getpid(), region)
|
||||
return new_vm_path
|
||||
else:
|
||||
# Choose the first free virtual machine
|
||||
chosen_vm_path = free_vms_paths[0][0]
|
||||
self.occupy_vm(chosen_vm_path, os.getpid(), region)
|
||||
return chosen_vm_path
|
||||
112
desktop_env/providers/aws/provider.py
Normal file
112
desktop_env/providers/aws/provider.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
import logging
|
||||
|
||||
from .manager import INSTANCE_TYPE
|
||||
from desktop_env.providers.base import Provider
|
||||
|
||||
logger = logging.getLogger("desktopenv.providers.aws.AWSProvider")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
WAIT_DELAY = 15
|
||||
MAX_ATTEMPTS = 10
|
||||
|
||||
|
||||
class AWSProvider(Provider):
|
||||
|
||||
def start_emulator(self, path_to_vm: str, headless: bool):
|
||||
logger.info("Starting AWS VM...")
|
||||
ec2_client = boto3.client('ec2', region_name=self.region)
|
||||
|
||||
try:
|
||||
# Start the instance
|
||||
ec2_client.start_instances(InstanceIds=[path_to_vm])
|
||||
logger.info(f"Instance {path_to_vm} is starting...")
|
||||
|
||||
# Wait for the instance to be in the 'running' state
|
||||
waiter = ec2_client.get_waiter('instance_running')
|
||||
waiter.wait(InstanceIds=[path_to_vm], WaiterConfig={'Delay': WAIT_DELAY, 'MaxAttempts': MAX_ATTEMPTS})
|
||||
logger.info(f"Instance {path_to_vm} is now running.")
|
||||
|
||||
except ClientError as e:
|
||||
logger.error(f"Failed to start the AWS VM {path_to_vm}: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_ip_address(self, path_to_vm: str) -> str:
|
||||
logger.info("Getting AWS VM IP address...")
|
||||
ec2_client = boto3.client('ec2', region_name=self.region)
|
||||
|
||||
try:
|
||||
response = ec2_client.describe_instances(InstanceIds=[path_to_vm])
|
||||
for reservation in response['Reservations']:
|
||||
for instance in reservation['Instances']:
|
||||
private_ip_address = instance.get('PrivateIpAddress', '')
|
||||
return private_ip_address
|
||||
return '' # Return an empty string if no IP address is found
|
||||
except ClientError as e:
|
||||
logger.error(f"Failed to retrieve private IP address for the instance {path_to_vm}: {str(e)}")
|
||||
raise
|
||||
|
||||
def save_state(self, path_to_vm: str, snapshot_name: str):
|
||||
logger.info("Saving AWS VM state...")
|
||||
ec2_client = boto3.client('ec2', region_name=self.region)
|
||||
|
||||
try:
|
||||
image_response = ec2_client.create_image(InstanceId=path_to_vm, ImageId=snapshot_name)
|
||||
image_id = image_response['ImageId']
|
||||
logger.info(f"AMI {image_id} created successfully from instance {path_to_vm}.")
|
||||
return image_id
|
||||
except ClientError as e:
|
||||
logger.error(f"Failed to create AMI from the instance {path_to_vm}: {str(e)}")
|
||||
raise
|
||||
|
||||
def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str):
|
||||
logger.info(f"Reverting AWS VM to snapshot: {snapshot_name}...")
|
||||
ec2_client = boto3.client('ec2', region_name=self.region)
|
||||
|
||||
try:
|
||||
# Step 1: Retrieve the original instance details
|
||||
instance_details = ec2_client.describe_instances(InstanceIds=[path_to_vm])
|
||||
instance = instance_details['Reservations'][0]['Instances'][0]
|
||||
security_groups = [sg['GroupId'] for sg in instance['SecurityGroups']]
|
||||
subnet_id = instance['SubnetId']
|
||||
instance_type = instance['InstanceType']
|
||||
iam_instance_profile = instance.get('IamInstanceProfile', {}).get('Arn', '')
|
||||
|
||||
# Step 2: Launch a new instance from the snapshot
|
||||
logger.info(f"Launching a new instance from snapshot {snapshot_name}...")
|
||||
new_instance = ec2_client.run_instances(
|
||||
ImageId=snapshot_name,
|
||||
InstanceType=instance_type,
|
||||
SecurityGroupIds=security_groups,
|
||||
SubnetId=subnet_id,
|
||||
IamInstanceProfile={'Arn': iam_instance_profile} if iam_instance_profile else {},
|
||||
MinCount=1,
|
||||
MaxCount=1
|
||||
)
|
||||
new_instance_id = new_instance['Instances'][0]['InstanceId']
|
||||
logger.info(f"New instance {new_instance_id} launched from snapshot {snapshot_name}.")
|
||||
|
||||
# Step 3: Terminate the old instance
|
||||
ec2_client.terminate_instances(InstanceIds=[path_to_vm])
|
||||
logger.info(f"Old instance {path_to_vm} has been terminated.")
|
||||
|
||||
return new_instance_id
|
||||
|
||||
except ClientError as e:
|
||||
logger.error(f"Failed to revert to snapshot {snapshot_name} for the instance {path_to_vm}: {str(e)}")
|
||||
raise
|
||||
|
||||
def stop_emulator(self, path_to_vm, region=None):
|
||||
logger.info(f"Stopping AWS VM {path_to_vm}...")
|
||||
ec2_client = boto3.client('ec2', region_name=self.region)
|
||||
|
||||
try:
|
||||
ec2_client.stop_instances(InstanceIds=[path_to_vm])
|
||||
waiter = ec2_client.get_waiter('instance_stopped')
|
||||
waiter.wait(InstanceIds=[path_to_vm], WaiterConfig={'Delay': WAIT_DELAY, 'MaxAttempts': MAX_ATTEMPTS})
|
||||
logger.info(f"Instance {path_to_vm} has been stopped.")
|
||||
except ClientError as e:
|
||||
logger.error(f"Failed to stop the AWS VM {path_to_vm}: {str(e)}")
|
||||
raise
|
||||
0
desktop_env/providers/azure/__init__.py
Normal file
0
desktop_env/providers/azure/__init__.py
Normal file
0
desktop_env/providers/azure/manager.py
Normal file
0
desktop_env/providers/azure/manager.py
Normal file
0
desktop_env/providers/azure/provider.py
Normal file
0
desktop_env/providers/azure/provider.py
Normal file
89
desktop_env/providers/base.py
Normal file
89
desktop_env/providers/base.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class Provider(ABC):
|
||||
def __init__(self, region: str = None):
|
||||
"""
|
||||
Region of the cloud service.
|
||||
"""
|
||||
self.region = region
|
||||
|
||||
@abstractmethod
|
||||
def start_emulator(self, path_to_vm: str, headless: bool):
|
||||
"""
|
||||
Method to start the emulator.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_ip_address(self, path_to_vm: str) -> str:
|
||||
"""
|
||||
Method to get the private IP address of the VM. Private IP means inside the VPC.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_state(self, path_to_vm: str, snapshot_name: str):
|
||||
"""
|
||||
Method to save the state of the VM.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str) -> str:
|
||||
"""
|
||||
Method to revert the VM to a given snapshot.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stop_emulator(self, path_to_vm: str):
|
||||
"""
|
||||
Method to stop the emulator.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class VMManager(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def initialize_registry(self, **kwargs):
|
||||
"""
|
||||
Initialize registry.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_vm(self, vm_path, **kwargs):
|
||||
"""
|
||||
Add the path of new VM to the registration.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def occupy_vm(self, vm_path, pid, **kwargs):
|
||||
"""
|
||||
Mark the path of VM occupied by the pid.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_free_vms(self, **kwargs):
|
||||
"""
|
||||
List the paths of VM that are free to use allocated.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def check_and_clean(self, **kwargs):
|
||||
"""
|
||||
Check the registration list, and remove the paths of VM that are not in use.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_vm_path(self, **kwargs):
|
||||
"""
|
||||
Get a virtual machine that is not occupied, generate a new one if no free VM.
|
||||
"""
|
||||
pass
|
||||
0
desktop_env/providers/gcp/__init__.py
Normal file
0
desktop_env/providers/gcp/__init__.py
Normal file
0
desktop_env/providers/gcp/manager.py
Normal file
0
desktop_env/providers/gcp/manager.py
Normal file
0
desktop_env/providers/gcp/provider.py
Normal file
0
desktop_env/providers/gcp/provider.py
Normal file
0
desktop_env/providers/virtualbox/__init__.py
Normal file
0
desktop_env/providers/virtualbox/__init__.py
Normal file
0
desktop_env/providers/virtualbox/manager.py
Normal file
0
desktop_env/providers/virtualbox/manager.py
Normal file
0
desktop_env/providers/virtualbox/provider.py
Normal file
0
desktop_env/providers/virtualbox/provider.py
Normal file
0
desktop_env/providers/vmware/__init__.py
Normal file
0
desktop_env/providers/vmware/__init__.py
Normal file
@@ -2,152 +2,44 @@ import os
|
||||
import platform
|
||||
import random
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
import threading
|
||||
from filelock import FileLock
|
||||
import uuid
|
||||
import zipfile
|
||||
|
||||
from time import sleep
|
||||
import shutil
|
||||
import psutil
|
||||
import subprocess
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
__version__ = "0.1.15"
|
||||
import logging
|
||||
|
||||
from desktop_env.providers.base import VMManager
|
||||
|
||||
logger = logging.getLogger("desktopenv.providers.vmware.VMwareVMManager")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
MAX_RETRY_TIMES = 10
|
||||
UBUNTU_ARM_URL = "https://huggingface.co/datasets/xlangai/ubuntu_arm/resolve/main/Ubuntu.zip"
|
||||
UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_x86/resolve/main/Ubuntu.zip"
|
||||
DOWNLOADED_FILE_NAME = "Ubuntu.zip"
|
||||
REGISTRY_PATH = '.vms'
|
||||
VMS_DIR = "./vm_data"
|
||||
REGISTRY_PATH = '.vmware_vms'
|
||||
VMS_DIR = "./vmware_vm_data"
|
||||
update_lock = threading.Lock()
|
||||
|
||||
|
||||
class VirtualMachineManager:
|
||||
def __init__(self, registry_path=REGISTRY_PATH):
|
||||
self.registry_path = registry_path
|
||||
self.lock = threading.Lock()
|
||||
self.initialize_registry()
|
||||
|
||||
def initialize_registry(self):
|
||||
with self.lock: # Locking during initialization
|
||||
if not os.path.exists(self.registry_path):
|
||||
with open(self.registry_path, 'w') as file:
|
||||
file.write('')
|
||||
|
||||
def add_vm(self, vm_path):
|
||||
with self.lock:
|
||||
with open(self.registry_path, 'r') as file:
|
||||
lines = file.readlines()
|
||||
new_lines = lines + [f'{vm_path}|free\n']
|
||||
with open(self.registry_path, 'w') as file:
|
||||
file.writelines(new_lines)
|
||||
|
||||
def occupy_vm(self, vm_path, pid):
|
||||
with self.lock:
|
||||
new_lines = []
|
||||
with open(self.registry_path, 'r') as file:
|
||||
lines = file.readlines()
|
||||
for line in lines:
|
||||
registered_vm_path, _ = line.strip().split('|')
|
||||
if registered_vm_path == vm_path:
|
||||
new_lines.append(f'{registered_vm_path}|{pid}\n')
|
||||
else:
|
||||
new_lines.append(line)
|
||||
with open(self.registry_path, 'w') as file:
|
||||
file.writelines(new_lines)
|
||||
|
||||
def release_vm(self, vm_path):
|
||||
with self.lock: # Lock when modifying the registry
|
||||
new_lines = []
|
||||
with open(self.registry_path, 'r') as file:
|
||||
lines = file.readlines()
|
||||
for line in lines:
|
||||
registered_vm_path, _ = line.strip().split('|')
|
||||
if registered_vm_path != vm_path:
|
||||
new_lines.append(line)
|
||||
else:
|
||||
new_lines.append(f'{registered_vm_path}|free\n')
|
||||
with open(self.registry_path, 'w') as file:
|
||||
file.writelines(new_lines)
|
||||
|
||||
def check_and_clean(self, vms_dir):
|
||||
with self.lock: # Lock when cleaning up the registry and vms_dir
|
||||
|
||||
# Check and clean on the running vms, detect the released ones and mark then as 'free'
|
||||
active_pids = {p.pid for p in psutil.process_iter()}
|
||||
new_lines = []
|
||||
vm_paths = []
|
||||
|
||||
with open(self.registry_path, 'r') as file:
|
||||
lines = file.readlines()
|
||||
for line in lines:
|
||||
vm_path, pid_str = line.strip().split('|')
|
||||
if not os.path.exists(vm_path):
|
||||
print(f"VM {vm_path} not found, releasing it.")
|
||||
new_lines.append(f'{vm_path}|free\n')
|
||||
continue
|
||||
|
||||
vm_paths.append(vm_path)
|
||||
if pid_str == "free":
|
||||
new_lines.append(line)
|
||||
continue
|
||||
|
||||
if int(pid_str) in active_pids:
|
||||
new_lines.append(line)
|
||||
else:
|
||||
new_lines.append(f'{vm_path}|free\n')
|
||||
with open(self.registry_path, 'w') as file:
|
||||
file.writelines(new_lines)
|
||||
|
||||
# Check and clean on the files inside vms_dir, delete the unregistered ones
|
||||
os.makedirs(vms_dir, exist_ok=True)
|
||||
vm_names = os.listdir(vms_dir)
|
||||
for vm_name in vm_names:
|
||||
# skip the downloaded .zip file
|
||||
if vm_name == DOWNLOADED_FILE_NAME:
|
||||
continue
|
||||
# Skip the .DS_Store file on macOS
|
||||
if vm_name == ".DS_Store":
|
||||
continue
|
||||
|
||||
flag = True
|
||||
for vm_path in vm_paths:
|
||||
if vm_name + ".vmx" in vm_path:
|
||||
flag = False
|
||||
if flag:
|
||||
shutil.rmtree(os.path.join(vms_dir, vm_name))
|
||||
|
||||
def list_vms(self):
|
||||
with self.lock: # Lock when reading the registry
|
||||
all_vms = []
|
||||
with open(self.registry_path, 'r') as file:
|
||||
lines = file.readlines()
|
||||
for line in lines:
|
||||
vm_path, pid_str = line.strip().split('|')
|
||||
all_vms.append((vm_path, pid_str))
|
||||
return all_vms
|
||||
|
||||
def list_free_vms(self):
|
||||
with self.lock: # Lock when reading the registry
|
||||
free_vms = []
|
||||
with open(self.registry_path, 'r') as file:
|
||||
lines = file.readlines()
|
||||
for line in lines:
|
||||
vm_path, pid_str = line.strip().split('|')
|
||||
if pid_str == "free":
|
||||
free_vms.append((vm_path, pid_str))
|
||||
return free_vms
|
||||
|
||||
def generate_new_vm_name(self, vms_dir):
|
||||
registry_idx = 0
|
||||
while True:
|
||||
attempted_new_name = f"Ubuntu{registry_idx}"
|
||||
if os.path.exists(
|
||||
os.path.join(vms_dir, attempted_new_name, attempted_new_name, attempted_new_name + ".vmx")):
|
||||
registry_idx += 1
|
||||
else:
|
||||
return attempted_new_name
|
||||
def generate_new_vm_name(vms_dir):
|
||||
registry_idx = 0
|
||||
while True:
|
||||
attempted_new_name = f"Ubuntu{registry_idx}"
|
||||
if os.path.exists(
|
||||
os.path.join(vms_dir, attempted_new_name, attempted_new_name, attempted_new_name + ".vmx")):
|
||||
registry_idx += 1
|
||||
else:
|
||||
return attempted_new_name
|
||||
|
||||
|
||||
def _update_vm(vmx_path, target_vm_name):
|
||||
@@ -186,7 +78,7 @@ def _update_vm(vmx_path, target_vm_name):
|
||||
with open(vmx_path, 'w') as file:
|
||||
file.write(updated_content)
|
||||
|
||||
print(".vmx file updated successfully.")
|
||||
logger.info(".vmx file updated successfully.")
|
||||
|
||||
vmx_file_base_name = os.path.splitext(vmx_file)[0]
|
||||
|
||||
@@ -205,10 +97,10 @@ def _update_vm(vmx_path, target_vm_name):
|
||||
target_dir_path = os.sep.join(path_parts)
|
||||
os.rename(dir_path, target_dir_path)
|
||||
|
||||
print("VM files renamed successfully.")
|
||||
logger.info("VM files renamed successfully.")
|
||||
|
||||
|
||||
def _install_virtual_machine(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu"):
|
||||
def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu"):
|
||||
os.makedirs(vms_dir, exist_ok=True)
|
||||
|
||||
def __download_and_unzip_vm():
|
||||
@@ -224,7 +116,7 @@ def _install_virtual_machine(vm_name, vms_dir, downloaded_file_name, original_vm
|
||||
raise Exception("Unsupported platform or architecture")
|
||||
|
||||
# Download the virtual machine image
|
||||
print("Downloading the virtual machine image...")
|
||||
logger.info("Downloading the virtual machine image...")
|
||||
downloaded_size = 0
|
||||
|
||||
while True:
|
||||
@@ -237,7 +129,7 @@ def _install_virtual_machine(vm_name, vms_dir, downloaded_file_name, original_vm
|
||||
with requests.get(url, headers=headers, stream=True) as response:
|
||||
if response.status_code == 416:
|
||||
# This means the range was not satisfiable, possibly the file was fully downloaded
|
||||
print("Fully downloaded or the file sized changed.")
|
||||
logger.info("Fully downloaded or the file sized changed.")
|
||||
break
|
||||
|
||||
response.raise_for_status()
|
||||
@@ -257,18 +149,18 @@ def _install_virtual_machine(vm_name, vms_dir, downloaded_file_name, original_vm
|
||||
size = file.write(data)
|
||||
progress_bar.update(size)
|
||||
except (requests.exceptions.RequestException, IOError) as e:
|
||||
print(f"Download error: {e}")
|
||||
logger.error(f"Download error: {e}")
|
||||
sleep(1) # Wait for 1 second before retrying
|
||||
print("Retrying...")
|
||||
logger.error("Retrying...")
|
||||
else:
|
||||
print("Download succeeds.")
|
||||
logger.info("Download succeeds.")
|
||||
break # Download completed successfully
|
||||
|
||||
# Unzip the downloaded file
|
||||
print("Unzipping the downloaded file...☕️")
|
||||
logger.info("Unzipping the downloaded file...☕️")
|
||||
with zipfile.ZipFile(downloaded_file_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(os.path.join(vms_dir, vm_name))
|
||||
print("Files have been successfully extracted to the directory:", os.path.join(vms_dir, vm_name))
|
||||
logger.info("Files have been successfully extracted to the directory: " + str(os.path.join(vms_dir, vm_name)))
|
||||
|
||||
vm_path = os.path.join(vms_dir, vm_name, vm_name, vm_name + ".vmx")
|
||||
|
||||
@@ -277,7 +169,7 @@ def _install_virtual_machine(vm_name, vms_dir, downloaded_file_name, original_vm
|
||||
__download_and_unzip_vm()
|
||||
_update_vm(os.path.join(vms_dir, vm_name, original_vm_name, original_vm_name + ".vmx"), vm_name)
|
||||
else:
|
||||
print(f"Virtual machine exists: {vm_path}")
|
||||
logger.info(f"Virtual machine exists: {vm_path}")
|
||||
|
||||
# Determine the platform of the host machine and decide the parameter for vmrun
|
||||
def get_vmrun_type():
|
||||
@@ -294,16 +186,16 @@ def _install_virtual_machine(vm_name, vms_dir, downloaded_file_name, original_vm
|
||||
for attempt in range(max_retries):
|
||||
result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8")
|
||||
if result.returncode == 0:
|
||||
print("Virtual machine started.")
|
||||
logger.info("Virtual machine started.")
|
||||
return True
|
||||
else:
|
||||
if "Error" in result.stderr:
|
||||
print(f"Attempt {attempt + 1} failed with specific error: {result.stderr}")
|
||||
logger.error(f"Attempt {attempt + 1} failed with specific error: {result.stderr}")
|
||||
else:
|
||||
print(f"Attempt {attempt + 1} failed: {result.stderr}")
|
||||
logger.error(f"Attempt {attempt + 1} failed: {result.stderr}")
|
||||
|
||||
if attempt == max_retries - 1:
|
||||
print("Maximum retry attempts reached, failed to start the virtual machine.")
|
||||
logger.error("Maximum retry attempts reached, failed to start the virtual machine.")
|
||||
return False
|
||||
|
||||
if not start_vm(vm_path):
|
||||
@@ -317,12 +209,12 @@ def _install_virtual_machine(vm_name, vms_dir, downloaded_file_name, original_vm
|
||||
return result.stdout.strip()
|
||||
else:
|
||||
if "Error" in result.stderr:
|
||||
print(f"Attempt {attempt + 1} failed with specific error: {result.stderr}")
|
||||
logger.error(f"Attempt {attempt + 1} failed with specific error: {result.stderr}")
|
||||
else:
|
||||
print(f"Attempt {attempt + 1} failed: {result.stderr}")
|
||||
logger.error(f"Attempt {attempt + 1} failed: {result.stderr}")
|
||||
|
||||
if attempt == max_retries - 1:
|
||||
print("Maximum retry attempts reached, failed to get the IP of virtual machine.")
|
||||
logger.error("Maximum retry attempts reached, failed to get the IP of virtual machine.")
|
||||
return None
|
||||
|
||||
vm_ip = get_vm_ip(vm_path)
|
||||
@@ -338,33 +230,33 @@ def _install_virtual_machine(vm_name, vms_dir, downloaded_file_name, original_vm
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
print(f"Type: {type(e).__name__}")
|
||||
print(f"Error detail: {str(e)}")
|
||||
logger.error(f"Error: {e}")
|
||||
logger.error(f"Type: {type(e).__name__}")
|
||||
logger.error(f"Error detail: {str(e)}")
|
||||
sleep(2)
|
||||
return False
|
||||
|
||||
# Try downloading the screenshot until successful
|
||||
while not download_screenshot(vm_ip):
|
||||
print("Check whether the virtual machine is ready...")
|
||||
logger.info("Check whether the virtual machine is ready...")
|
||||
|
||||
print("Virtual machine is ready. Start to make a snapshot on the virtual machine. It would take a while...")
|
||||
logger.info("Virtual machine is ready. Start to make a snapshot on the virtual machine. It would take a while...")
|
||||
|
||||
def create_vm_snapshot(vm_path, max_retries=20):
|
||||
command = f'vmrun {get_vmrun_type()} snapshot "{vm_path}" "init_state"'
|
||||
for attempt in range(max_retries):
|
||||
result = subprocess.run(command, shell=True, text=True, capture_output=True, encoding="utf-8")
|
||||
if result.returncode == 0:
|
||||
print("Snapshot created.")
|
||||
logger.info("Snapshot created.")
|
||||
return True
|
||||
else:
|
||||
if "Error" in result.stderr:
|
||||
print(f"Attempt {attempt + 1} failed with specific error: {result.stderr}")
|
||||
logger.error(f"Attempt {attempt + 1} failed with specific error: {result.stderr}")
|
||||
else:
|
||||
print(f"Attempt {attempt + 1} failed: {result.stderr}")
|
||||
logger.error(f"Attempt {attempt + 1} failed: {result.stderr}")
|
||||
|
||||
if attempt == max_retries - 1:
|
||||
print("Maximum retry attempts reached, failed to create snapshot.")
|
||||
logger.error("Maximum retry attempts reached, failed to create snapshot.")
|
||||
return False
|
||||
|
||||
# Create a snapshot of the virtual machine
|
||||
@@ -374,20 +266,114 @@ def _install_virtual_machine(vm_name, vms_dir, downloaded_file_name, original_vm
|
||||
raise ValueError("Error encountered during installation, please rerun the code for retrying.")
|
||||
|
||||
|
||||
def _get_vm_path():
|
||||
vm_manager = VirtualMachineManager(REGISTRY_PATH)
|
||||
vm_manager.check_and_clean(vms_dir=VMS_DIR)
|
||||
free_vms_paths = vm_manager.list_free_vms()
|
||||
if len(free_vms_paths) == 0:
|
||||
# No free virtual machine available, generate a new one
|
||||
print("No free virtual machine available. Generating a new one, which would take a while...☕")
|
||||
new_vm_name = vm_manager.generate_new_vm_name(vms_dir=VMS_DIR)
|
||||
new_vm_path = _install_virtual_machine(new_vm_name, vms_dir=VMS_DIR, downloaded_file_name=DOWNLOADED_FILE_NAME)
|
||||
vm_manager.add_vm(new_vm_path)
|
||||
vm_manager.occupy_vm(new_vm_path, os.getpid())
|
||||
return new_vm_path
|
||||
else:
|
||||
# Choose the first free virtual machine
|
||||
chosen_vm_path = free_vms_paths[0][0]
|
||||
vm_manager.occupy_vm(chosen_vm_path, os.getpid())
|
||||
return chosen_vm_path
|
||||
class VMwareVMManager(VMManager):
|
||||
def __init__(self, registry_path=REGISTRY_PATH):
|
||||
self.registry_path = registry_path
|
||||
self.lock = FileLock(".vmware_lck", timeout=10)
|
||||
self.initialize_registry()
|
||||
|
||||
def initialize_registry(self):
|
||||
with self.lock: # Locking during initialization
|
||||
if not os.path.exists(self.registry_path):
|
||||
with open(self.registry_path, 'w') as file:
|
||||
file.write('')
|
||||
|
||||
def add_vm(self, vm_path, region=None):
|
||||
assert region in [None, 'local'], "For VMware provider, the region should be neither None or 'local'."
|
||||
with self.lock:
|
||||
with open(self.registry_path, 'r') as file:
|
||||
lines = file.readlines()
|
||||
new_lines = lines + [f'{vm_path}|free\n']
|
||||
with open(self.registry_path, 'w') as file:
|
||||
file.writelines(new_lines)
|
||||
|
||||
def occupy_vm(self, vm_path, pid, region=None):
|
||||
assert region in [None, 'local'], "For VMware provider, the region should be neither None or 'local'."
|
||||
with self.lock:
|
||||
new_lines = []
|
||||
with open(self.registry_path, 'r') as file:
|
||||
lines = file.readlines()
|
||||
for line in lines:
|
||||
registered_vm_path, _ = line.strip().split('|')
|
||||
if registered_vm_path == vm_path:
|
||||
new_lines.append(f'{registered_vm_path}|{pid}\n')
|
||||
else:
|
||||
new_lines.append(line)
|
||||
with open(self.registry_path, 'w') as file:
|
||||
file.writelines(new_lines)
|
||||
|
||||
def check_and_clean(self, vms_dir):
|
||||
with self.lock: # Lock when cleaning up the registry and vms_dir
|
||||
# Check and clean on the running vms, detect the released ones and mark then as 'free'
|
||||
active_pids = {p.pid for p in psutil.process_iter()}
|
||||
new_lines = []
|
||||
vm_paths = []
|
||||
|
||||
with open(self.registry_path, 'r') as file:
|
||||
lines = file.readlines()
|
||||
for line in lines:
|
||||
vm_path, pid_str = line.strip().split('|')
|
||||
if not os.path.exists(vm_path):
|
||||
logger.info(f"VM {vm_path} not found, releasing it.")
|
||||
new_lines.append(f'{vm_path}|free\n')
|
||||
continue
|
||||
|
||||
vm_paths.append(vm_path)
|
||||
if pid_str == "free":
|
||||
new_lines.append(line)
|
||||
continue
|
||||
|
||||
if int(pid_str) in active_pids:
|
||||
new_lines.append(line)
|
||||
else:
|
||||
new_lines.append(f'{vm_path}|free\n')
|
||||
with open(self.registry_path, 'w') as file:
|
||||
file.writelines(new_lines)
|
||||
|
||||
# Check and clean on the files inside vms_dir, delete the unregistered ones
|
||||
os.makedirs(vms_dir, exist_ok=True)
|
||||
vm_names = os.listdir(vms_dir)
|
||||
for vm_name in vm_names:
|
||||
# skip the downloaded .zip file
|
||||
if vm_name == DOWNLOADED_FILE_NAME:
|
||||
continue
|
||||
# Skip the .DS_Store file on macOS
|
||||
if vm_name == ".DS_Store":
|
||||
continue
|
||||
|
||||
flag = True
|
||||
for vm_path in vm_paths:
|
||||
if vm_name + ".vmx" in vm_path:
|
||||
flag = False
|
||||
if flag:
|
||||
shutil.rmtree(os.path.join(vms_dir, vm_name))
|
||||
|
||||
def list_free_vms(self):
|
||||
with self.lock: # Lock when reading the registry
|
||||
free_vms = []
|
||||
with open(self.registry_path, 'r') as file:
|
||||
lines = file.readlines()
|
||||
for line in lines:
|
||||
vm_path, pid_str = line.strip().split('|')
|
||||
if pid_str == "free":
|
||||
free_vms.append((vm_path, pid_str))
|
||||
return free_vms
|
||||
|
||||
def get_vm_path(self, region=None):
|
||||
assert region in [None, 'local'], "For VMware provider, the region should be neither None or 'local'."
|
||||
self.check_and_clean(vms_dir=VMS_DIR)
|
||||
free_vms_paths = self.list_free_vms()
|
||||
if len(free_vms_paths) == 0:
|
||||
# No free virtual machine available, generate a new one
|
||||
logger.info("No free virtual machine available. Generating a new one, which would take a while...☕")
|
||||
new_vm_name = generate_new_vm_name(vms_dir=VMS_DIR)
|
||||
new_vm_path = _install_vm(new_vm_name, vms_dir=VMS_DIR,
|
||||
downloaded_file_name=DOWNLOADED_FILE_NAME)
|
||||
self.add_vm(new_vm_path)
|
||||
self.occupy_vm(new_vm_path, os.getpid())
|
||||
return new_vm_path
|
||||
else:
|
||||
# Choose the first free virtual machine
|
||||
chosen_vm_path = free_vms_paths[0][0]
|
||||
self.occupy_vm(chosen_vm_path, os.getpid())
|
||||
return chosen_vm_path
|
||||
90
desktop_env/providers/vmware/provider.py
Normal file
90
desktop_env/providers/vmware/provider.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import logging
|
||||
import platform
|
||||
import subprocess
|
||||
import time
|
||||
import os
|
||||
from desktop_env.providers.base import Provider
|
||||
|
||||
logger = logging.getLogger("desktopenv.providers.vmware.VMwareProvider")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
WAIT_TIME = 3
|
||||
|
||||
|
||||
def get_vmrun_type(return_list=False):
|
||||
if platform.system() == 'Windows' or platform.system() == 'Linux':
|
||||
if return_list:
|
||||
return ['-T', 'ws']
|
||||
else:
|
||||
return '-T ws'
|
||||
elif platform.system() == 'Darwin': # Darwin is the system name for macOS
|
||||
if return_list:
|
||||
return ['-T', 'fusion']
|
||||
else:
|
||||
return '-T fusion'
|
||||
else:
|
||||
raise Exception("Unsupported operating system")
|
||||
|
||||
|
||||
class VMwareProvider(Provider):
|
||||
@staticmethod
|
||||
def _execute_command(command: list):
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True,
|
||||
encoding="utf-8")
|
||||
if result.returncode != 0:
|
||||
raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m")
|
||||
return result.stdout.strip()
|
||||
|
||||
def start_emulator(self, path_to_vm: str, headless: bool):
|
||||
print("Starting VMware VM...")
|
||||
logger.info("Starting VMware VM...")
|
||||
|
||||
while True:
|
||||
try:
|
||||
output = subprocess.check_output(f"vmrun {get_vmrun_type()} list", shell=True, stderr=subprocess.STDOUT)
|
||||
output = output.decode()
|
||||
output = output.splitlines()
|
||||
normalized_path_to_vm = os.path.abspath(os.path.normpath(path_to_vm))
|
||||
|
||||
if any(os.path.abspath(os.path.normpath(line)) == normalized_path_to_vm for line in output):
|
||||
logger.info("VM is running.")
|
||||
break
|
||||
else:
|
||||
logger.info("Starting VM...")
|
||||
VMwareProvider._execute_command(["vmrun"] + get_vmrun_type(return_list=True) + ["start", path_to_vm]) if not headless else \
|
||||
VMwareProvider._execute_command(
|
||||
["vmrun"] + get_vmrun_type(return_list=True) + ["start", path_to_vm, "nogui"])
|
||||
time.sleep(WAIT_TIME)
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"Error executing command: {e.output.decode().strip()}")
|
||||
|
||||
def get_ip_address(self, path_to_vm: str) -> str:
|
||||
logger.info("Getting VMware VM IP address...")
|
||||
while True:
|
||||
try:
|
||||
output = VMwareProvider._execute_command(
|
||||
["vmrun"] + get_vmrun_type(return_list=True) + ["getGuestIPAddress", path_to_vm, "-wait"]
|
||||
)
|
||||
logger.info(f"VMware VM IP address: {output}")
|
||||
return output
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
time.sleep(WAIT_TIME)
|
||||
logger.info("Retrying to get VMware VM IP address...")
|
||||
|
||||
def save_state(self, path_to_vm: str, snapshot_name: str):
|
||||
logger.info("Saving VMware VM state...")
|
||||
VMwareProvider._execute_command(["vmrun"] + get_vmrun_type(return_list=True) + ["snapshot", path_to_vm, snapshot_name])
|
||||
time.sleep(WAIT_TIME) # Wait for the VM to save
|
||||
|
||||
def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str):
|
||||
logger.info(f"Reverting VMware VM to snapshot: {snapshot_name}...")
|
||||
VMwareProvider._execute_command(["vmrun"] + get_vmrun_type(return_list=True) + ["revertToSnapshot", path_to_vm, snapshot_name])
|
||||
time.sleep(WAIT_TIME) # Wait for the VM to revert
|
||||
return path_to_vm
|
||||
|
||||
def stop_emulator(self, path_to_vm: str):
|
||||
logger.info("Stopping VMware VM...")
|
||||
VMwareProvider._execute_command(["vmrun"] + get_vmrun_type(return_list=True) + ["stop", path_to_vm])
|
||||
time.sleep(WAIT_TIME) # Wait for the VM to stop
|
||||
@@ -1,8 +1,61 @@
|
||||
# Server setup
|
||||
|
||||
This README is useful if you want to set up your own machine for the environment. This README is not yet finished. Please contact the author if you need any assistance.
|
||||
|
||||
## Set up the OSWorld server service in VM
|
||||
|
||||
1. First please set up the environment:
|
||||
```shell
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
if you customize the environment in this step, you should change the parameters in the service file we will mention later accordingly.
|
||||
|
||||
2. Copy the `main.py` and `pyxcursor.py` and to the `/home/user-name` where the `user-name` is your username of the ubuntu, here we make it `user` as default. If you customize the path of placing these files in this step, you should change the parameters in the service file we will mention later accordingly.
|
||||
|
||||
3. Copy the `osworld_server.service` to the systemd configuration directory at `/etc/systemd/system/`:
|
||||
```shell
|
||||
sudo cp osworld_server.service /etc/systemd/system/
|
||||
```
|
||||
|
||||
Reload the systemd daemon to recognize the new service:
|
||||
```shell
|
||||
sudo systemctl daemon-reload
|
||||
```
|
||||
|
||||
Enable the service to start on boot:
|
||||
```shell
|
||||
sudo systemctl enable osworld_server.service
|
||||
```
|
||||
|
||||
Start the service:
|
||||
```shell
|
||||
sudo systemctl start osworld_server.service
|
||||
```
|
||||
|
||||
Verify the service is running correctly:
|
||||
```shell
|
||||
sudo systemctl status osworld_server.service
|
||||
```
|
||||
|
||||
You should see output indicating the service is active and running. If there are errors, review the logs with `journalctl -xe` for further troubleshooting.
|
||||
|
||||
If you need to make adjustments to the service configuration, you can edit the `/etc/systemd/system/osworld_server.service` file:
|
||||
```shell
|
||||
sudo nano /etc/systemd/system/osworld_server.service
|
||||
```
|
||||
|
||||
After making changes, reload the daemon and restart the service:
|
||||
```shell
|
||||
sudo systemctl
|
||||
```
|
||||
|
||||
<!-- vimc: call SyntaxRange#Include('```xml', '```', 'xml', 'NonText'): -->
|
||||
<!-- vimc: call SyntaxRange#Include('```css', '```', 'css', 'NonText'): -->
|
||||
<!-- vimc: call SyntaxRange#Include('```sh', '```', 'sh', 'NonText'): -->
|
||||
<!-- vimc: call SyntaxRange#Include('```bash', '```', 'sh', 'NonText'): -->
|
||||
|
||||
## Others
|
||||
|
||||
### About the Converted Accessibility Tree
|
||||
|
||||
For several applications like Firefox or Thunderbird, you should first enable
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
[Unit]
|
||||
Description=OSBench Server
|
||||
StartLimitIntervalSec=60
|
||||
StartLimitBurst=4
|
||||
After=network.target auditd.service
|
||||
|
||||
[Service]
|
||||
ExecStart=/usr/bin/python3 /home/user/main.py
|
||||
User=user
|
||||
WorkingDirectory=/home/user
|
||||
Restart=on-failure
|
||||
RestartSec=1
|
||||
Environment="DISPLAY=%i"
|
||||
|
||||
[Install]
|
||||
WantedBy=graphical.target
|
||||
@@ -1,5 +1,5 @@
|
||||
[Unit]
|
||||
Description=OSBench Server
|
||||
Description=OSWorld Server
|
||||
StartLimitIntervalSec=60
|
||||
StartLimitBurst=4
|
||||
After=network.target auditd.service
|
||||
2
main.py
2
main.py
@@ -5,7 +5,7 @@ import os
|
||||
import sys
|
||||
import time
|
||||
import argparse
|
||||
from desktop_env.envs.desktop_env import DesktopEnv
|
||||
from desktop_env.desktop_env import DesktopEnv
|
||||
|
||||
# Logger Configs {{{ #
|
||||
logger = logging.getLogger()
|
||||
|
||||
3
run.py
3
run.py
@@ -6,14 +6,13 @@ import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
# import wandb
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import lib_run_single
|
||||
from desktop_env.envs.desktop_env import DesktopEnv
|
||||
from desktop_env.desktop_env import DesktopEnv
|
||||
from mm_agents.agent import PromptAgent
|
||||
|
||||
# Logger Configs {{{ #
|
||||
|
||||
Reference in New Issue
Block a user