From 6d7fb3d78685effce8086960e93552b12394f45a Mon Sep 17 00:00:00 2001 From: Chi Wang Date: Sat, 29 Apr 2023 11:46:28 -0700 Subject: [PATCH] raise content_filter error (#1018) * raise content_filter error * import error handling --- flaml/autogen/oai/completion.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/flaml/autogen/oai/completion.py b/flaml/autogen/oai/completion.py index 6bce0e171..c92bd4819 100644 --- a/flaml/autogen/oai/completion.py +++ b/flaml/autogen/oai/completion.py @@ -17,11 +17,13 @@ try: APIConnectionError, Timeout, ) + from openai import Completion as openai_Completion import diskcache ERROR = None except ImportError: ERROR = ImportError("please install flaml[openai] option to use the flaml.oai subpackage.") + openai_Completion = object logger = logging.getLogger(__name__) if not logger.handlers: # Add the console handler. @@ -46,7 +48,7 @@ def get_key(config): return config -class Completion: +class Completion(openai_Completion): """A class for OpenAI completion API. It also supports: ChatCompletion, Azure OpenAI API. @@ -151,26 +153,33 @@ class Completion: response = openai_completion.create(request_timeout=request_timeout, **config) except ( ServiceUnavailableError, - APIError, APIConnectionError, ): # transient error logger.warning(f"retrying in {cls.retry_time} seconds...", exc_info=1) sleep(cls.retry_time) - except (RateLimitError, Timeout) as e: + except APIError as err: + error_code = err and err.json_body and err.json_body.get("error") + error_code = error_code and error_code.get("code") + if error_code == "content_filter": + raise + # transient error + logger.warning(f"retrying in {cls.retry_time} seconds...", exc_info=1) + sleep(cls.retry_time) + except (RateLimitError, Timeout) as err: time_left = cls.retry_timeout - (time.time() - start_time + cls.retry_time) if ( time_left > 0 - and isinstance(e, RateLimitError) + and isinstance(err, RateLimitError) or time_left > request_timeout - and isinstance(e, Timeout) + and isinstance(err, Timeout) ): logger.info(f"retrying in {cls.retry_time} seconds...", exc_info=1) elif eval_only: raise else: break - if isinstance(e, Timeout): + if isinstance(err, Timeout): if "request_timeout" in config: raise request_timeout <<= 1