Merge remote-tracking branch 'origin/main'

This commit is contained in:
Timothyxxx
2024-05-21 21:08:43 +08:00
2 changed files with 6 additions and 0 deletions

View File

@@ -15,6 +15,7 @@ import dashscope
import google.generativeai as genai import google.generativeai as genai
import openai import openai
import requests import requests
from requests.exceptions import SSLError
import tiktoken import tiktoken
from PIL import Image from PIL import Image
from google.api_core.exceptions import InvalidArgument, ResourceExhausted, InternalServerError, BadRequest from google.api_core.exceptions import InvalidArgument, ResourceExhausted, InternalServerError, BadRequest
@@ -518,6 +519,9 @@ class PromptAgent:
# but you are forbidden to add "Exception", that is, a common type of exception # but you are forbidden to add "Exception", that is, a common type of exception
# because we want to catch this kind of Exception in the outside to ensure each example won't exceed the time limit # because we want to catch this kind of Exception in the outside to ensure each example won't exceed the time limit
( (
# General exceptions
SSLError,
# OpenAI exceptions # OpenAI exceptions
openai.RateLimitError, openai.RateLimitError,
openai.BadRequestError, openai.BadRequestError,

2
run.py
View File

@@ -132,6 +132,8 @@ def test(
agent = PromptAgent( agent = PromptAgent(
model=args.model, model=args.model,
max_tokens=args.max_tokens, max_tokens=args.max_tokens,
top_p=args.top_p,
temperature=args.temperature,
action_space=args.action_space, action_space=args.action_space,
observation_type=args.observation_type, observation_type=args.observation_type,
max_trajectory_length=args.max_trajectory_length, max_trajectory_length=args.max_trajectory_length,