29 | 29 | from litellm.proxy._types import UserAPIKeyAuth |
30 | 30 | from litellm.proxy.auth.auth_utils import get_model_rate_limit_from_metadata |
31 | 31 | from litellm.types.llms.openai import BaseLiteLLMOpenAIResponseObject |
| 32 | +from litellm.types.utils import ModelResponse, Usage |
32 | 33 |
33 | 34 | if TYPE_CHECKING: |
34 | 35 | from opentelemetry.trace import Span as _Span |
@@ -1232,6 +1233,28 @@ def _create_pipeline_operations( |
1232 | 1233 |
1233 | 1234 | return pipeline_operations |
1234 | 1235 |
| 1236 | + def _get_total_tokens_from_usage(self, usage: Optional[Any], rate_limit_type: Literal["output", "input", "total"]) -> int: |
| 1237 | + # Return the token count relevant to this limiter (input, output, or total) from the response usage |
| 1238 | + total_tokens = 0 |
| 1239 | + # Spot fix for the /responses API, which may return usage as a plain dict rather than a Usage object |
| 1240 | + if usage: |
| 1241 | + if isinstance(usage, Usage): |
| 1242 | + if rate_limit_type == "output": |
| 1243 | + total_tokens = usage.completion_tokens |
| 1244 | + elif rate_limit_type == "input": |
| 1245 | + total_tokens = usage.prompt_tokens |
| 1246 | + elif rate_limit_type == "total": |
| 1247 | + total_tokens = usage.total_tokens |
| 1248 | + elif isinstance(usage, dict): |
| 1249 | + # Responses API usage comes as a dict in ResponsesAPIResponse |
| 1250 | + if rate_limit_type == "output": |
| 1251 | + total_tokens = usage.get("completion_tokens", 0) |
| 1252 | + elif rate_limit_type == "input": |
| 1253 | + total_tokens = usage.get("prompt_tokens", 0) |
| 1254 | + elif rate_limit_type == "total": |
| 1255 | + total_tokens = usage.get("total_tokens", 0) |
| 1256 | + return total_tokens |
| 1257 | + |
1235 | 1258 | async def _execute_token_increment_script( |
1236 | 1259 | self, |
1237 | 1260 | pipeline_operations: List["RedisPipelineIncrementOperation"], |
@@ -1335,7 +1358,6 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti |
1335 | 1358 | get_model_group_from_litellm_kwargs, |
1336 | 1359 | ) |
1337 | 1360 | from litellm.types.caching import RedisPipelineIncrementOperation |
1338 | | - from litellm.types.utils import ModelResponse, Usage |
1339 | 1361 |
1340 | 1362 | rate_limit_type = self.get_rate_limit_type() |
1341 | 1363 |
@@ -1371,22 +1393,7 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti |
1371 | 1393 | response_obj, BaseLiteLLMOpenAIResponseObject |
1372 | 1394 | ): |
1373 | 1395 | _usage = getattr(response_obj, "usage", None) |
1374 | | - if _usage: |
1375 | | - if isinstance(_usage, Usage): |
1376 | | - if rate_limit_type == "output": |
1377 | | - total_tokens = _usage.completion_tokens |
1378 | | - elif rate_limit_type == "input": |
1379 | | - total_tokens = _usage.prompt_tokens |
1380 | | - elif rate_limit_type == "total": |
1381 | | - total_tokens = _usage.total_tokens |
1382 | | - elif isinstance(_usage, dict): |
1383 | | - # Responses API usage comes as a dict in ResponsesAPIResponse |
1384 | | - if rate_limit_type == "output": |
1385 | | - total_tokens = _usage.get("completion_tokens", 0) |
1386 | | - elif rate_limit_type == "input": |
1387 | | - total_tokens = _usage.get("prompt_tokens", 0) |
1388 | | - elif rate_limit_type == "total": |
1389 | | - total_tokens = _usage.get("total_tokens", 0) |
| 1396 | + total_tokens = self._get_total_tokens_from_usage(usage=_usage, rate_limit_type=rate_limit_type) |
1390 | 1397 |
1391 | 1398 | # Create pipeline operations for TPM increments |
1392 | 1399 | pipeline_operations: List[RedisPipelineIncrementOperation] = [] |
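For context, a minimal sketch of how the extracted _get_total_tokens_from_usage helper behaves for the two usage shapes it handles. The limiter variable below stands in for an instance of the hook class this diff touches (its name is not shown in the hunks), and the token counts are illustrative only, not taken from the PR:

    from litellm.types.utils import Usage

    # Chat Completions style usage: a Usage object
    usage_obj = Usage(prompt_tokens=12, completion_tokens=30, total_tokens=42)
    limiter._get_total_tokens_from_usage(usage=usage_obj, rate_limit_type="output")  # -> 30

    # Responses API style usage: a plain dict carrying the same keys the helper reads
    usage_dict = {"prompt_tokens": 12, "completion_tokens": 30, "total_tokens": 42}
    limiter._get_total_tokens_from_usage(usage=usage_dict, rate_limit_type="total")  # -> 42

    # No usage on the response object: the helper falls back to 0
    limiter._get_total_tokens_from_usage(usage=None, rate_limit_type="input")  # -> 0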