 TEST_CRAWL_SITE_URL = "https://www.google.com"
 FIRECRAWL_API_VERSION = "v0"
-test_api_endpoint = self.api_base_url.rstrip('/') + f"/{FIRECRAWL_API_VERSION}/scrape"
+test_api_endpoint = self.api_base_url.rstrip("/") + f"/{FIRECRAWL_API_VERSION}/scrape"
 headers = {
     "Authorization": f"Bearer {self.api_key}",
 data_source_binding = DataSourceBearerBinding.query.filter(
     db.and_(
         DataSourceBearerBinding.tenant_id == current_user.current_tenant_id,
-        DataSourceBearerBinding.provider == 'firecrawl',
+        DataSourceBearerBinding.provider == "firecrawl",
         DataSourceBearerBinding.endpoint_url == self.api_base_url,
-        DataSourceBearerBinding.bearer_key == self.api_key
+        DataSourceBearerBinding.bearer_key == self.api_key,
     )
 ).first()
 if data_source_binding:
 else:
     new_data_source_binding = DataSourceBearerBinding(
         tenant_id=current_user.current_tenant_id,
-        provider='firecrawl',
+        provider="firecrawl",
         endpoint_url=self.api_base_url,
-        bearer_key=self.api_key
+        bearer_key=self.api_key,
     )
     db.session.add(new_data_source_binding)
     db.session.commit()
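For orientation, the credential check assembled above amounts to roughly the following call (the request body is not visible in this diff, so the json payload below is an assumption):

    import requests

    requests.post(
        test_api_endpoint,                                # f"{api_base_url}/v0/scrape"
        headers={"Authorization": f"Bearer {api_key}"},
        json={"url": TEST_CRAWL_SITE_URL},                # assumed payload
    )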
 class BaseHTTPException(HTTPException):
-    error_code: str = 'unknown'
+    error_code: str = "unknown"
     data: Optional[dict] = None
     def __init__(self, description=None, response=None):
             "code": self.error_code,
             "message": self.description,
             "status": self.code,
-        }
+        }
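For orientation, a minimal sketch of how BaseHTTPException is meant to be subclassed (the subclass name, status code, and message below are hypothetical, not taken from this diff):

    class NotFoundError(BaseHTTPException):
        error_code = "not_found"                                # serialized as "code"
        description = "The requested resource was not found."   # serialized as "message"
        code = 404                                              # HTTP status, serialized as "status"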
 class ExternalApi(Api):
     def handle_error(self, e):
         """Error handler for the API transforms a raised exception into a Flask
         response, with the appropriate HTTP status code and body.
         status_code = e.code
         default_data = {
-            'code': re.sub(r'(?<!^)(?=[A-Z])', '_', type(e).__name__).lower(),
-            'message': getattr(e, 'description', http_status_message(status_code)),
-            'status': status_code
+            "code": re.sub(r"(?<!^)(?=[A-Z])", "_", type(e).__name__).lower(),
+            "message": getattr(e, "description", http_status_message(status_code)),
+            "status": status_code,
         }
-        if default_data['message'] and default_data['message'] == 'Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)':
-            default_data['message'] = 'Invalid JSON payload received or JSON payload is empty.'
+        if (
+            default_data["message"]
+            and default_data["message"] == "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)"
+        ):
+            default_data["message"] = "Invalid JSON payload received or JSON payload is empty."
         headers = e.get_response().headers
         elif isinstance(e, ValueError):
             status_code = 400
             default_data = {
-                'code': 'invalid_param',
-                'message': str(e),
-                'status': status_code
+                "code": "invalid_param",
+                "message": str(e),
+                "status": status_code,
             }
         elif isinstance(e, AppInvokeQuotaExceededError):
             status_code = 429
             default_data = {
-                'code': 'too_many_requests',
-                'message': str(e),
-                'status': status_code
+                "code": "too_many_requests",
+                "message": str(e),
+                "status": status_code,
             }
         else:
             status_code = 500
             default_data = {
-                'message': http_status_message(status_code),
+                "message": http_status_message(status_code),
             }
         # Werkzeug exceptions generate a content-length header which is added
         # to the response in addition to the actual content-length header
         # https://github.com/flask-restful/flask-restful/issues/534
-        remove_headers = ('Content-Length',)
+        remove_headers = ("Content-Length",)
         for header in remove_headers:
             headers.pop(header, None)
-        data = getattr(e, 'data', default_data)
+        data = getattr(e, "data", default_data)
         error_cls_name = type(e).__name__
         if error_cls_name in self.errors:
             custom_data = self.errors.get(error_cls_name, {})
             custom_data = custom_data.copy()
-            status_code = custom_data.get('status', 500)
+            status_code = custom_data.get("status", 500)
-            if 'message' in custom_data:
-                custom_data['message'] = custom_data['message'].format(
-                    message=str(e.description if hasattr(e, 'description') else e)
+            if "message" in custom_data:
+                custom_data["message"] = custom_data["message"].format(
+                    message=str(e.description if hasattr(e, "description") else e)
                 )
             data.update(custom_data)
         # another NotAcceptable error).
         supported_mediatypes = list(self.representations.keys())  # only supported application/json
         fallback_mediatype = supported_mediatypes[0] if supported_mediatypes else "text/plain"
-        data = {
-            'code': 'not_acceptable',
-            'message': data.get('message')
-        }
-        resp = self.make_response(
-            data,
-            status_code,
-            headers,
-            fallback_mediatype = fallback_mediatype
-        )
+        data = {"code": "not_acceptable", "message": data.get("message")}
+        resp = self.make_response(data, status_code, headers, fallback_mediatype=fallback_mediatype)
         elif status_code == 400:
-            if isinstance(data.get('message'), dict):
-                param_key, param_value = list(data.get('message').items())[0]
-                data = {
-                    'code': 'invalid_param',
-                    'message': param_value,
-                    'params': param_key
-                }
+            if isinstance(data.get("message"), dict):
+                param_key, param_value = list(data.get("message").items())[0]
+                data = {"code": "invalid_param", "message": param_value, "params": param_key}
             else:
-                if 'code' not in data:
-                    data['code'] = 'unknown'
+                if "code" not in data:
+                    data["code"] = "unknown"
             resp = self.make_response(data, status_code, headers)
         else:
-            if 'code' not in data:
-                data['code'] = 'unknown'
+            if "code" not in data:
+                data["code"] = "unknown"
             resp = self.make_response(data, status_code, headers)
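To illustrate the envelope handle_error produces: a raised ValueError (message text below is illustrative) falls into the 400 branch above and is rendered as

    {"code": "invalid_param", "message": "limit must be a positive integer", "status": 400}

with HTTP status 400, while unknown exceptions fall through to the 500 branch with only a "message" key and the "unknown" code filled in at the end.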
 if mgfunc:
     self._mgf = mgfunc
 else:
-    self._mgf = lambda x,y: MGF1(x,y,self._hashObj)
+    self._mgf = lambda x, y: MGF1(x, y, self._hashObj)
 self._label = _copy_bytes(None, None, label)
 self._randfunc = randfunc
 # See 7.1.1 in RFC3447
 modBits = Crypto.Util.number.size(self._key.n)
-k = ceil_div(modBits, 8) # Convert from bits to bytes
+k = ceil_div(modBits, 8)  # Convert from bits to bytes
 hLen = self._hashObj.digest_size
 mLen = len(message)
 # Step 2a
 lHash = sha1(self._label).digest()
 # Step 2b
-ps = b'\x00' * ps_len
+ps = b"\x00" * ps_len
 # Step 2c
-db = lHash + ps + b'\x01' + _copy_bytes(None, None, message)
+db = lHash + ps + b"\x01" + _copy_bytes(None, None, message)
 # Step 2d
 ros = self._randfunc(hLen)
 # Step 2e
-dbMask = self._mgf(ros, k-hLen-1)
+dbMask = self._mgf(ros, k - hLen - 1)
 # Step 2f
 maskedDB = strxor(db, dbMask)
 # Step 2g
 # Step 2h
 maskedSeed = strxor(ros, seedMask)
 # Step 2i
-em = b'\x00' + maskedSeed + maskedDB
+em = b"\x00" + maskedSeed + maskedDB
 # Step 3a (OS2IP)
 em_int = bytes_to_long(em)
 # Step 3b (RSAEP)
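For reference, steps 2a-2i above are the EME-OAEP encoding of RFC 3447 section 7.1.1; with seed = ros and the values computed above:

    DB         = lHash || PS || 0x01 || M
    dbMask     = MGF(seed, k - hLen - 1)
    maskedDB   = DB xor dbMask
    seedMask   = MGF(maskedDB, hLen)
    maskedSeed = seed xor seedMask
    EM         = 0x00 || maskedSeed || maskedDB

EM is then converted to an integer (OS2IP) and encrypted with the RSA public key (RSAEP).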
| """ | """ | ||||
| # See 7.1.2 in RFC3447 | # See 7.1.2 in RFC3447 | ||||
| modBits = Crypto.Util.number.size(self._key.n) | modBits = Crypto.Util.number.size(self._key.n) | ||||
| k = ceil_div(modBits,8) # Convert from bits to bytes | |||||
| k = ceil_div(modBits, 8) # Convert from bits to bytes | |||||
| hLen = self._hashObj.digest_size | hLen = self._hashObj.digest_size | ||||
| # Step 1b and 1c | # Step 1b and 1c | ||||
| if len(ciphertext) != k or k<hLen+2: | |||||
| if len(ciphertext) != k or k < hLen + 2: | |||||
| raise ValueError("Ciphertext with incorrect length.") | raise ValueError("Ciphertext with incorrect length.") | ||||
| # Step 2a (O2SIP) | # Step 2a (O2SIP) | ||||
| ct_int = bytes_to_long(ciphertext) | ct_int = bytes_to_long(ciphertext) | ||||
| y = em[0] | y = em[0] | ||||
| # y must be 0, but we MUST NOT check it here in order not to | # y must be 0, but we MUST NOT check it here in order not to | ||||
| # allow attacks like Manger's (http://dl.acm.org/citation.cfm?id=704143) | # allow attacks like Manger's (http://dl.acm.org/citation.cfm?id=704143) | ||||
| maskedSeed = em[1:hLen+1] | |||||
| maskedDB = em[hLen+1:] | |||||
| maskedSeed = em[1 : hLen + 1] | |||||
| maskedDB = em[hLen + 1 :] | |||||
| # Step 3c | # Step 3c | ||||
| seedMask = self._mgf(maskedDB, hLen) | seedMask = self._mgf(maskedDB, hLen) | ||||
| # Step 3d | # Step 3d | ||||
| seed = strxor(maskedSeed, seedMask) | seed = strxor(maskedSeed, seedMask) | ||||
| # Step 3e | # Step 3e | ||||
| dbMask = self._mgf(seed, k-hLen-1) | |||||
| dbMask = self._mgf(seed, k - hLen - 1) | |||||
| # Step 3f | # Step 3f | ||||
| db = strxor(maskedDB, dbMask) | db = strxor(maskedDB, dbMask) | ||||
| # Step 3g | # Step 3g | ||||
| one_pos = hLen + db[hLen:].find(b'\x01') | |||||
| one_pos = hLen + db[hLen:].find(b"\x01") | |||||
| lHash1 = db[:hLen] | lHash1 = db[:hLen] | ||||
| invalid = bord(y) | int(one_pos < hLen) | invalid = bord(y) | int(one_pos < hLen) | ||||
| hash_compare = strxor(lHash1, lHash) | hash_compare = strxor(lHash1, lHash) | ||||
| if invalid != 0: | if invalid != 0: | ||||
| raise ValueError("Incorrect decryption.") | raise ValueError("Incorrect decryption.") | ||||
| # Step 4 | # Step 4 | ||||
| return db[one_pos + 1:] | |||||
| return db[one_pos + 1 :] | |||||
| def new(key, hashAlgo=None, mgfunc=None, label=b'', randfunc=None): | |||||
| def new(key, hashAlgo=None, mgfunc=None, label=b"", randfunc=None): | |||||
| """Return a cipher object :class:`PKCS1OAEP_Cipher` that can be used to perform PKCS#1 OAEP encryption or decryption. | """Return a cipher object :class:`PKCS1OAEP_Cipher` that can be used to perform PKCS#1 OAEP encryption or decryption. | ||||
| :param key: | :param key: |
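Assuming this hunk is PyCryptodome's Crypto.Cipher.PKCS1_OAEP module (the class name and new() signature above match it), a minimal usage sketch; key size and message are illustrative:

    from Crypto.PublicKey import RSA
    from Crypto.Cipher import PKCS1_OAEP

    key = RSA.generate(2048)
    cipher = PKCS1_OAEP.new(key)                    # SHA-1 and MGF1 by default, per the code above
    ciphertext = cipher.encrypt(b"secret message")
    assert PKCS1_OAEP.new(key).decrypt(ciphertext) == b"secret message"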
 def run(script):
-    return subprocess.getstatusoutput('source /root/.bashrc && ' + script)
+    return subprocess.getstatusoutput("source /root/.bashrc && " + script)
 class TimestampField(fields.Raw):
 if re.match(pattern, email) is not None:
     return email
-error = ('{email} is not a valid email.'
-         .format(email=email))
+error = "{email} is not a valid email.".format(email=email)
 raise ValueError(error)
 def uuid_value(value):
-    if value == '':
+    if value == "":
         return str(value)
     try:
         uuid_obj = uuid.UUID(value)
         return str(uuid_obj)
     except ValueError:
-        error = ('{value} is not a valid uuid.'
-                 .format(value=value))
+        error = "{value} is not a valid uuid.".format(value=value)
         raise ValueError(error)
 def alphanumeric(value: str):
     # check if the value is alphanumeric and underlined
-    if re.match(r'^[a-zA-Z0-9_]+$', value):
+    if re.match(r"^[a-zA-Z0-9_]+$", value):
         return value
-    raise ValueError(f'{value} is not a valid alphanumeric value')
+    raise ValueError(f"{value} is not a valid alphanumeric value")
 def timestamp_value(timestamp):
     try:
         raise ValueError
         return int_timestamp
     except ValueError:
-        error = ('{timestamp} is not a valid timestamp.'
-                 .format(timestamp=timestamp))
+        error = "{timestamp} is not a valid timestamp.".format(timestamp=timestamp)
         raise ValueError(error)
 class str_len:
-    """ Restrict input to an integer in a range (inclusive) """
+    """Restrict input to an integer in a range (inclusive)"""
-    def __init__(self, max_length, argument='argument'):
+    def __init__(self, max_length, argument="argument"):
         self.max_length = max_length
         self.argument = argument
     def __call__(self, value):
         length = len(value)
         if length > self.max_length:
-            error = ('Invalid {arg}: {val}. {arg} cannot exceed length {length}'
-                     .format(arg=self.argument, val=value, length=self.max_length))
+            error = "Invalid {arg}: {val}. {arg} cannot exceed length {length}".format(
+                arg=self.argument, val=value, length=self.max_length
+            )
             raise ValueError(error)
         return value
 class float_range:
-    """ Restrict input to an float in a range (inclusive) """
-    def __init__(self, low, high, argument='argument'):
+    """Restrict input to an float in a range (inclusive)"""
+    def __init__(self, low, high, argument="argument"):
         self.low = low
         self.high = high
         self.argument = argument
     def __call__(self, value):
         value = _get_float(value)
         if value < self.low or value > self.high:
-            error = ('Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}'
-                     .format(arg=self.argument, val=value, lo=self.low, hi=self.high))
+            error = "Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}".format(
+                arg=self.argument, val=value, lo=self.low, hi=self.high
+            )
             raise ValueError(error)
         return value
 class datetime_string:
-    def __init__(self, format, argument='argument'):
+    def __init__(self, format, argument="argument"):
         self.format = format
         self.argument = argument
     try:
         datetime.strptime(value, self.format)
     except ValueError:
-        error = ('Invalid {arg}: {val}. {arg} must be conform to the format {format}'
-                 .format(arg=self.argument, val=value, format=self.format))
+        error = "Invalid {arg}: {val}. {arg} must be conform to the format {format}".format(
+            arg=self.argument, val=value, format=self.format
+        )
         raise ValueError(error)
     return value
 try:
     return float(value)
 except (TypeError, ValueError):
-    raise ValueError('{} is not a valid float'.format(value))
+    raise ValueError("{} is not a valid float".format(value))
 def timezone(timezone_string):
     if timezone_string and timezone_string in available_timezones():
         return timezone_string
-    error = ('{timezone_string} is not a valid timezone.'
-             .format(timezone_string=timezone_string))
+    error = "{timezone_string} is not a valid timezone.".format(timezone_string=timezone_string)
     raise ValueError(error)
 def get_remote_ip(request) -> str:
-    if request.headers.get('CF-Connecting-IP'):
-        return request.headers.get('Cf-Connecting-Ip')
+    if request.headers.get("CF-Connecting-IP"):
+        return request.headers.get("Cf-Connecting-Ip")
     elif request.headers.getlist("X-Forwarded-For"):
         return request.headers.getlist("X-Forwarded-For")[0]
     else:
 def generate_text_hash(text: str) -> str:
-    hash_text = str(text) + 'None'
+    hash_text = str(text) + "None"
     return sha256(hash_text.encode()).hexdigest()
 def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Response:
     if isinstance(response, dict):
-        return Response(response=json.dumps(response), status=200, mimetype='application/json')
+        return Response(response=json.dumps(response), status=200, mimetype="application/json")
     else:
         def generate() -> Generator:
             yield from response
-        return Response(stream_with_context(generate()), status=200,
-                        mimetype='text/event-stream')
+        return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
 class TokenManager:
     @classmethod
     def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str:
         old_token = cls._get_current_token_for_account(account.id, token_type)
         if old_token:
             if isinstance(old_token, bytes):
-                old_token = old_token.decode('utf-8')
+                old_token = old_token.decode("utf-8")
             cls.revoke_token(old_token, token_type)
         token = str(uuid.uuid4())
-        token_data = {
-            'account_id': account.id,
-            'email': account.email,
-            'token_type': token_type
-        }
+        token_data = {"account_id": account.id, "email": account.email, "token_type": token_type}
         if additional_data:
             token_data.update(additional_data)
-        expiry_hours = current_app.config[f'{token_type.upper()}_TOKEN_EXPIRY_HOURS']
+        expiry_hours = current_app.config[f"{token_type.upper()}_TOKEN_EXPIRY_HOURS"]
         token_key = cls._get_token_key(token, token_type)
-        redis_client.setex(
-            token_key,
-            expiry_hours * 60 * 60,
-            json.dumps(token_data)
-        )
+        redis_client.setex(token_key, expiry_hours * 60 * 60, json.dumps(token_data))
         cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
         return token
     @classmethod
     def _get_token_key(cls, token: str, token_type: str) -> str:
-        return f'{token_type}:token:{token}'
+        return f"{token_type}:token:{token}"
     @classmethod
     def revoke_token(cls, token: str, token_type: str):
     @classmethod
     def _get_account_token_key(cls, account_id: str, token_type: str) -> str:
-        return f'{token_type}:account:{account_id}'
+        return f"{token_type}:account:{account_id}"
 class RateLimiter:
     current_time = int(time.time())
     window_start_time = current_time - self.time_window
-    redis_client.zremrangebyscore(key, '-inf', window_start_time)
+    redis_client.zremrangebyscore(key, "-inf", window_start_time)
     attempts = redis_client.zcard(key)
     if attempts and int(attempts) >= self.max_attempts:
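A hypothetical call to the TokenManager above ("reset_password" is an illustrative token type; it implies a RESET_PASSWORD_TOKEN_EXPIRY_HOURS entry in the Flask config, which this diff does not show):

    token = TokenManager.generate_token(account, "reset_password", {"source": "web"})
    # stores {"account_id": ..., "email": ..., "token_type": "reset_password", "source": "web"}
    # in Redis under "reset_password:token:<uuid>" with the configured expiry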
 class InfiniteScrollPagination:
     def __init__(self, data, limit, has_more):
         self.data = data
 end_index = json_string.find("```", start_index + len("```json"))
 if start_index != -1 and end_index != -1:
-    extracted_content = json_string[start_index + len("```json"):end_index].strip()
+    extracted_content = json_string[start_index + len("```json") : end_index].strip()
     # Parse the JSON string into a Python dictionary
     parsed = json.loads(extracted_content)
 elif start_index != -1 and end_index == -1 and json_string.endswith("``"):
     end_index = json_string.find("``", start_index + len("```json"))
-    extracted_content = json_string[start_index + len("```json"):end_index].strip()
+    extracted_content = json_string[start_index + len("```json") : end_index].strip()
     # Parse the JSON string into a Python dictionary
     parsed = json.loads(extracted_content)
 for key in expected_keys:
     if key not in json_obj:
         raise OutputParserException(
-            f"Got invalid return object. Expected key `{key}` "
-            f"to be present, but got {json_obj}"
+            f"Got invalid return object. Expected key `{key}` " f"to be present, but got {json_obj}"
         )
 return json_obj
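Illustrative input and output for the extraction logic above (the enclosing function is not shown in this hunk):

    json_string = 'Thought: done\n```json\n{"action": "final"}\n```'
    # the code above slices out the fenced block, json.loads() it, and checks expected_keys,
    # yielding {"action": "final"} when expected_keys == ["action"]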
 @wraps(func)
 def decorated_view(*args, **kwargs):
-    auth_header = request.headers.get('Authorization')
-    admin_api_key_enable = os.getenv('ADMIN_API_KEY_ENABLE', default='False')
-    if admin_api_key_enable.lower() == 'true':
+    auth_header = request.headers.get("Authorization")
+    admin_api_key_enable = os.getenv("ADMIN_API_KEY_ENABLE", default="False")
+    if admin_api_key_enable.lower() == "true":
         if auth_header:
-            if ' ' not in auth_header:
-                raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+            if " " not in auth_header:
+                raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
             auth_scheme, auth_token = auth_header.split(None, 1)
             auth_scheme = auth_scheme.lower()
-            if auth_scheme != 'bearer':
-                raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
-        admin_api_key = os.getenv('ADMIN_API_KEY')
+            if auth_scheme != "bearer":
+                raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
+        admin_api_key = os.getenv("ADMIN_API_KEY")
         if admin_api_key:
-            if os.getenv('ADMIN_API_KEY') == auth_token:
-                workspace_id = request.headers.get('X-WORKSPACE-ID')
+            if os.getenv("ADMIN_API_KEY") == auth_token:
+                workspace_id = request.headers.get("X-WORKSPACE-ID")
                 if workspace_id:
-                    tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
-                        .filter(Tenant.id == workspace_id) \
-                        .filter(TenantAccountJoin.tenant_id == Tenant.id) \
-                        .filter(TenantAccountJoin.role == 'owner') \
+                    tenant_account_join = (
+                        db.session.query(Tenant, TenantAccountJoin)
+                        .filter(Tenant.id == workspace_id)
+                        .filter(TenantAccountJoin.tenant_id == Tenant.id)
+                        .filter(TenantAccountJoin.role == "owner")
                         .one_or_none()
+                    )
                     if tenant_account_join:
                         tenant, ta = tenant_account_join
                         account = Account.query.filter_by(id=ta.account_id).first()
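A hypothetical request against an endpoint guarded by the decorator above (URL and values are illustrative):

    import requests

    requests.post(
        "https://example.com/console/admin/endpoint",
        headers={"Authorization": f"Bearer {admin_api_key}", "X-WORKSPACE-ID": workspace_id},
    )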
 class GitHubOAuth(OAuth):
-    _AUTH_URL = 'https://github.com/login/oauth/authorize'
-    _TOKEN_URL = 'https://github.com/login/oauth/access_token'
-    _USER_INFO_URL = 'https://api.github.com/user'
-    _EMAIL_INFO_URL = 'https://api.github.com/user/emails'
+    _AUTH_URL = "https://github.com/login/oauth/authorize"
+    _TOKEN_URL = "https://github.com/login/oauth/access_token"
+    _USER_INFO_URL = "https://api.github.com/user"
+    _EMAIL_INFO_URL = "https://api.github.com/user/emails"
     def get_authorization_url(self):
         params = {
-            'client_id': self.client_id,
-            'redirect_uri': self.redirect_uri,
-            'scope': 'user:email' # Request only basic user information
+            "client_id": self.client_id,
+            "redirect_uri": self.redirect_uri,
+            "scope": "user:email",  # Request only basic user information
         }
         return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
     def get_access_token(self, code: str):
         data = {
-            'client_id': self.client_id,
-            'client_secret': self.client_secret,
-            'code': code,
-            'redirect_uri': self.redirect_uri
+            "client_id": self.client_id,
+            "client_secret": self.client_secret,
+            "code": code,
+            "redirect_uri": self.redirect_uri,
         }
-        headers = {'Accept': 'application/json'}
+        headers = {"Accept": "application/json"}
         response = requests.post(self._TOKEN_URL, data=data, headers=headers)
         response_json = response.json()
-        access_token = response_json.get('access_token')
+        access_token = response_json.get("access_token")
         if not access_token:
             raise ValueError(f"Error in GitHub OAuth: {response_json}")
         return access_token
     def get_raw_user_info(self, token: str):
-        headers = {'Authorization': f"token {token}"}
+        headers = {"Authorization": f"token {token}"}
         response = requests.get(self._USER_INFO_URL, headers=headers)
         response.raise_for_status()
         user_info = response.json()
         email_response = requests.get(self._EMAIL_INFO_URL, headers=headers)
         email_info = email_response.json()
-        primary_email = next((email for email in email_info if email['primary'] == True), None)
+        primary_email = next((email for email in email_info if email["primary"] == True), None)
-        return {**user_info, 'email': primary_email['email']}
+        return {**user_info, "email": primary_email["email"]}
     def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
-        email = raw_info.get('email')
+        email = raw_info.get("email")
         if not email:
             email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com"
-        return OAuthUserInfo(
-            id=str(raw_info['id']),
-            name=raw_info['name'],
-            email=email
-        )
+        return OAuthUserInfo(id=str(raw_info["id"]), name=raw_info["name"], email=email)
 class GoogleOAuth(OAuth):
-    _AUTH_URL = 'https://accounts.google.com/o/oauth2/v2/auth'
-    _TOKEN_URL = 'https://oauth2.googleapis.com/token'
-    _USER_INFO_URL = 'https://www.googleapis.com/oauth2/v3/userinfo'
+    _AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
+    _TOKEN_URL = "https://oauth2.googleapis.com/token"
+    _USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
     def get_authorization_url(self):
         params = {
-            'client_id': self.client_id,
-            'response_type': 'code',
-            'redirect_uri': self.redirect_uri,
-            'scope': 'openid email'
+            "client_id": self.client_id,
+            "response_type": "code",
+            "redirect_uri": self.redirect_uri,
+            "scope": "openid email",
         }
         return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
     def get_access_token(self, code: str):
         data = {
-            'client_id': self.client_id,
-            'client_secret': self.client_secret,
-            'code': code,
-            'grant_type': 'authorization_code',
-            'redirect_uri': self.redirect_uri
+            "client_id": self.client_id,
+            "client_secret": self.client_secret,
+            "code": code,
+            "grant_type": "authorization_code",
+            "redirect_uri": self.redirect_uri,
         }
-        headers = {'Accept': 'application/json'}
+        headers = {"Accept": "application/json"}
         response = requests.post(self._TOKEN_URL, data=data, headers=headers)
         response_json = response.json()
-        access_token = response_json.get('access_token')
+        access_token = response_json.get("access_token")
         if not access_token:
             raise ValueError(f"Error in Google OAuth: {response_json}")
         return access_token
     def get_raw_user_info(self, token: str):
-        headers = {'Authorization': f"Bearer {token}"}
+        headers = {"Authorization": f"Bearer {token}"}
         response = requests.get(self._USER_INFO_URL, headers=headers)
         response.raise_for_status()
         return response.json()
     def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
-        return OAuthUserInfo(
-            id=str(raw_info['sub']),
-            name=None,
-            email=raw_info['email']
-        )
+        return OAuthUserInfo(id=str(raw_info["sub"]), name=None, email=raw_info["email"])
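A sketch of the flow these OAuth classes implement (the constructor arguments are assumed from how self.client_id, self.client_secret, and self.redirect_uri are used above; they are not shown in this diff):

    oauth = GitHubOAuth(client_id="...", client_secret="...", redirect_uri="https://example.com/oauth/callback")
    authorize_url = oauth.get_authorization_url()   # send the user here
    token = oauth.get_access_token(code)            # "code" comes back on the redirect
    raw_info = oauth.get_raw_user_info(token)       # includes the primary email
    user = oauth._transform_user_info(raw_info)     # -> OAuthUserInfo(id, name, email)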
 class NotionOAuth(OAuthDataSource):
-    _AUTH_URL = 'https://api.notion.com/v1/oauth/authorize'
-    _TOKEN_URL = 'https://api.notion.com/v1/oauth/token'
+    _AUTH_URL = "https://api.notion.com/v1/oauth/authorize"
+    _TOKEN_URL = "https://api.notion.com/v1/oauth/token"
     _NOTION_PAGE_SEARCH = "https://api.notion.com/v1/search"
     _NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks"
     _NOTION_BOT_USER = "https://api.notion.com/v1/users/me"
     def get_authorization_url(self):
         params = {
-            'client_id': self.client_id,
-            'response_type': 'code',
-            'redirect_uri': self.redirect_uri,
-            'owner': 'user'
+            "client_id": self.client_id,
+            "response_type": "code",
+            "redirect_uri": self.redirect_uri,
+            "owner": "user",
         }
         return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
     def get_access_token(self, code: str):
-        data = {
-            'code': code,
-            'grant_type': 'authorization_code',
-            'redirect_uri': self.redirect_uri
-        }
-        headers = {'Accept': 'application/json'}
+        data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
+        headers = {"Accept": "application/json"}
         auth = (self.client_id, self.client_secret)
         response = requests.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)
         response_json = response.json()
-        access_token = response_json.get('access_token')
+        access_token = response_json.get("access_token")
         if not access_token:
             raise ValueError(f"Error in Notion OAuth: {response_json}")
-        workspace_name = response_json.get('workspace_name')
-        workspace_icon = response_json.get('workspace_icon')
-        workspace_id = response_json.get('workspace_id')
+        workspace_name = response_json.get("workspace_name")
+        workspace_icon = response_json.get("workspace_icon")
+        workspace_id = response_json.get("workspace_id")
         # get all authorized pages
         pages = self.get_authorized_pages(access_token)
         source_info = {
-            'workspace_name': workspace_name,
-            'workspace_icon': workspace_icon,
-            'workspace_id': workspace_id,
-            'pages': pages,
-            'total': len(pages)
+            "workspace_name": workspace_name,
+            "workspace_icon": workspace_icon,
+            "workspace_id": workspace_id,
+            "pages": pages,
+            "total": len(pages),
         }
         # save data source binding
         data_source_binding = DataSourceOauthBinding.query.filter(
             db.and_(
                 DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-                DataSourceOauthBinding.provider == 'notion',
-                DataSourceOauthBinding.access_token == access_token
+                DataSourceOauthBinding.provider == "notion",
+                DataSourceOauthBinding.access_token == access_token,
             )
         ).first()
         if data_source_binding:
             tenant_id=current_user.current_tenant_id,
             access_token=access_token,
             source_info=source_info,
-            provider='notion'
+            provider="notion",
         )
         db.session.add(new_data_source_binding)
         db.session.commit()
         # get all authorized pages
         pages = self.get_authorized_pages(access_token)
         source_info = {
-            'workspace_name': workspace_name,
-            'workspace_icon': workspace_icon,
-            'workspace_id': workspace_id,
-            'pages': pages,
-            'total': len(pages)
+            "workspace_name": workspace_name,
+            "workspace_icon": workspace_icon,
+            "workspace_id": workspace_id,
+            "pages": pages,
+            "total": len(pages),
         }
         # save data source binding
         data_source_binding = DataSourceOauthBinding.query.filter(
             db.and_(
                 DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-                DataSourceOauthBinding.provider == 'notion',
-                DataSourceOauthBinding.access_token == access_token
+                DataSourceOauthBinding.provider == "notion",
+                DataSourceOauthBinding.access_token == access_token,
             )
         ).first()
         if data_source_binding:
             tenant_id=current_user.current_tenant_id,
             access_token=access_token,
             source_info=source_info,
-            provider='notion'
+            provider="notion",
         )
         db.session.add(new_data_source_binding)
         db.session.commit()
         data_source_binding = DataSourceOauthBinding.query.filter(
             db.and_(
                 DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-                DataSourceOauthBinding.provider == 'notion',
+                DataSourceOauthBinding.provider == "notion",
                 DataSourceOauthBinding.id == binding_id,
-                DataSourceOauthBinding.disabled == False
+                DataSourceOauthBinding.disabled == False,
             )
         ).first()
         if data_source_binding:
             pages = self.get_authorized_pages(data_source_binding.access_token)
             source_info = data_source_binding.source_info
             new_source_info = {
-                'workspace_name': source_info['workspace_name'],
-                'workspace_icon': source_info['workspace_icon'],
-                'workspace_id': source_info['workspace_id'],
-                'pages': pages,
-                'total': len(pages)
+                "workspace_name": source_info["workspace_name"],
+                "workspace_icon": source_info["workspace_icon"],
+                "workspace_id": source_info["workspace_id"],
+                "pages": pages,
+                "total": len(pages),
             }
             data_source_binding.source_info = new_source_info
             data_source_binding.disabled = False
             db.session.commit()
         else:
-            raise ValueError('Data source binding not found')
+            raise ValueError("Data source binding not found")
     def get_authorized_pages(self, access_token: str):
         pages = []
         database_results = self.notion_database_search(access_token)
         # get page detail
         for page_result in page_results:
-            page_id = page_result['id']
-            page_name = 'Untitled'
-            for key in page_result['properties']:
-                if 'title' in page_result['properties'][key] and page_result['properties'][key]['title']:
-                    title_list = page_result['properties'][key]['title']
-                    if len(title_list) > 0 and 'plain_text' in title_list[0]:
-                        page_name = title_list[0]['plain_text']
-            page_icon = page_result['icon']
+            page_id = page_result["id"]
+            page_name = "Untitled"
+            for key in page_result["properties"]:
+                if "title" in page_result["properties"][key] and page_result["properties"][key]["title"]:
+                    title_list = page_result["properties"][key]["title"]
+                    if len(title_list) > 0 and "plain_text" in title_list[0]:
+                        page_name = title_list[0]["plain_text"]
+            page_icon = page_result["icon"]
             if page_icon:
-                icon_type = page_icon['type']
-                if icon_type == 'external' or icon_type == 'file':
-                    url = page_icon[icon_type]['url']
-                    icon = {
-                        'type': 'url',
-                        'url': url if url.startswith('http') else f'https://www.notion.so{url}'
-                    }
+                icon_type = page_icon["type"]
+                if icon_type == "external" or icon_type == "file":
+                    url = page_icon[icon_type]["url"]
+                    icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"}
                 else:
-                    icon = {
-                        'type': 'emoji',
-                        'emoji': page_icon[icon_type]
-                    }
+                    icon = {"type": "emoji", "emoji": page_icon[icon_type]}
             else:
                 icon = None
-            parent = page_result['parent']
-            parent_type = parent['type']
-            if parent_type == 'block_id':
+            parent = page_result["parent"]
+            parent_type = parent["type"]
+            if parent_type == "block_id":
                 parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type])
-            elif parent_type == 'workspace':
-                parent_id = 'root'
+            elif parent_type == "workspace":
+                parent_id = "root"
             else:
                 parent_id = parent[parent_type]
             page = {
-                'page_id': page_id,
-                'page_name': page_name,
-                'page_icon': icon,
-                'parent_id': parent_id,
-                'type': 'page'
+                "page_id": page_id,
+                "page_name": page_name,
+                "page_icon": icon,
+                "parent_id": parent_id,
+                "type": "page",
             }
             pages.append(page)
         # get database detail
         for database_result in database_results:
-            page_id = database_result['id']
-            if len(database_result['title']) > 0:
-                page_name = database_result['title'][0]['plain_text']
+            page_id = database_result["id"]
+            if len(database_result["title"]) > 0:
+                page_name = database_result["title"][0]["plain_text"]
             else:
-                page_name = 'Untitled'
-            page_icon = database_result['icon']
+                page_name = "Untitled"
+            page_icon = database_result["icon"]
             if page_icon:
-                icon_type = page_icon['type']
-                if icon_type == 'external' or icon_type == 'file':
-                    url = page_icon[icon_type]['url']
-                    icon = {
-                        'type': 'url',
-                        'url': url if url.startswith('http') else f'https://www.notion.so{url}'
-                    }
+                icon_type = page_icon["type"]
+                if icon_type == "external" or icon_type == "file":
+                    url = page_icon[icon_type]["url"]
+                    icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"}
                 else:
-                    icon = {
-                        'type': icon_type,
-                        icon_type: page_icon[icon_type]
-                    }
+                    icon = {"type": icon_type, icon_type: page_icon[icon_type]}
             else:
                 icon = None
-            parent = database_result['parent']
-            parent_type = parent['type']
-            if parent_type == 'block_id':
+            parent = database_result["parent"]
+            parent_type = parent["type"]
+            if parent_type == "block_id":
                 parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type])
-            elif parent_type == 'workspace':
-                parent_id = 'root'
+            elif parent_type == "workspace":
+                parent_id = "root"
             else:
                 parent_id = parent[parent_type]
             page = {
-                'page_id': page_id,
-                'page_name': page_name,
-                'page_icon': icon,
-                'parent_id': parent_id,
-                'type': 'database'
+                "page_id": page_id,
+                "page_name": page_name,
+                "page_icon": icon,
+                "parent_id": parent_id,
+                "type": "database",
             }
             pages.append(page)
         return pages
     def notion_page_search(self, access_token: str):
-        data = {
-            'filter': {
-                "value": "page",
-                "property": "object"
-            }
-        }
+        data = {"filter": {"value": "page", "property": "object"}}
         headers = {
-            'Content-Type': 'application/json',
-            'Authorization': f"Bearer {access_token}",
-            'Notion-Version': '2022-06-28',
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {access_token}",
+            "Notion-Version": "2022-06-28",
         }
         response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
         response_json = response.json()
-        results = response_json.get('results', [])
+        results = response_json.get("results", [])
         return results
     def notion_block_parent_page_id(self, access_token: str, block_id: str):
         headers = {
-            'Authorization': f"Bearer {access_token}",
-            'Notion-Version': '2022-06-28',
+            "Authorization": f"Bearer {access_token}",
+            "Notion-Version": "2022-06-28",
         }
-        response = requests.get(url=f'{self._NOTION_BLOCK_SEARCH}/{block_id}', headers=headers)
+        response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
         response_json = response.json()
-        parent = response_json['parent']
-        parent_type = parent['type']
-        if parent_type == 'block_id':
+        parent = response_json["parent"]
+        parent_type = parent["type"]
+        if parent_type == "block_id":
             return self.notion_block_parent_page_id(access_token, parent[parent_type])
         return parent[parent_type]
     def notion_workspace_name(self, access_token: str):
         headers = {
-            'Authorization': f"Bearer {access_token}",
-            'Notion-Version': '2022-06-28',
+            "Authorization": f"Bearer {access_token}",
+            "Notion-Version": "2022-06-28",
         }
         response = requests.get(url=self._NOTION_BOT_USER, headers=headers)
         response_json = response.json()
-        if 'object' in response_json and response_json['object'] == 'user':
-            user_type = response_json['type']
+        if "object" in response_json and response_json["object"] == "user":
+            user_type = response_json["type"]
             user_info = response_json[user_type]
-            if 'workspace_name' in user_info:
-                return user_info['workspace_name']
-        return 'workspace'
+            if "workspace_name" in user_info:
+                return user_info["workspace_name"]
+        return "workspace"
     def notion_database_search(self, access_token: str):
-        data = {
-            'filter': {
-                "value": "database",
-                "property": "object"
-            }
-        }
+        data = {"filter": {"value": "database", "property": "object"}}
         headers = {
-            'Content-Type': 'application/json',
-            'Authorization': f"Bearer {access_token}",
-            'Notion-Version': '2022-06-28',
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {access_token}",
+            "Notion-Version": "2022-06-28",
         }
         response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
         response_json = response.json()
-        results = response_json.get('results', [])
+        results = response_json.get("results", [])
         return results
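For reference, each entry that get_authorized_pages() above appends to pages has this shape (values are illustrative):

    {"page_id": "...", "page_name": "Untitled", "page_icon": {"type": "emoji", "emoji": "..."}, "parent_id": "root", "type": "page"}

with "type": "database" for database results and "page_icon": None when no icon is set.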
 self.sk = dify_config.SECRET_KEY
 def issue(self, payload):
-    return jwt.encode(payload, self.sk, algorithm='HS256')
+    return jwt.encode(payload, self.sk, algorithm="HS256")
 def verify(self, token):
     try:
-        return jwt.decode(token, self.sk, algorithms=['HS256'])
+        return jwt.decode(token, self.sk, algorithms=["HS256"])
     except jwt.exceptions.InvalidSignatureError:
-        raise Unauthorized('Invalid token signature.')
+        raise Unauthorized("Invalid token signature.")
     except jwt.exceptions.DecodeError:
-        raise Unauthorized('Invalid token.')
+        raise Unauthorized("Invalid token.")
     except jwt.exceptions.ExpiredSignatureError:
-        raise Unauthorized('Token has expired.')
+        raise Unauthorized("Token has expired.")
 password_pattern = r"^(?=.*[a-zA-Z])(?=.*\d).{8,}$"
 def valid_password(password):
     # Define a regex pattern for password rules
     pattern = password_pattern
     if re.match(pattern, password) is not None:
         return password
-    raise ValueError('Not a valid password.')
+    raise ValueError("Not a valid password.")
 def hash_password(password_str, salt_byte):
-    dk = hashlib.pbkdf2_hmac('sha256', password_str.encode('utf-8'), salt_byte, 10000)
+    dk = hashlib.pbkdf2_hmac("sha256", password_str.encode("utf-8"), salt_byte, 10000)
     return binascii.hexlify(dk)
 def get_decrypt_decoding(tenant_id):
     filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"
-    cache_key = 'tenant_privkey:{hash}'.format(hash=hashlib.sha3_256(filepath.encode()).hexdigest())
+    cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest())
     private_key = redis_client.get(cache_key)
     if not private_key:
         try:
 def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa):
     if encrypted_text.startswith(prefix_hybrid):
-        encrypted_text = encrypted_text[len(prefix_hybrid):]
+        encrypted_text = encrypted_text[len(prefix_hybrid) :]
-    enc_aes_key = encrypted_text[:rsa_key.size_in_bytes()]
-    nonce = encrypted_text[rsa_key.size_in_bytes():rsa_key.size_in_bytes() + 16]
-    tag = encrypted_text[rsa_key.size_in_bytes() + 16:rsa_key.size_in_bytes() + 32]
-    ciphertext = encrypted_text[rsa_key.size_in_bytes() + 32:]
+    enc_aes_key = encrypted_text[: rsa_key.size_in_bytes()]
+    nonce = encrypted_text[rsa_key.size_in_bytes() : rsa_key.size_in_bytes() + 16]
+    tag = encrypted_text[rsa_key.size_in_bytes() + 16 : rsa_key.size_in_bytes() + 32]
+    ciphertext = encrypted_text[rsa_key.size_in_bytes() + 32 :]
     aes_key = cipher_rsa.decrypt(enc_aes_key)
 class SMTPClient:
-    def __init__(self, server: str, port: int, username: str, password: str, _from: str, use_tls=False, opportunistic_tls=False):
+    def __init__(
+        self, server: str, port: int, username: str, password: str, _from: str, use_tls=False, opportunistic_tls=False
+    ):
         self.server = server
         self.port = port
         self._from = _from
         smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10)
         else:
             smtp = smtplib.SMTP(self.server, self.port, timeout=10)
         if self.username and self.password:
             smtp.login(self.username, self.password)
         msg = MIMEMultipart()
-        msg['Subject'] = mail['subject']
-        msg['From'] = self._from
-        msg['To'] = mail['to']
-        msg.attach(MIMEText(mail['html'], 'html'))
+        msg["Subject"] = mail["subject"]
+        msg["From"] = self._from
+        msg["To"] = mail["to"]
+        msg.attach(MIMEText(mail["html"], "html"))
-        smtp.sendmail(self._from, mail['to'], msg.as_string())
+        smtp.sendmail(self._from, mail["to"], msg.as_string())
         except smtplib.SMTPException as e:
             logging.error(f"SMTP error occurred: {str(e)}")
             raise
| "core/**/*.py", | "core/**/*.py", | ||||
| "controllers/**/*.py", | "controllers/**/*.py", | ||||
| "models/**/*.py", | "models/**/*.py", | ||||
| "utils/**/*.py", | |||||
| "migrations/**/*", | "migrations/**/*", | ||||
| "services/**/*.py", | "services/**/*.py", | ||||
| "tasks/**/*.py", | "tasks/**/*.py", | ||||
| "tests/**/*.py", | "tests/**/*.py", | ||||
| "libs/**/*.py", | |||||
| "configs/**/*.py", | "configs/**/*.py", | ||||
| ] | ] | ||||