Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
tags/0.15.0
 click.echo(click.style("Starting database migration.", fg="green"))

 # run db migration
-import flask_migrate
+import flask_migrate  # type: ignore

 flask_migrate.upgrade()

 indexing_runner = IndexingRunner()
 try:
-    response = indexing_runner.indexing_estimate(
+    estimate_response = indexing_runner.indexing_estimate(
         current_user.current_tenant_id,
         [extract_setting],
         data_process_rule_dict,
         "English",
         dataset_id,
     )
+    return estimate_response.model_dump(), 200
 except LLMBadRequestError:
     raise ProviderNotInitializeError(
         "No Embedding Model available. Please configure a valid provider "
 except Exception as e:
     raise IndexingEstimateError(str(e))
-return response.model_dump(), 200
+return response, 200

 class DocumentBatchIndexingEstimateApi(DocumentResource):

             "English",
             dataset_id,
         )
+        return response.model_dump(), 200
     except LLMBadRequestError:
         raise ProviderNotInitializeError(
             "No Embedding Model available. Please configure a valid provider "
         raise ProviderNotInitializeError(ex.description)
     except Exception as e:
         raise IndexingEstimateError(str(e))
-    return response.model_dump(), 200

 class DocumentBatchIndexingStatusApi(DocumentResource):
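Note (illustration, not part of the diff): the hunks above return the serialized estimate from inside the try block, so errors raised by indexing_estimate are still mapped by the except clauses, while a default response stays available as a fallback. A minimal sketch, assuming the estimate object is a Pydantic BaseModel with model_dump():

    from pydantic import BaseModel

    class EstimateSketch(BaseModel):
        # Hypothetical stand-in for the real IndexingEstimate model.
        total_segments: int = 0
        preview: list[str] = []

    def estimate_endpoint(run_estimate: bool) -> tuple[dict, int]:
        response: dict = {"total_segments": 0, "preview": []}  # default payload
        if run_estimate:
            try:
                estimate = EstimateSketch(total_segments=3, preview=["chunk"])
                # Returning inside the try keeps the except clauses in charge of error mapping.
                return estimate.model_dump(), 200
            except ValueError as e:  # the real code maps LLM/provider errors here
                return {"error": str(e)}, 400
        return response, 200

    print(estimate_endpoint(True))   # ({'total_segments': 3, 'preview': ['chunk']}, 200)
    print(estimate_endpoint(False))  # ({'total_segments': 0, 'preview': []}, 200)
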
 from libs.login import current_user
 from models.dataset import Dataset, Document, DocumentSegment
 from services.dataset_service import DocumentService
+from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
 from services.file_service import FileService

             "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
         }
         args["data_source"] = data_source
+        knowledge_config = KnowledgeConfig(**args)
         # validate args
-        DocumentService.document_create_args_validate(args)
+        DocumentService.document_create_args_validate(knowledge_config)

         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
-                document_data=args,
+                knowledge_config=knowledge_config,
                 account=current_user,
                 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                 created_from="api",

         args["data_source"] = data_source
         # validate args
         args["original_document_id"] = str(document_id)
-        DocumentService.document_create_args_validate(args)
+        knowledge_config = KnowledgeConfig(**args)
+        DocumentService.document_create_args_validate(knowledge_config)

         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
-                document_data=args,
+                knowledge_config=knowledge_config,
                 account=current_user,
                 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                 created_from="api",

         data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
         args["data_source"] = data_source
         # validate args
-        DocumentService.document_create_args_validate(args)
+        knowledge_config = KnowledgeConfig(**args)
+        DocumentService.document_create_args_validate(knowledge_config)

         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
-                document_data=args,
+                knowledge_config=knowledge_config,
                 account=dataset.created_by_account,
                 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                 created_from="api",

         args["data_source"] = data_source
         # validate args
         args["original_document_id"] = str(document_id)
-        DocumentService.document_create_args_validate(args)
+        knowledge_config = KnowledgeConfig(**args)
+        DocumentService.document_create_args_validate(knowledge_config)

         try:
             documents, batch = DocumentService.save_document_with_dataset_id(
                 dataset=dataset,
-                document_data=args,
+                knowledge_config=knowledge_config,
                 account=dataset.created_by_account,
                 dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                 created_from="api",
             tenant_id=tenant_id,
             model_type=ModelType.TEXT_EMBEDDING,
         )
-        preview_texts = []
+        preview_texts = []  # type: ignore
         total_segments = 0
         index_type = doc_form

             if len(preview_texts) < 10:
                 if doc_form and doc_form == "qa_model":
                     preview_detail = QAPreviewDetail(
-                        question=document.page_content, answer=document.metadata.get("answer")
+                        question=document.page_content, answer=document.metadata.get("answer") or ""
                     )
                     preview_texts.append(preview_detail)
                 else:
-                    preview_detail = PreviewDetail(content=document.page_content)
+                    preview_detail = PreviewDetail(content=document.page_content)  # type: ignore
                     if document.children:
-                        preview_detail.child_chunks = [child.page_content for child in document.children]
+                        preview_detail.child_chunks = [child.page_content for child in document.children]  # type: ignore
                     preview_texts.append(preview_detail)

         # delete image files and related db records
         if doc_form and doc_form == "qa_model":
             return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[])
-        return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
+        return IndexingEstimate(total_segments=total_segments, preview=preview_texts)  # type: ignore

     def _extract(
         self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict

             embedding_model_instance=embedding_model_instance,
         )
-        return character_splitter
+        return character_splitter  # type: ignore

     def _split_to_documents_for_estimate(
         self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule

         # create keyword index
         create_keyword_thread = threading.Thread(
             target=self._process_keyword_index,
-            args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),
+            args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),  # type: ignore
         )
         create_keyword_thread.start()
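Note: document.metadata.get("answer") or "" guards against a missing or None answer when the preview model expects a plain string. A small standalone illustration (QAPreviewSketch is a stand-in, not the real QAPreviewDetail):

    from pydantic import BaseModel, ValidationError

    class QAPreviewSketch(BaseModel):
        question: str
        answer: str  # not Optional, so None would be rejected

    metadata: dict = {}  # the "answer" key may be absent

    try:
        QAPreviewSketch(question="q", answer=metadata.get("answer"))
    except ValidationError:
        print("None is not a valid str")

    # The `or ""` fallback keeps the field a plain string.
    print(QAPreviewSketch(question="q", answer=metadata.get("answer") or ""))
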
 include_segment_ids = []
 segment_child_map = {}
 for document in documents:
-    document_id = document.metadata["document_id"]
+    document_id = document.metadata.get("document_id")
     dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
-    if dataset_document and dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
-        child_index_node_id = document.metadata["doc_id"]
-        result = (
-            db.session.query(ChildChunk, DocumentSegment)
-            .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
-            .filter(
-                ChildChunk.index_node_id == child_index_node_id,
-                DocumentSegment.dataset_id == dataset_document.dataset_id,
-                DocumentSegment.enabled == True,
-                DocumentSegment.status == "completed",
-            )
-            .first()
-        )
-        if result:
-            child_chunk, segment = result
-            if not segment:
-                continue
-            if segment.id not in include_segment_ids:
-                include_segment_ids.append(segment.id)
-                child_chunk_detail = {
-                    "id": child_chunk.id,
-                    "content": child_chunk.content,
-                    "position": child_chunk.position,
-                    "score": document.metadata.get("score", 0.0),
-                }
-                map_detail = {
-                    "max_score": document.metadata.get("score", 0.0),
-                    "child_chunks": [child_chunk_detail],
-                }
-                segment_child_map[segment.id] = map_detail
-                record = {
-                    "segment": segment,
-                }
-                records.append(record)
-            else:
-                child_chunk_detail = {
-                    "id": child_chunk.id,
-                    "content": child_chunk.content,
-                    "position": child_chunk.position,
-                    "score": document.metadata.get("score", 0.0),
-                }
-                segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
-                segment_child_map[segment.id]["max_score"] = max(
-                    segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
-                )
-        else:
-            continue
-    else:
-        index_node_id = document.metadata["doc_id"]
-        segment = (
-            db.session.query(DocumentSegment)
-            .filter(
-                DocumentSegment.dataset_id == dataset_document.dataset_id,
-                DocumentSegment.enabled == True,
-                DocumentSegment.status == "completed",
-                DocumentSegment.index_node_id == index_node_id,
-            )
-            .first()
-        )
-        if not segment:
-            continue
-        include_segment_ids.append(segment.id)
-        record = {
-            "segment": segment,
-            "score": document.metadata.get("score", None),
-        }
-        records.append(record)
+    if dataset_document:
+        if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+            child_index_node_id = document.metadata.get("doc_id")
+            result = (
+                db.session.query(ChildChunk, DocumentSegment)
+                .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
+                .filter(
+                    ChildChunk.index_node_id == child_index_node_id,
+                    DocumentSegment.dataset_id == dataset_document.dataset_id,
+                    DocumentSegment.enabled == True,
+                    DocumentSegment.status == "completed",
+                )
+                .first()
+            )
+            if result:
+                child_chunk, segment = result
+                if not segment:
+                    continue
+                if segment.id not in include_segment_ids:
+                    include_segment_ids.append(segment.id)
+                    child_chunk_detail = {
+                        "id": child_chunk.id,
+                        "content": child_chunk.content,
+                        "position": child_chunk.position,
+                        "score": document.metadata.get("score", 0.0),
+                    }
+                    map_detail = {
+                        "max_score": document.metadata.get("score", 0.0),
+                        "child_chunks": [child_chunk_detail],
+                    }
+                    segment_child_map[segment.id] = map_detail
+                    record = {
+                        "segment": segment,
+                    }
+                    records.append(record)
+                else:
+                    child_chunk_detail = {
+                        "id": child_chunk.id,
+                        "content": child_chunk.content,
+                        "position": child_chunk.position,
+                        "score": document.metadata.get("score", 0.0),
+                    }
+                    segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
+                    segment_child_map[segment.id]["max_score"] = max(
+                        segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
+                    )
+            else:
+                continue
+        else:
+            index_node_id = document.metadata["doc_id"]
+            segment = (
+                db.session.query(DocumentSegment)
+                .filter(
+                    DocumentSegment.dataset_id == dataset_document.dataset_id,
+                    DocumentSegment.enabled == True,
+                    DocumentSegment.status == "completed",
+                    DocumentSegment.index_node_id == index_node_id,
+                )
+                .first()
+            )
+            if not segment:
+                continue
+            include_segment_ids.append(segment.id)
+            record = {
+                "segment": segment,
+                "score": document.metadata.get("score", None),
+            }
+            records.append(record)
 for record in records:
     if record["segment"].id in segment_child_map:
         record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None)
             db.session.add(segment_document)
             db.session.flush()
             if save_child:
-                for postion, child in enumerate(doc.children, start=1):
-                    child_segment = ChildChunk(
-                        tenant_id=self._dataset.tenant_id,
-                        dataset_id=self._dataset.id,
-                        document_id=self._document_id,
-                        segment_id=segment_document.id,
-                        position=postion,
-                        index_node_id=child.metadata["doc_id"],
-                        index_node_hash=child.metadata["doc_hash"],
-                        content=child.page_content,
-                        word_count=len(child.page_content),
-                        type="automatic",
-                        created_by=self._user_id,
-                    )
-                    db.session.add(child_segment)
+                if doc.children:
+                    for postion, child in enumerate(doc.children, start=1):
+                        child_segment = ChildChunk(
+                            tenant_id=self._dataset.tenant_id,
+                            dataset_id=self._dataset.id,
+                            document_id=self._document_id,
+                            segment_id=segment_document.id,
+                            position=postion,
+                            index_node_id=child.metadata.get("doc_id"),
+                            index_node_hash=child.metadata.get("doc_hash"),
+                            content=child.page_content,
+                            word_count=len(child.page_content),
+                            type="automatic",
+                            created_by=self._user_id,
+                        )
+                        db.session.add(child_segment)
         else:
             segment_document.content = doc.page_content
             if doc.metadata.get("answer"):
                 segment_document.answer = doc.metadata.pop("answer", "")
-            segment_document.index_node_hash = doc.metadata["doc_hash"]
+            segment_document.index_node_hash = doc.metadata.get("doc_hash")
             segment_document.word_count = len(doc.page_content)
             segment_document.tokens = tokens
             if save_child and doc.children:

                         document_id=self._document_id,
                         segment_id=segment_document.id,
                         position=position,
-                        index_node_id=child.metadata["doc_id"],
-                        index_node_hash=child.metadata["doc_hash"],
+                        index_node_id=child.metadata.get("doc_id"),
+                        index_node_hash=child.metadata.get("doc_hash"),
                         content=child.page_content,
                         word_count=len(child.page_content),
                         type="automatic",
 from typing import Optional, cast

 import pandas as pd
-from openpyxl import load_workbook
+from openpyxl import load_workbook  # type: ignore

 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document

             embedding_model_instance=embedding_model_instance,
         )
-        return character_splitter
+        return character_splitter  # type: ignore
     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
         process_rule = kwargs.get("process_rule")
+        if not process_rule:
+            raise ValueError("No process rule found.")
         if process_rule.get("mode") == "automatic":
             automatic_rule = DatasetProcessRule.AUTOMATIC_RULES
             rules = Rule(**automatic_rule)
         else:
+            if not process_rule.get("rules"):
+                raise ValueError("No rules found in process rule.")
             rules = Rule(**process_rule.get("rules"))
         # Split the text documents into nodes.
+        if not rules.segmentation:
+            raise ValueError("No segmentation found in rules.")
         splitter = self._get_splitter(
             processing_rule_mode=process_rule.get("mode"),
             max_tokens=rules.segmentation.max_tokens,
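Note: the added checks follow a fail-fast pattern, so a missing or malformed process_rule raises immediately instead of failing later inside the splitter. Sketched with a hypothetical transform-like function (names are illustrative):

    def transform_sketch(**kwargs) -> str:
        process_rule = kwargs.get("process_rule")
        if not process_rule:
            raise ValueError("No process rule found.")
        if process_rule.get("mode") != "automatic" and not process_rule.get("rules"):
            raise ValueError("No rules found in process rule.")
        # ...the real method builds a text splitter from the validated rules here...
        return process_rule.get("mode", "")

    print(transform_sketch(process_rule={"mode": "automatic"}))  # automatic
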
     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
         process_rule = kwargs.get("process_rule")
+        if not process_rule:
+            raise ValueError("No process rule found.")
+        if not process_rule.get("rules"):
+            raise ValueError("No rules found in process rule.")
         rules = Rule(**process_rule.get("rules"))
-        all_documents = []
+        all_documents = []  # type: ignore
         if rules.parent_mode == ParentMode.PARAGRAPH:
             # Split the text documents into nodes.
             splitter = self._get_splitter(

         process_rule_mode: str,
         embedding_model_instance: Optional[ModelInstance],
     ) -> list[ChildDocument]:
+        if not rules.subchunk_segmentation:
+            raise ValueError("No subchunk segmentation found in rules.")
         child_splitter = self._get_splitter(
             processing_rule_mode=process_rule_mode,
             max_tokens=rules.subchunk_segmentation.max_tokens,
     def transform(self, documents: list[Document], **kwargs) -> list[Document]:
         preview = kwargs.get("preview")
         process_rule = kwargs.get("process_rule")
+        if not process_rule:
+            raise ValueError("No process rule found.")
+        if not process_rule.get("rules"):
+            raise ValueError("No rules found in process rule.")
         rules = Rule(**process_rule.get("rules"))
         splitter = self._get_splitter(
             processing_rule_mode=process_rule.get("mode"),
-            max_tokens=rules.segmentation.max_tokens,
-            chunk_overlap=rules.segmentation.chunk_overlap,
-            separator=rules.segmentation.separator,
+            max_tokens=rules.segmentation.max_tokens if rules.segmentation else 0,
+            chunk_overlap=rules.segmentation.chunk_overlap if rules.segmentation else 0,
+            separator=rules.segmentation.separator if rules.segmentation else "",
             embedding_model_instance=kwargs.get("embedding_model_instance"),
         )

                 all_documents.extend(split_documents)
         if preview:
             self._format_qa_document(
-                current_app._get_current_object(),
-                kwargs.get("tenant_id"),
+                current_app._get_current_object(),  # type: ignore
+                kwargs.get("tenant_id"),  # type: ignore
                 all_documents[0],
                 all_qa_documents,
                 kwargs.get("doc_language", "English"),

                 document_format_thread = threading.Thread(
                     target=self._format_qa_document,
                     kwargs={
-                        "flask_app": current_app._get_current_object(),
-                        "tenant_id": kwargs.get("tenant_id"),
+                        "flask_app": current_app._get_current_object(),  # type: ignore
+                        "tenant_id": kwargs.get("tenant_id"),  # type: ignore
                         "document_node": doc,
                         "all_qa_documents": all_qa_documents,
                         "document_language": kwargs.get("doc_language", "English"),
 from collections.abc import Sequence
 from typing import Any, Optional

-from pydantic import BaseModel, Field
+from pydantic import BaseModel

 class ChildDocument(BaseModel):
     """Arbitrary metadata about the page content (e.g., source, relationships to other
     documents, etc.).
     """

-    metadata: Optional[dict] = Field(default_factory=dict)
+    metadata: dict = {}

 class Document(BaseModel):
     """Arbitrary metadata about the page content (e.g., source, relationships to other
     documents, etc.).
     """

-    metadata: Optional[dict] = Field(default_factory=dict)
+    metadata: dict = {}

     provider: Optional[str] = "dify"
 def init_app(app: DifyApp):
     # register blueprint routers

-    from flask_cors import CORS
+    from flask_cors import CORS  # type: ignore

     from controllers.console import bp as console_app_bp
     from controllers.files import bp as files_bp
 import logging
 import time
+from collections import defaultdict

 import click
 from celery import shared_task  # type: ignore
-from flask import render_template

 from extensions.ext_mail import mail
 from models.account import Account, Tenant, TenantAccountJoin

     try:
         dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all()
         # group by tenant_id
-        dataset_auto_disable_logs_map = {}
+        dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
         for dataset_auto_disable_log in dataset_auto_disable_logs:
             dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)

             if not tenant:
                 continue
             current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
+            if not current_owner_join:
+                continue
             account = Account.query.filter(Account.id == current_owner_join.account_id).first()
             if not account:
                 continue

-            dataset_auto_dataset_map = {}
+            dataset_auto_dataset_map = {}  # type: ignore
             for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
                 dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
                     dataset_auto_disable_log.document_id

                 document_count = len(document_ids)
                 knowledge_details.append(f"<li>Knowledge base {dataset.name}: {document_count} documents</li>")
-            html_content = render_template(
-                "clean_document_job_mail_template-US.html",
-            )
-            mail.send(to=to, subject="立即加入 Dify 工作空间", html=html_content)

         end_at = time.perf_counter()
         logging.info(
             click.style("Send document clean notify mail succeeded: latency: {}".format(end_at - start_at), fg="green")
         )
     except Exception:
-        logging.exception("Send invite member mail to {} failed".format(to))
+        logging.exception("Send invite member mail to failed")
 from typing import Optional, cast
 from uuid import uuid4

-import yaml
+import yaml  # type: ignore
 from packaging import version
 from pydantic import BaseModel
 from sqlalchemy import select

         else:
             cls._append_model_config_export_data(export_data, app_model)

-        return yaml.dump(export_data, allow_unicode=True)
+        return yaml.dump(export_data, allow_unicode=True)  # type: ignore

     @classmethod
     def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None:
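Note: allow_unicode=True keeps non-ASCII characters (for example Chinese app names) readable in the exported DSL instead of emitting escape sequences. For illustration:

    import yaml

    data = {"name": "知识库应用", "mode": "workflow"}
    print(yaml.dump(data))                      # non-ASCII escaped, e.g. name: "\u77E5\u8BC6..."
    print(yaml.dump(data, allow_unicode=True))  # name: 知识库应用
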
 from services.entities.knowledge_entities.knowledge_entities import (
     ChildChunkUpdateArgs,
     KnowledgeConfig,
+    RerankingModel,
     RetrievalModel,
     SegmentUpdateArgs,
 )

     }

     @staticmethod
-    def get_document(dataset_id: str, document_id: str) -> Optional[Document]:
-        document = (
-            db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
-        )
-        return document
+    def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]:
+        if document_id:
+            document = (
+                db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+            )
+            return document
+        else:
+            return None
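Note: widening document_id to Optional[str] lets callers forward a possibly-missing id and simply get None back; the shape is sketched below (the real method runs a SQLAlchemy query scoped to the dataset):

    from typing import Optional

    def get_document_sketch(dataset_id: str, document_id: Optional[str] = None) -> Optional[str]:
        if document_id:
            # Stand-in for the real dataset-scoped query.
            return f"document {document_id} in dataset {dataset_id}"
        return None

    print(get_document_sketch("ds-1"))           # None
    print(get_document_sketch("ds-1", "doc-9"))  # document doc-9 in dataset ds-1
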
     @staticmethod
     def get_document_by_id(document_id: str) -> Optional[Document]:

         if features.billing.enabled:
             if not knowledge_config.original_document_id:
                 count = 0
-                if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
-                    count = len(upload_file_list)
-                elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
-                    notion_info_list = knowledge_config.data_source.info_list.notion_info_list
-                    for notion_info in notion_info_list:
-                        count = count + len(notion_info.pages)
-                elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
-                    website_info = knowledge_config.data_source.info_list.website_info_list
-                    count = len(website_info.urls)
-                batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
-                if count > batch_upload_limit:
-                    raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
-                DocumentService.check_documents_upload_quota(count, features)
+                if knowledge_config.data_source:
+                    if knowledge_config.data_source.info_list.data_source_type == "upload_file":
+                        upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
+                        count = len(upload_file_list)
+                    elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
+                        notion_info_list = knowledge_config.data_source.info_list.notion_info_list
+                        for notion_info in notion_info_list:  # type: ignore
+                            count = count + len(notion_info.pages)
+                    elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
+                        website_info = knowledge_config.data_source.info_list.website_info_list
+                        count = len(website_info.urls)  # type: ignore
+                    batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
+                    if count > batch_upload_limit:
+                        raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
+                    DocumentService.check_documents_upload_quota(count, features)

         # if dataset is empty, update dataset data_source_type
         if not dataset.data_source_type:
-            dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type
+            dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type  # type: ignore

         if not dataset.indexing_technique:
             if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:

                     "score_threshold_enabled": False,
                 }

-                dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model
+                dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model  # type: ignore

         documents = []
         if knowledge_config.original_document_id:

             # save process rule
             if not dataset_process_rule:
                 process_rule = knowledge_config.process_rule
-                if process_rule.mode in ("custom", "hierarchical"):
-                    dataset_process_rule = DatasetProcessRule(
-                        dataset_id=dataset.id,
-                        mode=process_rule.mode,
-                        rules=process_rule.rules.model_dump_json(),
-                        created_by=account.id,
-                    )
-                elif process_rule.mode == "automatic":
-                    dataset_process_rule = DatasetProcessRule(
-                        dataset_id=dataset.id,
-                        mode=process_rule.mode,
-                        rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
-                        created_by=account.id,
-                    )
-                else:
-                    logging.warn(
-                        f"Invalid process rule mode: {process_rule['mode']}, can not find dataset process rule"
-                    )
-                    return
-                db.session.add(dataset_process_rule)
-                db.session.commit()
+                if process_rule:
+                    if process_rule.mode in ("custom", "hierarchical"):
+                        dataset_process_rule = DatasetProcessRule(
+                            dataset_id=dataset.id,
+                            mode=process_rule.mode,
+                            rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
+                            created_by=account.id,
+                        )
+                    elif process_rule.mode == "automatic":
+                        dataset_process_rule = DatasetProcessRule(
+                            dataset_id=dataset.id,
+                            mode=process_rule.mode,
+                            rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
+                            created_by=account.id,
+                        )
+                    else:
+                        logging.warn(
+                            f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule"
+                        )
+                        return
+                    db.session.add(dataset_process_rule)
+                    db.session.commit()

             lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
             with redis_client.lock(lock_name, timeout=600):
                 position = DocumentService.get_documents_position(dataset.id)
                 document_ids = []
                 duplicate_document_ids = []
                 if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
+                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
                     for file_id in upload_file_list:
                         file = (
                             db.session.query(UploadFile)

                                 name=file_name,
                             ).first()
                             if document:
-                                document.dataset_process_rule_id = dataset_process_rule.id
+                                document.dataset_process_rule_id = dataset_process_rule.id  # type: ignore
                                 document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                                 document.created_from = created_from
                                 document.doc_form = knowledge_config.doc_form

                                 continue
                         document = DocumentService.build_document(
                             dataset,
-                            dataset_process_rule.id,
+                            dataset_process_rule.id,  # type: ignore
                             knowledge_config.data_source.info_list.data_source_type,
                             knowledge_config.doc_form,
                             knowledge_config.doc_language,

                         position += 1
                 elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
                     notion_info_list = knowledge_config.data_source.info_list.notion_info_list
+                    if not notion_info_list:
+                        raise ValueError("No notion info list found.")
                     exist_page_ids = []
                     exist_document = {}
                     documents = Document.query.filter_by(

                             }
                             document = DocumentService.build_document(
                                 dataset,
-                                dataset_process_rule.id,
+                                dataset_process_rule.id,  # type: ignore
                                 knowledge_config.data_source.info_list.data_source_type,
                                 knowledge_config.doc_form,
                                 knowledge_config.doc_language,

                     clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
                 elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
                     website_info = knowledge_config.data_source.info_list.website_info_list
+                    if not website_info:
+                        raise ValueError("No website info list found.")
                     urls = website_info.urls
                     for url in urls:
                         data_source_info = {

                         document_name = url
                         document = DocumentService.build_document(
                             dataset,
-                            dataset_process_rule.id,
+                            dataset_process_rule.id,  # type: ignore
                             knowledge_config.data_source.info_list.data_source_type,
                             knowledge_config.doc_form,
                             knowledge_config.doc_language,

                 dataset_process_rule = DatasetProcessRule(
                     dataset_id=dataset.id,
                     mode=process_rule.mode,
-                    rules=process_rule.rules.model_dump_json(),
+                    rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
                     created_by=account.id,
                 )
             elif process_rule.mode == "automatic":

         file_name = ""
         data_source_info = {}
         if document_data.data_source.info_list.data_source_type == "upload_file":
+            if not document_data.data_source.info_list.file_info_list:
+                raise ValueError("No file info list found.")
             upload_file_list = document_data.data_source.info_list.file_info_list.file_ids
             for file_id in upload_file_list:
                 file = (

                     "upload_file_id": file_id,
                 }
         elif document_data.data_source.info_list.data_source_type == "notion_import":
+            if not document_data.data_source.info_list.notion_info_list:
+                raise ValueError("No notion info list found.")
             notion_info_list = document_data.data_source.info_list.notion_info_list
             for notion_info in notion_info_list:
                 workspace_id = notion_info.workspace_id

                     data_source_info = {
                         "notion_workspace_id": workspace_id,
                         "notion_page_id": page.page_id,
-                        "notion_page_icon": page.page_icon,
+                        "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,  # type: ignore
                         "type": page.type,
                     }
         elif document_data.data_source.info_list.data_source_type == "website_crawl":
             website_info = document_data.data_source.info_list.website_info_list
-            urls = website_info.urls
-            for url in urls:
-                data_source_info = {
-                    "url": url,
-                    "provider": website_info.provider,
-                    "job_id": website_info.job_id,
-                    "only_main_content": website_info.only_main_content,
-                    "mode": "crawl",
-                }
+            if website_info:
+                urls = website_info.urls
+                for url in urls:
+                    data_source_info = {
+                        "url": url,
+                        "provider": website_info.provider,
+                        "job_id": website_info.job_id,
+                        "only_main_content": website_info.only_main_content,  # type: ignore
+                        "mode": "crawl",
+                    }
         document.data_source_type = document_data.data_source.info_list.data_source_type
         document.data_source_info = json.dumps(data_source_info)
         document.name = file_name

         if features.billing.enabled:
             count = 0
             if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-                upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
+                upload_file_list = (
+                    knowledge_config.data_source.info_list.file_info_list.file_ids
+                    if knowledge_config.data_source.info_list.file_info_list
+                    else []
+                )
                 count = len(upload_file_list)
             elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
                 notion_info_list = knowledge_config.data_source.info_list.notion_info_list
-                for notion_info in notion_info_list:
-                    count = count + len(notion_info.pages)
+                if notion_info_list:
+                    for notion_info in notion_info_list:
+                        count = count + len(notion_info.pages)
             elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
                 website_info = knowledge_config.data_source.info_list.website_info_list
-                count = len(website_info.urls)
+                if website_info:
+                    count = len(website_info.urls)
             batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
             if count > batch_upload_limit:
                 raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

         retrieval_model = None
         if knowledge_config.indexing_technique == "high_quality":
             dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                knowledge_config.embedding_model_provider, knowledge_config.embedding_model
+                knowledge_config.embedding_model_provider,  # type: ignore
+                knowledge_config.embedding_model,  # type: ignore
             )
             dataset_collection_binding_id = dataset_collection_binding.id
             if knowledge_config.retrieval_model:
                 retrieval_model = knowledge_config.retrieval_model
             else:
-                default_retrieval_model = {
-                    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
-                    "reranking_enable": False,
-                    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
-                    "top_k": 2,
-                    "score_threshold_enabled": False,
-                }
-                retrieval_model = RetrievalModel(**default_retrieval_model)
+                retrieval_model = RetrievalModel(
+                    search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
+                    reranking_enable=False,
+                    reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
+                    top_k=2,
+                    score_threshold_enabled=False,
+                )
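Note: building the default retrieval settings as a typed model instead of unpacking a plain dict lets mypy and IDEs check the field names, and required fields are enforced at construction time. A hedged sketch with stand-in classes:

    from pydantic import BaseModel

    class RerankingModelSketch(BaseModel):
        reranking_provider_name: str = ""
        reranking_model_name: str = ""

    class RetrievalModelSketch(BaseModel):
        search_method: str
        reranking_enable: bool = False
        reranking_model: RerankingModelSketch = RerankingModelSketch()
        top_k: int = 2
        score_threshold_enabled: bool = False

    # Attribute names are statically checkable; a misspelled keyword is visible to type checkers.
    default_retrieval = RetrievalModelSketch(search_method="semantic_search")
    print(default_retrieval.top_k)  # 2
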
         # save dataset
         dataset = Dataset(
             tenant_id=tenant_id,

             raise ValueError("Can't update disabled segment")
         try:
             word_count_change = segment.word_count
-            content = args.content
+            content = args.content or segment.content
             if segment.content == content:
                 segment.word_count = len(content)
                 if document.doc_form == "qa_model":
                     segment.answer = args.answer
-                    segment.word_count += len(args.answer)
+                    segment.word_count += len(args.answer) if args.answer else 0
                 word_count_change = segment.word_count - word_count_change
                 if args.keywords:
                     segment.keywords = args.keywords

                 db.session.add(document)
                 # update segment index task
                 if args.enabled:
-                    VectorService.create_segments_vector([args.keywords], [segment], dataset)
+                    VectorService.create_segments_vector(
+                        [args.keywords] if args.keywords else None,
+                        [segment],
+                        dataset,
+                        document.doc_form,
+                    )
                 if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
                     # regenerate child chunks
                     # get embedding model instance

                         .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
                         .first()
                     )
+                    if not processing_rule:
+                        raise ValueError("No processing rule found.")
                     VectorService.generate_child_chunks(
                         segment, document, dataset, embedding_model_instance, processing_rule, True
                     )

                 segment.disabled_by = None
                 if document.doc_form == "qa_model":
                     segment.answer = args.answer
-                    segment.word_count += len(args.answer)
+                    segment.word_count += len(args.answer) if args.answer else 0
                 word_count_change = segment.word_count - word_count_change
                 # update document word count
                 if word_count_change != 0:

                         .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
                         .first()
                     )
+                    if not processing_rule:
+                        raise ValueError("No processing rule found.")
                     VectorService.generate_child_chunks(
                         segment, document, dataset, embedding_model_instance, processing_rule, True
                     )
     original_document_id: Optional[str] = None
     duplicate: bool = True
     indexing_technique: Literal["high_quality", "economy"]
-    data_source: Optional[DataSource] = None
+    data_source: DataSource
     process_rule: Optional[ProcessRule] = None
     retrieval_model: Optional[RetrievalModel] = None
     doc_form: str = "text_model"
         db.session.add(dataset_query)
         db.session.commit()

-        return cls.compact_retrieve_response(query, all_documents)
+        return cls.compact_retrieve_response(query, all_documents)  # type: ignore

     @classmethod
     def external_retrieve(
             .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
             .first()
         )
+        if not processing_rule:
+            raise ValueError("No processing rule found.")
         # get embedding model instance
         if dataset.indexing_technique == "high_quality":
             # check embedding model setting

     def generate_child_chunks(
         cls,
         segment: DocumentSegment,
-        dataset_document: Document,
+        dataset_document: DatasetDocument,
         dataset: Dataset,
         embedding_model_instance: ModelInstance,
         processing_rule: DatasetProcessRule,

             doc_language=dataset_document.doc_language,
         )
         # save child chunks
-        if len(documents) > 0 and len(documents[0].children) > 0:
+        if documents and documents[0].children:
             index_processor.load(dataset, documents)
             for position, child_chunk in enumerate(documents[0].children, start=1):
     for upload_file_id in image_upload_file_ids:
         image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
         try:
-            storage.delete(image_file.key)
+            if image_file and image_file.key:
+                storage.delete(image_file.key)
         except Exception:
             logging.exception(
                 "Delete image_files failed when storage deleted, \