diff --git a/gradio/external_utils.py b/gradio/external_utils.py index 8b69721aa9..28af83474a 100644 --- a/gradio/external_utils.py +++ b/gradio/external_utils.py @@ -16,6 +16,7 @@ from gradio import components from gradio.exceptions import Error, TooManyRequestsError +from typing import Any def get_model_info(model_name, hf_token=None): @@ -150,28 +151,29 @@ def chat_fn(message, history): def encode_to_base64(r: httpx.Response) -> str: # Handles the different ways HF API returns the prediction - base64_repr = base64.b64encode(r.content).decode("utf-8") + # Handles the different ways HF API returns the prediction + # Use a function-local variable for the frequently used data prefix data_prefix = ";base64," + # Fast-path: if content-type header says json, try that branch directly to avoid wasteful base64 encoding (big win for most JSON API responses) + content_type = r.headers.get("content-type") + if content_type == "application/json": + try: + # Only decode once + data: Any = r.json()[0] + content_type = data["content-type"] + base64_repr = data["blob"] + except KeyError as ke: + raise ValueError( + "Cannot determine content type returned by external API." + ) from ke + return f"data:{content_type}{data_prefix}{base64_repr}" + # Otherwise: Non-JSON direct response. Do base64 encoding. + base64_repr = base64.b64encode(r.content).decode("utf-8") # Case 1: base64 representation already includes data prefix if data_prefix in base64_repr: return base64_repr - else: - content_type = r.headers.get("content-type") - # Case 2: the data prefix is a key in the response - if content_type == "application/json": - try: - data = r.json()[0] - content_type = data["content-type"] - base64_repr = data["blob"] - except KeyError as ke: - raise ValueError( - "Cannot determine content type returned by external API." - ) from ke - # Case 3: the data prefix is included in the response headers - else: - pass - new_base64 = f"data:{content_type};base64,{base64_repr}" - return new_base64 + # If not, must synthesize the prefix + return f"data:{content_type}{data_prefix}{base64_repr}" def format_ner_list(input_string: str, ner_groups: list[dict[str, str | int]]):