* multi_env

* multi_env

---------

Co-authored-by: Timothyxxx <384084775@qq.com>
This commit is contained in:
Dunjie Lu
2024-11-02 22:28:23 +08:00
committed by GitHub
parent 3933e0d303
commit 8be2a40967
7 changed files with 493 additions and 42 deletions

View File

@@ -26,7 +26,7 @@ class DesktopEnv(gym.Env):
def __init__(
self,
provider_name: str = "vmware",
provider_name: str = "docker",
region: str = None,
path_to_vm: str = None,
snapshot_name: str = "init_state",
@@ -36,7 +36,7 @@ class DesktopEnv(gym.Env):
headless: bool = False,
require_a11y_tree: bool = True,
require_terminal: bool = False,
os_type: str = "Ubuntu",
os_type: str = "Windows",
):
"""
Args:
@@ -60,6 +60,18 @@ class DesktopEnv(gym.Env):
self.chromium_port = 9222
self.vnc_port = 8006
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region)
# self.server_port = server_port or 5000
# self.chromium_port = chromium_port or 9222
# self.vnc_port = vnc_port or 8006
# # Initialize provider with custom ports
# self.manager, self.provider = create_vm_manager_and_provider(
# provider_name,
# region,
# vnc_port=self.vnc_port,
# server_port=self.server_port,
# chromium_port=self.chromium_port
# )
self.os_type = os_type

View File

@@ -1,5 +1,6 @@
from desktop_env.providers.base import VMManager, Provider
# def create_vm_manager_and_provider(provider_name: str, region: str, vnc_port: int = None, server_port: int = None, chromium_port: int = None):
def create_vm_manager_and_provider(provider_name: str, region: str):
"""
Factory function to get the Virtual Machine Manager and Provider instances based on the provided provider name.
@@ -24,6 +25,7 @@ def create_vm_manager_and_provider(provider_name: str, region: str):
elif provider_name == "docker":
from desktop_env.providers.docker.manager import DockerVMManager
from desktop_env.providers.docker.provider import DockerProvider
# return DockerVMManager(), DockerProvider(region, vnc_port, server_port, chromium_port)
return DockerVMManager(), DockerProvider(region)
else:
raise NotImplementedError(f"{provider_name} not implemented!")

View File

