@@ -1583,6 +1583,231 @@ async def mock_should_rate_limit(descriptors, **kwargs):
15831583 assert "Current limit: 2" in exc_info .value .detail
15841584
15851585
@pytest.mark.asyncio
async def test_get_rate_limit_type_default_is_total(monkeypatch):
    """
    Test that get_rate_limit_type returns 'total' as the default when no setting is specified.

    This verifies the change from 'output' to 'total' as the default value.
    """
    local_cache = DualCache()
    parallel_request_handler = _PROXY_MaxParallelRequestsHandler(
        internal_usage_cache=InternalUsageCache(local_cache)
    )

    # Simulate an empty general_settings (no token_rate_limit_type configured).
    # raising=False covers the case where the attribute does not exist yet;
    # monkeypatch automatically restores/removes it at teardown, so the
    # manual try/finally restore the old version carried is unnecessary.
    import litellm.proxy.proxy_server as proxy_server

    monkeypatch.setattr(proxy_server, "general_settings", {}, raising=False)

    result = parallel_request_handler.get_rate_limit_type()
    assert result == "total", f"Default rate limit type should be 'total', got '{result}'"
1609+
@pytest.mark.asyncio
async def test_get_rate_limit_type_invalid_falls_back_to_total(monkeypatch):
    """
    Test that get_rate_limit_type falls back to 'total' when an invalid value is specified.
    """
    local_cache = DualCache()
    parallel_request_handler = _PROXY_MaxParallelRequestsHandler(
        internal_usage_cache=InternalUsageCache(local_cache)
    )

    # Configure an invalid token_rate_limit_type in general_settings.
    # raising=False covers the case where the attribute does not exist yet;
    # monkeypatch automatically restores it at teardown, so the manual
    # try/finally restore the old version carried is unnecessary.
    import litellm.proxy.proxy_server as proxy_server

    monkeypatch.setattr(
        proxy_server,
        "general_settings",
        {"token_rate_limit_type": "invalid_type"},
        raising=False,
    )

    result = parallel_request_handler.get_rate_limit_type()
    assert result == "total", f"Invalid rate limit type should fall back to 'total', got '{result}'"
1630+
1631+
@pytest.mark.parametrize(
    "token_rate_limit_type,expected_field",
    [
        ("input", "prompt_tokens"),
        ("output", "completion_tokens"),
        ("total", "total_tokens"),
    ],
)
@pytest.mark.asyncio
async def test_async_log_success_event_with_dict_usage(
    monkeypatch, token_rate_limit_type, expected_field
):
    """
    Test that async_log_success_event correctly handles usage as a dict (Responses API format).

    The Responses API returns usage as a dict in ResponsesAPIResponse instead of a Usage object.
    This test verifies that token counting works correctly with dict-based usage.
    """
    from unittest.mock import MagicMock

    hashed_key = hash_token("sk-12345")
    handler = _PROXY_MaxParallelRequestsHandler(
        internal_usage_cache=InternalUsageCache(DualCache())
    )

    # Force the handler to use the parametrized rate-limit type.
    monkeypatch.setattr(
        handler, "get_rate_limit_type", lambda: token_rate_limit_type
    )

    # Build a fake response whose usage is a plain dict (Responses API shape).
    fake_response = MagicMock()
    fake_response.usage = {
        "prompt_tokens": 25,
        "completion_tokens": 35,
        "total_tokens": 60,
    }
    # Swap the class so isinstance(..., BaseLiteLLMOpenAIResponseObject) passes.
    from litellm.types.utils import BaseLiteLLMOpenAIResponseObject

    fake_response.__class__ = type(
        'MockResponse', (BaseLiteLLMOpenAIResponseObject,), {}
    )

    # Minimal kwargs payload for the success event.
    event_kwargs = {
        "standard_logging_object": {
            "metadata": {
                "user_api_key_hash": hashed_key,
                "user_api_key_user_id": None,
                "user_api_key_team_id": None,
                "user_api_key_end_user_id": None,
            }
        },
        "model": "gpt-3.5-turbo",
    }

    # Intercept the pipeline increment so we can inspect the operations.
    recorded_ops = []

    async def fake_increment_pipeline(increment_list, **kwargs):
        recorded_ops.extend(increment_list)
        return True

    monkeypatch.setattr(
        handler.internal_usage_cache.dual_cache,
        "async_increment_cache_pipeline",
        fake_increment_pipeline,
    )

    # Fire the success event.
    await handler.async_log_success_event(
        kwargs=event_kwargs,
        response_obj=fake_response,
        start_time=datetime.now(),
        end_time=datetime.now(),
    )

    # Locate the TPM increment among the captured operations.
    tpm_operation = next(
        (op for op in recorded_ops if op["key"].endswith(":tokens")), None
    )

    assert tpm_operation is not None, "Should have a TPM increment operation"

    # The token count charged must match the configured rate-limit type.
    expected_tokens = {
        "input": 25,  # prompt_tokens
        "output": 35,  # completion_tokens
        "total": 60,  # total_tokens
    }

    assert (
        tpm_operation["increment_value"] == expected_tokens[token_rate_limit_type]
    ), f"Expected {expected_tokens[token_rate_limit_type]} tokens for type '{token_rate_limit_type}', got {tpm_operation['increment_value']}"
