diff --git a/poe_client/rate_limiter.py b/poe_client/rate_limiter.py index 915ff58..42d9c54 100644 --- a/poe_client/rate_limiter.py +++ b/poe_client/rate_limiter.py @@ -33,7 +33,7 @@ class Policy(object): restriction: int state: PolicyState - mutex: asyncio.Lock = asyncio.Lock() + mutex: asyncio.Lock def __init__(self, name: str, max_hits: int, period: int, restriction: int): """Initialize a new policy.""" @@ -42,53 +42,64 @@ def __init__(self, name: str, max_hits: int, period: int, restriction: int): self.period = period self.restriction = restriction self.state = PolicyState(current_hits=0, restriction=0) + self.mutex = asyncio.Lock() async def update_state(self, current_hits: int, restriction: int): """Update the state of the policy.""" async with self.mutex: - logging.debug( - "Updating state[{0}] to {1} hits, {2} restriction".format( - self.name, - current_hits, - restriction, - ) + await self._update_state(current_hits, restriction) + + async def _update_state(self, current_hits: int, restriction: int): + # We DO NOT acquire the mutex. External callers shouldn't call this. + # Methods of this class can call this if they've already acquired the mutex. + logging.debug( + "Updating state[{0}] to {1} hits, {2} restriction".format( + self.name, + current_hits, + restriction, ) - self.state = PolicyState(current_hits, restriction) + ) + self.state = PolicyState(current_hits, restriction) async def get_semaphore(self) -> bool: """Check state to see if request is allowed.""" logging.debug("{0} = {1}".format(self.name, self.state.__dict__)) - # If last request was restricted, wait and allow - if self.state.restriction: - logging.info( - "Rate limiter restricted. Sleeping for {0} seconds".format( - self.state.restriction + async with self.mutex: + # If last request was restricted, wait and allow + if self.state.restriction: + logging.info( + "Rate limiter restricted. Sleeping for {0} seconds".format( + self.state.restriction + ) ) - ) - await asyncio.sleep(self.state.restriction + 1) - return True - - # Reset state and allow if last request is older restriction time. - if self.state.last_request > (datetime.now() + timedelta(seconds=self.period)): - await self.update_state(0, 0) - return True - - if self.state.current_hits >= self.max_hits: - logging.info( - "Rate limiter max hits reached. Sleeping for {0} seconds".format( - self.period + await asyncio.sleep(self.state.restriction + 1) + return True + + # Reset state and allow if last request is older restriction time. + if self.state.last_request > ( + datetime.now() + timedelta(seconds=self.period) + ): + await self._update_state(0, 0) + return True + + if self.state.current_hits >= self.max_hits: + logging.info( + "Rate limiter max hits reached. Sleeping for {0} seconds".format( + self.period + ) ) - ) - await asyncio.sleep(self.period + 1) - return True + await asyncio.sleep(self.period + 1) + return True - # If we haven't reached the quota, increase and allow - if self.state.current_hits < self.max_hits: - await self.update_state(self.state.current_hits + 1, self.state.restriction) - return True + # If we haven't reached the quota, increase and allow + if self.state.current_hits < self.max_hits: + await self._update_state( + self.state.current_hits + 1, self.state.restriction + ) + return True - # Don't allow by default - return False + # Don't allow by default + return False class RateLimiter(object): @@ -136,17 +147,18 @@ async def parse_headers(self, headers) -> str: async def get_semaphore(self, policy_name: str) -> bool: """Get a semaphore to make a request.""" - if not self.policies: - logging.debug("No policies, do a blocking request") - return False - - semaphores = [] - for name, policy in self.policies.items(): - if name.startswith(policy_name): - logging.debug("getting semaphore {0}".format(name)) - for limit in policy.values(): - semaphores.append(limit.get_semaphore()) - - if semaphores: - await asyncio.wait(semaphores) - return True + async with self.mutex: + if not self.policies: + logging.debug("No policies, do a blocking request") + return False + + semaphores = [] + for name, policy in self.policies.items(): + if name.startswith(policy_name): + logging.debug("getting semaphore {0}".format(name)) + for limit in policy.values(): + semaphores.append(limit.get_semaphore()) + + if semaphores: + await asyncio.wait(semaphores) + return True