@@ -30,58 +30,128 @@ class DockerProvider(Provider):
self.chromium_port = None
self.environment = {"DISK_SIZE": "32G", "RAM_SIZE": "4G", "CPU_CORES": "4"} # Modify if needed
temp_dir = Path(os.getenv('TEMP' if platform.system() == 'Windows' else '/tmp'))
# temp_dir = Path(os.getenv('TEMP' if platform.system() == 'Windows' else '/tmp'))
temp_dir = Path(os.getenv('TEMP') if platform.system() == 'Windows' else '/tmp')
self.lock_file = temp_dir / "docker_port_allocation.lck"
self.lock_file.parent.mkdir(parents=True, exist_ok=True)
def _get_available_port(self, port: int, lock_file: Path = None):
if lock_file is None:
lock_file = self.lock_file
lock = FileLock(str(lock_file), timeout=LOCK_TIMEOUT)
with lock:
while port < 65354:
if port not in [conn.laddr.port for conn in psutil.net_connections()]:
return port
port += 1
def _get_used_ports(self):
"""Get all currently used ports (both system and Docker)."""
# Get system ports
system_ports = set(conn.laddr.port for conn in psutil.net_connections())
# Get Docker container ports
docker_ports = set()
for container in self.client.containers.list():
ports = container.attrs['NetworkSettings']['Ports']
if ports:
for port_mappings in ports.values():
if port_mappings:
docker_ports.update(int(p['HostPort']) for p in port_mappings)
return system_ports | docker_ports
def _get_available_port(self, start_port: int) -> int:
"""Find next available port starting from start_port."""
used_ports = self._get_used_ports()
port = start_port
while port < 65354:
if port not in used_ports:
return port
port += 1
raise PortAllocationError(f"No available ports found starting from {start_port}")
def _wait_for_vm_ready(self, timeout: int = 300):
"""Wait for VM to be ready by checking screenshot endpoint."""
start_time = time.time()
def check_screenshot():
try:
response = requests.get(
f"http://localhost:{self.server_port}/screenshot",
timeout=(10, 10)
)
return response.status_code == 200
except Exception:
return False
while time.time() - start_time < timeout:
if check_screenshot():
return True
logger.info("Checking if virtual machine is ready...")
time.sleep(RETRY_INTERVAL)
raise TimeoutError("VM failed to become ready within timeout period")
def start_emulator(self, path_to_vm: str, headless: bool, os_type: str):
self.vnc_port = self._get_available_port(8006)
self.server_port = self._get_available_port(5000)
# self.remote_debugging_port = self._get_available_port(1337)
self.chromium_port = self._get_available_port(9222)
logger.info(f"Occupying ports: {self.vnc_port}, {self.server_port}, {self.chromium_port}")
self.container = self.client.containers.run("happysixd/osworld-docker", environment=self.environment,
cap_add=["NET_ADMIN"], devices=["/dev/kvm"], volumes={
os.path.abspath(path_to_vm): {"bind": "/System.qcow2", "mode": "ro"}},
ports={8006: self.vnc_port, 5000: self.server_port,
9222: self.chromium_port}, detach=True)
# Use a single lock for all port allocation and container startup
lock = FileLock(str(self.lock_file), timeout=LOCK_TIMEOUT)
try:
with lock:
# Allocate all required ports
self.vnc_port = self._get_available_port(8006)
self.server_port = self._get_available_port(5000)
self.chromium_port = self._get_available_port(9222)
def download_screenshot(ip, port):
url = f"http://{ip}:{port}/screenshot"
try:
# max trey times 1, max timeout 1
response = requests.get(url, timeout=(10, 10))
if response.status_code == 200:
return True
except Exception as e:
time.sleep(RETRY_INTERVAL)
return False
# Start container while still holding the lock
self.container = self.client.containers.run(
"happysixd/osworld-docker",
environment=self.environment,
cap_add=["NET_ADMIN"],
devices=["/dev/kvm"],
volumes={
os.path.abspath(path_to_vm): {
"bind": "/System.qcow2",
"mode": "ro"
}
},
ports={
8006: self.vnc_port,
5000: self.server_port,
9222: self.chromium_port
},
detach=True
)
# Try downloading the screenshot until successful
while not download_screenshot("localhost", self.server_port):
logger.info("Check whether the virtual machine is ready...")
logger.info(f"Started container with ports - VNC: {self.vnc_port}, "
f"Server: {self.server_port}, Chrome: {self.chromium_port}")
# Wait for VM to be ready
self._wait_for_vm_ready()
except Exception as e:
# Clean up if anything goes wrong
if self.container:
try:
self.container.stop()
self.container.remove()
except:
pass
raise e
def get_ip_address(self, path_to_vm: str) -> str:
if not all([self.server_port, self.chromium_port, self.vnc_port]):
raise RuntimeError("VM not started - ports not allocated")
return f"localhost:{self.server_port}:{self.chromium_port}:{self.vnc_port}"
def save_state(self, path_to_vm: str, snapshot_name: str):
raise NotImplementedError("Not available for Docker.")
raise NotImplementedError("Snapshots not available for Docker provider")
def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str):
pass
def stop_emulator(self, path_to_vm: str):
logger.info("Stopping VM...")
self.container.stop()
self.container.remove()
time.sleep(WAIT_TIME)
if self.container:
logger.info("Stopping VM...")
try:
self.container.stop()
self.container.remove()
time.sleep(WAIT_TIME)
except Exception as e:
logger.error(f"Error stopping container: {e}")
finally:
self.container = None
self.server_port = None
self.vnc_port = None
self.chromium_port = None