Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
tags/0.15.0
| @@ -587,7 +587,7 @@ def upgrade_db(): | |||
| click.echo(click.style("Starting database migration.", fg="green")) | |||
| # run db migration | |||
| import flask_migrate | |||
| import flask_migrate # type: ignore | |||
| flask_migrate.upgrade() | |||
| @@ -413,7 +413,7 @@ class DocumentIndexingEstimateApi(DocumentResource): | |||
| indexing_runner = IndexingRunner() | |||
| try: | |||
| response = indexing_runner.indexing_estimate( | |||
| estimate_response = indexing_runner.indexing_estimate( | |||
| current_user.current_tenant_id, | |||
| [extract_setting], | |||
| data_process_rule_dict, | |||
| @@ -421,6 +421,7 @@ class DocumentIndexingEstimateApi(DocumentResource): | |||
| "English", | |||
| dataset_id, | |||
| ) | |||
| return estimate_response.model_dump(), 200 | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| "No Embedding Model available. Please configure a valid provider " | |||
| @@ -431,7 +432,7 @@ class DocumentIndexingEstimateApi(DocumentResource): | |||
| except Exception as e: | |||
| raise IndexingEstimateError(str(e)) | |||
| return response.model_dump(), 200 | |||
| return response, 200 | |||
| class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| @@ -521,6 +522,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| "English", | |||
| dataset_id, | |||
| ) | |||
| return response.model_dump(), 200 | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| "No Embedding Model available. Please configure a valid provider " | |||
| @@ -530,7 +532,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| raise ProviderNotInitializeError(ex.description) | |||
| except Exception as e: | |||
| raise IndexingEstimateError(str(e)) | |||
| return response.model_dump(), 200 | |||
| class DocumentBatchIndexingStatusApi(DocumentResource): | |||
| @@ -22,6 +22,7 @@ from fields.document_fields import document_fields, document_status_fields | |||
| from libs.login import current_user | |||
| from models.dataset import Dataset, Document, DocumentSegment | |||
| from services.dataset_service import DocumentService | |||
| from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig | |||
| from services.file_service import FileService | |||
| @@ -67,13 +68,14 @@ class DocumentAddByTextApi(DatasetApiResource): | |||
| "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, | |||
| } | |||
| args["data_source"] = data_source | |||
| knowledge_config = KnowledgeConfig(**args) | |||
| # validate args | |||
| DocumentService.document_create_args_validate(args) | |||
| DocumentService.document_create_args_validate(knowledge_config) | |||
| try: | |||
| documents, batch = DocumentService.save_document_with_dataset_id( | |||
| dataset=dataset, | |||
| document_data=args, | |||
| knowledge_config=knowledge_config, | |||
| account=current_user, | |||
| dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, | |||
| created_from="api", | |||
| @@ -122,12 +124,13 @@ class DocumentUpdateByTextApi(DatasetApiResource): | |||
| args["data_source"] = data_source | |||
| # validate args | |||
| args["original_document_id"] = str(document_id) | |||
| DocumentService.document_create_args_validate(args) | |||
| knowledge_config = KnowledgeConfig(**args) | |||
| DocumentService.document_create_args_validate(knowledge_config) | |||
| try: | |||
| documents, batch = DocumentService.save_document_with_dataset_id( | |||
| dataset=dataset, | |||
| document_data=args, | |||
| knowledge_config=knowledge_config, | |||
| account=current_user, | |||
| dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, | |||
| created_from="api", | |||
| @@ -186,12 +189,13 @@ class DocumentAddByFileApi(DatasetApiResource): | |||
| data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} | |||
| args["data_source"] = data_source | |||
| # validate args | |||
| DocumentService.document_create_args_validate(args) | |||
| knowledge_config = KnowledgeConfig(**args) | |||
| DocumentService.document_create_args_validate(knowledge_config) | |||
| try: | |||
| documents, batch = DocumentService.save_document_with_dataset_id( | |||
| dataset=dataset, | |||
| document_data=args, | |||
| knowledge_config=knowledge_config, | |||
| account=dataset.created_by_account, | |||
| dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, | |||
| created_from="api", | |||
| @@ -245,12 +249,14 @@ class DocumentUpdateByFileApi(DatasetApiResource): | |||
| args["data_source"] = data_source | |||
| # validate args | |||
| args["original_document_id"] = str(document_id) | |||
| DocumentService.document_create_args_validate(args) | |||
| knowledge_config = KnowledgeConfig(**args) | |||
| DocumentService.document_create_args_validate(knowledge_config) | |||
| try: | |||
| documents, batch = DocumentService.save_document_with_dataset_id( | |||
| dataset=dataset, | |||
| document_data=args, | |||
| knowledge_config=knowledge_config, | |||
| account=dataset.created_by_account, | |||
| dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, | |||
| created_from="api", | |||
| @@ -276,7 +276,7 @@ class IndexingRunner: | |||
| tenant_id=tenant_id, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| ) | |||
| preview_texts = [] | |||
| preview_texts = [] # type: ignore | |||
| total_segments = 0 | |||
| index_type = doc_form | |||
| @@ -300,13 +300,13 @@ class IndexingRunner: | |||
| if len(preview_texts) < 10: | |||
| if doc_form and doc_form == "qa_model": | |||
| preview_detail = QAPreviewDetail( | |||
| question=document.page_content, answer=document.metadata.get("answer") | |||
| question=document.page_content, answer=document.metadata.get("answer") or "" | |||
| ) | |||
| preview_texts.append(preview_detail) | |||
| else: | |||
| preview_detail = PreviewDetail(content=document.page_content) | |||
| preview_detail = PreviewDetail(content=document.page_content) # type: ignore | |||
| if document.children: | |||
| preview_detail.child_chunks = [child.page_content for child in document.children] | |||
| preview_detail.child_chunks = [child.page_content for child in document.children] # type: ignore | |||
| preview_texts.append(preview_detail) | |||
| # delete image files and related db records | |||
| @@ -325,7 +325,7 @@ class IndexingRunner: | |||
| if doc_form and doc_form == "qa_model": | |||
| return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[]) | |||
| return IndexingEstimate(total_segments=total_segments, preview=preview_texts) | |||
| return IndexingEstimate(total_segments=total_segments, preview=preview_texts) # type: ignore | |||
| def _extract( | |||
| self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict | |||
| @@ -454,7 +454,7 @@ class IndexingRunner: | |||
| embedding_model_instance=embedding_model_instance, | |||
| ) | |||
| return character_splitter | |||
| return character_splitter # type: ignore | |||
| def _split_to_documents_for_estimate( | |||
| self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule | |||
| @@ -535,7 +535,7 @@ class IndexingRunner: | |||
| # create keyword index | |||
| create_keyword_thread = threading.Thread( | |||
| target=self._process_keyword_index, | |||
| args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), | |||
| args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore | |||
| ) | |||
| create_keyword_thread.start() | |||
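The keyword-index hunk above hands current_app._get_current_object() to the worker thread. A hedged sketch of why: current_app is a context-bound proxy, so the thread needs the unwrapped Flask app in order to push its own application context (the added # type: ignore only silences mypy about the proxy attribute):

import threading
from flask import Flask, current_app

app = Flask(__name__)

def _worker(flask_app: Flask, dataset_id: str) -> None:
    # Re-enter an application context inside the thread before touching
    # anything that needs the app (db sessions, config, ...).
    with flask_app.app_context():
        print(f"building keyword index for {dataset_id}")

with app.app_context():
    thread = threading.Thread(
        target=_worker,
        args=(current_app._get_current_object(), "dataset-1"),  # unwrap the proxy
    )
    thread.start()
    thread.join()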
| @@ -258,78 +258,79 @@ class RetrievalService: | |||
| include_segment_ids = [] | |||
| segment_child_map = {} | |||
| for document in documents: | |||
| document_id = document.metadata["document_id"] | |||
| document_id = document.metadata.get("document_id") | |||
| dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() | |||
| if dataset_document and dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | |||
| child_index_node_id = document.metadata["doc_id"] | |||
| result = ( | |||
| db.session.query(ChildChunk, DocumentSegment) | |||
| .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) | |||
| .filter( | |||
| ChildChunk.index_node_id == child_index_node_id, | |||
| DocumentSegment.dataset_id == dataset_document.dataset_id, | |||
| DocumentSegment.enabled == True, | |||
| DocumentSegment.status == "completed", | |||
| if dataset_document: | |||
| if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | |||
| child_index_node_id = document.metadata.get("doc_id") | |||
| result = ( | |||
| db.session.query(ChildChunk, DocumentSegment) | |||
| .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) | |||
| .filter( | |||
| ChildChunk.index_node_id == child_index_node_id, | |||
| DocumentSegment.dataset_id == dataset_document.dataset_id, | |||
| DocumentSegment.enabled == True, | |||
| DocumentSegment.status == "completed", | |||
| ) | |||
| .first() | |||
| ) | |||
| .first() | |||
| ) | |||
| if result: | |||
| child_chunk, segment = result | |||
| if not segment: | |||
| continue | |||
| if segment.id not in include_segment_ids: | |||
| include_segment_ids.append(segment.id) | |||
| child_chunk_detail = { | |||
| "id": child_chunk.id, | |||
| "content": child_chunk.content, | |||
| "position": child_chunk.position, | |||
| "score": document.metadata.get("score", 0.0), | |||
| } | |||
| map_detail = { | |||
| "max_score": document.metadata.get("score", 0.0), | |||
| "child_chunks": [child_chunk_detail], | |||
| } | |||
| segment_child_map[segment.id] = map_detail | |||
| record = { | |||
| "segment": segment, | |||
| } | |||
| records.append(record) | |||
| if result: | |||
| child_chunk, segment = result | |||
| if not segment: | |||
| continue | |||
| if segment.id not in include_segment_ids: | |||
| include_segment_ids.append(segment.id) | |||
| child_chunk_detail = { | |||
| "id": child_chunk.id, | |||
| "content": child_chunk.content, | |||
| "position": child_chunk.position, | |||
| "score": document.metadata.get("score", 0.0), | |||
| } | |||
| map_detail = { | |||
| "max_score": document.metadata.get("score", 0.0), | |||
| "child_chunks": [child_chunk_detail], | |||
| } | |||
| segment_child_map[segment.id] = map_detail | |||
| record = { | |||
| "segment": segment, | |||
| } | |||
| records.append(record) | |||
| else: | |||
| child_chunk_detail = { | |||
| "id": child_chunk.id, | |||
| "content": child_chunk.content, | |||
| "position": child_chunk.position, | |||
| "score": document.metadata.get("score", 0.0), | |||
| } | |||
| segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) | |||
| segment_child_map[segment.id]["max_score"] = max( | |||
| segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) | |||
| ) | |||
| else: | |||
| child_chunk_detail = { | |||
| "id": child_chunk.id, | |||
| "content": child_chunk.content, | |||
| "position": child_chunk.position, | |||
| "score": document.metadata.get("score", 0.0), | |||
| } | |||
| segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) | |||
| segment_child_map[segment.id]["max_score"] = max( | |||
| segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) | |||
| ) | |||
| continue | |||
| else: | |||
| continue | |||
| else: | |||
| index_node_id = document.metadata["doc_id"] | |||
| index_node_id = document.metadata["doc_id"] | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter( | |||
| DocumentSegment.dataset_id == dataset_document.dataset_id, | |||
| DocumentSegment.enabled == True, | |||
| DocumentSegment.status == "completed", | |||
| DocumentSegment.index_node_id == index_node_id, | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter( | |||
| DocumentSegment.dataset_id == dataset_document.dataset_id, | |||
| DocumentSegment.enabled == True, | |||
| DocumentSegment.status == "completed", | |||
| DocumentSegment.index_node_id == index_node_id, | |||
| ) | |||
| .first() | |||
| ) | |||
| .first() | |||
| ) | |||
| if not segment: | |||
| continue | |||
| include_segment_ids.append(segment.id) | |||
| record = { | |||
| "segment": segment, | |||
| "score": document.metadata.get("score", None), | |||
| } | |||
| if not segment: | |||
| continue | |||
| include_segment_ids.append(segment.id) | |||
| record = { | |||
| "segment": segment, | |||
| "score": document.metadata.get("score", None), | |||
| } | |||
| records.append(record) | |||
| records.append(record) | |||
| for record in records: | |||
| if record["segment"].id in segment_child_map: | |||
| record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None) | |||
| @@ -122,26 +122,27 @@ class DatasetDocumentStore: | |||
| db.session.add(segment_document) | |||
| db.session.flush() | |||
| if save_child: | |||
| for postion, child in enumerate(doc.children, start=1): | |||
| child_segment = ChildChunk( | |||
| tenant_id=self._dataset.tenant_id, | |||
| dataset_id=self._dataset.id, | |||
| document_id=self._document_id, | |||
| segment_id=segment_document.id, | |||
| position=postion, | |||
| index_node_id=child.metadata["doc_id"], | |||
| index_node_hash=child.metadata["doc_hash"], | |||
| content=child.page_content, | |||
| word_count=len(child.page_content), | |||
| type="automatic", | |||
| created_by=self._user_id, | |||
| ) | |||
| db.session.add(child_segment) | |||
| if doc.children: | |||
| for postion, child in enumerate(doc.children, start=1): | |||
| child_segment = ChildChunk( | |||
| tenant_id=self._dataset.tenant_id, | |||
| dataset_id=self._dataset.id, | |||
| document_id=self._document_id, | |||
| segment_id=segment_document.id, | |||
| position=postion, | |||
| index_node_id=child.metadata.get("doc_id"), | |||
| index_node_hash=child.metadata.get("doc_hash"), | |||
| content=child.page_content, | |||
| word_count=len(child.page_content), | |||
| type="automatic", | |||
| created_by=self._user_id, | |||
| ) | |||
| db.session.add(child_segment) | |||
| else: | |||
| segment_document.content = doc.page_content | |||
| if doc.metadata.get("answer"): | |||
| segment_document.answer = doc.metadata.pop("answer", "") | |||
| segment_document.index_node_hash = doc.metadata["doc_hash"] | |||
| segment_document.index_node_hash = doc.metadata.get("doc_hash") | |||
| segment_document.word_count = len(doc.page_content) | |||
| segment_document.tokens = tokens | |||
| if save_child and doc.children: | |||
| @@ -160,8 +161,8 @@ class DatasetDocumentStore: | |||
| document_id=self._document_id, | |||
| segment_id=segment_document.id, | |||
| position=position, | |||
| index_node_id=child.metadata["doc_id"], | |||
| index_node_hash=child.metadata["doc_hash"], | |||
| index_node_id=child.metadata.get("doc_id"), | |||
| index_node_hash=child.metadata.get("doc_hash"), | |||
| content=child.page_content, | |||
| word_count=len(child.page_content), | |||
| type="automatic", | |||
| @@ -4,7 +4,7 @@ import os | |||
| from typing import Optional, cast | |||
| import pandas as pd | |||
| from openpyxl import load_workbook | |||
| from openpyxl import load_workbook # type: ignore | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| from core.rag.models.document import Document | |||
| @@ -81,4 +81,4 @@ class BaseIndexProcessor(ABC): | |||
| embedding_model_instance=embedding_model_instance, | |||
| ) | |||
| return character_splitter | |||
| return character_splitter # type: ignore | |||
| @@ -30,12 +30,18 @@ class ParagraphIndexProcessor(BaseIndexProcessor): | |||
| def transform(self, documents: list[Document], **kwargs) -> list[Document]: | |||
| process_rule = kwargs.get("process_rule") | |||
| if not process_rule: | |||
| raise ValueError("No process rule found.") | |||
| if process_rule.get("mode") == "automatic": | |||
| automatic_rule = DatasetProcessRule.AUTOMATIC_RULES | |||
| rules = Rule(**automatic_rule) | |||
| else: | |||
| if not process_rule.get("rules"): | |||
| raise ValueError("No rules found in process rule.") | |||
| rules = Rule(**process_rule.get("rules")) | |||
| # Split the text documents into nodes. | |||
| if not rules.segmentation: | |||
| raise ValueError("No segmentation found in rules.") | |||
| splitter = self._get_splitter( | |||
| processing_rule_mode=process_rule.get("mode"), | |||
| max_tokens=rules.segmentation.max_tokens, | |||
| @@ -30,8 +30,12 @@ class ParentChildIndexProcessor(BaseIndexProcessor): | |||
| def transform(self, documents: list[Document], **kwargs) -> list[Document]: | |||
| process_rule = kwargs.get("process_rule") | |||
| if not process_rule: | |||
| raise ValueError("No process rule found.") | |||
| if not process_rule.get("rules"): | |||
| raise ValueError("No rules found in process rule.") | |||
| rules = Rule(**process_rule.get("rules")) | |||
| all_documents = [] | |||
| all_documents = [] # type: ignore | |||
| if rules.parent_mode == ParentMode.PARAGRAPH: | |||
| # Split the text documents into nodes. | |||
| splitter = self._get_splitter( | |||
| @@ -161,6 +165,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor): | |||
| process_rule_mode: str, | |||
| embedding_model_instance: Optional[ModelInstance], | |||
| ) -> list[ChildDocument]: | |||
| if not rules.subchunk_segmentation: | |||
| raise ValueError("No subchunk segmentation found in rules.") | |||
| child_splitter = self._get_splitter( | |||
| processing_rule_mode=process_rule_mode, | |||
| max_tokens=rules.subchunk_segmentation.max_tokens, | |||
| @@ -37,12 +37,16 @@ class QAIndexProcessor(BaseIndexProcessor): | |||
| def transform(self, documents: list[Document], **kwargs) -> list[Document]: | |||
| preview = kwargs.get("preview") | |||
| process_rule = kwargs.get("process_rule") | |||
| if not process_rule: | |||
| raise ValueError("No process rule found.") | |||
| if not process_rule.get("rules"): | |||
| raise ValueError("No rules found in process rule.") | |||
| rules = Rule(**process_rule.get("rules")) | |||
| splitter = self._get_splitter( | |||
| processing_rule_mode=process_rule.get("mode"), | |||
| max_tokens=rules.segmentation.max_tokens, | |||
| chunk_overlap=rules.segmentation.chunk_overlap, | |||
| separator=rules.segmentation.separator, | |||
| max_tokens=rules.segmentation.max_tokens if rules.segmentation else 0, | |||
| chunk_overlap=rules.segmentation.chunk_overlap if rules.segmentation else 0, | |||
| separator=rules.segmentation.separator if rules.segmentation else "", | |||
| embedding_model_instance=kwargs.get("embedding_model_instance"), | |||
| ) | |||
| @@ -71,8 +75,8 @@ class QAIndexProcessor(BaseIndexProcessor): | |||
| all_documents.extend(split_documents) | |||
| if preview: | |||
| self._format_qa_document( | |||
| current_app._get_current_object(), | |||
| kwargs.get("tenant_id"), | |||
| current_app._get_current_object(), # type: ignore | |||
| kwargs.get("tenant_id"), # type: ignore | |||
| all_documents[0], | |||
| all_qa_documents, | |||
| kwargs.get("doc_language", "English"), | |||
| @@ -85,8 +89,8 @@ class QAIndexProcessor(BaseIndexProcessor): | |||
| document_format_thread = threading.Thread( | |||
| target=self._format_qa_document, | |||
| kwargs={ | |||
| "flask_app": current_app._get_current_object(), | |||
| "tenant_id": kwargs.get("tenant_id"), | |||
| "flask_app": current_app._get_current_object(), # type: ignore | |||
| "tenant_id": kwargs.get("tenant_id"), # type: ignore | |||
| "document_node": doc, | |||
| "all_qa_documents": all_qa_documents, | |||
| "document_language": kwargs.get("doc_language", "English"), | |||
| @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod | |||
| from collections.abc import Sequence | |||
| from typing import Any, Optional | |||
| from pydantic import BaseModel, Field | |||
| from pydantic import BaseModel | |||
| class ChildDocument(BaseModel): | |||
| @@ -15,7 +15,7 @@ class ChildDocument(BaseModel): | |||
| """Arbitrary metadata about the page content (e.g., source, relationships to other | |||
| documents, etc.). | |||
| """ | |||
| metadata: Optional[dict] = Field(default_factory=dict) | |||
| metadata: dict = {} | |||
| class Document(BaseModel): | |||
| @@ -28,7 +28,7 @@ class Document(BaseModel): | |||
| """Arbitrary metadata about the page content (e.g., source, relationships to other | |||
| documents, etc.). | |||
| """ | |||
| metadata: Optional[dict] = Field(default_factory=dict) | |||
| metadata: dict = {} | |||
| provider: Optional[str] = "dify" | |||
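On the metadata field change above (Optional[dict] = Field(default_factory=dict) becoming metadata: dict = {}): with Pydantic models a bare mutable default is still per-instance, since Pydantic copies mutable defaults rather than sharing them. A quick sketch, assuming Pydantic v2 behaviour:

from pydantic import BaseModel

class DocumentSketch(BaseModel):
    page_content: str
    metadata: dict = {}

a = DocumentSketch(page_content="a")
b = DocumentSketch(page_content="b")
a.metadata["source"] = "upload_file"
print(b.metadata)                # {} -- the default is copied for each instance
print(a.metadata is b.metadata)  # False

Dropping Optional also means metadata is always a dict, which is what lets the callers above read it with .get() without a None check.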
| @@ -5,7 +5,7 @@ from dify_app import DifyApp | |||
| def init_app(app: DifyApp): | |||
| # register blueprint routers | |||
| from flask_cors import CORS | |||
| from flask_cors import CORS # type: ignore | |||
| from controllers.console import bp as console_app_bp | |||
| from controllers.files import bp as files_bp | |||
| @@ -1,9 +1,9 @@ | |||
| import logging | |||
| import time | |||
| from collections import defaultdict | |||
| import click | |||
| from celery import shared_task # type: ignore | |||
| from flask import render_template | |||
| from extensions.ext_mail import mail | |||
| from models.account import Account, Tenant, TenantAccountJoin | |||
| @@ -27,7 +27,7 @@ def send_document_clean_notify_task(): | |||
| try: | |||
| dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all() | |||
| # group by tenant_id | |||
| dataset_auto_disable_logs_map = {} | |||
| dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) | |||
| for dataset_auto_disable_log in dataset_auto_disable_logs: | |||
| dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log) | |||
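The hunk above replaces a plain dict with defaultdict(list) for grouping the auto-disable logs by tenant, so a missing tenant key yields an empty list instead of a KeyError on append. A tiny sketch of the pattern with made-up sample data:

from collections import defaultdict

# Assumed sample data: (tenant_id, dataset_id) pairs standing in for the log rows.
logs = [("tenant-a", "ds-1"), ("tenant-a", "ds-2"), ("tenant-b", "ds-3")]

logs_by_tenant: dict[str, list[str]] = defaultdict(list)
for tenant_id, dataset_id in logs:
    logs_by_tenant[tenant_id].append(dataset_id)  # first access creates []

print(dict(logs_by_tenant))  # {'tenant-a': ['ds-1', 'ds-2'], 'tenant-b': ['ds-3']}

The same grouping pattern would also fit dataset_auto_dataset_map a few lines further down, which the diff still initializes as a plain {} before appending.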
| @@ -37,11 +37,13 @@ def send_document_clean_notify_task(): | |||
| if not tenant: | |||
| continue | |||
| current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first() | |||
| if not current_owner_join: | |||
| continue | |||
| account = Account.query.filter(Account.id == current_owner_join.account_id).first() | |||
| if not account: | |||
| continue | |||
| dataset_auto_dataset_map = {} | |||
| dataset_auto_dataset_map = {} # type: ignore | |||
| for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: | |||
| dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append( | |||
| dataset_auto_disable_log.document_id | |||
| @@ -53,14 +55,9 @@ def send_document_clean_notify_task(): | |||
| document_count = len(document_ids) | |||
| knowledge_details.append(f"<li>Knowledge base {dataset.name}: {document_count} documents</li>") | |||
| html_content = render_template( | |||
| "clean_document_job_mail_template-US.html", | |||
| ) | |||
| mail.send(to=to, subject="立即加入 Dify 工作空间", html=html_content) | |||
| end_at = time.perf_counter() | |||
| logging.info( | |||
| click.style("Send document clean notify mail succeeded: latency: {}".format(end_at - start_at), fg="green") | |||
| ) | |||
| except Exception: | |||
| logging.exception("Send invite member mail to {} failed".format(to)) | |||
| logging.exception("Send invite member mail to failed") | |||
| @@ -4,7 +4,7 @@ from enum import StrEnum | |||
| from typing import Optional, cast | |||
| from uuid import uuid4 | |||
| import yaml | |||
| import yaml # type: ignore | |||
| from packaging import version | |||
| from pydantic import BaseModel | |||
| from sqlalchemy import select | |||
| @@ -465,7 +465,7 @@ class AppDslService: | |||
| else: | |||
| cls._append_model_config_export_data(export_data, app_model) | |||
| return yaml.dump(export_data, allow_unicode=True) | |||
| return yaml.dump(export_data, allow_unicode=True) # type: ignore | |||
| @classmethod | |||
| def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None: | |||
| @@ -41,6 +41,7 @@ from models.source import DataSourceOauthBinding | |||
| from services.entities.knowledge_entities.knowledge_entities import ( | |||
| ChildChunkUpdateArgs, | |||
| KnowledgeConfig, | |||
| RerankingModel, | |||
| RetrievalModel, | |||
| SegmentUpdateArgs, | |||
| ) | |||
| @@ -548,12 +549,14 @@ class DocumentService: | |||
| } | |||
| @staticmethod | |||
| def get_document(dataset_id: str, document_id: str) -> Optional[Document]: | |||
| document = ( | |||
| db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() | |||
| ) | |||
| return document | |||
| def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]: | |||
| if document_id: | |||
| document = ( | |||
| db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() | |||
| ) | |||
| return document | |||
| else: | |||
| return None | |||
| @staticmethod | |||
| def get_document_by_id(document_id: str) -> Optional[Document]: | |||
| @@ -744,25 +747,26 @@ class DocumentService: | |||
| if features.billing.enabled: | |||
| if not knowledge_config.original_document_id: | |||
| count = 0 | |||
| if knowledge_config.data_source.info_list.data_source_type == "upload_file": | |||
| upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids | |||
| count = len(upload_file_list) | |||
| elif knowledge_config.data_source.info_list.data_source_type == "notion_import": | |||
| notion_info_list = knowledge_config.data_source.info_list.notion_info_list | |||
| for notion_info in notion_info_list: | |||
| count = count + len(notion_info.pages) | |||
| elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": | |||
| website_info = knowledge_config.data_source.info_list.website_info_list | |||
| count = len(website_info.urls) | |||
| batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) | |||
| if count > batch_upload_limit: | |||
| raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") | |||
| DocumentService.check_documents_upload_quota(count, features) | |||
| if knowledge_config.data_source: | |||
| if knowledge_config.data_source.info_list.data_source_type == "upload_file": | |||
| upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore | |||
| count = len(upload_file_list) | |||
| elif knowledge_config.data_source.info_list.data_source_type == "notion_import": | |||
| notion_info_list = knowledge_config.data_source.info_list.notion_info_list | |||
| for notion_info in notion_info_list: # type: ignore | |||
| count = count + len(notion_info.pages) | |||
| elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": | |||
| website_info = knowledge_config.data_source.info_list.website_info_list | |||
| count = len(website_info.urls) # type: ignore | |||
| batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) | |||
| if count > batch_upload_limit: | |||
| raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") | |||
| DocumentService.check_documents_upload_quota(count, features) | |||
| # if dataset is empty, update dataset data_source_type | |||
| if not dataset.data_source_type: | |||
| dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type | |||
| dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore | |||
| if not dataset.indexing_technique: | |||
| if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: | |||
| @@ -789,7 +793,7 @@ class DocumentService: | |||
| "score_threshold_enabled": False, | |||
| } | |||
| dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model | |||
| dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model # type: ignore | |||
| documents = [] | |||
| if knowledge_config.original_document_id: | |||
| @@ -801,34 +805,35 @@ class DocumentService: | |||
| # save process rule | |||
| if not dataset_process_rule: | |||
| process_rule = knowledge_config.process_rule | |||
| if process_rule.mode in ("custom", "hierarchical"): | |||
| dataset_process_rule = DatasetProcessRule( | |||
| dataset_id=dataset.id, | |||
| mode=process_rule.mode, | |||
| rules=process_rule.rules.model_dump_json(), | |||
| created_by=account.id, | |||
| ) | |||
| elif process_rule.mode == "automatic": | |||
| dataset_process_rule = DatasetProcessRule( | |||
| dataset_id=dataset.id, | |||
| mode=process_rule.mode, | |||
| rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), | |||
| created_by=account.id, | |||
| ) | |||
| else: | |||
| logging.warn( | |||
| f"Invalid process rule mode: {process_rule['mode']}, can not find dataset process rule" | |||
| ) | |||
| return | |||
| db.session.add(dataset_process_rule) | |||
| db.session.commit() | |||
| if process_rule: | |||
| if process_rule.mode in ("custom", "hierarchical"): | |||
| dataset_process_rule = DatasetProcessRule( | |||
| dataset_id=dataset.id, | |||
| mode=process_rule.mode, | |||
| rules=process_rule.rules.model_dump_json() if process_rule.rules else None, | |||
| created_by=account.id, | |||
| ) | |||
| elif process_rule.mode == "automatic": | |||
| dataset_process_rule = DatasetProcessRule( | |||
| dataset_id=dataset.id, | |||
| mode=process_rule.mode, | |||
| rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), | |||
| created_by=account.id, | |||
| ) | |||
| else: | |||
| logging.warn( | |||
| f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule" | |||
| ) | |||
| return | |||
| db.session.add(dataset_process_rule) | |||
| db.session.commit() | |||
| lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) | |||
| with redis_client.lock(lock_name, timeout=600): | |||
| position = DocumentService.get_documents_position(dataset.id) | |||
| document_ids = [] | |||
| duplicate_document_ids = [] | |||
| if knowledge_config.data_source.info_list.data_source_type == "upload_file": | |||
| upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids | |||
| upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore | |||
| for file_id in upload_file_list: | |||
| file = ( | |||
| db.session.query(UploadFile) | |||
| @@ -854,7 +859,7 @@ class DocumentService: | |||
| name=file_name, | |||
| ).first() | |||
| if document: | |||
| document.dataset_process_rule_id = dataset_process_rule.id | |||
| document.dataset_process_rule_id = dataset_process_rule.id # type: ignore | |||
| document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||
| document.created_from = created_from | |||
| document.doc_form = knowledge_config.doc_form | |||
| @@ -868,7 +873,7 @@ class DocumentService: | |||
| continue | |||
| document = DocumentService.build_document( | |||
| dataset, | |||
| dataset_process_rule.id, | |||
| dataset_process_rule.id, # type: ignore | |||
| knowledge_config.data_source.info_list.data_source_type, | |||
| knowledge_config.doc_form, | |||
| knowledge_config.doc_language, | |||
| @@ -886,6 +891,8 @@ class DocumentService: | |||
| position += 1 | |||
| elif knowledge_config.data_source.info_list.data_source_type == "notion_import": | |||
| notion_info_list = knowledge_config.data_source.info_list.notion_info_list | |||
| if not notion_info_list: | |||
| raise ValueError("No notion info list found.") | |||
| exist_page_ids = [] | |||
| exist_document = {} | |||
| documents = Document.query.filter_by( | |||
| @@ -921,7 +928,7 @@ class DocumentService: | |||
| } | |||
| document = DocumentService.build_document( | |||
| dataset, | |||
| dataset_process_rule.id, | |||
| dataset_process_rule.id, # type: ignore | |||
| knowledge_config.data_source.info_list.data_source_type, | |||
| knowledge_config.doc_form, | |||
| knowledge_config.doc_language, | |||
| @@ -944,6 +951,8 @@ class DocumentService: | |||
| clean_notion_document_task.delay(list(exist_document.values()), dataset.id) | |||
| elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": | |||
| website_info = knowledge_config.data_source.info_list.website_info_list | |||
| if not website_info: | |||
| raise ValueError("No website info list found.") | |||
| urls = website_info.urls | |||
| for url in urls: | |||
| data_source_info = { | |||
| @@ -959,7 +968,7 @@ class DocumentService: | |||
| document_name = url | |||
| document = DocumentService.build_document( | |||
| dataset, | |||
| dataset_process_rule.id, | |||
| dataset_process_rule.id, # type: ignore | |||
| knowledge_config.data_source.info_list.data_source_type, | |||
| knowledge_config.doc_form, | |||
| knowledge_config.doc_language, | |||
| @@ -1054,7 +1063,7 @@ class DocumentService: | |||
| dataset_process_rule = DatasetProcessRule( | |||
| dataset_id=dataset.id, | |||
| mode=process_rule.mode, | |||
| rules=process_rule.rules.model_dump_json(), | |||
| rules=process_rule.rules.model_dump_json() if process_rule.rules else None, | |||
| created_by=account.id, | |||
| ) | |||
| elif process_rule.mode == "automatic": | |||
| @@ -1073,6 +1082,8 @@ class DocumentService: | |||
| file_name = "" | |||
| data_source_info = {} | |||
| if document_data.data_source.info_list.data_source_type == "upload_file": | |||
| if not document_data.data_source.info_list.file_info_list: | |||
| raise ValueError("No file info list found.") | |||
| upload_file_list = document_data.data_source.info_list.file_info_list.file_ids | |||
| for file_id in upload_file_list: | |||
| file = ( | |||
| @@ -1090,6 +1101,8 @@ class DocumentService: | |||
| "upload_file_id": file_id, | |||
| } | |||
| elif document_data.data_source.info_list.data_source_type == "notion_import": | |||
| if not document_data.data_source.info_list.notion_info_list: | |||
| raise ValueError("No notion info list found.") | |||
| notion_info_list = document_data.data_source.info_list.notion_info_list | |||
| for notion_info in notion_info_list: | |||
| workspace_id = notion_info.workspace_id | |||
| @@ -1107,20 +1120,21 @@ class DocumentService: | |||
| data_source_info = { | |||
| "notion_workspace_id": workspace_id, | |||
| "notion_page_id": page.page_id, | |||
| "notion_page_icon": page.page_icon, | |||
| "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, # type: ignore | |||
| "type": page.type, | |||
| } | |||
| elif document_data.data_source.info_list.data_source_type == "website_crawl": | |||
| website_info = document_data.data_source.info_list.website_info_list | |||
| urls = website_info.urls | |||
| for url in urls: | |||
| data_source_info = { | |||
| "url": url, | |||
| "provider": website_info.provider, | |||
| "job_id": website_info.job_id, | |||
| "only_main_content": website_info.only_main_content, | |||
| "mode": "crawl", | |||
| } | |||
| if website_info: | |||
| urls = website_info.urls | |||
| for url in urls: | |||
| data_source_info = { | |||
| "url": url, | |||
| "provider": website_info.provider, | |||
| "job_id": website_info.job_id, | |||
| "only_main_content": website_info.only_main_content, # type: ignore | |||
| "mode": "crawl", | |||
| } | |||
| document.data_source_type = document_data.data_source.info_list.data_source_type | |||
| document.data_source_info = json.dumps(data_source_info) | |||
| document.name = file_name | |||
| @@ -1155,15 +1169,21 @@ class DocumentService: | |||
| if features.billing.enabled: | |||
| count = 0 | |||
| if knowledge_config.data_source.info_list.data_source_type == "upload_file": | |||
| upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids | |||
| upload_file_list = ( | |||
| knowledge_config.data_source.info_list.file_info_list.file_ids | |||
| if knowledge_config.data_source.info_list.file_info_list | |||
| else [] | |||
| ) | |||
| count = len(upload_file_list) | |||
| elif knowledge_config.data_source.info_list.data_source_type == "notion_import": | |||
| notion_info_list = knowledge_config.data_source.info_list.notion_info_list | |||
| for notion_info in notion_info_list: | |||
| count = count + len(notion_info.pages) | |||
| if notion_info_list: | |||
| for notion_info in notion_info_list: | |||
| count = count + len(notion_info.pages) | |||
| elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": | |||
| website_info = knowledge_config.data_source.info_list.website_info_list | |||
| count = len(website_info.urls) | |||
| if website_info: | |||
| count = len(website_info.urls) | |||
| batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) | |||
| if count > batch_upload_limit: | |||
| raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") | |||
| @@ -1174,20 +1194,20 @@ class DocumentService: | |||
| retrieval_model = None | |||
| if knowledge_config.indexing_technique == "high_quality": | |||
| dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( | |||
| knowledge_config.embedding_model_provider, knowledge_config.embedding_model | |||
| knowledge_config.embedding_model_provider, # type: ignore | |||
| knowledge_config.embedding_model, # type: ignore | |||
| ) | |||
| dataset_collection_binding_id = dataset_collection_binding.id | |||
| if knowledge_config.retrieval_model: | |||
| retrieval_model = knowledge_config.retrieval_model | |||
| else: | |||
| default_retrieval_model = { | |||
| "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| "reranking_enable": False, | |||
| "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, | |||
| "top_k": 2, | |||
| "score_threshold_enabled": False, | |||
| } | |||
| retrieval_model = RetrievalModel(**default_retrieval_model) | |||
| retrieval_model = RetrievalModel( | |||
| search_method=RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| reranking_enable=False, | |||
| reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""), | |||
| top_k=2, | |||
| score_threshold_enabled=False, | |||
| ) | |||
| # save dataset | |||
| dataset = Dataset( | |||
| tenant_id=tenant_id, | |||
| @@ -1557,12 +1577,12 @@ class SegmentService: | |||
| raise ValueError("Can't update disabled segment") | |||
| try: | |||
| word_count_change = segment.word_count | |||
| content = args.content | |||
| content = args.content or segment.content | |||
| if segment.content == content: | |||
| segment.word_count = len(content) | |||
| if document.doc_form == "qa_model": | |||
| segment.answer = args.answer | |||
| segment.word_count += len(args.answer) | |||
| segment.word_count += len(args.answer) if args.answer else 0 | |||
| word_count_change = segment.word_count - word_count_change | |||
| if args.keywords: | |||
| segment.keywords = args.keywords | |||
| @@ -1577,7 +1597,12 @@ class SegmentService: | |||
| db.session.add(document) | |||
| # update segment index task | |||
| if args.enabled: | |||
| VectorService.create_segments_vector([args.keywords], [segment], dataset) | |||
| VectorService.create_segments_vector( | |||
| [args.keywords] if args.keywords else None, | |||
| [segment], | |||
| dataset, | |||
| document.doc_form, | |||
| ) | |||
| if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: | |||
| # regenerate child chunks | |||
| # get embedding model instance | |||
| @@ -1605,6 +1630,8 @@ class SegmentService: | |||
| .filter(DatasetProcessRule.id == document.dataset_process_rule_id) | |||
| .first() | |||
| ) | |||
| if not processing_rule: | |||
| raise ValueError("No processing rule found.") | |||
| VectorService.generate_child_chunks( | |||
| segment, document, dataset, embedding_model_instance, processing_rule, True | |||
| ) | |||
| @@ -1639,7 +1666,7 @@ class SegmentService: | |||
| segment.disabled_by = None | |||
| if document.doc_form == "qa_model": | |||
| segment.answer = args.answer | |||
| segment.word_count += len(args.answer) | |||
| segment.word_count += len(args.answer) if args.answer else 0 | |||
| word_count_change = segment.word_count - word_count_change | |||
| # update document word count | |||
| if word_count_change != 0: | |||
| @@ -1673,6 +1700,8 @@ class SegmentService: | |||
| .filter(DatasetProcessRule.id == document.dataset_process_rule_id) | |||
| .first() | |||
| ) | |||
| if not processing_rule: | |||
| raise ValueError("No processing rule found.") | |||
| VectorService.generate_child_chunks( | |||
| segment, document, dataset, embedding_model_instance, processing_rule, True | |||
| ) | |||
| @@ -97,7 +97,7 @@ class KnowledgeConfig(BaseModel): | |||
| original_document_id: Optional[str] = None | |||
| duplicate: bool = True | |||
| indexing_technique: Literal["high_quality", "economy"] | |||
| data_source: Optional[DataSource] = None | |||
| data_source: DataSource | |||
| process_rule: Optional[ProcessRule] = None | |||
| retrieval_model: Optional[RetrievalModel] = None | |||
| doc_form: str = "text_model" | |||
| @@ -69,7 +69,7 @@ class HitTestingService: | |||
| db.session.add(dataset_query) | |||
| db.session.commit() | |||
| return cls.compact_retrieve_response(query, all_documents) | |||
| return cls.compact_retrieve_response(query, all_documents) # type: ignore | |||
| @classmethod | |||
| def external_retrieve( | |||
| @@ -29,6 +29,8 @@ class VectorService: | |||
| .filter(DatasetProcessRule.id == document.dataset_process_rule_id) | |||
| .first() | |||
| ) | |||
| if not processing_rule: | |||
| raise ValueError("No processing rule found.") | |||
| # get embedding model instance | |||
| if dataset.indexing_technique == "high_quality": | |||
| # check embedding model setting | |||
| @@ -98,7 +100,7 @@ class VectorService: | |||
| def generate_child_chunks( | |||
| cls, | |||
| segment: DocumentSegment, | |||
| dataset_document: Document, | |||
| dataset_document: DatasetDocument, | |||
| dataset: Dataset, | |||
| embedding_model_instance: ModelInstance, | |||
| processing_rule: DatasetProcessRule, | |||
| @@ -130,7 +132,7 @@ class VectorService: | |||
| doc_language=dataset_document.doc_language, | |||
| ) | |||
| # save child chunks | |||
| if len(documents) > 0 and len(documents[0].children) > 0: | |||
| if documents and documents[0].children: | |||
| index_processor.load(dataset, documents) | |||
| for position, child_chunk in enumerate(documents[0].children, start=1): | |||
| @@ -44,7 +44,8 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form | |||
| for upload_file_id in image_upload_file_ids: | |||
| image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() | |||
| try: | |||
| storage.delete(image_file.key) | |||
| if image_file and image_file.key: | |||
| storage.delete(image_file.key) | |||
| except Exception: | |||
| logging.exception( | |||
| "Delete image_files failed when storage deleted, \ | |||