Commit 9e85dcb

read responses api usage
1 parent b83bc10

2 files changed: +243 −10 lines


litellm/proxy/hooks/parallel_request_limiter_v3.py

Lines changed: 18 additions & 10 deletions
@@ -1313,11 +1313,10 @@ async def async_increment_tokens_with_ttl_preservation(
 
     def get_rate_limit_type(self) -> Literal["output", "input", "total"]:
         from litellm.proxy.proxy_server import general_settings
-
         specified_rate_limit_type = general_settings.get(
-            "token_rate_limit_type", "output"
+            "token_rate_limit_type", "total"
         )
-        if not specified_rate_limit_type or specified_rate_limit_type not in [
+        if specified_rate_limit_type not in [
             "output",
             "input",
             "total",
@@ -1372,13 +1371,22 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti
             response_obj, BaseLiteLLMOpenAIResponseObject
         ):
             _usage = getattr(response_obj, "usage", None)
-            if _usage and isinstance(_usage, Usage):
-                if rate_limit_type == "output":
-                    total_tokens = _usage.completion_tokens
-                elif rate_limit_type == "input":
-                    total_tokens = _usage.prompt_tokens
-                elif rate_limit_type == "total":
-                    total_tokens = _usage.total_tokens
+            if _usage:
+                if isinstance(_usage, Usage):
+                    if rate_limit_type == "output":
+                        total_tokens = _usage.completion_tokens
+                    elif rate_limit_type == "input":
+                        total_tokens = _usage.prompt_tokens
+                    elif rate_limit_type == "total":
+                        total_tokens = _usage.total_tokens
+                elif isinstance(_usage, dict):
+                    # Responses API usage comes as a dict in ResponsesAPIResponse
+                    if rate_limit_type == "output":
+                        total_tokens = _usage.get("completion_tokens", 0)
+                    elif rate_limit_type == "input":
+                        total_tokens = _usage.get("prompt_tokens", 0)
+                    elif rate_limit_type == "total":
+                        total_tokens = _usage.get("total_tokens", 0)
 
             # Create pipeline operations for TPM increments
             pipeline_operations: List[RedisPipelineIncrementOperation] = []
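
Read on its own, the hunk above makes the limiter accept token usage in two shapes: the typed Usage object that chat completions return, and the plain dict that the Responses API carries in ResponsesAPIResponse.usage. The standalone sketch below mirrors that selection logic; the helper name select_tokens and the FakeUsage stand-in are illustrative only, not part of the LiteLLM codebase.

# Standalone sketch (hypothetical helper, not the actual LiteLLM function):
# pick the token count the rate limiter should increment, accepting either
# a typed usage object or the plain dict returned by the Responses API.
from dataclasses import dataclass
from typing import Literal, Union


@dataclass
class FakeUsage:  # minimal stand-in for litellm.types.utils.Usage
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


def select_tokens(
    usage: Union[FakeUsage, dict, None],
    rate_limit_type: Literal["output", "input", "total"],
) -> int:
    if usage is None:
        return 0
    field = {
        "output": "completion_tokens",
        "input": "prompt_tokens",
        "total": "total_tokens",
    }[rate_limit_type]
    if isinstance(usage, dict):
        # Responses API: usage arrives as a dict; missing keys count as 0
        return usage.get(field, 0)
    return getattr(usage, field)


assert select_tokens({"prompt_tokens": 25, "completion_tokens": 35, "total_tokens": 60}, "total") == 60
assert select_tokens({"prompt_tokens": 25}, "output") == 0  # missing field defaults to 0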

tests/test_litellm/proxy/hooks/test_parallel_request_limiter_v3.py

Lines changed: 225 additions & 0 deletions
@@ -1583,6 +1583,231 @@ async def mock_should_rate_limit(descriptors, **kwargs):
     assert "Current limit: 2" in exc_info.value.detail
 
 
+@pytest.mark.asyncio
+async def test_get_rate_limit_type_default_is_total(monkeypatch):
+    """
+    Test that get_rate_limit_type returns 'total' as the default when no setting is specified.
+
+    This verifies the change from 'output' to 'total' as the default value.
+    """
+    local_cache = DualCache()
+    parallel_request_handler = _PROXY_MaxParallelRequestsHandler(
+        internal_usage_cache=InternalUsageCache(local_cache)
+    )
+
+    # Mock general_settings to return empty dict (no token_rate_limit_type set)
+    import litellm.proxy.proxy_server as proxy_server
+    original_settings = getattr(proxy_server, 'general_settings', {})
+    monkeypatch.setattr(proxy_server, 'general_settings', {})
+
+    try:
+        result = parallel_request_handler.get_rate_limit_type()
+        assert result == "total", f"Default rate limit type should be 'total', got '{result}'"
+    finally:
+        monkeypatch.setattr(proxy_server, 'general_settings', original_settings)
+
+
+@pytest.mark.asyncio
+async def test_get_rate_limit_type_invalid_falls_back_to_total(monkeypatch):
+    """
+    Test that get_rate_limit_type falls back to 'total' when an invalid value is specified.
+    """
+    local_cache = DualCache()
+    parallel_request_handler = _PROXY_MaxParallelRequestsHandler(
+        internal_usage_cache=InternalUsageCache(local_cache)
+    )
+
+    # Mock general_settings to return an invalid token_rate_limit_type
+    import litellm.proxy.proxy_server as proxy_server
+    original_settings = getattr(proxy_server, 'general_settings', {})
+    monkeypatch.setattr(proxy_server, 'general_settings', {'token_rate_limit_type': 'invalid_type'})
+
+    try:
+        result = parallel_request_handler.get_rate_limit_type()
+        assert result == "total", f"Invalid rate limit type should fall back to 'total', got '{result}'"
+    finally:
+        monkeypatch.setattr(proxy_server, 'general_settings', original_settings)
+
+
+@pytest.mark.parametrize(
+    "token_rate_limit_type,expected_field",
+    [
+        ("input", "prompt_tokens"),
+        ("output", "completion_tokens"),
+        ("total", "total_tokens"),
+    ],
+)
+@pytest.mark.asyncio
+async def test_async_log_success_event_with_dict_usage(monkeypatch, token_rate_limit_type, expected_field):
+    """
+    Test that async_log_success_event correctly handles usage as a dict (Responses API format).
+
+    The Responses API returns usage as a dict in ResponsesAPIResponse instead of a Usage object.
+    This test verifies that token counting works correctly with dict-based usage.
+    """
+    from unittest.mock import MagicMock
+
+    _api_key = "sk-12345"
+    _api_key = hash_token(_api_key)
+    local_cache = DualCache()
+    parallel_request_handler = _PROXY_MaxParallelRequestsHandler(
+        internal_usage_cache=InternalUsageCache(local_cache)
+    )
+
+    # Mock the get_rate_limit_type method
+    def mock_get_rate_limit_type():
+        return token_rate_limit_type
+
+    monkeypatch.setattr(
+        parallel_request_handler, "get_rate_limit_type", mock_get_rate_limit_type
+    )
+
+    # Create a mock response object with usage as a dict (Responses API format)
+    mock_response = MagicMock()
+    mock_response.usage = {
+        "prompt_tokens": 25,
+        "completion_tokens": 35,
+        "total_tokens": 60
+    }
+    # Make isinstance check for BaseLiteLLMOpenAIResponseObject return True
+    from litellm.types.utils import BaseLiteLLMOpenAIResponseObject
+    mock_response.__class__ = type('MockResponse', (BaseLiteLLMOpenAIResponseObject,), {})
+
+    # Create mock kwargs for the success event
+    mock_kwargs = {
+        "standard_logging_object": {
+            "metadata": {
+                "user_api_key_hash": _api_key,
+                "user_api_key_user_id": None,
+                "user_api_key_team_id": None,
+                "user_api_key_end_user_id": None,
+            }
+        },
+        "model": "gpt-3.5-turbo",
+    }
+
+    # Mock the pipeline increment method to capture the operations
+    captured_operations = []
+
+    async def mock_increment_pipeline(increment_list, **kwargs):
+        captured_operations.extend(increment_list)
+        return True
+
+    monkeypatch.setattr(
+        parallel_request_handler.internal_usage_cache.dual_cache,
+        "async_increment_cache_pipeline",
+        mock_increment_pipeline,
+    )
+
+    # Call the success event handler
+    await parallel_request_handler.async_log_success_event(
+        kwargs=mock_kwargs,
+        response_obj=mock_response,
+        start_time=datetime.now(),
+        end_time=datetime.now(),
+    )
+
+    # Find the TPM increment operation
+    tpm_operation = None
+    for op in captured_operations:
+        if op["key"].endswith(":tokens"):
+            tpm_operation = op
+            break
+
+    assert tpm_operation is not None, "Should have a TPM increment operation"
+
+    # Check that the correct token count was used based on the rate limit type
+    expected_tokens = {
+        "input": 25,  # prompt_tokens
+        "output": 35,  # completion_tokens
+        "total": 60,  # total_tokens
+    }
+
+    assert (
+        tpm_operation["increment_value"] == expected_tokens[token_rate_limit_type]
+    ), f"Expected {expected_tokens[token_rate_limit_type]} tokens for type '{token_rate_limit_type}', got {tpm_operation['increment_value']}"
+
+
+@pytest.mark.asyncio
+async def test_async_log_success_event_with_dict_usage_missing_fields(monkeypatch):
+    """
+    Test that async_log_success_event handles dict usage with missing fields gracefully.
+
+    When usage dict is missing expected fields, it should default to 0.
+    """
+    from unittest.mock import MagicMock
+
+    _api_key = "sk-12345"
+    _api_key = hash_token(_api_key)
+    local_cache = DualCache()
+    parallel_request_handler = _PROXY_MaxParallelRequestsHandler(
+        internal_usage_cache=InternalUsageCache(local_cache)
+    )
+
+    # Mock the get_rate_limit_type method
+    def mock_get_rate_limit_type():
+        return "output"
+
+    monkeypatch.setattr(
+        parallel_request_handler, "get_rate_limit_type", mock_get_rate_limit_type
+    )
+
+    # Create a mock response object with usage as a dict missing some fields
+    mock_response = MagicMock()
+    mock_response.usage = {
+        "prompt_tokens": 25,
+        # completion_tokens is missing
+        # total_tokens is missing
+    }
+    from litellm.types.utils import BaseLiteLLMOpenAIResponseObject
+    mock_response.__class__ = type('MockResponse', (BaseLiteLLMOpenAIResponseObject,), {})
+
+    # Create mock kwargs for the success event
+    mock_kwargs = {
+        "standard_logging_object": {
+            "metadata": {
+                "user_api_key_hash": _api_key,
+                "user_api_key_user_id": None,
+                "user_api_key_team_id": None,
+                "user_api_key_end_user_id": None,
+            }
+        },
+        "model": "gpt-3.5-turbo",
+    }
+
+    # Mock the pipeline increment method to capture the operations
+    captured_operations = []
+
+    async def mock_increment_pipeline(increment_list, **kwargs):
+        captured_operations.extend(increment_list)
+        return True
+
+    monkeypatch.setattr(
+        parallel_request_handler.internal_usage_cache.dual_cache,
+        "async_increment_cache_pipeline",
+        mock_increment_pipeline,
+    )
+
+    # Call the success event handler - should not raise exception
+    await parallel_request_handler.async_log_success_event(
+        kwargs=mock_kwargs,
+        response_obj=mock_response,
+        start_time=datetime.now(),
+        end_time=datetime.now(),
+    )
+
+    # Find the TPM increment operation
+    tpm_operation = None
+    for op in captured_operations:
+        if op["key"].endswith(":tokens"):
+            tpm_operation = op
+            break
+
+    assert tpm_operation is not None, "Should have a TPM increment operation"
+    # Should default to 0 when field is missing
+    assert tpm_operation["increment_value"] == 0, "Should default to 0 when completion_tokens is missing"
+
+
 @pytest.mark.asyncio
 async def test_execute_token_increment_script_cluster_compatibility():
     """
