| @@ -3,6 +3,7 @@ import logging | |||
| from flask import request | |||
| from flask_login import current_user | |||
| from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse | |||
| from sqlalchemy import select | |||
| from werkzeug.exceptions import Unauthorized | |||
| import services | |||
| @@ -88,9 +89,8 @@ class WorkspaceListApi(Resource): | |||
| parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") | |||
| args = parser.parse_args() | |||
| tenants = Tenant.query.order_by(Tenant.created_at.desc()).paginate( | |||
| page=args["page"], per_page=args["limit"], error_out=False | |||
| ) | |||
| stmt = select(Tenant).order_by(Tenant.created_at.desc()) | |||
| tenants = db.paginate(select=stmt, page=args["page"], per_page=args["limit"], error_out=False) | |||
| has_more = False | |||
| if tenants.has_next: | |||
| @@ -162,7 +162,7 @@ class CustomConfigWorkspaceApi(Resource): | |||
| parser.add_argument("replace_webapp_logo", type=str, location="json") | |||
| args = parser.parse_args() | |||
| tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404() | |||
| tenant = db.get_or_404(Tenant, current_user.current_tenant_id) | |||
| custom_config_dict = { | |||
| "remove_webapp_brand": args["remove_webapp_brand"], | |||
| @@ -226,7 +226,7 @@ class WorkspaceInfoApi(Resource): | |||
| parser.add_argument("name", type=str, required=True, location="json") | |||
| args = parser.parse_args() | |||
| tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404() | |||
| tenant = db.get_or_404(Tenant, current_user.current_tenant_id) | |||
| tenant.name = args["name"] | |||
| db.session.commit() | |||
| @@ -347,14 +347,18 @@ class NotionExtractor(BaseExtractor): | |||
| @classmethod | |||
| def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: | |||
| data_source_binding = DataSourceOauthBinding.query.filter( | |||
| db.and_( | |||
| DataSourceOauthBinding.tenant_id == tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| DataSourceOauthBinding.disabled == False, | |||
| DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"', | |||
| data_source_binding = ( | |||
| db.session.query(DataSourceOauthBinding) | |||
| .filter( | |||
| db.and_( | |||
| DataSourceOauthBinding.tenant_id == tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| DataSourceOauthBinding.disabled == False, | |||
| DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"', | |||
| ) | |||
| ) | |||
| ).first() | |||
| .first() | |||
| ) | |||
| if not data_source_binding: | |||
| raise Exception( | |||
| @@ -61,13 +61,17 @@ class NotionOAuth(OAuthDataSource): | |||
| "total": len(pages), | |||
| } | |||
| # save data source binding | |||
| data_source_binding = DataSourceOauthBinding.query.filter( | |||
| db.and_( | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| DataSourceOauthBinding.access_token == access_token, | |||
| data_source_binding = ( | |||
| db.session.query(DataSourceOauthBinding) | |||
| .filter( | |||
| db.and_( | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| DataSourceOauthBinding.access_token == access_token, | |||
| ) | |||
| ) | |||
| ).first() | |||
| .first() | |||
| ) | |||
| if data_source_binding: | |||
| data_source_binding.source_info = source_info | |||
| data_source_binding.disabled = False | |||
| @@ -97,13 +101,17 @@ class NotionOAuth(OAuthDataSource): | |||
| "total": len(pages), | |||
| } | |||
| # save data source binding | |||
| data_source_binding = DataSourceOauthBinding.query.filter( | |||
| db.and_( | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| DataSourceOauthBinding.access_token == access_token, | |||
| data_source_binding = ( | |||
| db.session.query(DataSourceOauthBinding) | |||
| .filter( | |||
| db.and_( | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| DataSourceOauthBinding.access_token == access_token, | |||
| ) | |||
| ) | |||
| ).first() | |||
| .first() | |||
| ) | |||
| if data_source_binding: | |||
| data_source_binding.source_info = source_info | |||
| data_source_binding.disabled = False | |||
| @@ -121,14 +129,18 @@ class NotionOAuth(OAuthDataSource): | |||
| def sync_data_source(self, binding_id: str): | |||
| # save data source binding | |||
| data_source_binding = DataSourceOauthBinding.query.filter( | |||
| db.and_( | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| DataSourceOauthBinding.id == binding_id, | |||
| DataSourceOauthBinding.disabled == False, | |||
| data_source_binding = ( | |||
| db.session.query(DataSourceOauthBinding) | |||
| .filter( | |||
| db.and_( | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| DataSourceOauthBinding.id == binding_id, | |||
| DataSourceOauthBinding.disabled == False, | |||
| ) | |||
| ) | |||
| ).first() | |||
| .first() | |||
| ) | |||
| if data_source_binding: | |||
| # get all authorized pages | |||
| pages = self.get_authorized_pages(data_source_binding.access_token) | |||
| @@ -45,7 +45,7 @@ def mail_clean_document_notify_task(): | |||
| if plan != "sandbox": | |||
| knowledge_details = [] | |||
| # check tenant | |||
| tenant = Tenant.query.filter(Tenant.id == tenant_id).first() | |||
| tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first() | |||
| if not tenant: | |||
| continue | |||
| # check current owner | |||
| @@ -300,9 +300,9 @@ class AccountService: | |||
| """Link account integrate""" | |||
| try: | |||
| # Query whether there is an existing binding record for the same provider | |||
| account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by( | |||
| account_id=account.id, provider=provider | |||
| ).first() | |||
| account_integrate: Optional[AccountIntegrate] = ( | |||
| db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first() | |||
| ) | |||
| if account_integrate: | |||
| # If it exists, update the record | |||
| @@ -851,7 +851,7 @@ class TenantService: | |||
| @staticmethod | |||
| def get_custom_config(tenant_id: str) -> dict: | |||
| tenant = Tenant.query.filter(Tenant.id == tenant_id).one_or_404() | |||
| tenant = db.get_or_404(Tenant, tenant_id) | |||
| return cast(dict, tenant.custom_config_dict) | |||
| @@ -4,7 +4,7 @@ from typing import cast | |||
| import pandas as pd | |||
| from flask_login import current_user | |||
| from sqlalchemy import or_ | |||
| from sqlalchemy import or_, select | |||
| from werkzeug.datastructures import FileStorage | |||
| from werkzeug.exceptions import NotFound | |||
| @@ -124,8 +124,9 @@ class AppAnnotationService: | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| if keyword: | |||
| annotations = ( | |||
| MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id) | |||
| stmt = ( | |||
| select(MessageAnnotation) | |||
| .filter(MessageAnnotation.app_id == app_id) | |||
| .filter( | |||
| or_( | |||
| MessageAnnotation.question.ilike("%{}%".format(keyword)), | |||
| @@ -133,14 +134,14 @@ class AppAnnotationService: | |||
| ) | |||
| ) | |||
| .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) | |||
| .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) | |||
| ) | |||
| else: | |||
| annotations = ( | |||
| MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id) | |||
| stmt = ( | |||
| select(MessageAnnotation) | |||
| .filter(MessageAnnotation.app_id == app_id) | |||
| .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) | |||
| .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) | |||
| ) | |||
| annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False) | |||
| return annotations.items, annotations.total | |||
| @classmethod | |||
| @@ -325,13 +326,16 @@ class AppAnnotationService: | |||
| if not annotation: | |||
| raise NotFound("Annotation not found") | |||
| annotation_hit_histories = ( | |||
| AppAnnotationHitHistory.query.filter( | |||
| stmt = ( | |||
| select(AppAnnotationHitHistory) | |||
| .filter( | |||
| AppAnnotationHitHistory.app_id == app_id, | |||
| AppAnnotationHitHistory.annotation_id == annotation_id, | |||
| ) | |||
| .order_by(AppAnnotationHitHistory.created_at.desc()) | |||
| .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) | |||
| ) | |||
| annotation_hit_histories = db.paginate( | |||
| select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False | |||
| ) | |||
| return annotation_hit_histories.items, annotation_hit_histories.total | |||
| @@ -1087,14 +1087,18 @@ class DocumentService: | |||
| exist_document[data_source_info["notion_page_id"]] = document.id | |||
| for notion_info in notion_info_list: | |||
| workspace_id = notion_info.workspace_id | |||
| data_source_binding = DataSourceOauthBinding.query.filter( | |||
| db.and_( | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| DataSourceOauthBinding.disabled == False, | |||
| DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', | |||
| data_source_binding = ( | |||
| db.session.query(DataSourceOauthBinding) | |||
| .filter( | |||
| db.and_( | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| DataSourceOauthBinding.disabled == False, | |||
| DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', | |||
| ) | |||
| ) | |||
| ).first() | |||
| .first() | |||
| ) | |||
| if not data_source_binding: | |||
| raise ValueError("Data source binding not found.") | |||
| for page in notion_info.pages: | |||
| @@ -1302,14 +1306,18 @@ class DocumentService: | |||
| notion_info_list = document_data.data_source.info_list.notion_info_list | |||
| for notion_info in notion_info_list: | |||
| workspace_id = notion_info.workspace_id | |||
| data_source_binding = DataSourceOauthBinding.query.filter( | |||
| db.and_( | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| DataSourceOauthBinding.disabled == False, | |||
| DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', | |||
| data_source_binding = ( | |||
| db.session.query(DataSourceOauthBinding) | |||
| .filter( | |||
| db.and_( | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| DataSourceOauthBinding.disabled == False, | |||
| DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', | |||
| ) | |||
| ) | |||
| ).first() | |||
| .first() | |||
| ) | |||
| if not data_source_binding: | |||
| raise ValueError("Data source binding not found.") | |||
| for page in notion_info.pages: | |||
| @@ -44,14 +44,18 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): | |||
| page_id = data_source_info["notion_page_id"] | |||
| page_type = data_source_info["type"] | |||
| page_edited_time = data_source_info["last_edited_time"] | |||
| data_source_binding = DataSourceOauthBinding.query.filter( | |||
| db.and_( | |||
| DataSourceOauthBinding.tenant_id == document.tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| DataSourceOauthBinding.disabled == False, | |||
| DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', | |||
| data_source_binding = ( | |||
| db.session.query(DataSourceOauthBinding) | |||
| .filter( | |||
| db.and_( | |||
| DataSourceOauthBinding.tenant_id == document.tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| DataSourceOauthBinding.disabled == False, | |||
| DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', | |||
| ) | |||
| ) | |||
| ).first() | |||
| .first() | |||
| ) | |||
| if not data_source_binding: | |||
| raise ValueError("Data source binding not found.") | |||