@@ -25,7 +25,7 @@ class FireCrawlDataSource(BearerDataSource):
TEST_CRAWL_SITE_URL = "https://www.google.com"
FIRECRAWL_API_VERSION = "v0"
test_api_endpoint = self.api_base_url.rstrip('/') + f"/{FIRECRAWL_API_VERSION}/scrape"
test_api_endpoint = self.api_base_url.rstrip("/") + f"/{FIRECRAWL_API_VERSION}/scrape"
headers = {
"Authorization": f"Bearer {self.api_key}",
@@ -45,9 +45,9 @@ class FireCrawlDataSource(BearerDataSource):
data_source_binding = DataSourceBearerBinding.query.filter(
db.and_(
DataSourceBearerBinding.tenant_id == current_user.current_tenant_id,
DataSourceBearerBinding.provider == 'firecrawl',
DataSourceBearerBinding.provider == "firecrawl",
DataSourceBearerBinding.endpoint_url == self.api_base_url,
DataSourceBearerBinding.bearer_key == self.api_key
DataSourceBearerBinding.bearer_key == self.api_key,
)
).first()
if data_source_binding:
@@ -56,9 +56,9 @@ class FireCrawlDataSource(BearerDataSource):
else:
new_data_source_binding = DataSourceBearerBinding(
tenant_id=current_user.current_tenant_id,
provider='firecrawl',
provider="firecrawl",
endpoint_url=self.api_base_url,
bearer_key=self.api_key
bearer_key=self.api_key,
)
db.session.add(new_data_source_binding)
db.session.commit()
@@ -4,7 +4,7 @@ from werkzeug.exceptions import HTTPException
class BaseHTTPException(HTTPException):
error_code: str = 'unknown'
error_code: str = "unknown"
data: Optional[dict] = None
def __init__(self, description=None, response=None):
@@ -14,4 +14,4 @@ class BaseHTTPException(HTTPException):
"code": self.error_code,
"message": self.description,
"status": self.code,
}
}
@@ -10,7 +10,6 @@ from core.errors.error import AppInvokeQuotaExceededError
class ExternalApi(Api):
def handle_error(self, e):
"""Error handler for the API transforms a raised exception into a Flask
response, with the appropriate HTTP status code and body.
@@ -29,54 +28,57 @@ class ExternalApi(Api):
status_code = e.code
default_data = {
'code': re.sub(r'(?<!^)(?=[A-Z])', '_', type(e).__name__).lower(),
'message': getattr(e, 'description', http_status_message(status_code)),
'status': status_code
"code": re.sub(r"(?<!^)(?=[A-Z])", "_", type(e).__name__).lower(),
"message": getattr(e, "description", http_status_message(status_code)),
"status": status_code,
}
if default_data['message'] and default_data['message'] == 'Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)':
default_data['message'] = 'Invalid JSON payload received or JSON payload is empty.'
if (
default_data["message"]
and default_data["message"] == "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)"
):
default_data["message"] = "Invalid JSON payload received or JSON payload is empty."
headers = e.get_response().headers
elif isinstance(e, ValueError):
status_code = 400
default_data = {
'code': 'invalid_param',
'message': str(e),
'status': status_code
"code": "invalid_param",
"message": str(e),
"status": status_code,
}
elif isinstance(e, AppInvokeQuotaExceededError):
status_code = 429
default_data = {
'code': 'too_many_requests',
'message': str(e),
'status': status_code
"code": "too_many_requests",
"message": str(e),
"status": status_code,
}
else:
status_code = 500
default_data = {
'message': http_status_message(status_code),
"message": http_status_message(status_code),
}
# Werkzeug exceptions generate a content-length header which is added
# to the response in addition to the actual content-length header
# https://github.com/flask-restful/flask-restful/issues/534
remove_headers = ('Content-Length',)
remove_headers = ("Content-Length",)
for header in remove_headers:
headers.pop(header, None)
data = getattr(e, 'data', default_data)
data = getattr(e, "data", default_data)
error_cls_name = type(e).__name__
if error_cls_name in self.errors:
custom_data = self.errors.get(error_cls_name, {})
custom_data = custom_data.copy()
status_code = custom_data.get('status', 500)
status_code = custom_data.get("status", 500)
if 'message' in custom_data:
custom_data['message'] = custom_data['message'].format(
message=str(e.description if hasattr(e, 'description') else e)
if "message" in custom_data:
custom_data["message"] = custom_data["message"].format(
message=str(e.description if hasattr(e, "description") else e)
)
data.update(custom_data)
@@ -94,32 +96,20 @@ class ExternalApi(Api):
# another NotAcceptable error).
supported_mediatypes = list(self.representations.keys())  # only supported application/json
fallback_mediatype = supported_mediatypes[0] if supported_mediatypes else "text/plain"
data = {
'code': 'not_acceptable',
'message': data.get('message')
}
resp = self.make_response(
data,
status_code,
headers,
fallback_mediatype = fallback_mediatype
)
data = {"code": "not_acceptable", "message": data.get("message")}
resp = self.make_response(data, status_code, headers, fallback_mediatype=fallback_mediatype)
elif status_code == 400:
if isinstance(data.get('message'), dict):
param_key, param_value = list(data.get('message').items())[0]
data = {
'code': 'invalid_param',
'message': param_value,
'params': param_key
}
if isinstance(data.get("message"), dict):
param_key, param_value = list(data.get("message").items())[0]
data = {"code": "invalid_param", "message": param_value, "params": param_key}
else:
if 'code' not in data:
data['code'] = 'unknown'
if "code" not in data:
data["code"] = "unknown"
resp = self.make_response(data, status_code, headers)
else:
if 'code' not in data:
data['code'] = 'unknown'
if "code" not in data:
data["code"] = "unknown"
resp = self.make_response(data, status_code, headers)
@@ -70,7 +70,7 @@ class PKCS1OAEP_Cipher:
if mgfunc:
self._mgf = mgfunc
else:
self._mgf = lambda x,y: MGF1(x,y,self._hashObj)
self._mgf = lambda x, y: MGF1(x, y, self._hashObj)
self._label = _copy_bytes(None, None, label)
self._randfunc = randfunc
@@ -107,7 +107,7 @@ class PKCS1OAEP_Cipher:
# See 7.1.1 in RFC3447
modBits = Crypto.Util.number.size(self._key.n)
k = ceil_div(modBits, 8) # Convert from bits to bytes
k = ceil_div(modBits, 8)  # Convert from bits to bytes
hLen = self._hashObj.digest_size
mLen = len(message)
@@ -118,13 +118,13 @@ class PKCS1OAEP_Cipher:
# Step 2a
lHash = sha1(self._label).digest()
# Step 2b
ps = b'\x00' * ps_len
ps = b"\x00" * ps_len
# Step 2c
db = lHash + ps + b'\x01' + _copy_bytes(None, None, message)
db = lHash + ps + b"\x01" + _copy_bytes(None, None, message)
# Step 2d
ros = self._randfunc(hLen)
# Step 2e
dbMask = self._mgf(ros, k-hLen-1)
dbMask = self._mgf(ros, k - hLen - 1)
# Step 2f
maskedDB = strxor(db, dbMask)
# Step 2g
@@ -132,7 +132,7 @@ class PKCS1OAEP_Cipher:
# Step 2h
maskedSeed = strxor(ros, seedMask)
# Step 2i
em = b'\x00' + maskedSeed + maskedDB
em = b"\x00" + maskedSeed + maskedDB
# Step 3a (OS2IP)
em_int = bytes_to_long(em)
# Step 3b (RSAEP)
@@ -160,10 +160,10 @@ class PKCS1OAEP_Cipher:
"""
# See 7.1.2 in RFC3447
modBits = Crypto.Util.number.size(self._key.n)
k = ceil_div(modBits,8) # Convert from bits to bytes
k = ceil_div(modBits, 8)  # Convert from bits to bytes
hLen = self._hashObj.digest_size
# Step 1b and 1c
if len(ciphertext) != k or k<hLen+2:
if len(ciphertext) != k or k < hLen + 2:
raise ValueError("Ciphertext with incorrect length.")
# Step 2a (O2SIP)
ct_int = bytes_to_long(ciphertext)
@@ -178,18 +178,18 @@ class PKCS1OAEP_Cipher:
y = em[0]
# y must be 0, but we MUST NOT check it here in order not to
# allow attacks like Manger's (http://dl.acm.org/citation.cfm?id=704143)
maskedSeed = em[1:hLen+1]
maskedDB = em[hLen+1:]
maskedSeed = em[1 : hLen + 1]
maskedDB = em[hLen + 1 :]
# Step 3c
seedMask = self._mgf(maskedDB, hLen)
# Step 3d
seed = strxor(maskedSeed, seedMask)
# Step 3e
dbMask = self._mgf(seed, k-hLen-1)
dbMask = self._mgf(seed, k - hLen - 1)
# Step 3f
db = strxor(maskedDB, dbMask)
# Step 3g
one_pos = hLen + db[hLen:].find(b'\x01')
one_pos = hLen + db[hLen:].find(b"\x01")
lHash1 = db[:hLen]
invalid = bord(y) | int(one_pos < hLen)
hash_compare = strxor(lHash1, lHash)
@@ -200,9 +200,10 @@ class PKCS1OAEP_Cipher:
if invalid != 0:
raise ValueError("Incorrect decryption.")
# Step 4
return db[one_pos + 1:]
return db[one_pos + 1 :]
def new(key, hashAlgo=None, mgfunc=None, label=b'', randfunc=None):
def new(key, hashAlgo=None, mgfunc=None, label=b"", randfunc=None):
"""Return a cipher object :class:`PKCS1OAEP_Cipher` that can be used to perform PKCS#1 OAEP encryption or decryption.
:param key:
@@ -21,7 +21,7 @@ from models.account import Account
def run(script):
return subprocess.getstatusoutput('source /root/.bashrc && ' + script)
return subprocess.getstatusoutput("source /root/.bashrc && " + script)
class TimestampField(fields.Raw):
@@ -36,29 +36,29 @@ def email(email):
if re.match(pattern, email) is not None:
return email
error = ('{email} is not a valid email.'
.format(email=email))
error = "{email} is not a valid email.".format(email=email)
raise ValueError(error)
def uuid_value(value):
if value == '':
if value == "":
return str(value)
try:
uuid_obj = uuid.UUID(value)
return str(uuid_obj)
except ValueError:
error = ('{value} is not a valid uuid.'
.format(value=value))
error = "{value} is not a valid uuid.".format(value=value)
raise ValueError(error)
def alphanumeric(value: str):
# check if the value is alphanumeric and underlined
if re.match(r'^[a-zA-Z0-9_]+$', value):
if re.match(r"^[a-zA-Z0-9_]+$", value):
return value
raise ValueError(f'{value} is not a valid alphanumeric value')
raise ValueError(f"{value} is not a valid alphanumeric value")
def timestamp_value(timestamp):
try:
@@ -67,31 +67,32 @@ def timestamp_value(timestamp):
raise ValueError
return int_timestamp
except ValueError:
error = ('{timestamp} is not a valid timestamp.'
.format(timestamp=timestamp))
error = "{timestamp} is not a valid timestamp.".format(timestamp=timestamp)
raise ValueError(error)
class str_len:
""" Restrict input to an integer in a range (inclusive) """
"""Restrict input to an integer in a range (inclusive)"""
def __init__(self, max_length, argument='argument'):
def __init__(self, max_length, argument="argument"):
self.max_length = max_length
self.argument = argument
def __call__(self, value):
length = len(value)
if length > self.max_length:
error = ('Invalid {arg}: {val}. {arg} cannot exceed length {length}'
.format(arg=self.argument, val=value, length=self.max_length))
error = "Invalid {arg}: {val}. {arg} cannot exceed length {length}".format(
arg=self.argument, val=value, length=self.max_length
)
raise ValueError(error)
return value
class float_range:
""" Restrict input to an float in a range (inclusive) """
def __init__(self, low, high, argument='argument'):
"""Restrict input to an float in a range (inclusive)"""
def __init__(self, low, high, argument="argument"):
self.low = low
self.high = high
self.argument = argument
@@ -99,15 +100,16 @@ class float_range:
def __call__(self, value):
value = _get_float(value)
if value < self.low or value > self.high:
error = ('Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}'
.format(arg=self.argument, val=value, lo=self.low, hi=self.high))
error = "Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}".format(
arg=self.argument, val=value, lo=self.low, hi=self.high
)
raise ValueError(error)
return value
class datetime_string:
def __init__(self, format, argument='argument'):
def __init__(self, format, argument="argument"):
self.format = format
self.argument = argument
@@ -115,8 +117,9 @@ class datetime_string:
try:
datetime.strptime(value, self.format)
except ValueError:
error = ('Invalid {arg}: {val}. {arg} must be conform to the format {format}'
.format(arg=self.argument, val=value, format=self.format))
error = "Invalid {arg}: {val}. {arg} must be conform to the format {format}".format(
arg=self.argument, val=value, format=self.format
)
raise ValueError(error)
return value
@@ -126,14 +129,14 @@ def _get_float(value):
try:
return float(value)
except (TypeError, ValueError):
raise ValueError('{} is not a valid float'.format(value))
raise ValueError("{} is not a valid float".format(value))
def timezone(timezone_string):
if timezone_string and timezone_string in available_timezones():
return timezone_string
error = ('{timezone_string} is not a valid timezone.'
.format(timezone_string=timezone_string))
error = "{timezone_string} is not a valid timezone.".format(timezone_string=timezone_string)
raise ValueError(error)
@@ -147,8 +150,8 @@ def generate_string(n):
def get_remote_ip(request) -> str:
if request.headers.get('CF-Connecting-IP'):
return request.headers.get('Cf-Connecting-Ip')
if request.headers.get("CF-Connecting-IP"):
return request.headers.get("Cf-Connecting-Ip")
elif request.headers.getlist("X-Forwarded-For"):
return request.headers.getlist("X-Forwarded-For")[0]
else:
@@ -156,54 +159,45 @@ def get_remote_ip(request) -> str:
def generate_text_hash(text: str) -> str:
hash_text = str(text) + 'None'
hash_text = str(text) + "None"
return sha256(hash_text.encode()).hexdigest()
def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype='application/json')
return Response(response=json.dumps(response), status=200, mimetype="application/json")
else:
def generate() -> Generator:
yield from response
return Response(stream_with_context(generate()), status=200,
mimetype='text/event-stream')
return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
class TokenManager:
@classmethod
def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str:
old_token = cls._get_current_token_for_account(account.id, token_type)
if old_token:
if isinstance(old_token, bytes):
old_token = old_token.decode('utf-8')
old_token = old_token.decode("utf-8")
cls.revoke_token(old_token, token_type)
token = str(uuid.uuid4())
token_data = {
'account_id': account.id,
'email': account.email,
'token_type': token_type
}
token_data = {"account_id": account.id, "email": account.email, "token_type": token_type}
if additional_data:
token_data.update(additional_data)
expiry_hours = current_app.config[f'{token_type.upper()}_TOKEN_EXPIRY_HOURS']
expiry_hours = current_app.config[f"{token_type.upper()}_TOKEN_EXPIRY_HOURS"]
token_key = cls._get_token_key(token, token_type)
redis_client.setex(
token_key,
expiry_hours * 60 * 60,
json.dumps(token_data)
)
redis_client.setex(token_key, expiry_hours * 60 * 60, json.dumps(token_data))
cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
return token
@classmethod
def _get_token_key(cls, token: str, token_type: str) -> str:
return f'{token_type}:token:{token}'
return f"{token_type}:token:{token}"
@classmethod
def revoke_token(cls, token: str, token_type: str):
@@ -233,7 +227,7 @@ class TokenManager:
@classmethod
def _get_account_token_key(cls, account_id: str, token_type: str) -> str:
return f'{token_type}:account:{account_id}'
return f"{token_type}:account:{account_id}"
class RateLimiter:
@@ -250,7 +244,7 @@ class RateLimiter:
current_time = int(time.time())
window_start_time = current_time - self.time_window
redis_client.zremrangebyscore(key, '-inf', window_start_time)
redis_client.zremrangebyscore(key, "-inf", window_start_time)
attempts = redis_client.zcard(key)
if attempts and int(attempts) >= self.max_attempts:
@@ -1,4 +1,3 @@
class InfiniteScrollPagination:
def __init__(self, data, limit, has_more):
self.data = data
@@ -10,13 +10,13 @@ def parse_json_markdown(json_string: str) -> dict:
end_index = json_string.find("```", start_index + len("```json"))
if start_index != -1 and end_index != -1:
extracted_content = json_string[start_index + len("```json"):end_index].strip()
extracted_content = json_string[start_index + len("```json") : end_index].strip()
# Parse the JSON string into a Python dictionary
parsed = json.loads(extracted_content)
elif start_index != -1 and end_index == -1 and json_string.endswith("``"):
end_index = json_string.find("``", start_index + len("```json"))
extracted_content = json_string[start_index + len("```json"):end_index].strip()
extracted_content = json_string[start_index + len("```json") : end_index].strip()
# Parse the JSON string into a Python dictionary
parsed = json.loads(extracted_content)
@@ -37,7 +37,6 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict:
for key in expected_keys:
if key not in json_obj:
raise OutputParserException(
f"Got invalid return object. Expected key `{key}` "
f"to be present, but got {json_obj}"
f"Got invalid return object. Expected key `{key}` " f"to be present, but got {json_obj}"
)
return json_obj
@@ -51,27 +51,29 @@ def login_required(func):
@wraps(func)
def decorated_view(*args, **kwargs):
auth_header = request.headers.get('Authorization')
admin_api_key_enable = os.getenv('ADMIN_API_KEY_ENABLE', default='False')
if admin_api_key_enable.lower() == 'true':
auth_header = request.headers.get("Authorization")
admin_api_key_enable = os.getenv("ADMIN_API_KEY_ENABLE", default="False")
if admin_api_key_enable.lower() == "true":
if auth_header:
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
if " " not in auth_header:
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
admin_api_key = os.getenv('ADMIN_API_KEY')
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
admin_api_key = os.getenv("ADMIN_API_KEY")
if admin_api_key:
if os.getenv('ADMIN_API_KEY') == auth_token:
workspace_id = request.headers.get('X-WORKSPACE-ID')
if os.getenv("ADMIN_API_KEY") == auth_token:
workspace_id = request.headers.get("X-WORKSPACE-ID")
if workspace_id:
tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
.filter(Tenant.id == workspace_id) \
.filter(TenantAccountJoin.tenant_id == Tenant.id) \
.filter(TenantAccountJoin.role == 'owner') \
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == workspace_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.role == "owner")
.one_or_none()
)
if tenant_account_join:
tenant, ta = tenant_account_join
account = Account.query.filter_by(id=ta.account_id).first()
@@ -35,31 +35,31 @@ class OAuth:
class GitHubOAuth(OAuth):
_AUTH_URL = 'https://github.com/login/oauth/authorize'
_TOKEN_URL = 'https://github.com/login/oauth/access_token'
_USER_INFO_URL = 'https://api.github.com/user'
_EMAIL_INFO_URL = 'https://api.github.com/user/emails'
_AUTH_URL = "https://github.com/login/oauth/authorize"
_TOKEN_URL = "https://github.com/login/oauth/access_token"
_USER_INFO_URL = "https://api.github.com/user"
_EMAIL_INFO_URL = "https://api.github.com/user/emails"
def get_authorization_url(self):
params = {
'client_id': self.client_id,
'redirect_uri': self.redirect_uri,
'scope': 'user:email'  # Request only basic user information
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"scope": "user:email",  # Request only basic user information
}
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str):
data = {
'client_id': self.client_id,
'client_secret': self.client_secret,
'code': code,
'redirect_uri': self.redirect_uri
"client_id": self.client_id,
"client_secret": self.client_secret,
"code": code,
"redirect_uri": self.redirect_uri,
}
headers = {'Accept': 'application/json'}
headers = {"Accept": "application/json"}
response = requests.post(self._TOKEN_URL, data=data, headers=headers)
response_json = response.json()
access_token = response_json.get('access_token')
access_token = response_json.get("access_token")
if not access_token:
raise ValueError(f"Error in GitHub OAuth: {response_json}")
@@ -67,55 +67,51 @@ class GitHubOAuth(OAuth):
return access_token
def get_raw_user_info(self, token: str):
headers = {'Authorization': f"token {token}"}
headers = {"Authorization": f"token {token}"}
response = requests.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status()
user_info = response.json()
email_response = requests.get(self._EMAIL_INFO_URL, headers=headers)
email_info = email_response.json()
primary_email = next((email for email in email_info if email['primary'] == True), None)
primary_email = next((email for email in email_info if email["primary"] == True), None)
return {**user_info, 'email': primary_email['email']}
return {**user_info, "email": primary_email["email"]}
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
email = raw_info.get('email')
email = raw_info.get("email")
if not email:
email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com"
return OAuthUserInfo(
id=str(raw_info['id']),
name=raw_info['name'],
email=email
)
return OAuthUserInfo(id=str(raw_info["id"]), name=raw_info["name"], email=email)
class GoogleOAuth(OAuth):
_AUTH_URL = 'https://accounts.google.com/o/oauth2/v2/auth'
_TOKEN_URL = 'https://oauth2.googleapis.com/token'
_USER_INFO_URL = 'https://www.googleapis.com/oauth2/v3/userinfo'
_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
_TOKEN_URL = "https://oauth2.googleapis.com/token"
_USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
def get_authorization_url(self):
params = {
'client_id': self.client_id,
'response_type': 'code',
'redirect_uri': self.redirect_uri,
'scope': 'openid email'
"client_id": self.client_id,
"response_type": "code",
"redirect_uri": self.redirect_uri,
"scope": "openid email",
}
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str):
data = {
'client_id': self.client_id,
'client_secret': self.client_secret,
'code': code,
'grant_type': 'authorization_code',
'redirect_uri': self.redirect_uri
"client_id": self.client_id,
"client_secret": self.client_secret,
"code": code,
"grant_type": "authorization_code",
"redirect_uri": self.redirect_uri,
}
headers = {'Accept': 'application/json'}
headers = {"Accept": "application/json"}
response = requests.post(self._TOKEN_URL, data=data, headers=headers)
response_json = response.json()
access_token = response_json.get('access_token')
access_token = response_json.get("access_token")
if not access_token:
raise ValueError(f"Error in Google OAuth: {response_json}")
@@ -123,16 +119,10 @@ class GoogleOAuth(OAuth):
return access_token
def get_raw_user_info(self, token: str):
headers = {'Authorization': f"Bearer {token}"}
headers = {"Authorization": f"Bearer {token}"}
response = requests.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status()
return response.json()
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
return OAuthUserInfo(
id=str(raw_info['sub']),
name=None,
email=raw_info['email']
)
return OAuthUserInfo(id=str(raw_info["sub"]), name=None, email=raw_info["email"])
@@ -21,53 +21,49 @@ class OAuthDataSource:
class NotionOAuth(OAuthDataSource):
_AUTH_URL = 'https://api.notion.com/v1/oauth/authorize'
_TOKEN_URL = 'https://api.notion.com/v1/oauth/token'
_AUTH_URL = "https://api.notion.com/v1/oauth/authorize"
_TOKEN_URL = "https://api.notion.com/v1/oauth/token"
_NOTION_PAGE_SEARCH = "https://api.notion.com/v1/search"
_NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks"
_NOTION_BOT_USER = "https://api.notion.com/v1/users/me"
def get_authorization_url(self):
params = {
'client_id': self.client_id,
'response_type': 'code',
'redirect_uri': self.redirect_uri,
'owner': 'user'
"client_id": self.client_id,
"response_type": "code",
"redirect_uri": self.redirect_uri,
"owner": "user",
}
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str):
data = {
'code': code,
'grant_type': 'authorization_code',
'redirect_uri': self.redirect_uri
}
headers = {'Accept': 'application/json'}
data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
headers = {"Accept": "application/json"}
auth = (self.client_id, self.client_secret)
response = requests.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)
response_json = response.json()
access_token = response_json.get('access_token')
access_token = response_json.get("access_token")
if not access_token:
raise ValueError(f"Error in Notion OAuth: {response_json}")
workspace_name = response_json.get('workspace_name')
workspace_icon = response_json.get('workspace_icon')
workspace_id = response_json.get('workspace_id')
workspace_name = response_json.get("workspace_name")
workspace_icon = response_json.get("workspace_icon")
workspace_id = response_json.get("workspace_id")
# get all authorized pages
pages = self.get_authorized_pages(access_token)
source_info = {
'workspace_name': workspace_name,
'workspace_icon': workspace_icon,
'workspace_id': workspace_id,
'pages': pages,
'total': len(pages)
"workspace_name": workspace_name,
"workspace_icon": workspace_icon,
"workspace_id": workspace_id,
"pages": pages,
"total": len(pages),
}
# save data source binding
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == 'notion',
DataSourceOauthBinding.access_token == access_token
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
).first()
if data_source_binding:
@@ -79,7 +75,7 @@ class NotionOAuth(OAuthDataSource):
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=source_info,
provider='notion'
provider="notion",
)
db.session.add(new_data_source_binding)
db.session.commit()
@@ -91,18 +87,18 @@ class NotionOAuth(OAuthDataSource):
# get all authorized pages
pages = self.get_authorized_pages(access_token)
source_info = {
'workspace_name': workspace_name,
'workspace_icon': workspace_icon,
'workspace_id': workspace_id,
'pages': pages,
'total': len(pages)
"workspace_name": workspace_name,
"workspace_icon": workspace_icon,
"workspace_id": workspace_id,
"pages": pages,
"total": len(pages),
}
# save data source binding
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == 'notion',
DataSourceOauthBinding.access_token == access_token
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token,
)
).first()
if data_source_binding:
@@ -114,7 +110,7 @@ class NotionOAuth(OAuthDataSource):
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=source_info,
provider='notion'
provider="notion",
)
db.session.add(new_data_source_binding)
db.session.commit()
@@ -124,9 +120,9 @@ class NotionOAuth(OAuthDataSource):
data_source_binding = DataSourceOauthBinding.query.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == 'notion',
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.id == binding_id,
DataSourceOauthBinding.disabled == False
DataSourceOauthBinding.disabled == False,
)
).first()
if data_source_binding:
@@ -134,17 +130,17 @@ class NotionOAuth(OAuthDataSource):
pages = self.get_authorized_pages(data_source_binding.access_token)
source_info = data_source_binding.source_info
new_source_info = {
'workspace_name': source_info['workspace_name'],
'workspace_icon': source_info['workspace_icon'],
'workspace_id': source_info['workspace_id'],
'pages': pages,
'total': len(pages)
"workspace_name": source_info["workspace_name"],
"workspace_icon": source_info["workspace_icon"],
"workspace_id": source_info["workspace_id"],
"pages": pages,
"total": len(pages),
}
data_source_binding.source_info = new_source_info
data_source_binding.disabled = False
db.session.commit()
else:
raise ValueError('Data source binding not found')
raise ValueError("Data source binding not found")
def get_authorized_pages(self, access_token: str):
pages = []
@@ -152,143 +148,121 @@ class NotionOAuth(OAuthDataSource):
database_results = self.notion_database_search(access_token)
# get page detail
for page_result in page_results:
page_id = page_result['id']
page_name = 'Untitled'
for key in page_result['properties']:
if 'title' in page_result['properties'][key] and page_result['properties'][key]['title']:
title_list = page_result['properties'][key]['title']
if len(title_list) > 0 and 'plain_text' in title_list[0]:
page_name = title_list[0]['plain_text']
page_icon = page_result['icon']
page_id = page_result["id"]
page_name = "Untitled"
for key in page_result["properties"]:
if "title" in page_result["properties"][key] and page_result["properties"][key]["title"]:
title_list = page_result["properties"][key]["title"]
if len(title_list) > 0 and "plain_text" in title_list[0]:
page_name = title_list[0]["plain_text"]
page_icon = page_result["icon"]
if page_icon:
icon_type = page_icon['type']
if icon_type == 'external' or icon_type == 'file':
url = page_icon[icon_type]['url']
icon = {
'type': 'url',
'url': url if url.startswith('http') else f'https://www.notion.so{url}'
}
icon_type = page_icon["type"]
if icon_type == "external" or icon_type == "file":
url = page_icon[icon_type]["url"]
icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"}
else:
icon = {
'type': 'emoji',
'emoji': page_icon[icon_type]
}
icon = {"type": "emoji", "emoji": page_icon[icon_type]}
else:
icon = None
parent = page_result['parent']
parent_type = parent['type']
if parent_type == 'block_id':
parent = page_result["parent"]
parent_type = parent["type"]
if parent_type == "block_id":
parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type])
elif parent_type == 'workspace':
parent_id = 'root'
elif parent_type == "workspace":
parent_id = "root"
else:
parent_id = parent[parent_type]
page = {
'page_id': page_id,
'page_name': page_name,
'page_icon': icon,
'parent_id': parent_id,
'type': 'page'
"page_id": page_id,
"page_name": page_name,
"page_icon": icon,
"parent_id": parent_id,
"type": "page",
}
pages.append(page)
# get database detail
for database_result in database_results:
page_id = database_result['id']
if len(database_result['title']) > 0:
page_name = database_result['title'][0]['plain_text']
page_id = database_result["id"]
if len(database_result["title"]) > 0:
page_name = database_result["title"][0]["plain_text"]
else:
page_name = 'Untitled'
page_icon = database_result['icon']
page_name = "Untitled"
page_icon = database_result["icon"]
if page_icon:
icon_type = page_icon['type']
if icon_type == 'external' or icon_type == 'file':
url = page_icon[icon_type]['url']
icon = {
'type': 'url',
'url': url if url.startswith('http') else f'https://www.notion.so{url}'
}
icon_type = page_icon["type"]
if icon_type == "external" or icon_type == "file":
url = page_icon[icon_type]["url"]
icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"}
else:
icon = {
'type': icon_type,
icon_type: page_icon[icon_type]
}
icon = {"type": icon_type, icon_type: page_icon[icon_type]}
else:
icon = None
parent = database_result['parent']
parent_type = parent['type']
if parent_type == 'block_id':
parent = database_result["parent"]
parent_type = parent["type"]
if parent_type == "block_id":
parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type])
elif parent_type == 'workspace':
parent_id = 'root'
elif parent_type == "workspace":
parent_id = "root"
else:
parent_id = parent[parent_type]
page = {
'page_id': page_id,
'page_name': page_name,
'page_icon': icon,
'parent_id': parent_id,
'type': 'database'
"page_id": page_id,
"page_name": page_name,
"page_icon": icon,
"parent_id": parent_id,
"type": "database",
}
pages.append(page)
return pages
def notion_page_search(self, access_token: str):
data = {
'filter': {
"value": "page",
"property": "object"
}
}
data = {"filter": {"value": "page", "property": "object"}}
headers = {
'Content-Type': 'application/json',
'Authorization': f"Bearer {access_token}",
'Notion-Version': '2022-06-28',
"Content-Type": "application/json",
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
}
response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response_json = response.json()
results = response_json.get('results', [])
results = response_json.get("results", [])
return results
def notion_block_parent_page_id(self, access_token: str, block_id: str):
headers = {
'Authorization': f"Bearer {access_token}",
'Notion-Version': '2022-06-28',
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
}
response = requests.get(url=f'{self._NOTION_BLOCK_SEARCH}/{block_id}', headers=headers)
response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
response_json = response.json()
parent = response_json['parent']
parent_type = parent['type']
if parent_type == 'block_id':
parent = response_json["parent"]
parent_type = parent["type"]
if parent_type == "block_id":
return self.notion_block_parent_page_id(access_token, parent[parent_type])
return parent[parent_type]
def notion_workspace_name(self, access_token: str):
headers = {
'Authorization': f"Bearer {access_token}",
'Notion-Version': '2022-06-28',
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
}
response = requests.get(url=self._NOTION_BOT_USER, headers=headers)
response_json = response.json()
if 'object' in response_json and response_json['object'] == 'user':
user_type = response_json['type']
if "object" in response_json and response_json["object"] == "user":
user_type = response_json["type"]
user_info = response_json[user_type]
if 'workspace_name' in user_info:
return user_info['workspace_name']
return 'workspace'
if "workspace_name" in user_info:
return user_info["workspace_name"]
return "workspace"
def notion_database_search(self, access_token: str):
data = {
'filter': {
"value": "database",
"property": "object"
}
}
data = {"filter": {"value": "database", "property": "object"}}
headers = {
'Content-Type': 'application/json',
'Authorization': f"Bearer {access_token}",
'Notion-Version': '2022-06-28',
"Content-Type": "application/json",
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
}
response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response_json = response.json()
results = response_json.get('results', [])
results = response_json.get("results", [])
return results
@@ -9,14 +9,14 @@ class PassportService:
self.sk = dify_config.SECRET_KEY
def issue(self, payload):
return jwt.encode(payload, self.sk, algorithm='HS256')
return jwt.encode(payload, self.sk, algorithm="HS256")
def verify(self, token):
try:
return jwt.decode(token, self.sk, algorithms=['HS256'])
return jwt.decode(token, self.sk, algorithms=["HS256"])
except jwt.exceptions.InvalidSignatureError:
raise Unauthorized('Invalid token signature.')
raise Unauthorized("Invalid token signature.")
except jwt.exceptions.DecodeError:
raise Unauthorized('Invalid token.')
raise Unauthorized("Invalid token.")
except jwt.exceptions.ExpiredSignatureError:
raise Unauthorized('Token has expired.')
raise Unauthorized("Token has expired.")
@@ -5,6 +5,7 @@ import re
password_pattern = r"^(?=.*[a-zA-Z])(?=.*\d).{8,}$"
def valid_password(password):
# Define a regex pattern for password rules
pattern = password_pattern
@@ -12,11 +13,11 @@ def valid_password(password):
if re.match(pattern, password) is not None:
return password
raise ValueError('Not a valid password.')
raise ValueError("Not a valid password.")
def hash_password(password_str, salt_byte):
dk = hashlib.pbkdf2_hmac('sha256', password_str.encode('utf-8'), salt_byte, 10000)
dk = hashlib.pbkdf2_hmac("sha256", password_str.encode("utf-8"), salt_byte, 10000)
return binascii.hexlify(dk)
@@ -48,7 +48,7 @@ def encrypt(text, public_key):
def get_decrypt_decoding(tenant_id):
filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"
cache_key = 'tenant_privkey:{hash}'.format(hash=hashlib.sha3_256(filepath.encode()).hexdigest())
cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest())
private_key = redis_client.get(cache_key)
if not private_key:
try:
@@ -66,12 +66,12 @@ def get_decrypt_decoding(tenant_id):
def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa):
if encrypted_text.startswith(prefix_hybrid):
encrypted_text = encrypted_text[len(prefix_hybrid):]
encrypted_text = encrypted_text[len(prefix_hybrid) :]
enc_aes_key = encrypted_text[:rsa_key.size_in_bytes()]
nonce = encrypted_text[rsa_key.size_in_bytes():rsa_key.size_in_bytes() + 16]
tag = encrypted_text[rsa_key.size_in_bytes() + 16:rsa_key.size_in_bytes() + 32]
ciphertext = encrypted_text[rsa_key.size_in_bytes() + 32:]
enc_aes_key = encrypted_text[: rsa_key.size_in_bytes()]
nonce = encrypted_text[rsa_key.size_in_bytes() : rsa_key.size_in_bytes() + 16]
tag = encrypted_text[rsa_key.size_in_bytes() + 16 : rsa_key.size_in_bytes() + 32]
ciphertext = encrypted_text[rsa_key.size_in_bytes() + 32 :]
aes_key = cipher_rsa.decrypt(enc_aes_key)
@@ -5,7 +5,9 @@ from email.mime.text import MIMEText
class SMTPClient:
def __init__(self, server: str, port: int, username: str, password: str, _from: str, use_tls=False, opportunistic_tls=False):
def __init__(
self, server: str, port: int, username: str, password: str, _from: str, use_tls=False, opportunistic_tls=False
):
self.server = server
self.port = port
self._from = _from
@@ -25,17 +27,17 @@ class SMTPClient:
smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10)
else:
smtp = smtplib.SMTP(self.server, self.port, timeout=10)
if self.username and self.password:
smtp.login(self.username, self.password)
msg = MIMEMultipart()
msg['Subject'] = mail['subject']
msg['From'] = self._from
msg['To'] = mail['to']
msg.attach(MIMEText(mail['html'], 'html'))
msg["Subject"] = mail["subject"]
msg["From"] = self._from
msg["To"] = mail["to"]
msg.attach(MIMEText(mail["html"], "html"))
smtp.sendmail(self._from, mail['to'], msg.as_string())
smtp.sendmail(self._from, mail["to"], msg.as_string())
except smtplib.SMTPException as e:
logging.error(f"SMTP error occurred: {str(e)}")
raise
@@ -73,12 +73,10 @@ exclude = [
"core/**/*.py",
"controllers/**/*.py",
"models/**/*.py",
"utils/**/*.py",
"migrations/**/*",
"services/**/*.py",
"tasks/**/*.py",
"tests/**/*.py",
"libs/**/*.py",
"configs/**/*.py",
]