| from pydantic import BaseModel, Field | from pydantic import BaseModel, Field | ||||
| from core.rag.datasource.retrieval_service import RetrievalService | from core.rag.datasource.retrieval_service import RetrievalService | ||||
| from core.rag.entities.context_entities import DocumentContext | |||||
| from core.rag.models.document import Document as RetrievalDocument | from core.rag.models.document import Document as RetrievalDocument | ||||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | from core.rag.retrieval.retrieval_methods import RetrievalMethod | ||||
| from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool | from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import Dataset, Document, DocumentSegment | |||||
| from models.dataset import Dataset | |||||
| from models.dataset import Document as DatasetDocument | |||||
| from services.external_knowledge_service import ExternalDatasetService | from services.external_knowledge_service import ExternalDatasetService | ||||
| default_retrieval_model = { | default_retrieval_model = { | ||||
| if not dataset: | if not dataset: | ||||
| return "" | return "" | ||||
| for hit_callback in self.hit_callbacks: | for hit_callback in self.hit_callbacks: | ||||
| hit_callback.on_query(query, dataset.id) | hit_callback.on_query(query, dataset.id) | ||||
| if dataset.provider == "external": | if dataset.provider == "external": | ||||
| ) | ) | ||||
| else: | else: | ||||
| documents = [] | documents = [] | ||||
| for hit_callback in self.hit_callbacks: | for hit_callback in self.hit_callbacks: | ||||
| hit_callback.on_tool_end(documents) | hit_callback.on_tool_end(documents) | ||||
| document_score_list = {} | document_score_list = {} | ||||
| if item.metadata is not None and item.metadata.get("score"): | if item.metadata is not None and item.metadata.get("score"): | ||||
| document_score_list[item.metadata["doc_id"]] = item.metadata["score"] | document_score_list[item.metadata["doc_id"]] = item.metadata["score"] | ||||
| document_context_list = [] | document_context_list = [] | ||||
| index_node_ids = [document.metadata["doc_id"] for document in documents] | |||||
| segments = DocumentSegment.query.filter( | |||||
| DocumentSegment.dataset_id == self.dataset_id, | |||||
| DocumentSegment.completed_at.isnot(None), | |||||
| DocumentSegment.status == "completed", | |||||
| DocumentSegment.enabled == True, | |||||
| DocumentSegment.index_node_id.in_(index_node_ids), | |||||
| ).all() | |||||
| if segments: | |||||
| index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} | |||||
| sorted_segments = sorted( | |||||
| segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) | |||||
| ) | |||||
| for segment in sorted_segments: | |||||
| records = RetrievalService.format_retrieval_documents(documents) | |||||
| if records: | |||||
| for record in records: | |||||
| segment = record.segment | |||||
| if segment.answer: | if segment.answer: | ||||
| document_context_list.append( | document_context_list.append( | ||||
| f"question:{segment.get_sign_content()} answer:{segment.answer}" | |||||
| DocumentContext( | |||||
| content=f"question:{segment.get_sign_content()} answer:{segment.answer}", | |||||
| score=record.score, | |||||
| ) | |||||
| ) | ) | ||||
| else: | else: | ||||
| document_context_list.append(segment.get_sign_content()) | |||||
| document_context_list.append( | |||||
| DocumentContext( | |||||
| content=segment.get_sign_content(), | |||||
| score=record.score, | |||||
| ) | |||||
| ) | |||||
| retrieval_resource_list = [] | |||||
| if self.return_resource: | if self.return_resource: | ||||
| context_list = [] | |||||
| resource_number = 1 | |||||
| for segment in sorted_segments: | |||||
| document_segment = Document.query.filter( | |||||
| Document.id == segment.document_id, | |||||
| Document.enabled == True, | |||||
| Document.archived == False, | |||||
| for record in records: | |||||
| segment = record.segment | |||||
| dataset = Dataset.query.filter_by(id=segment.dataset_id).first() | |||||
| document = DatasetDocument.query.filter( | |||||
| DatasetDocument.id == segment.document_id, | |||||
| DatasetDocument.enabled == True, | |||||
| DatasetDocument.archived == False, | |||||
| ).first() | ).first() | ||||
| if not document_segment: | |||||
| continue | |||||
| if dataset and document_segment: | |||||
| if dataset and document: | |||||
| source = { | source = { | ||||
| "position": resource_number, | |||||
| "dataset_id": dataset.id, | "dataset_id": dataset.id, | ||||
| "dataset_name": dataset.name, | "dataset_name": dataset.name, | ||||
| "document_id": document_segment.id, | |||||
| "document_name": document_segment.name, | |||||
| "data_source_type": document_segment.data_source_type, | |||||
| "document_id": document.id, # type: ignore | |||||
| "document_name": document.name, # type: ignore | |||||
| "data_source_type": document.data_source_type, # type: ignore | |||||
| "segment_id": segment.id, | "segment_id": segment.id, | ||||
| "retriever_from": self.retriever_from, | "retriever_from": self.retriever_from, | ||||
| "score": document_score_list.get(segment.index_node_id, None), | |||||
| "score": record.score or 0.0, | |||||
| } | } | ||||
| if self.retriever_from == "dev": | if self.retriever_from == "dev": | ||||
| source["hit_count"] = segment.hit_count | source["hit_count"] = segment.hit_count | ||||
| source["word_count"] = segment.word_count | source["word_count"] = segment.word_count | ||||
| source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" | source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" | ||||
| else: | else: | ||||
| source["content"] = segment.content | source["content"] = segment.content | ||||
| context_list.append(source) | |||||
| resource_number += 1 | |||||
| for hit_callback in self.hit_callbacks: | |||||
| hit_callback.return_retriever_resource_info(context_list) | |||||
| retrieval_resource_list.append(source) | |||||
| return str("\n".join(document_context_list)) | |||||
| if self.return_resource and retrieval_resource_list: | |||||
| retrieval_resource_list = sorted( | |||||
| retrieval_resource_list, | |||||
| key=lambda x: x.get("score") or 0.0, | |||||
| reverse=True, | |||||
| ) | |||||
| for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore | |||||
| item["position"] = position # type: ignore | |||||
| for hit_callback in self.hit_callbacks: | |||||
| hit_callback.return_retriever_resource_info(retrieval_resource_list) | |||||
| if document_context_list: | |||||
| document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True) | |||||
| return str("\n".join([document_context.content for document_context in document_context_list])) | |||||
| return "" |