diff --git a/README.md b/README.md index 0a52673..d3a9b39 100644 --- a/README.md +++ b/README.md @@ -150,6 +150,11 @@ Set **OPENAI_API_KEY** environment variable with your API key export OPENAI_API_KEY='changeme' ``` +Optionally, set **OPENAI_BASE_URL** to use a custom OpenAI-compatible API endpoint +```bash +export OPENAI_BASE_URL='http://your-custom-endpoint.com/v1' # Optional: defaults to https://api.openai.com +``` + ```bash python run.py --path_to_vm Ubuntu/Ubuntu.vmx --headless --observation_type screenshot --model gpt-4-vision-preview --result_dir ./results ``` diff --git a/mm_agents/agent.py b/mm_agents/agent.py index 1eefbba..35752b7 100644 --- a/mm_agents/agent.py +++ b/mm_agents/agent.py @@ -615,13 +615,17 @@ class PromptAgent: else: return response.json()['choices'][0]['message']['content'] elif self.model.startswith("gpt"): + # Support custom OpenAI base URL via environment variable + base_url = os.environ.get('OPENAI_BASE_URL', 'https://api.openai.com') + # Smart handling: avoid duplicate /v1 if base_url already ends with /v1 + api_url = f"{base_url}/chat/completions" if base_url.endswith('/v1') else f"{base_url}/v1/chat/completions" headers = { "Content-Type": "application/json", "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}" } logger.info("Generating content with GPT model: %s", self.model) response = requests.post( - "https://api.openai.com/v1/chat/completions", + api_url, headers=headers, json=payload ) @@ -631,7 +635,7 @@ class PromptAgent: logger.error("Context length exceeded. Retrying with a smaller context.") payload["messages"] = [payload["messages"][0]] + payload["messages"][-1:] retry_response = requests.post( - "https://api.openai.com/v1/chat/completions", + api_url, headers=headers, json=payload )