Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions poe_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,8 @@ async def _get_json(
# We ignore typing in the dict assignment. kwargs only has dicts as values,
# but we're assigning booleans here. We can't set the typing inline without
# flake8 complaining about overly complex annotation.
logging.debug("NOT BLOCKING")
kwargs["raise_for_status"] = True # type: ignore
else:
logging.debug("BLOCKING")
kwargs["raise_for_status"] = False # type: ignore

# The types are ignored because for some reason it can't understand
Expand All @@ -189,6 +187,12 @@ async def _get_json(
] = await self._limiter.parse_headers(resp.headers)

if resp.status != 200:
logging.debug(
"Got status code %s with text %s and headers %s",
resp.status,
await resp.text(),
resp.headers,
)
raise ValueError(
"Invalid request: status code {0}, expected 200".format(
resp.status,
Expand Down
15 changes: 2 additions & 13 deletions poe_client/rate_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,28 +47,19 @@ def __init__(self, name: str, max_hits: int, period: int, restriction: int):
async def update_state(self, current_hits: int, restriction: int):
"""Update the state of the policy."""
async with self.mutex:
logging.debug(
"Updating state[{0}] to {1} hits, {2} restriction".format(
self.name,
current_hits,
restriction,
)
)
self.state.current_hits = current_hits
self.state.restriction = restriction

async def get_semaphore(self) -> bool:
"""Check state to see if request is allowed."""
logging.debug("{0} = {1}".format(self.name, self.state.__dict__))

# If last request was restricted, wait and allow
if self.state.restriction:
logging.info(
"Rate limiter restricted. Sleeping for {0} seconds".format(
self.state.restriction
)
)
await asyncio.sleep(self.state.restriction + 1)
await asyncio.sleep(self.state.restriction)
return True

if self.state.current_hits >= self.max_hits:
Expand All @@ -77,7 +68,7 @@ async def get_semaphore(self) -> bool:
self.period
)
)
await asyncio.sleep(self.period + 1)
await asyncio.sleep(self.period)
return True

# If we haven't reached the quota, increase and allow
Expand Down Expand Up @@ -146,15 +137,13 @@ async def get_semaphore(self, policy_name: str) -> bool:
"""Get a semaphore to make a request."""
async with self.mutex:
if not self.policies:
logging.debug("No policies, do a blocking request")
return False

semaphores = []
for name, policy in self.policies.items():
if not name.startswith(policy_name):
continue

logging.debug("getting semaphore {0}".format(name))
for limit in policy.values():
semaphores.append(limit.get_semaphore())

Expand Down