1730+
@pytest.mark.asyncio
async def test_async_log_success_event_with_dict_usage_missing_fields(monkeypatch):
    """
    Test that async_log_success_event handles dict usage with missing fields gracefully.

    When usage dict is missing expected fields, it should default to 0.
    """
    from unittest.mock import MagicMock

    hashed_key = hash_token("sk-12345")
    handler = _PROXY_MaxParallelRequestsHandler(
        internal_usage_cache=InternalUsageCache(DualCache())
    )

    # Pin the rate-limit type to "output" so the missing completion_tokens
    # field is the one the handler must tolerate.
    monkeypatch.setattr(handler, "get_rate_limit_type", lambda: "output")

    # Fake response whose usage dict omits completion_tokens and total_tokens.
    fake_response = MagicMock()
    fake_response.usage = {
        "prompt_tokens": 25,
        # completion_tokens is missing
        # total_tokens is missing
    }
    from litellm.types.utils import BaseLiteLLMOpenAIResponseObject

    fake_response.__class__ = type(
        'MockResponse', (BaseLiteLLMOpenAIResponseObject,), {}
    )

    # Minimal kwargs payload for the success event.
    event_kwargs = {
        "standard_logging_object": {
            "metadata": {
                "user_api_key_hash": hashed_key,
                "user_api_key_user_id": None,
                "user_api_key_team_id": None,
                "user_api_key_end_user_id": None,
            }
        },
        "model": "gpt-3.5-turbo",
    }

    # Intercept the pipeline increment so we can inspect the operations.
    recorded_ops = []

    async def fake_increment_pipeline(increment_list, **kwargs):
        recorded_ops.extend(increment_list)
        return True

    monkeypatch.setattr(
        handler.internal_usage_cache.dual_cache,
        "async_increment_cache_pipeline",
        fake_increment_pipeline,
    )

    # Fire the success event — must not raise despite the missing fields.
    await handler.async_log_success_event(
        kwargs=event_kwargs,
        response_obj=fake_response,
        start_time=datetime.now(),
        end_time=datetime.now(),
    )

    # Locate the TPM increment among the captured operations.
    tpm_operation = next(
        (op for op in recorded_ops if op["key"].endswith(":tokens")), None
    )

    assert tpm_operation is not None, "Should have a TPM increment operation"
    # Missing field must be treated as zero tokens.
    assert tpm_operation["increment_value"] == 0, "Should default to 0 when completion_tokens is missing"
1810+
15861811@pytest .mark .asyncio
15871812async def test_execute_token_increment_script_cluster_compatibility ():
15881813 """
0 commit comments