Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 61 additions & 49 deletions poe_client/rate_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Policy(object):
restriction: int
state: PolicyState

mutex: asyncio.Lock = asyncio.Lock()
mutex: asyncio.Lock

def __init__(self, name: str, max_hits: int, period: int, restriction: int):
"""Initialize a new policy."""
Expand All @@ -42,53 +42,64 @@ def __init__(self, name: str, max_hits: int, period: int, restriction: int):
self.period = period
self.restriction = restriction
self.state = PolicyState(current_hits=0, restriction=0)
self.mutex = asyncio.Lock()

async def update_state(self, current_hits: int, restriction: int):
"""Update the state of the policy."""
async with self.mutex:
logging.debug(
"Updating state[{0}] to {1} hits, {2} restriction".format(
self.name,
current_hits,
restriction,
)
await self._update_state(current_hits, restriction)

async def _update_state(self, current_hits: int, restriction: int):
# We DO NOT acquire the mutex. External callers shouldn't call this.
# Methods of this class can call this if they've already acquired the mutex.
logging.debug(
"Updating state[{0}] to {1} hits, {2} restriction".format(
self.name,
current_hits,
restriction,
)
self.state = PolicyState(current_hits, restriction)
)
self.state = PolicyState(current_hits, restriction)

async def get_semaphore(self) -> bool:
"""Check state to see if request is allowed."""
logging.debug("{0} = {1}".format(self.name, self.state.__dict__))
# If last request was restricted, wait and allow
if self.state.restriction:
logging.info(
"Rate limiter restricted. Sleeping for {0} seconds".format(
self.state.restriction
async with self.mutex:
# If last request was restricted, wait and allow
if self.state.restriction:
logging.info(
"Rate limiter restricted. Sleeping for {0} seconds".format(
self.state.restriction
)
)
)
await asyncio.sleep(self.state.restriction + 1)
return True

# Reset state and allow if last request is older restriction time.
if self.state.last_request > (datetime.now() + timedelta(seconds=self.period)):
await self.update_state(0, 0)
return True

if self.state.current_hits >= self.max_hits:
logging.info(
"Rate limiter max hits reached. Sleeping for {0} seconds".format(
self.period
await asyncio.sleep(self.state.restriction + 1)
return True

# Reset state and allow if last request is older restriction time.
if self.state.last_request > (
datetime.now() + timedelta(seconds=self.period)
):
await self._update_state(0, 0)
return True

if self.state.current_hits >= self.max_hits:
logging.info(
"Rate limiter max hits reached. Sleeping for {0} seconds".format(
self.period
)
)
)
await asyncio.sleep(self.period + 1)
return True
await asyncio.sleep(self.period + 1)
return True

# If we haven't reached the quota, increase and allow
if self.state.current_hits < self.max_hits:
await self.update_state(self.state.current_hits + 1, self.state.restriction)
return True
# If we haven't reached the quota, increase and allow
if self.state.current_hits < self.max_hits:
await self._update_state(
self.state.current_hits + 1, self.state.restriction
)
return True

# Don't allow by default
return False
# Don't allow by default
return False


class RateLimiter(object):
Expand Down Expand Up @@ -136,17 +147,18 @@ async def parse_headers(self, headers) -> str:

async def get_semaphore(self, policy_name: str) -> bool:
"""Get a semaphore to make a request."""
if not self.policies:
logging.debug("No policies, do a blocking request")
return False

semaphores = []
for name, policy in self.policies.items():
if name.startswith(policy_name):
logging.debug("getting semaphore {0}".format(name))
for limit in policy.values():
semaphores.append(limit.get_semaphore())

if semaphores:
await asyncio.wait(semaphores)
return True
async with self.mutex:
if not self.policies:
logging.debug("No policies, do a blocking request")
return False

semaphores = []
for name, policy in self.policies.items():
if name.startswith(policy_name):
logging.debug("getting semaphore {0}".format(name))
for limit in policy.values():
semaphores.append(limit.get_semaphore())

if semaphores:
await asyncio.wait(semaphores)
return True