| @@ -6,6 +6,7 @@ from typing import Optional | |||
| import click | |||
| from flask import current_app | |||
| from sqlalchemy import select | |||
| from werkzeug.exceptions import NotFound | |||
| from configs import dify_config | |||
| @@ -297,11 +298,11 @@ def migrate_knowledge_vector_database(): | |||
| page = 1 | |||
| while True: | |||
| try: | |||
| datasets = ( | |||
| Dataset.query.filter(Dataset.indexing_technique == "high_quality") | |||
| .order_by(Dataset.created_at.desc()) | |||
| .paginate(page=page, per_page=50) | |||
| stmt = ( | |||
| select(Dataset).filter(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc()) | |||
| ) | |||
| datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) | |||
| except NotFound: | |||
| break | |||
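For readers less familiar with the new style: Flask-SQLAlchemy 3.x still paginates, but it is driven by a SQLAlchemy 2.0 `select()` statement handed to `db.paginate()` instead of the legacy `Model.query` interface. A minimal standalone sketch of the before/after pattern this hunk adopts, assuming the repo's `db` from `extensions.ext_database` and the `Dataset` model (the helper name is illustrative only, not code from this PR):

```python
from sqlalchemy import select

from extensions.ext_database import db  # import path assumed from this PR's other hunks
from models.dataset import Dataset  # assumed import path


def iter_high_quality_datasets(per_page: int = 50):
    """Illustrative helper: page through datasets the way the migrated loop does."""
    page = 1
    while True:
        stmt = (
            select(Dataset)
            .filter(Dataset.indexing_technique == "high_quality")
            .order_by(Dataset.created_at.desc())
        )
        # db.paginate() executes the select and wraps the result in a Pagination object.
        pagination = db.paginate(select=stmt, page=page, per_page=per_page, error_out=False)
        if not pagination.items:
            break
        yield from pagination.items
        page += 1
```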
| @@ -592,11 +593,15 @@ def old_metadata_migration(): | |||
| ) | |||
| db.session.add(dataset_metadata_binding) | |||
| else: | |||
| dataset_metadata_binding = DatasetMetadataBinding.query.filter( | |||
| DatasetMetadataBinding.dataset_id == document.dataset_id, | |||
| DatasetMetadataBinding.document_id == document.id, | |||
| DatasetMetadataBinding.metadata_id == dataset_metadata.id, | |||
| ).first() | |||
| dataset_metadata_binding = ( | |||
| db.session.query(DatasetMetadataBinding) # type: ignore | |||
| .filter( | |||
| DatasetMetadataBinding.dataset_id == document.dataset_id, | |||
| DatasetMetadataBinding.document_id == document.id, | |||
| DatasetMetadataBinding.metadata_id == dataset_metadata.id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if not dataset_metadata_binding: | |||
| dataset_metadata_binding = DatasetMetadataBinding( | |||
| tenant_id=document.tenant_id, | |||
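The same substitution is applied throughout this PR for single-row lookups: the implicit `Model.query` property is replaced by an explicit `db.session.query(Model)`, which does not depend on the legacy query class attached to the model base. A minimal sketch of the lookup pattern, assuming the repo's `db` and `DatasetMetadataBinding` model (function name illustrative):

```python
from extensions.ext_database import db
from models.dataset import DatasetMetadataBinding  # assumed import path


def find_binding(dataset_id: str, document_id: str, metadata_id: str):
    """Look up one metadata binding, or None, using an explicit session query."""
    return (
        db.session.query(DatasetMetadataBinding)
        .filter(
            DatasetMetadataBinding.dataset_id == dataset_id,
            DatasetMetadataBinding.document_id == document_id,
            DatasetMetadataBinding.metadata_id == metadata_id,
        )
        .first()  # first matching row, or None if nothing matches
    )
```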
| @@ -526,14 +526,20 @@ class DatasetIndexingStatusApi(Resource): | |||
| ) | |||
| documents_status = [] | |||
| for document in documents: | |||
| completed_segments = DocumentSegment.query.filter( | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.document_id == str(document.id), | |||
| DocumentSegment.status != "re_segment", | |||
| ).count() | |||
| total_segments = DocumentSegment.query.filter( | |||
| DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" | |||
| ).count() | |||
| completed_segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter( | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.document_id == str(document.id), | |||
| DocumentSegment.status != "re_segment", | |||
| ) | |||
| .count() | |||
| ) | |||
| total_segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") | |||
| .count() | |||
| ) | |||
| document.completed_segments = completed_segments | |||
| document.total_segments = total_segments | |||
| documents_status.append(marshal(document, document_status_fields)) | |||
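The segment counters keep the query-level `.count()` shape, only rooted at `db.session.query(...)`. A sketch of the count as this PR writes it, plus a 2.0-style `select(func.count())` equivalent that the PR does not use and is shown only for comparison; model and import paths are assumed from the codebase:

```python
from sqlalchemy import func, select

from extensions.ext_database import db
from models.dataset import DocumentSegment  # assumed import path


def count_completed_segments(document_id: str) -> int:
    """Count segments the way the migrated endpoint does."""
    return (
        db.session.query(DocumentSegment)
        .filter(
            DocumentSegment.completed_at.isnot(None),
            DocumentSegment.document_id == document_id,
            DocumentSegment.status != "re_segment",
        )
        .count()
    )


def count_completed_segments_2_0(document_id: str) -> int:
    """Same count as a select() statement (not what this PR uses; for comparison only)."""
    stmt = (
        select(func.count())
        .select_from(DocumentSegment)
        .filter(
            DocumentSegment.completed_at.isnot(None),
            DocumentSegment.document_id == document_id,
            DocumentSegment.status != "re_segment",
        )
    )
    return db.session.scalar(stmt) or 0
```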
| @@ -6,7 +6,7 @@ from typing import cast | |||
| from flask import request | |||
| from flask_login import current_user | |||
| from flask_restful import Resource, fields, marshal, marshal_with, reqparse | |||
| from sqlalchemy import asc, desc | |||
| from sqlalchemy import asc, desc, select | |||
| from werkzeug.exceptions import Forbidden, NotFound | |||
| import services | |||
| @@ -112,7 +112,7 @@ class GetProcessRuleApi(Resource): | |||
| limits = DocumentService.DEFAULT_RULES["limits"] | |||
| if document_id: | |||
| # get the latest process rule | |||
| document = Document.query.get_or_404(document_id) | |||
| document = db.get_or_404(Document, document_id) | |||
| dataset = DatasetService.get_dataset(document.dataset_id) | |||
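`Model.query.get_or_404(id)` belongs to the legacy query interface; Flask-SQLAlchemy 3.x exposes the same behaviour as a method on the extension object, backed by `session.get()`, so primary-key lookups can be served from the identity map. A short sketch, assuming a Flask request context (the helper name is illustrative):

```python
from extensions.ext_database import db
from models.dataset import Document  # assumed import path


def load_document_or_404(document_id: str) -> Document:
    """Fetch by primary key; aborts the request with 404 if no row exists."""
    return db.get_or_404(Document, document_id)
```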
| @@ -175,7 +175,7 @@ class DatasetDocumentListApi(Resource): | |||
| except services.errors.account.NoPermissionError as e: | |||
| raise Forbidden(str(e)) | |||
| query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) | |||
| query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) | |||
| if search: | |||
| search = f"%{search}%" | |||
| @@ -209,18 +209,24 @@ class DatasetDocumentListApi(Resource): | |||
| desc(Document.position), | |||
| ) | |||
| paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) | |||
| paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) | |||
| documents = paginated_documents.items | |||
| if fetch: | |||
| for document in documents: | |||
| completed_segments = DocumentSegment.query.filter( | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.document_id == str(document.id), | |||
| DocumentSegment.status != "re_segment", | |||
| ).count() | |||
| total_segments = DocumentSegment.query.filter( | |||
| DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" | |||
| ).count() | |||
| completed_segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter( | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.document_id == str(document.id), | |||
| DocumentSegment.status != "re_segment", | |||
| ) | |||
| .count() | |||
| ) | |||
| total_segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") | |||
| .count() | |||
| ) | |||
| document.completed_segments = completed_segments | |||
| document.total_segments = total_segments | |||
| data = marshal(documents, document_with_segments_fields) | |||
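`db.paginate()` returns the same `Pagination` object that the legacy `query.paginate()` did, so downstream code that reads `.items` and `.total` does not change. A small sketch of the attributes this endpoint relies on, assuming Flask-SQLAlchemy 3.x; with `error_out=False`, an out-of-range page comes back with an empty `items` list instead of aborting with 404:

```python
from sqlalchemy import select

from extensions.ext_database import db
from models.dataset import Document  # assumed import path


def list_documents(dataset_id: str, page: int, limit: int):
    stmt = select(Document).filter_by(dataset_id=dataset_id)
    pagination = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False)
    return {
        "data": pagination.items,         # rows for the requested page
        "total": pagination.total,        # total row count across all pages
        "has_more": pagination.has_next,  # whether another page exists
        "page": pagination.page,
        "limit": pagination.per_page,
    }
```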
| @@ -563,14 +569,20 @@ class DocumentBatchIndexingStatusApi(DocumentResource): | |||
| documents = self.get_batch_documents(dataset_id, batch) | |||
| documents_status = [] | |||
| for document in documents: | |||
| completed_segments = DocumentSegment.query.filter( | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.document_id == str(document.id), | |||
| DocumentSegment.status != "re_segment", | |||
| ).count() | |||
| total_segments = DocumentSegment.query.filter( | |||
| DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" | |||
| ).count() | |||
| completed_segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter( | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.document_id == str(document.id), | |||
| DocumentSegment.status != "re_segment", | |||
| ) | |||
| .count() | |||
| ) | |||
| total_segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") | |||
| .count() | |||
| ) | |||
| document.completed_segments = completed_segments | |||
| document.total_segments = total_segments | |||
| if document.is_paused: | |||
| @@ -589,14 +601,20 @@ class DocumentIndexingStatusApi(DocumentResource): | |||
| document_id = str(document_id) | |||
| document = self.get_document(dataset_id, document_id) | |||
| completed_segments = DocumentSegment.query.filter( | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.document_id == str(document_id), | |||
| DocumentSegment.status != "re_segment", | |||
| ).count() | |||
| total_segments = DocumentSegment.query.filter( | |||
| DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment" | |||
| ).count() | |||
| completed_segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter( | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.document_id == str(document_id), | |||
| DocumentSegment.status != "re_segment", | |||
| ) | |||
| .count() | |||
| ) | |||
| total_segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment") | |||
| .count() | |||
| ) | |||
| document.completed_segments = completed_segments | |||
| document.total_segments = total_segments | |||
| @@ -4,6 +4,7 @@ import pandas as pd | |||
| from flask import request | |||
| from flask_login import current_user | |||
| from flask_restful import Resource, marshal, reqparse | |||
| from sqlalchemy import select | |||
| from werkzeug.exceptions import Forbidden, NotFound | |||
| import services | |||
| @@ -26,6 +27,7 @@ from controllers.console.wraps import ( | |||
| from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from fields.segment_fields import child_chunk_fields, segment_fields | |||
| from libs.login import login_required | |||
| @@ -74,9 +76,14 @@ class DatasetDocumentSegmentListApi(Resource): | |||
| hit_count_gte = args["hit_count_gte"] | |||
| keyword = args["keyword"] | |||
| query = DocumentSegment.query.filter( | |||
| DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| ).order_by(DocumentSegment.position.asc()) | |||
| query = ( | |||
| select(DocumentSegment) | |||
| .filter( | |||
| DocumentSegment.document_id == str(document_id), | |||
| DocumentSegment.tenant_id == current_user.current_tenant_id, | |||
| ) | |||
| .order_by(DocumentSegment.position.asc()) | |||
| ) | |||
| if status_list: | |||
| query = query.filter(DocumentSegment.status.in_(status_list)) | |||
| @@ -93,7 +100,7 @@ class DatasetDocumentSegmentListApi(Resource): | |||
| elif args["enabled"].lower() == "false": | |||
| query = query.filter(DocumentSegment.enabled == False) | |||
| segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) | |||
| segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) | |||
| response = { | |||
| "data": marshal(segments.items, segment_fields), | |||
| @@ -276,9 +283,11 @@ class DatasetDocumentSegmentUpdateApi(Resource): | |||
| raise ProviderNotInitializeError(ex.description) | |||
| # check segment | |||
| segment_id = str(segment_id) | |||
| segment = DocumentSegment.query.filter( | |||
| DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| ).first() | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) | |||
| .first() | |||
| ) | |||
| if not segment: | |||
| raise NotFound("Segment not found.") | |||
| # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor | |||
| @@ -320,9 +329,11 @@ class DatasetDocumentSegmentUpdateApi(Resource): | |||
| raise NotFound("Document not found.") | |||
| # check segment | |||
| segment_id = str(segment_id) | |||
| segment = DocumentSegment.query.filter( | |||
| DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| ).first() | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) | |||
| .first() | |||
| ) | |||
| if not segment: | |||
| raise NotFound("Segment not found.") | |||
| # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor | |||
| @@ -423,9 +434,11 @@ class ChildChunkAddApi(Resource): | |||
| raise NotFound("Document not found.") | |||
| # check segment | |||
| segment_id = str(segment_id) | |||
| segment = DocumentSegment.query.filter( | |||
| DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| ).first() | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) | |||
| .first() | |||
| ) | |||
| if not segment: | |||
| raise NotFound("Segment not found.") | |||
| if not current_user.is_dataset_editor: | |||
| @@ -478,9 +491,11 @@ class ChildChunkAddApi(Resource): | |||
| raise NotFound("Document not found.") | |||
| # check segment | |||
| segment_id = str(segment_id) | |||
| segment = DocumentSegment.query.filter( | |||
| DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| ).first() | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) | |||
| .first() | |||
| ) | |||
| if not segment: | |||
| raise NotFound("Segment not found.") | |||
| parser = reqparse.RequestParser() | |||
| @@ -523,9 +538,11 @@ class ChildChunkAddApi(Resource): | |||
| raise NotFound("Document not found.") | |||
| # check segment | |||
| segment_id = str(segment_id) | |||
| segment = DocumentSegment.query.filter( | |||
| DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| ).first() | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) | |||
| .first() | |||
| ) | |||
| if not segment: | |||
| raise NotFound("Segment not found.") | |||
| # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor | |||
| @@ -567,16 +584,20 @@ class ChildChunkUpdateApi(Resource): | |||
| raise NotFound("Document not found.") | |||
| # check segment | |||
| segment_id = str(segment_id) | |||
| segment = DocumentSegment.query.filter( | |||
| DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| ).first() | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) | |||
| .first() | |||
| ) | |||
| if not segment: | |||
| raise NotFound("Segment not found.") | |||
| # check child chunk | |||
| child_chunk_id = str(child_chunk_id) | |||
| child_chunk = ChildChunk.query.filter( | |||
| ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id | |||
| ).first() | |||
| child_chunk = ( | |||
| db.session.query(ChildChunk) | |||
| .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) | |||
| .first() | |||
| ) | |||
| if not child_chunk: | |||
| raise NotFound("Child chunk not found.") | |||
| # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor | |||
| @@ -612,16 +633,20 @@ class ChildChunkUpdateApi(Resource): | |||
| raise NotFound("Document not found.") | |||
| # check segment | |||
| segment_id = str(segment_id) | |||
| segment = DocumentSegment.query.filter( | |||
| DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| ).first() | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) | |||
| .first() | |||
| ) | |||
| if not segment: | |||
| raise NotFound("Segment not found.") | |||
| # check child chunk | |||
| child_chunk_id = str(child_chunk_id) | |||
| child_chunk = ChildChunk.query.filter( | |||
| ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id | |||
| ).first() | |||
| child_chunk = ( | |||
| db.session.query(ChildChunk) | |||
| .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) | |||
| .first() | |||
| ) | |||
| if not child_chunk: | |||
| raise NotFound("Child chunk not found.") | |||
| # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor | |||
| @@ -2,10 +2,10 @@ import json | |||
| from flask import request | |||
| from flask_restful import marshal, reqparse | |||
| from sqlalchemy import desc | |||
| from sqlalchemy import desc, select | |||
| from werkzeug.exceptions import NotFound | |||
| import services.dataset_service | |||
| import services | |||
| from controllers.common.errors import FilenameNotExistsError | |||
| from controllers.service_api import api | |||
| from controllers.service_api.app.error import ( | |||
| @@ -337,7 +337,7 @@ class DocumentListApi(DatasetApiResource): | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) | |||
| query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) | |||
| if search: | |||
| search = f"%{search}%" | |||
| @@ -345,7 +345,7 @@ class DocumentListApi(DatasetApiResource): | |||
| query = query.order_by(desc(Document.created_at), desc(Document.position)) | |||
| paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) | |||
| paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) | |||
| documents = paginated_documents.items | |||
| response = { | |||
| @@ -374,14 +374,20 @@ class DocumentIndexingStatusApi(DatasetApiResource): | |||
| raise NotFound("Documents not found.") | |||
| documents_status = [] | |||
| for document in documents: | |||
| completed_segments = DocumentSegment.query.filter( | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.document_id == str(document.id), | |||
| DocumentSegment.status != "re_segment", | |||
| ).count() | |||
| total_segments = DocumentSegment.query.filter( | |||
| DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" | |||
| ).count() | |||
| completed_segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter( | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.document_id == str(document.id), | |||
| DocumentSegment.status != "re_segment", | |||
| ) | |||
| .count() | |||
| ) | |||
| total_segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") | |||
| .count() | |||
| ) | |||
| document.completed_segments = completed_segments | |||
| document.total_segments = total_segments | |||
| if document.is_paused: | |||
| @@ -46,14 +46,22 @@ class DatasetIndexToolCallbackHandler: | |||
| DatasetDocument.id == document.metadata["document_id"] | |||
| ).first() | |||
| if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | |||
| child_chunk = ChildChunk.query.filter( | |||
| ChildChunk.index_node_id == document.metadata["doc_id"], | |||
| ChildChunk.dataset_id == dataset_document.dataset_id, | |||
| ChildChunk.document_id == dataset_document.id, | |||
| ).first() | |||
| child_chunk = ( | |||
| db.session.query(ChildChunk) | |||
| .filter( | |||
| ChildChunk.index_node_id == document.metadata["doc_id"], | |||
| ChildChunk.dataset_id == dataset_document.dataset_id, | |||
| ChildChunk.document_id == dataset_document.id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if child_chunk: | |||
| segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update( | |||
| {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.id == child_chunk.segment_id) | |||
| .update( | |||
| {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False | |||
| ) | |||
| ) | |||
| else: | |||
| query = db.session.query(DocumentSegment).filter( | |||
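The hit-count increment above stays a query-level UPDATE: `.update()` on a session query emits a single `UPDATE ... SET hit_count = hit_count + 1` and returns the number of matched rows rather than ORM objects, and `synchronize_session=False` skips reconciling instances already loaded in the session. A sketch of the pattern (helper name illustrative):

```python
from extensions.ext_database import db
from models.dataset import DocumentSegment  # assumed import path


def bump_hit_count(segment_id: str) -> int:
    """Increment hit_count atomically in SQL; returns the number of rows updated."""
    updated = (
        db.session.query(DocumentSegment)
        .filter(DocumentSegment.id == segment_id)
        .update(
            {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
            synchronize_session=False,
        )
    )
    db.session.commit()
    return updated
```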
| @@ -51,7 +51,7 @@ class IndexingRunner: | |||
| for dataset_document in dataset_documents: | |||
| try: | |||
| # get dataset | |||
| dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() | |||
| dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() | |||
| if not dataset: | |||
| raise ValueError("no dataset found") | |||
| @@ -103,15 +103,17 @@ class IndexingRunner: | |||
| """Run the indexing process when the index_status is splitting.""" | |||
| try: | |||
| # get dataset | |||
| dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() | |||
| dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() | |||
| if not dataset: | |||
| raise ValueError("no dataset found") | |||
| # get exist document_segment list and delete | |||
| document_segments = DocumentSegment.query.filter_by( | |||
| dataset_id=dataset.id, document_id=dataset_document.id | |||
| ).all() | |||
| document_segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter_by(dataset_id=dataset.id, document_id=dataset_document.id) | |||
| .all() | |||
| ) | |||
| for document_segment in document_segments: | |||
| db.session.delete(document_segment) | |||
| @@ -162,15 +164,17 @@ class IndexingRunner: | |||
| """Run the indexing process when the index_status is indexing.""" | |||
| try: | |||
| # get dataset | |||
| dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() | |||
| dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() | |||
| if not dataset: | |||
| raise ValueError("no dataset found") | |||
| # get exist document_segment list and delete | |||
| document_segments = DocumentSegment.query.filter_by( | |||
| dataset_id=dataset.id, document_id=dataset_document.id | |||
| ).all() | |||
| document_segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter_by(dataset_id=dataset.id, document_id=dataset_document.id) | |||
| .all() | |||
| ) | |||
| documents = [] | |||
| if document_segments: | |||
| @@ -254,7 +258,7 @@ class IndexingRunner: | |||
| embedding_model_instance = None | |||
| if dataset_id: | |||
| dataset = Dataset.query.filter_by(id=dataset_id).first() | |||
| dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() | |||
| if not dataset: | |||
| raise ValueError("Dataset not found.") | |||
| if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality": | |||
| @@ -587,7 +591,7 @@ class IndexingRunner: | |||
| @staticmethod | |||
| def _process_keyword_index(flask_app, dataset_id, document_id, documents): | |||
| with flask_app.app_context(): | |||
| dataset = Dataset.query.filter_by(id=dataset_id).first() | |||
| dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() | |||
| if not dataset: | |||
| raise ValueError("no dataset found") | |||
| keyword = Keyword(dataset) | |||
| @@ -676,7 +680,7 @@ class IndexingRunner: | |||
| """ | |||
| Update the document segment by document id. | |||
| """ | |||
| DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) | |||
| db.session.query(DocumentSegment).filter_by(document_id=dataset_document_id).update(update_params) | |||
| db.session.commit() | |||
| def _transform( | |||
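Both filtering styles appear in this PR, so a quick note on the difference: `filter_by(**kwargs)` matches keyword arguments against columns of the primary entity by equality, while `filter(*criteria)` takes full SQL expressions; both are available on `db.session.query(...)` and on `select(...)`. A sketch showing the two equivalent forms of the status update above, with names assumed from the codebase:

```python
from extensions.ext_database import db
from models.dataset import DocumentSegment  # assumed import path


def mark_segments_keyword_style(document_id: str, update_params: dict) -> None:
    # Keyword-style equality match on DocumentSegment columns...
    db.session.query(DocumentSegment).filter_by(document_id=document_id).update(update_params)
    db.session.commit()


def mark_segments_expression_style(document_id: str, update_params: dict) -> None:
    # ...is equivalent to the expression-style filter below.
    db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).update(update_params)
    db.session.commit()
```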
| @@ -237,7 +237,7 @@ class DatasetRetrieval: | |||
| if show_retrieve_source: | |||
| for record in records: | |||
| segment = record.segment | |||
| dataset = Dataset.query.filter_by(id=segment.dataset_id).first() | |||
| dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() | |||
| document = DatasetDocument.query.filter( | |||
| DatasetDocument.id == segment.document_id, | |||
| DatasetDocument.enabled == True, | |||
| @@ -511,14 +511,23 @@ class DatasetRetrieval: | |||
| ).first() | |||
| if dataset_document: | |||
| if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | |||
| child_chunk = ChildChunk.query.filter( | |||
| ChildChunk.index_node_id == document.metadata["doc_id"], | |||
| ChildChunk.dataset_id == dataset_document.dataset_id, | |||
| ChildChunk.document_id == dataset_document.id, | |||
| ).first() | |||
| child_chunk = ( | |||
| db.session.query(ChildChunk) | |||
| .filter( | |||
| ChildChunk.index_node_id == document.metadata["doc_id"], | |||
| ChildChunk.dataset_id == dataset_document.dataset_id, | |||
| ChildChunk.document_id == dataset_document.id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if child_chunk: | |||
| segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update( | |||
| {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.id == child_chunk.segment_id) | |||
| .update( | |||
| {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, | |||
| synchronize_session=False, | |||
| ) | |||
| ) | |||
| db.session.commit() | |||
| else: | |||
| @@ -84,13 +84,17 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): | |||
| document_context_list = [] | |||
| index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] | |||
| segments = DocumentSegment.query.filter( | |||
| DocumentSegment.dataset_id.in_(self.dataset_ids), | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.status == "completed", | |||
| DocumentSegment.enabled == True, | |||
| DocumentSegment.index_node_id.in_(index_node_ids), | |||
| ).all() | |||
| segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter( | |||
| DocumentSegment.dataset_id.in_(self.dataset_ids), | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.status == "completed", | |||
| DocumentSegment.enabled == True, | |||
| DocumentSegment.index_node_id.in_(index_node_ids), | |||
| ) | |||
| .all() | |||
| ) | |||
| if segments: | |||
| index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} | |||
| @@ -106,12 +110,16 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): | |||
| context_list = [] | |||
| resource_number = 1 | |||
| for segment in sorted_segments: | |||
| dataset = Dataset.query.filter_by(id=segment.dataset_id).first() | |||
| document = Document.query.filter( | |||
| Document.id == segment.document_id, | |||
| Document.enabled == True, | |||
| Document.archived == False, | |||
| ).first() | |||
| dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() | |||
| document = ( | |||
| db.session.query(Document) | |||
| .filter( | |||
| Document.id == segment.document_id, | |||
| Document.enabled == True, | |||
| Document.archived == False, | |||
| ) | |||
| .first() | |||
| ) | |||
| if dataset and document: | |||
| source = { | |||
| "position": resource_number, | |||
| @@ -185,7 +185,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| if self.return_resource: | |||
| for record in records: | |||
| segment = record.segment | |||
| dataset = Dataset.query.filter_by(id=segment.dataset_id).first() | |||
| dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() | |||
| document = DatasetDocument.query.filter( | |||
| DatasetDocument.id == segment.document_id, | |||
| DatasetDocument.enabled == True, | |||
| @@ -275,12 +275,16 @@ class KnowledgeRetrievalNode(LLMNode): | |||
| if records: | |||
| for record in records: | |||
| segment = record.segment | |||
| dataset = Dataset.query.filter_by(id=segment.dataset_id).first() | |||
| document = Document.query.filter( | |||
| Document.id == segment.document_id, | |||
| Document.enabled == True, | |||
| Document.archived == False, | |||
| ).first() | |||
| dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore | |||
| document = ( | |||
| db.session.query(Document) | |||
| .filter( | |||
| Document.id == segment.document_id, | |||
| Document.enabled == True, | |||
| Document.archived == False, | |||
| ) | |||
| .first() | |||
| ) | |||
| if dataset and document: | |||
| source = { | |||
| "metadata": { | |||
| @@ -93,7 +93,8 @@ class Dataset(Base): | |||
| @property | |||
| def latest_process_rule(self): | |||
| return ( | |||
| DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id) | |||
| db.session.query(DatasetProcessRule) | |||
| .filter(DatasetProcessRule.dataset_id == self.id) | |||
| .order_by(DatasetProcessRule.created_at.desc()) | |||
| .first() | |||
| ) | |||
| @@ -138,7 +139,8 @@ class Dataset(Base): | |||
| @property | |||
| def word_count(self): | |||
| return ( | |||
| Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) | |||
| db.session.query(Document) | |||
| .with_entities(func.coalesce(func.sum(Document.word_count))) | |||
| .filter(Document.dataset_id == self.id) | |||
| .scalar() | |||
| ) | |||
| @@ -440,12 +442,13 @@ class Document(Base): | |||
| @property | |||
| def segment_count(self): | |||
| return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count() | |||
| return db.session.query(DocumentSegment).filter(DocumentSegment.document_id == self.id).count() | |||
| @property | |||
| def hit_count(self): | |||
| return ( | |||
| DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) | |||
| db.session.query(DocumentSegment) | |||
| .with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) | |||
| .filter(DocumentSegment.document_id == self.id) | |||
| .scalar() | |||
| ) | |||
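The aggregate properties keep the `with_entities(...)` plus `.scalar()` shape, just rooted at `db.session.query(...)`. A sketch of the same aggregate written both ways, with models and import paths assumed; note the repo's code calls `func.coalesce` with a single argument, while the second variant below adds an explicit `0` default, which is a tweak in this sketch rather than code from the PR:

```python
from sqlalchemy import func, select

from extensions.ext_database import db
from models.dataset import DocumentSegment  # assumed import path


def document_hit_count(document_id: str):
    """Aggregate in the style this PR uses."""
    return (
        db.session.query(DocumentSegment)
        .with_entities(func.coalesce(func.sum(DocumentSegment.hit_count)))
        .filter(DocumentSegment.document_id == document_id)
        .scalar()
    )


def document_hit_count_2_0(document_id: str) -> int:
    """Same aggregate as a select() statement (for comparison only)."""
    stmt = select(func.coalesce(func.sum(DocumentSegment.hit_count), 0)).filter(
        DocumentSegment.document_id == document_id
    )
    return db.session.scalar(stmt)
```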
| @@ -892,7 +895,7 @@ class DatasetKeywordTable(Base): | |||
| return dct | |||
| # get dataset | |||
| dataset = Dataset.query.filter_by(id=self.dataset_id).first() | |||
| dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first() | |||
| if not dataset: | |||
| return None | |||
| if self.data_source_type == "database": | |||
| @@ -2,7 +2,7 @@ import datetime | |||
| import time | |||
| import click | |||
| from sqlalchemy import func | |||
| from sqlalchemy import func, select | |||
| from werkzeug.exceptions import NotFound | |||
| import app | |||
| @@ -51,8 +51,9 @@ def clean_unused_datasets_task(): | |||
| ) | |||
| # Main query with join and filter | |||
| datasets = ( | |||
| Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) | |||
| stmt = ( | |||
| select(Dataset) | |||
| .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) | |||
| .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) | |||
| .filter( | |||
| Dataset.created_at < plan_sandbox_clean_day, | |||
| @@ -60,9 +61,10 @@ def clean_unused_datasets_task(): | |||
| func.coalesce(document_subquery_old.c.document_count, 0) > 0, | |||
| ) | |||
| .order_by(Dataset.created_at.desc()) | |||
| .paginate(page=1, per_page=50) | |||
| ) | |||
| datasets = db.paginate(stmt, page=1, per_page=50) | |||
| except NotFound: | |||
| break | |||
| if datasets.items is None or len(datasets.items) == 0: | |||
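Worth noting in this scheduled task: unlike the API endpoints above, `db.paginate(stmt, page=1, per_page=50)` is called here without `error_out=False`. With the default `error_out=True`, an invalid or out-of-range page aborts with `werkzeug.exceptions.NotFound`, which is what the surrounding `except NotFound: break` guard relies on; the task also breaks when page 1 comes back empty. A compressed sketch of that control flow, with the real query's joins and date filters reduced to a placeholder:

```python
from sqlalchemy import select
from werkzeug.exceptions import NotFound

from extensions.ext_database import db
from models.dataset import Dataset  # assumed import path


def clean_in_batches():
    """Illustrative loop: the real task adds outer joins and date filters to the statement."""
    while True:
        stmt = select(Dataset).order_by(Dataset.created_at.desc())
        try:
            # error_out defaults to True, so an invalid page aborts with NotFound.
            datasets = db.paginate(stmt, page=1, per_page=50)
        except NotFound:
            break
        if not datasets.items:
            break
        for dataset in datasets.items:
            ...  # per-dataset cleanup goes here
```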
| @@ -99,7 +101,7 @@ def clean_unused_datasets_task(): | |||
| # update document | |||
| update_params = {Document.enabled: False} | |||
| Document.query.filter_by(dataset_id=dataset.id).update(update_params) | |||
| db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params) | |||
| db.session.commit() | |||
| click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green")) | |||
| except Exception as e: | |||
| @@ -135,8 +137,9 @@ def clean_unused_datasets_task(): | |||
| ) | |||
| # Main query with join and filter | |||
| datasets = ( | |||
| Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) | |||
| stmt = ( | |||
| select(Dataset) | |||
| .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) | |||
| .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) | |||
| .filter( | |||
| Dataset.created_at < plan_pro_clean_day, | |||
| @@ -144,8 +147,8 @@ def clean_unused_datasets_task(): | |||
| func.coalesce(document_subquery_old.c.document_count, 0) > 0, | |||
| ) | |||
| .order_by(Dataset.created_at.desc()) | |||
| .paginate(page=1, per_page=50) | |||
| ) | |||
| datasets = db.paginate(stmt, page=1, per_page=50) | |||
| except NotFound: | |||
| break | |||
| @@ -175,7 +178,7 @@ def clean_unused_datasets_task(): | |||
| # update document | |||
| update_params = {Document.enabled: False} | |||
| Document.query.filter_by(dataset_id=dataset.id).update(update_params) | |||
| db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params) | |||
| db.session.commit() | |||
| click.echo( | |||
| click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green") | |||
| @@ -19,7 +19,9 @@ def create_tidb_serverless_task(): | |||
| while True: | |||
| try: | |||
| # check the number of idle tidb serverless | |||
| idle_tidb_serverless_number = TidbAuthBinding.query.filter(TidbAuthBinding.active == False).count() | |||
| idle_tidb_serverless_number = ( | |||
| db.session.query(TidbAuthBinding).filter(TidbAuthBinding.active == False).count() | |||
| ) | |||
| if idle_tidb_serverless_number >= tidb_serverless_number: | |||
| break | |||
| # create tidb serverless | |||
| @@ -29,7 +29,9 @@ def mail_clean_document_notify_task(): | |||
| # send document clean notify mail | |||
| try: | |||
| dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all() | |||
| dataset_auto_disable_logs = ( | |||
| db.session.query(DatasetAutoDisableLog).filter(DatasetAutoDisableLog.notified == False).all() | |||
| ) | |||
| # group by tenant_id | |||
| dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) | |||
| for dataset_auto_disable_log in dataset_auto_disable_logs: | |||
| @@ -65,7 +67,7 @@ def mail_clean_document_notify_task(): | |||
| ) | |||
| for dataset_id, document_ids in dataset_auto_dataset_map.items(): | |||
| dataset = Dataset.query.filter(Dataset.id == dataset_id).first() | |||
| dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | |||
| if dataset: | |||
| document_count = len(document_ids) | |||
| knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") | |||
| @@ -5,6 +5,7 @@ import click | |||
| import app | |||
| from configs import dify_config | |||
| from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService | |||
| from extensions.ext_database import db | |||
| from models.dataset import TidbAuthBinding | |||
| @@ -14,9 +15,11 @@ def update_tidb_serverless_status_task(): | |||
| start_at = time.perf_counter() | |||
| try: | |||
| # check the number of idle tidb serverless | |||
| tidb_serverless_list = TidbAuthBinding.query.filter( | |||
| TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING" | |||
| ).all() | |||
| tidb_serverless_list = ( | |||
| db.session.query(TidbAuthBinding) | |||
| .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") | |||
| .all() | |||
| ) | |||
| if len(tidb_serverless_list) == 0: | |||
| return | |||
| # update tidb serverless status | |||
| @@ -9,7 +9,7 @@ from collections import Counter | |||
| from typing import Any, Optional | |||
| from flask_login import current_user | |||
| from sqlalchemy import func | |||
| from sqlalchemy import func, select | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import NotFound | |||
| @@ -77,11 +77,13 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde | |||
| class DatasetService: | |||
| @staticmethod | |||
| def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False): | |||
| query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc()) | |||
| query = select(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc()) | |||
| if user: | |||
| # get permitted dataset ids | |||
| dataset_permission = DatasetPermission.query.filter_by(account_id=user.id, tenant_id=tenant_id).all() | |||
| dataset_permission = ( | |||
| db.session.query(DatasetPermission).filter_by(account_id=user.id, tenant_id=tenant_id).all() | |||
| ) | |||
| permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None | |||
| if user.current_role == TenantAccountRole.DATASET_OPERATOR: | |||
| @@ -129,7 +131,7 @@ class DatasetService: | |||
| else: | |||
| return [], 0 | |||
| datasets = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) | |||
| datasets = db.paginate(select=query, page=page, per_page=per_page, max_per_page=100, error_out=False) | |||
| return datasets.items, datasets.total | |||
| @@ -153,9 +155,10 @@ class DatasetService: | |||
| @staticmethod | |||
| def get_datasets_by_ids(ids, tenant_id): | |||
| datasets = Dataset.query.filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id).paginate( | |||
| page=1, per_page=len(ids), max_per_page=len(ids), error_out=False | |||
| ) | |||
| stmt = select(Dataset).filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id) | |||
| datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False) | |||
| return datasets.items, datasets.total | |||
| @staticmethod | |||
| @@ -174,7 +177,7 @@ class DatasetService: | |||
| retrieval_model: Optional[RetrievalModel] = None, | |||
| ): | |||
| # check if dataset name already exists | |||
| if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first(): | |||
| if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first(): | |||
| raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") | |||
| embedding_model = None | |||
| if indexing_technique == "high_quality": | |||
| @@ -235,7 +238,7 @@ class DatasetService: | |||
| @staticmethod | |||
| def get_dataset(dataset_id) -> Optional[Dataset]: | |||
| dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first() | |||
| dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first() | |||
| return dataset | |||
| @staticmethod | |||
| @@ -436,7 +439,7 @@ class DatasetService: | |||
| # update Retrieval model | |||
| filtered_data["retrieval_model"] = data["retrieval_model"] | |||
| dataset.query.filter_by(id=dataset_id).update(filtered_data) | |||
| db.session.query(Dataset).filter_by(id=dataset_id).update(filtered_data) | |||
| db.session.commit() | |||
| if action: | |||
| @@ -460,7 +463,7 @@ class DatasetService: | |||
| @staticmethod | |||
| def dataset_use_check(dataset_id) -> bool: | |||
| count = AppDatasetJoin.query.filter_by(dataset_id=dataset_id).count() | |||
| count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count() | |||
| if count > 0: | |||
| return True | |||
| return False | |||
| @@ -475,7 +478,9 @@ class DatasetService: | |||
| logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") | |||
| raise NoPermissionError("You do not have permission to access this dataset.") | |||
| if dataset.permission == "partial_members": | |||
| user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first() | |||
| user_permission = ( | |||
| db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first() | |||
| ) | |||
| if ( | |||
| not user_permission | |||
| and dataset.tenant_id != user.current_tenant_id | |||
| @@ -499,23 +504,24 @@ class DatasetService: | |||
| elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM: | |||
| if not any( | |||
| dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all() | |||
| dp.dataset_id == dataset.id | |||
| for dp in db.session.query(DatasetPermission).filter_by(account_id=user.id).all() | |||
| ): | |||
| raise NoPermissionError("You do not have permission to access this dataset.") | |||
| @staticmethod | |||
| def get_dataset_queries(dataset_id: str, page: int, per_page: int): | |||
| dataset_queries = ( | |||
| DatasetQuery.query.filter_by(dataset_id=dataset_id) | |||
| .order_by(db.desc(DatasetQuery.created_at)) | |||
| .paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) | |||
| ) | |||
| stmt = select(DatasetQuery).filter_by(dataset_id=dataset_id).order_by(db.desc(DatasetQuery.created_at)) | |||
| dataset_queries = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False) | |||
| return dataset_queries.items, dataset_queries.total | |||
| @staticmethod | |||
| def get_related_apps(dataset_id: str): | |||
| return ( | |||
| AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) | |||
| db.session.query(AppDatasetJoin) | |||
| .filter(AppDatasetJoin.dataset_id == dataset_id) | |||
| .order_by(db.desc(AppDatasetJoin.created_at)) | |||
| .all() | |||
| ) | |||
| @@ -530,10 +536,14 @@ class DatasetService: | |||
| } | |||
| # get recent 30 days auto disable logs | |||
| start_date = datetime.datetime.now() - datetime.timedelta(days=30) | |||
| dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter( | |||
| DatasetAutoDisableLog.dataset_id == dataset_id, | |||
| DatasetAutoDisableLog.created_at >= start_date, | |||
| ).all() | |||
| dataset_auto_disable_logs = ( | |||
| db.session.query(DatasetAutoDisableLog) | |||
| .filter( | |||
| DatasetAutoDisableLog.dataset_id == dataset_id, | |||
| DatasetAutoDisableLog.created_at >= start_date, | |||
| ) | |||
| .all() | |||
| ) | |||
| if dataset_auto_disable_logs: | |||
| return { | |||
| "document_ids": [log.document_id for log in dataset_auto_disable_logs], | |||
| @@ -873,7 +883,9 @@ class DocumentService: | |||
| @staticmethod | |||
| def get_documents_position(dataset_id): | |||
| document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first() | |||
| document = ( | |||
| db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first() | |||
| ) | |||
| if document: | |||
| return document.position + 1 | |||
| else: | |||
| @@ -1010,13 +1022,17 @@ class DocumentService: | |||
| } | |||
| # check duplicate | |||
| if knowledge_config.duplicate: | |||
| document = Document.query.filter_by( | |||
| dataset_id=dataset.id, | |||
| tenant_id=current_user.current_tenant_id, | |||
| data_source_type="upload_file", | |||
| enabled=True, | |||
| name=file_name, | |||
| ).first() | |||
| document = ( | |||
| db.session.query(Document) | |||
| .filter_by( | |||
| dataset_id=dataset.id, | |||
| tenant_id=current_user.current_tenant_id, | |||
| data_source_type="upload_file", | |||
| enabled=True, | |||
| name=file_name, | |||
| ) | |||
| .first() | |||
| ) | |||
| if document: | |||
| document.dataset_process_rule_id = dataset_process_rule.id # type: ignore | |||
| document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | |||
| @@ -1054,12 +1070,16 @@ class DocumentService: | |||
| raise ValueError("No notion info list found.") | |||
| exist_page_ids = [] | |||
| exist_document = {} | |||
| documents = Document.query.filter_by( | |||
| dataset_id=dataset.id, | |||
| tenant_id=current_user.current_tenant_id, | |||
| data_source_type="notion_import", | |||
| enabled=True, | |||
| ).all() | |||
| documents = ( | |||
| db.session.query(Document) | |||
| .filter_by( | |||
| dataset_id=dataset.id, | |||
| tenant_id=current_user.current_tenant_id, | |||
| data_source_type="notion_import", | |||
| enabled=True, | |||
| ) | |||
| .all() | |||
| ) | |||
| if documents: | |||
| for document in documents: | |||
| data_source_info = json.loads(document.data_source_info) | |||
| @@ -1206,12 +1226,16 @@ class DocumentService: | |||
| @staticmethod | |||
| def get_tenant_documents_count(): | |||
| documents_count = Document.query.filter( | |||
| Document.completed_at.isnot(None), | |||
| Document.enabled == True, | |||
| Document.archived == False, | |||
| Document.tenant_id == current_user.current_tenant_id, | |||
| ).count() | |||
| documents_count = ( | |||
| db.session.query(Document) | |||
| .filter( | |||
| Document.completed_at.isnot(None), | |||
| Document.enabled == True, | |||
| Document.archived == False, | |||
| Document.tenant_id == current_user.current_tenant_id, | |||
| ) | |||
| .count() | |||
| ) | |||
| return documents_count | |||
| @staticmethod | |||
| @@ -1328,7 +1352,7 @@ class DocumentService: | |||
| db.session.commit() | |||
| # update document segment | |||
| update_params = {DocumentSegment.status: "re_segment"} | |||
| DocumentSegment.query.filter_by(document_id=document.id).update(update_params) | |||
| db.session.query(DocumentSegment).filter_by(document_id=document.id).update(update_params) | |||
| db.session.commit() | |||
| # trigger async task | |||
| document_indexing_update_task.delay(document.dataset_id, document.id) | |||
| @@ -1918,7 +1942,8 @@ class SegmentService: | |||
| @classmethod | |||
| def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): | |||
| index_node_ids = ( | |||
| DocumentSegment.query.with_entities(DocumentSegment.index_node_id) | |||
| db.session.query(DocumentSegment) | |||
| .with_entities(DocumentSegment.index_node_id) | |||
| .filter( | |||
| DocumentSegment.id.in_(segment_ids), | |||
| DocumentSegment.dataset_id == dataset.id, | |||
| @@ -2157,20 +2182,28 @@ class SegmentService: | |||
| def get_child_chunks( | |||
| cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None | |||
| ): | |||
| query = ChildChunk.query.filter_by( | |||
| tenant_id=current_user.current_tenant_id, | |||
| dataset_id=dataset_id, | |||
| document_id=document_id, | |||
| segment_id=segment_id, | |||
| ).order_by(ChildChunk.position.asc()) | |||
| query = ( | |||
| select(ChildChunk) | |||
| .filter_by( | |||
| tenant_id=current_user.current_tenant_id, | |||
| dataset_id=dataset_id, | |||
| document_id=document_id, | |||
| segment_id=segment_id, | |||
| ) | |||
| .order_by(ChildChunk.position.asc()) | |||
| ) | |||
| if keyword: | |||
| query = query.where(ChildChunk.content.ilike(f"%{keyword}%")) | |||
| return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) | |||
| return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) | |||
| @classmethod | |||
| def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]: | |||
| """Get a child chunk by its ID.""" | |||
| result = ChildChunk.query.filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id).first() | |||
| result = ( | |||
| db.session.query(ChildChunk) | |||
| .filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id) | |||
| .first() | |||
| ) | |||
| return result if isinstance(result, ChildChunk) else None | |||
| @classmethod | |||
| @@ -2184,7 +2217,7 @@ class SegmentService: | |||
| limit: int = 20, | |||
| ): | |||
| """Get segments for a document with optional filtering.""" | |||
| query = DocumentSegment.query.filter( | |||
| query = select(DocumentSegment).filter( | |||
| DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id | |||
| ) | |||
| @@ -2194,9 +2227,8 @@ class SegmentService: | |||
| if keyword: | |||
| query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%")) | |||
| paginated_segments = query.order_by(DocumentSegment.position.asc()).paginate( | |||
| page=page, per_page=limit, max_per_page=100, error_out=False | |||
| ) | |||
| query = query.order_by(DocumentSegment.position.asc()) | |||
| paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) | |||
| return paginated_segments.items, paginated_segments.total | |||
| @@ -2236,9 +2268,11 @@ class SegmentService: | |||
| raise ValueError(ex.description) | |||
| # check segment | |||
| segment = DocumentSegment.query.filter( | |||
| DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id | |||
| ).first() | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id) | |||
| .first() | |||
| ) | |||
| if not segment: | |||
| raise NotFound("Segment not found.") | |||
| @@ -2251,9 +2285,11 @@ class SegmentService: | |||
| @classmethod | |||
| def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]: | |||
| """Get a segment by its ID.""" | |||
| result = DocumentSegment.query.filter( | |||
| DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id | |||
| ).first() | |||
| result = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id) | |||
| .first() | |||
| ) | |||
| return result if isinstance(result, DocumentSegment) else None | |||
| @@ -5,6 +5,7 @@ from typing import Any, Optional, Union, cast | |||
| from urllib.parse import urlparse | |||
| import httpx | |||
| from sqlalchemy import select | |||
| from constants import HIDDEN_VALUE | |||
| from core.helper import ssrf_proxy | |||
| @@ -24,14 +25,20 @@ from services.errors.dataset import DatasetNameDuplicateError | |||
| class ExternalDatasetService: | |||
| @staticmethod | |||
| def get_external_knowledge_apis(page, per_page, tenant_id, search=None) -> tuple[list[ExternalKnowledgeApis], int]: | |||
| query = ExternalKnowledgeApis.query.filter(ExternalKnowledgeApis.tenant_id == tenant_id).order_by( | |||
| ExternalKnowledgeApis.created_at.desc() | |||
| def get_external_knowledge_apis( | |||
| page, per_page, tenant_id, search=None | |||
| ) -> tuple[list[ExternalKnowledgeApis], int | None]: | |||
| query = ( | |||
| select(ExternalKnowledgeApis) | |||
| .filter(ExternalKnowledgeApis.tenant_id == tenant_id) | |||
| .order_by(ExternalKnowledgeApis.created_at.desc()) | |||
| ) | |||
| if search: | |||
| query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%")) | |||
| external_knowledge_apis = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) | |||
| external_knowledge_apis = db.paginate( | |||
| select=query, page=page, per_page=per_page, max_per_page=100, error_out=False | |||
| ) | |||
| return external_knowledge_apis.items, external_knowledge_apis.total | |||
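The signature change to `tuple[list[ExternalKnowledgeApis], int | None]` follows from the pagination type: in Flask-SQLAlchemy 3.x, `Pagination.total` is optional (it is only populated when the count query runs), so a method that forwards `.total` has to widen its annotation. A small sketch of the shape, with names taken from this hunk and import paths assumed:

```python
from sqlalchemy import select

from extensions.ext_database import db
from models.dataset import ExternalKnowledgeApis  # assumed import path


def get_external_knowledge_apis(
    page: int, per_page: int, tenant_id: str
) -> tuple[list[ExternalKnowledgeApis], int | None]:
    stmt = (
        select(ExternalKnowledgeApis)
        .filter(ExternalKnowledgeApis.tenant_id == tenant_id)
        .order_by(ExternalKnowledgeApis.created_at.desc())
    )
    pagination = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False)
    # pagination.total is Optional, hence the widened return type.
    return pagination.items, pagination.total
```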
| @@ -92,18 +99,18 @@ class ExternalDatasetService: | |||
| @staticmethod | |||
| def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis: | |||
| external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by( | |||
| id=external_knowledge_api_id | |||
| ).first() | |||
| external_knowledge_api: Optional[ExternalKnowledgeApis] = ( | |||
| db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id).first() | |||
| ) | |||
| if external_knowledge_api is None: | |||
| raise ValueError("api template not found") | |||
| return external_knowledge_api | |||
| @staticmethod | |||
| def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis: | |||
| external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by( | |||
| id=external_knowledge_api_id, tenant_id=tenant_id | |||
| ).first() | |||
| external_knowledge_api: Optional[ExternalKnowledgeApis] = ( | |||
| db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first() | |||
| ) | |||
| if external_knowledge_api is None: | |||
| raise ValueError("api template not found") | |||
| if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE: | |||
| @@ -120,9 +127,9 @@ class ExternalDatasetService: | |||
| @staticmethod | |||
| def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str): | |||
| external_knowledge_api = ExternalKnowledgeApis.query.filter_by( | |||
| id=external_knowledge_api_id, tenant_id=tenant_id | |||
| ).first() | |||
| external_knowledge_api = ( | |||
| db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first() | |||
| ) | |||
| if external_knowledge_api is None: | |||
| raise ValueError("api template not found") | |||
| @@ -131,25 +138,29 @@ class ExternalDatasetService: | |||
| @staticmethod | |||
| def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]: | |||
| count = ExternalKnowledgeBindings.query.filter_by(external_knowledge_api_id=external_knowledge_api_id).count() | |||
| count = ( | |||
| db.session.query(ExternalKnowledgeBindings) | |||
| .filter_by(external_knowledge_api_id=external_knowledge_api_id) | |||
| .count() | |||
| ) | |||
| if count > 0: | |||
| return True, count | |||
| return False, 0 | |||
| @staticmethod | |||
| def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings: | |||
| external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ExternalKnowledgeBindings.query.filter_by( | |||
| dataset_id=dataset_id, tenant_id=tenant_id | |||
| ).first() | |||
| external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ( | |||
| db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first() | |||
| ) | |||
| if not external_knowledge_binding: | |||
| raise ValueError("external knowledge binding not found") | |||
| return external_knowledge_binding | |||
| @staticmethod | |||
| def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict): | |||
| external_knowledge_api = ExternalKnowledgeApis.query.filter_by( | |||
| id=external_knowledge_api_id, tenant_id=tenant_id | |||
| ).first() | |||
| external_knowledge_api = ( | |||
| db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first() | |||
| ) | |||
| if external_knowledge_api is None: | |||
| raise ValueError("api template not found") | |||
| settings = json.loads(external_knowledge_api.settings) | |||
| @@ -212,11 +223,13 @@ class ExternalDatasetService: | |||
| @staticmethod | |||
| def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset: | |||
| # check if dataset name already exists | |||
| if Dataset.query.filter_by(name=args.get("name"), tenant_id=tenant_id).first(): | |||
| if db.session.query(Dataset).filter_by(name=args.get("name"), tenant_id=tenant_id).first(): | |||
| raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.") | |||
| external_knowledge_api = ExternalKnowledgeApis.query.filter_by( | |||
| id=args.get("external_knowledge_api_id"), tenant_id=tenant_id | |||
| ).first() | |||
| external_knowledge_api = ( | |||
| db.session.query(ExternalKnowledgeApis) | |||
| .filter_by(id=args.get("external_knowledge_api_id"), tenant_id=tenant_id) | |||
| .first() | |||
| ) | |||
| if external_knowledge_api is None: | |||
| raise ValueError("api template not found") | |||
| @@ -254,15 +267,17 @@ class ExternalDatasetService: | |||
| external_retrieval_parameters: dict, | |||
| metadata_condition: Optional[MetadataCondition] = None, | |||
| ) -> list: | |||
| external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by( | |||
| dataset_id=dataset_id, tenant_id=tenant_id | |||
| ).first() | |||
| external_knowledge_binding = ( | |||
| db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first() | |||
| ) | |||
| if not external_knowledge_binding: | |||
| raise ValueError("external knowledge binding not found") | |||
| external_knowledge_api = ExternalKnowledgeApis.query.filter_by( | |||
| id=external_knowledge_binding.external_knowledge_api_id | |||
| ).first() | |||
| external_knowledge_api = ( | |||
| db.session.query(ExternalKnowledgeApis) | |||
| .filter_by(id=external_knowledge_binding.external_knowledge_api_id) | |||
| .first() | |||
| ) | |||
| if not external_knowledge_api: | |||
| raise ValueError("external api template not found") | |||
| @@ -20,9 +20,11 @@ class MetadataService: | |||
| @staticmethod | |||
| def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata: | |||
| # check if metadata name already exists | |||
| if DatasetMetadata.query.filter_by( | |||
| tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name | |||
| ).first(): | |||
| if ( | |||
| db.session.query(DatasetMetadata) | |||
| .filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name) | |||
| .first() | |||
| ): | |||
| raise ValueError("Metadata name already exists.") | |||
| for field in BuiltInField: | |||
| if field.value == metadata_args.name: | |||
| @@ -42,16 +44,18 @@ class MetadataService: | |||
| def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: # type: ignore | |||
| lock_key = f"dataset_metadata_lock_{dataset_id}" | |||
| # check if metadata name already exists | |||
| if DatasetMetadata.query.filter_by( | |||
| tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name | |||
| ).first(): | |||
| if ( | |||
| db.session.query(DatasetMetadata) | |||
| .filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name) | |||
| .first() | |||
| ): | |||
| raise ValueError("Metadata name already exists.") | |||
| for field in BuiltInField: | |||
| if field.value == name: | |||
| raise ValueError("Metadata name already exists in Built-in fields.") | |||
| try: | |||
| MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) | |||
| metadata = DatasetMetadata.query.filter_by(id=metadata_id).first() | |||
| metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first() | |||
| if metadata is None: | |||
| raise ValueError("Metadata not found.") | |||
| old_name = metadata.name | |||
| @@ -60,7 +64,9 @@ class MetadataService: | |||
| metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | |||
| # update related documents | |||
| dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all() | |||
| dataset_metadata_bindings = ( | |||
| db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all() | |||
| ) | |||
| if dataset_metadata_bindings: | |||
| document_ids = [binding.document_id for binding in dataset_metadata_bindings] | |||
| documents = DocumentService.get_document_by_ids(document_ids) | |||
| @@ -82,13 +88,15 @@ class MetadataService: | |||
| lock_key = f"dataset_metadata_lock_{dataset_id}" | |||
| try: | |||
| MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) | |||
| metadata = DatasetMetadata.query.filter_by(id=metadata_id).first() | |||
| metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first() | |||
| if metadata is None: | |||
| raise ValueError("Metadata not found.") | |||
| db.session.delete(metadata) | |||
| # deal related documents | |||
| dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all() | |||
| dataset_metadata_bindings = ( | |||
| db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all() | |||
| ) | |||
| if dataset_metadata_bindings: | |||
| document_ids = [binding.document_id for binding in dataset_metadata_bindings] | |||
| documents = DocumentService.get_document_by_ids(document_ids) | |||
| @@ -193,7 +201,7 @@ class MetadataService: | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| # deal metadata binding | |||
| DatasetMetadataBinding.query.filter_by(document_id=operation.document_id).delete() | |||
| db.session.query(DatasetMetadataBinding).filter_by(document_id=operation.document_id).delete() | |||
| for metadata_value in operation.metadata_list: | |||
| dataset_metadata_binding = DatasetMetadataBinding( | |||
| tenant_id=current_user.current_tenant_id, | |||
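Like the bulk updates elsewhere in this PR, clearing bindings stays a query-level delete: `db.session.query(...).filter_by(...).delete()` issues one `DELETE` statement and returns the matched row count, instead of loading each `DatasetMetadataBinding` and deleting it through the unit of work. A sketch (helper name illustrative, import paths assumed):

```python
from extensions.ext_database import db
from models.dataset import DatasetMetadataBinding  # assumed import path


def clear_document_metadata_bindings(document_id: str) -> int:
    """Delete all bindings for a document with a single DELETE; returns rows removed."""
    removed = db.session.query(DatasetMetadataBinding).filter_by(document_id=document_id).delete()
    db.session.commit()
    return removed
```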
| @@ -230,9 +238,9 @@ class MetadataService: | |||
| "id": item.get("id"), | |||
| "name": item.get("name"), | |||
| "type": item.get("type"), | |||
| "count": DatasetMetadataBinding.query.filter_by( | |||
| metadata_id=item.get("id"), dataset_id=dataset.id | |||
| ).count(), | |||
| "count": db.session.query(DatasetMetadataBinding) | |||
| .filter_by(metadata_id=item.get("id"), dataset_id=dataset.id) | |||
| .count(), | |||
| } | |||
| for item in dataset.doc_metadata or [] | |||
| if item.get("id") != "built-in" | |||
| @@ -41,7 +41,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] | |||
| DocumentSegment.status: "indexing", | |||
| DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), | |||
| } | |||
| DocumentSegment.query.filter_by(id=segment.id).update(update_params) | |||
| db.session.query(DocumentSegment).filter_by(id=segment.id).update(update_params) | |||
| db.session.commit() | |||
| document = Document( | |||
| page_content=segment.content, | |||
| @@ -78,7 +78,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] | |||
| DocumentSegment.status: "completed", | |||
| DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), | |||
| } | |||
| DocumentSegment.query.filter_by(id=segment.id).update(update_params) | |||
| db.session.query(DocumentSegment).filter_by(id=segment.id).update(update_params) | |||
| db.session.commit() | |||
| end_at = time.perf_counter() | |||
| @@ -24,7 +24,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): | |||
| start_at = time.perf_counter() | |||
| try: | |||
| dataset = Dataset.query.filter_by(id=dataset_id).first() | |||
| dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() | |||
| if not dataset: | |||
| raise Exception("Dataset not found") | |||