| @@ -3,11 +3,13 @@ from typing import Any | |||
| from pydantic import BaseModel, Field | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rag.entities.context_entities import DocumentContext | |||
| from core.rag.models.document import Document as RetrievalDocument | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, Document, DocumentSegment | |||
| from models.dataset import Dataset | |||
| from models.dataset import Document as DatasetDocument | |||
| from services.external_knowledge_service import ExternalDatasetService | |||
| default_retrieval_model = { | |||
| @@ -54,7 +56,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| if not dataset: | |||
| return "" | |||
| for hit_callback in self.hit_callbacks: | |||
| hit_callback.on_query(query, dataset.id) | |||
| if dataset.provider == "external": | |||
| @@ -125,7 +126,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| ) | |||
| else: | |||
| documents = [] | |||
| for hit_callback in self.hit_callbacks: | |||
| hit_callback.on_tool_end(documents) | |||
| document_score_list = {} | |||
| @@ -134,50 +134,46 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| if item.metadata is not None and item.metadata.get("score"): | |||
| document_score_list[item.metadata["doc_id"]] = item.metadata["score"] | |||
| document_context_list = [] | |||
| index_node_ids = [document.metadata["doc_id"] for document in documents] | |||
| segments = DocumentSegment.query.filter( | |||
| DocumentSegment.dataset_id == self.dataset_id, | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.status == "completed", | |||
| DocumentSegment.enabled == True, | |||
| DocumentSegment.index_node_id.in_(index_node_ids), | |||
| ).all() | |||
| if segments: | |||
| index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} | |||
| sorted_segments = sorted( | |||
| segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) | |||
| ) | |||
| for segment in sorted_segments: | |||
| records = RetrievalService.format_retrieval_documents(documents) | |||
| if records: | |||
| for record in records: | |||
| segment = record.segment | |||
| if segment.answer: | |||
| document_context_list.append( | |||
| f"question:{segment.get_sign_content()} answer:{segment.answer}" | |||
| DocumentContext( | |||
| content=f"question:{segment.get_sign_content()} answer:{segment.answer}", | |||
| score=record.score, | |||
| ) | |||
| ) | |||
| else: | |||
| document_context_list.append(segment.get_sign_content()) | |||
| document_context_list.append( | |||
| DocumentContext( | |||
| content=segment.get_sign_content(), | |||
| score=record.score, | |||
| ) | |||
| ) | |||
| retrieval_resource_list = [] | |||
| if self.return_resource: | |||
| context_list = [] | |||
| resource_number = 1 | |||
| for segment in sorted_segments: | |||
| document_segment = Document.query.filter( | |||
| Document.id == segment.document_id, | |||
| Document.enabled == True, | |||
| Document.archived == False, | |||
| for record in records: | |||
| segment = record.segment | |||
| dataset = Dataset.query.filter_by(id=segment.dataset_id).first() | |||
| document = DatasetDocument.query.filter( | |||
| DatasetDocument.id == segment.document_id, | |||
| DatasetDocument.enabled == True, | |||
| DatasetDocument.archived == False, | |||
| ).first() | |||
| if not document_segment: | |||
| continue | |||
| if dataset and document_segment: | |||
| if dataset and document: | |||
| source = { | |||
| "position": resource_number, | |||
| "dataset_id": dataset.id, | |||
| "dataset_name": dataset.name, | |||
| "document_id": document_segment.id, | |||
| "document_name": document_segment.name, | |||
| "data_source_type": document_segment.data_source_type, | |||
| "document_id": document.id, # type: ignore | |||
| "document_name": document.name, # type: ignore | |||
| "data_source_type": document.data_source_type, # type: ignore | |||
| "segment_id": segment.id, | |||
| "retriever_from": self.retriever_from, | |||
| "score": document_score_list.get(segment.index_node_id, None), | |||
| "score": record.score or 0.0, | |||
| } | |||
| if self.retriever_from == "dev": | |||
| source["hit_count"] = segment.hit_count | |||
| source["word_count"] = segment.word_count | |||
| @@ -187,10 +183,19 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" | |||
| else: | |||
| source["content"] = segment.content | |||
| context_list.append(source) | |||
| resource_number += 1 | |||
| for hit_callback in self.hit_callbacks: | |||
| hit_callback.return_retriever_resource_info(context_list) | |||
| retrieval_resource_list.append(source) | |||
| return str("\n".join(document_context_list)) | |||
| if self.return_resource and retrieval_resource_list: | |||
| retrieval_resource_list = sorted( | |||
| retrieval_resource_list, | |||
| key=lambda x: x.get("score") or 0.0, | |||
| reverse=True, | |||
| ) | |||
| for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore | |||
| item["position"] = position # type: ignore | |||
| for hit_callback in self.hit_callbacks: | |||
| hit_callback.return_retriever_resource_info(retrieval_resource_list) | |||
| if document_context_list: | |||
| document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True) | |||
| return str("\n".join([document_context.content for document_context in document_context_list])) | |||
| return "" | |||