
Commit 7bdf52c

fix linting error
1 parent 9e85dcb commit 7bdf52c

File tree

1 file changed (+24, -17)

litellm/proxy/hooks/parallel_request_limiter_v3.py

Lines changed: 24 additions & 17 deletions
@@ -29,6 +29,7 @@
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.proxy.auth.auth_utils import get_model_rate_limit_from_metadata
 from litellm.types.llms.openai import BaseLiteLLMOpenAIResponseObject
+from litellm.types.utils import ModelResponse, Usage
 
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
@@ -1232,6 +1233,28 @@ def _create_pipeline_operations(
 
         return pipeline_operations
 
+    def _get_total_tokens_from_usage(self, usage: Any | None, rate_limit_type: Literal["output", "input", "total"]) -> int:
+        # Get total tokens from response
+        total_tokens = 0
+        # spot fix for /responses api
+        if usage:
+            if isinstance(usage, Usage):
+                if rate_limit_type == "output":
+                    total_tokens = usage.completion_tokens
+                elif rate_limit_type == "input":
+                    total_tokens = usage.prompt_tokens
+                elif rate_limit_type == "total":
+                    total_tokens = usage.total_tokens
+            elif isinstance(usage, dict):
+                # Responses API usage comes as a dict in ResponsesAPIResponse
+                if rate_limit_type == "output":
+                    total_tokens = usage.get("completion_tokens", 0)
+                elif rate_limit_type == "input":
+                    total_tokens = usage.get("prompt_tokens", 0)
+                elif rate_limit_type == "total":
+                    total_tokens = usage.get("total_tokens", 0)
+        return total_tokens
+
     async def _execute_token_increment_script(
         self,
         pipeline_operations: List["RedisPipelineIncrementOperation"],
@@ -1335,7 +1358,6 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time
             get_model_group_from_litellm_kwargs,
         )
         from litellm.types.caching import RedisPipelineIncrementOperation
-        from litellm.types.utils import ModelResponse, Usage
 
         rate_limit_type = self.get_rate_limit_type()
 
@@ -1371,22 +1393,7 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time
                 response_obj, BaseLiteLLMOpenAIResponseObject
             ):
                 _usage = getattr(response_obj, "usage", None)
-                if _usage:
-                    if isinstance(_usage, Usage):
-                        if rate_limit_type == "output":
-                            total_tokens = _usage.completion_tokens
-                        elif rate_limit_type == "input":
-                            total_tokens = _usage.prompt_tokens
-                        elif rate_limit_type == "total":
-                            total_tokens = _usage.total_tokens
-                    elif isinstance(_usage, dict):
-                        # Responses API usage comes as a dict in ResponsesAPIResponse
-                        if rate_limit_type == "output":
-                            total_tokens = _usage.get("completion_tokens", 0)
-                        elif rate_limit_type == "input":
-                            total_tokens = _usage.get("prompt_tokens", 0)
-                        elif rate_limit_type == "total":
-                            total_tokens = _usage.get("total_tokens", 0)
+                total_tokens = self._get_total_tokens_from_usage(usage=_usage, rate_limit_type=rate_limit_type)
 
             # Create pipeline operations for TPM increments
             pipeline_operations: List[RedisPipelineIncrementOperation] = []
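
For reference, the sketch below reproduces the branching that the new _get_total_tokens_from_usage helper centralises: attribute access when usage is a Usage-style object, dict lookups with a 0 default for the dict-shaped usage the Responses API returns. It is not part of the commit; UsageLike, resolve_total_tokens, and the sample token counts are illustrative stand-ins so the snippet runs without the proxy or litellm installed.

# resolve_total_tokens_sketch.py
# Standalone illustration of the token-selection logic extracted in this
# commit. UsageLike is a stand-in for litellm.types.utils.Usage; the sample
# numbers below are made up.
from dataclasses import dataclass
from typing import Any, Literal


@dataclass
class UsageLike:
    # Same attribute names the helper reads on litellm's Usage object
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


def resolve_total_tokens(
    usage: Any, rate_limit_type: Literal["output", "input", "total"]
) -> int:
    total_tokens = 0
    if not usage:
        return total_tokens
    if isinstance(usage, UsageLike):
        # Usage-style object: read the counter as an attribute
        if rate_limit_type == "output":
            total_tokens = usage.completion_tokens
        elif rate_limit_type == "input":
            total_tokens = usage.prompt_tokens
        elif rate_limit_type == "total":
            total_tokens = usage.total_tokens
    elif isinstance(usage, dict):
        # Responses API shape: usage arrives as a plain dict
        if rate_limit_type == "output":
            total_tokens = usage.get("completion_tokens", 0)
        elif rate_limit_type == "input":
            total_tokens = usage.get("prompt_tokens", 0)
        elif rate_limit_type == "total":
            total_tokens = usage.get("total_tokens", 0)
    return total_tokens


if __name__ == "__main__":
    usage_obj = UsageLike(prompt_tokens=12, completion_tokens=30, total_tokens=42)
    usage_dict = {"prompt_tokens": 12, "completion_tokens": 30, "total_tokens": 42}
    print(resolve_total_tokens(usage_obj, "output"))  # 30
    print(resolve_total_tokens(usage_dict, "total"))  # 42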

0 commit comments
