| @@ -27,6 +27,9 @@ class RateLimit: | |||
| def __init__(self, client_id: str, max_active_requests: int): | |||
| self.max_active_requests = max_active_requests | |||
| # must be called after max_active_requests is set | |||
| if self.disabled(): | |||
| return | |||
| if hasattr(self, "initialized"): | |||
| return | |||
| self.initialized = True | |||
| @@ -37,6 +40,8 @@ class RateLimit: | |||
| self.flush_cache(use_local_value=True) | |||
| def flush_cache(self, use_local_value=False): | |||
| if self.disabled(): | |||
| return | |||
| self.last_recalculate_time = time.time() | |||
| # flush max active requests | |||
| if use_local_value or not redis_client.exists(self.max_active_requests_key): | |||
| @@ -59,18 +64,18 @@ class RateLimit: | |||
| redis_client.hdel(self.active_requests_key, *timeout_requests) | |||
| def enter(self, request_id: Optional[str] = None) -> str: | |||
| if self.disabled(): | |||
| return RateLimit._UNLIMITED_REQUEST_ID | |||
| if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL: | |||
| self.flush_cache() | |||
| if self.max_active_requests <= 0: | |||
| return RateLimit._UNLIMITED_REQUEST_ID | |||
| if not request_id: | |||
| request_id = RateLimit.gen_request_key() | |||
| active_requests_count = redis_client.hlen(self.active_requests_key) | |||
| if active_requests_count >= self.max_active_requests: | |||
| raise AppInvokeQuotaExceededError( | |||
| "Too many requests. Please try again later. The current maximum " | |||
| "concurrent requests allowed is {}.".format(self.max_active_requests) | |||
| f"Too many requests. Please try again later. The current maximum concurrent requests allowed " | |||
| f"for {self.client_id} is {self.max_active_requests}." | |||
| ) | |||
| redis_client.hset(self.active_requests_key, request_id, str(time.time())) | |||
| return request_id | |||
| @@ -80,6 +85,9 @@ class RateLimit: | |||
| return | |||
| redis_client.hdel(self.active_requests_key, request_id) | |||
| def disabled(self): | |||
| return self.max_active_requests <= 0 | |||
| @staticmethod | |||
| def gen_request_key() -> str: | |||
| return str(uuid.uuid4()) | |||