
Commit 2ea855d

Merge pull request #17707 from raghav-stripe/raghav-fix-responsesapi-rl
fix: responses api not applying tpm rate limits on api keys
2 parents 9d7a255 + face817 commit 2ea855d
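
Why the fix is needed: the Responses API attaches `usage` to the response as a plain dict (inside `ResponsesAPIResponse`), while the rate limiter previously only counted tokens when `usage` was a `Usage` object. For Responses API calls the token count therefore stayed at 0 and the per-key TPM counters were never incremented. The following is a minimal standalone sketch, not litellm code, illustrating the before/after extraction for a dict-shaped usage payload (names and values here are illustrative only):

from typing import Any

# Dict-shaped usage, as the Responses API returns it (illustrative values)
usage: Any = {"prompt_tokens": 25, "completion_tokens": 35, "total_tokens": 60}

FIELD_FOR_TYPE = {
    "input": "prompt_tokens",
    "output": "completion_tokens",
    "total": "total_tokens",
}

def old_extract(usage: Any, rate_limit_type: str) -> int:
    # Old behavior: only `isinstance(usage, Usage)` was handled; a stand-in
    # attribute check is used here so the sketch runs without litellm installed.
    # A plain dict fails the check and contributes 0 tokens.
    if not isinstance(usage, dict) and hasattr(usage, "total_tokens"):
        return getattr(usage, FIELD_FOR_TYPE[rate_limit_type], 0)
    return 0

def new_extract(usage: Any, rate_limit_type: str) -> int:
    # New behavior (mirrors the added _get_total_tokens_from_usage helper):
    # dict-shaped usage is counted as well.
    if isinstance(usage, dict):
        return usage.get(FIELD_FOR_TYPE[rate_limit_type], 0)
    return getattr(usage, FIELD_FOR_TYPE[rate_limit_type], 0)

print(old_extract(usage, "total"))  # 0  -> TPM counter never incremented
print(new_extract(usage, "total"))  # 60 -> TPM counter incremented correctly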

File tree

3 files changed (+255, -15 lines changed)

litellm/proxy/hooks/parallel_request_limiter_v3.py

Lines changed: 26 additions & 11 deletions
@@ -29,6 +29,7 @@
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.proxy.auth.auth_utils import get_model_rate_limit_from_metadata
 from litellm.types.llms.openai import BaseLiteLLMOpenAIResponseObject
+from litellm.types.utils import ModelResponse, Usage

 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
@@ -1232,6 +1233,28 @@ def _create_pipeline_operations(

         return pipeline_operations

+    def _get_total_tokens_from_usage(self, usage: Any | None, rate_limit_type: Literal["output", "input", "total"]) -> int:
+        # Get total tokens from response
+        total_tokens = 0
+        # spot fix for /responses api
+        if usage:
+            if isinstance(usage, Usage):
+                if rate_limit_type == "output":
+                    total_tokens = usage.completion_tokens
+                elif rate_limit_type == "input":
+                    total_tokens = usage.prompt_tokens
+                elif rate_limit_type == "total":
+                    total_tokens = usage.total_tokens
+            elif isinstance(usage, dict):
+                # Responses API usage comes as a dict in ResponsesAPIResponse
+                if rate_limit_type == "output":
+                    total_tokens = usage.get("completion_tokens", 0)
+                elif rate_limit_type == "input":
+                    total_tokens = usage.get("prompt_tokens", 0)
+                elif rate_limit_type == "total":
+                    total_tokens = usage.get("total_tokens", 0)
+        return total_tokens
+
     async def _execute_token_increment_script(
         self,
         pipeline_operations: List["RedisPipelineIncrementOperation"],
@@ -1313,11 +1336,10 @@ async def async_increment_tokens_with_ttl_preservation(

     def get_rate_limit_type(self) -> Literal["output", "input", "total"]:
         from litellm.proxy.proxy_server import general_settings
-
         specified_rate_limit_type = general_settings.get(
-            "token_rate_limit_type", "output"
+            "token_rate_limit_type", "total"
         )
-        if not specified_rate_limit_type or specified_rate_limit_type not in [
+        if specified_rate_limit_type not in [
             "output",
             "input",
             "total",
@@ -1336,7 +1358,6 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti
             get_model_group_from_litellm_kwargs,
         )
         from litellm.types.caching import RedisPipelineIncrementOperation
-        from litellm.types.utils import ModelResponse, Usage

         rate_limit_type = self.get_rate_limit_type()

@@ -1372,13 +1393,7 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti
             response_obj, BaseLiteLLMOpenAIResponseObject
         ):
             _usage = getattr(response_obj, "usage", None)
-            if _usage and isinstance(_usage, Usage):
-                if rate_limit_type == "output":
-                    total_tokens = _usage.completion_tokens
-                elif rate_limit_type == "input":
-                    total_tokens = _usage.prompt_tokens
-                elif rate_limit_type == "total":
-                    total_tokens = _usage.total_tokens
+            total_tokens = self._get_total_tokens_from_usage(usage=_usage, rate_limit_type=rate_limit_type)

             # Create pipeline operations for TPM increments
             pipeline_operations: List[RedisPipelineIncrementOperation] = []
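
Note the second behavioral change in this file: `token_rate_limit_type`, read from `general_settings`, now defaults to "total" instead of "output", and (per the new tests added further down) unrecognized values also fall back to "total". A small standalone sketch of that resolution logic, using a plain dict in place of the proxy's `general_settings` (an illustration of the fallback, not the proxy's actual code path):

def resolve_rate_limit_type(general_settings: dict) -> str:
    # Mirrors the updated default: missing or invalid values resolve to "total".
    specified = general_settings.get("token_rate_limit_type", "total")
    if specified not in ["output", "input", "total"]:
        return "total"
    return specified

print(resolve_rate_limit_type({}))                                          # "total" (new default)
print(resolve_rate_limit_type({"token_rate_limit_type": "output"}))         # "output" (explicit opt-in)
print(resolve_rate_limit_type({"token_rate_limit_type": "invalid_type"}))   # "total" (fallback)

Deployments that relied on the previous output-token-only accounting will likely need to set token_rate_limit_type to "output" explicitly in general_settings to keep the old behavior.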

tests/test_litellm/proxy/hooks/test_dynamic_rate_limiter_v3.py

Lines changed: 4 additions & 4 deletions
@@ -1403,13 +1403,13 @@ async def mock_increment(pipeline_operations, parent_otel_span=None):
         end_time=None,
     )

-    # Verify increments happened with actual token count (50 completion tokens)
+    # Verify increments happened with actual token count (60 total tokens)
     assert len(increment_calls) == 2, f"Expected 2 increment calls, got {len(increment_calls)}"

-    # Both should increment by 50 (completion_tokens, since rate_limit_type defaults to 'output')
+    # Both should increment by 50 (total_tokens, since rate_limit_type defaults to 'total')
     for call in increment_calls:
-        assert call["increment_value"] == 50, (
-            f"Expected increment of 50 tokens, got {call['increment_value']} for key {call['key']}"
+        assert call["increment_value"] == 60, (
+            f"Expected increment of 60 tokens, got {call['increment_value']} for key {call['key']}"
         )

     # Verify correct keys were used

tests/test_litellm/proxy/hooks/test_parallel_request_limiter_v3.py

Lines changed: 225 additions & 0 deletions
@@ -1583,6 +1583,231 @@ async def mock_should_rate_limit(descriptors, **kwargs):
     assert "Current limit: 2" in exc_info.value.detail


+@pytest.mark.asyncio
+async def test_get_rate_limit_type_default_is_total(monkeypatch):
+    """
+    Test that get_rate_limit_type returns 'total' as the default when no setting is specified.
+
+    This verifies the change from 'output' to 'total' as the default value.
+    """
+    local_cache = DualCache()
+    parallel_request_handler = _PROXY_MaxParallelRequestsHandler(
+        internal_usage_cache=InternalUsageCache(local_cache)
+    )
+
+    # Mock general_settings to return empty dict (no token_rate_limit_type set)
+    import litellm.proxy.proxy_server as proxy_server
+    original_settings = getattr(proxy_server, 'general_settings', {})
+    monkeypatch.setattr(proxy_server, 'general_settings', {})
+
+    try:
+        result = parallel_request_handler.get_rate_limit_type()
+        assert result == "total", f"Default rate limit type should be 'total', got '{result}'"
+    finally:
+        monkeypatch.setattr(proxy_server, 'general_settings', original_settings)
+
+
+@pytest.mark.asyncio
+async def test_get_rate_limit_type_invalid_falls_back_to_total(monkeypatch):
+    """
+    Test that get_rate_limit_type falls back to 'total' when an invalid value is specified.
+    """
+    local_cache = DualCache()
+    parallel_request_handler = _PROXY_MaxParallelRequestsHandler(
+        internal_usage_cache=InternalUsageCache(local_cache)
+    )
+
+    # Mock general_settings to return an invalid token_rate_limit_type
+    import litellm.proxy.proxy_server as proxy_server
+    original_settings = getattr(proxy_server, 'general_settings', {})
+    monkeypatch.setattr(proxy_server, 'general_settings', {'token_rate_limit_type': 'invalid_type'})
+
+    try:
+        result = parallel_request_handler.get_rate_limit_type()
+        assert result == "total", f"Invalid rate limit type should fall back to 'total', got '{result}'"
+    finally:
+        monkeypatch.setattr(proxy_server, 'general_settings', original_settings)
+
+
+@pytest.mark.parametrize(
+    "token_rate_limit_type,expected_field",
+    [
+        ("input", "prompt_tokens"),
+        ("output", "completion_tokens"),
+        ("total", "total_tokens"),
+    ],
+)
+@pytest.mark.asyncio
+async def test_async_log_success_event_with_dict_usage(monkeypatch, token_rate_limit_type, expected_field):
+    """
+    Test that async_log_success_event correctly handles usage as a dict (Responses API format).
+
+    The Responses API returns usage as a dict in ResponsesAPIResponse instead of a Usage object.
+    This test verifies that token counting works correctly with dict-based usage.
+    """
+    from unittest.mock import MagicMock
+
+    _api_key = "sk-12345"
+    _api_key = hash_token(_api_key)
+    local_cache = DualCache()
+    parallel_request_handler = _PROXY_MaxParallelRequestsHandler(
+        internal_usage_cache=InternalUsageCache(local_cache)
+    )
+
+    # Mock the get_rate_limit_type method
+    def mock_get_rate_limit_type():
+        return token_rate_limit_type
+
+    monkeypatch.setattr(
+        parallel_request_handler, "get_rate_limit_type", mock_get_rate_limit_type
+    )
+
+    # Create a mock response object with usage as a dict (Responses API format)
+    mock_response = MagicMock()
+    mock_response.usage = {
+        "prompt_tokens": 25,
+        "completion_tokens": 35,
+        "total_tokens": 60
+    }
+    # Make isinstance check for BaseLiteLLMOpenAIResponseObject return True
+    from litellm.types.utils import BaseLiteLLMOpenAIResponseObject
+    mock_response.__class__ = type('MockResponse', (BaseLiteLLMOpenAIResponseObject,), {})
+
+    # Create mock kwargs for the success event
+    mock_kwargs = {
+        "standard_logging_object": {
+            "metadata": {
+                "user_api_key_hash": _api_key,
+                "user_api_key_user_id": None,
+                "user_api_key_team_id": None,
+                "user_api_key_end_user_id": None,
+            }
+        },
+        "model": "gpt-3.5-turbo",
+    }
+
+    # Mock the pipeline increment method to capture the operations
+    captured_operations = []
+
+    async def mock_increment_pipeline(increment_list, **kwargs):
+        captured_operations.extend(increment_list)
+        return True
+
+    monkeypatch.setattr(
+        parallel_request_handler.internal_usage_cache.dual_cache,
+        "async_increment_cache_pipeline",
+        mock_increment_pipeline,
+    )
+
+    # Call the success event handler
+    await parallel_request_handler.async_log_success_event(
+        kwargs=mock_kwargs,
+        response_obj=mock_response,
+        start_time=datetime.now(),
+        end_time=datetime.now(),
+    )
+
+    # Find the TPM increment operation
+    tpm_operation = None
+    for op in captured_operations:
+        if op["key"].endswith(":tokens"):
+            tpm_operation = op
+            break
+
+    assert tpm_operation is not None, "Should have a TPM increment operation"
+
+    # Check that the correct token count was used based on the rate limit type
+    expected_tokens = {
+        "input": 25,  # prompt_tokens
+        "output": 35,  # completion_tokens
+        "total": 60,  # total_tokens
+    }
+
+    assert (
+        tpm_operation["increment_value"] == expected_tokens[token_rate_limit_type]
+    ), f"Expected {expected_tokens[token_rate_limit_type]} tokens for type '{token_rate_limit_type}', got {tpm_operation['increment_value']}"
+
+
+@pytest.mark.asyncio
+async def test_async_log_success_event_with_dict_usage_missing_fields(monkeypatch):
+    """
+    Test that async_log_success_event handles dict usage with missing fields gracefully.
+
+    When usage dict is missing expected fields, it should default to 0.
+    """
+    from unittest.mock import MagicMock
+
+    _api_key = "sk-12345"
+    _api_key = hash_token(_api_key)
+    local_cache = DualCache()
+    parallel_request_handler = _PROXY_MaxParallelRequestsHandler(
+        internal_usage_cache=InternalUsageCache(local_cache)
+    )
+
+    # Mock the get_rate_limit_type method
+    def mock_get_rate_limit_type():
+        return "output"
+
+    monkeypatch.setattr(
+        parallel_request_handler, "get_rate_limit_type", mock_get_rate_limit_type
+    )
+
+    # Create a mock response object with usage as a dict missing some fields
+    mock_response = MagicMock()
+    mock_response.usage = {
+        "prompt_tokens": 25,
+        # completion_tokens is missing
+        # total_tokens is missing
+    }
+    from litellm.types.utils import BaseLiteLLMOpenAIResponseObject
+    mock_response.__class__ = type('MockResponse', (BaseLiteLLMOpenAIResponseObject,), {})
+
+    # Create mock kwargs for the success event
+    mock_kwargs = {
+        "standard_logging_object": {
+            "metadata": {
+                "user_api_key_hash": _api_key,
+                "user_api_key_user_id": None,
+                "user_api_key_team_id": None,
+                "user_api_key_end_user_id": None,
+            }
+        },
+        "model": "gpt-3.5-turbo",
+    }
+
+    # Mock the pipeline increment method to capture the operations
+    captured_operations = []
+
+    async def mock_increment_pipeline(increment_list, **kwargs):
+        captured_operations.extend(increment_list)
+        return True
+
+    monkeypatch.setattr(
+        parallel_request_handler.internal_usage_cache.dual_cache,
+        "async_increment_cache_pipeline",
+        mock_increment_pipeline,
+    )
+
+    # Call the success event handler - should not raise exception
+    await parallel_request_handler.async_log_success_event(
+        kwargs=mock_kwargs,
+        response_obj=mock_response,
+        start_time=datetime.now(),
+        end_time=datetime.now(),
+    )
+
+    # Find the TPM increment operation
+    tpm_operation = None
+    for op in captured_operations:
+        if op["key"].endswith(":tokens"):
+            tpm_operation = op
+            break
+
+    assert tpm_operation is not None, "Should have a TPM increment operation"
+    # Should default to 0 when field is missing
+    assert tpm_operation["increment_value"] == 0, "Should default to 0 when completion_tokens is missing"
+
+
 @pytest.mark.asyncio
 async def test_execute_token_increment_script_cluster_compatibility():
     """
