@@ -1,9 +1,11 @@
+import concurrent.futures
 import json
-import threading
 from typing import Optional
 
 from flask import Flask, current_app
+from sqlalchemy.orm import load_only
 
+from configs import dify_config
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.vdb.vector_factory import Vector
@@ -27,6 +29,7 @@ default_retrieval_model = {
 
 
 class RetrievalService:
+    # Cache precompiled regular expressions to avoid repeated compilation
     @classmethod
     def retrieve(
         cls,
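The comment added above mentions caching precompiled regular expressions at class level, but none of the visible hunks show the pattern itself. A hypothetical sketch of the idiom the comment describes (the attribute and method names here are invented, not from the PR):

```python
import re


class RetrievalService:
    # Compile once at class definition time and reuse on every call,
    # instead of paying for re.compile() inside each request.
    _WHITESPACE_PATTERN = re.compile(r"\s+")

    @classmethod
    def normalize(cls, query: str) -> str:
        return cls._WHITESPACE_PATTERN.sub(" ", query).strip()
```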
@@ -41,74 +44,62 @@ class RetrievalService:
     ):
         if not query:
             return []
-        dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
-        if not dataset:
-            return []
+
+        dataset = cls._get_dataset(dataset_id)
+        if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
+            return []
+
         all_documents: list[Document] = []
-        threads: list[threading.Thread] = []
         exceptions: list[str] = []
-        # retrieval_model source with keyword
-        if retrieval_method == "keyword_search":
-            keyword_thread = threading.Thread(
-                target=RetrievalService.keyword_search,
-                kwargs={
-                    "flask_app": current_app._get_current_object(),  # type: ignore
-                    "dataset_id": dataset_id,
-                    "query": query,
-                    "top_k": top_k,
-                    "all_documents": all_documents,
-                    "exceptions": exceptions,
-                },
-            )
-            threads.append(keyword_thread)
-            keyword_thread.start()
-        # retrieval_model source with semantic
-        if RetrievalMethod.is_support_semantic_search(retrieval_method):
-            embedding_thread = threading.Thread(
-                target=RetrievalService.embedding_search,
-                kwargs={
-                    "flask_app": current_app._get_current_object(),  # type: ignore
-                    "dataset_id": dataset_id,
-                    "query": query,
-                    "top_k": top_k,
-                    "score_threshold": score_threshold,
-                    "reranking_model": reranking_model,
-                    "all_documents": all_documents,
-                    "retrieval_method": retrieval_method,
-                    "exceptions": exceptions,
-                },
-            )
-            threads.append(embedding_thread)
-            embedding_thread.start()
-
-        # retrieval source with full text
-        if RetrievalMethod.is_support_fulltext_search(retrieval_method):
-            full_text_index_thread = threading.Thread(
-                target=RetrievalService.full_text_index_search,
-                kwargs={
-                    "flask_app": current_app._get_current_object(),  # type: ignore
-                    "dataset_id": dataset_id,
-                    "query": query,
-                    "retrieval_method": retrieval_method,
-                    "score_threshold": score_threshold,
-                    "top_k": top_k,
-                    "reranking_model": reranking_model,
-                    "all_documents": all_documents,
-                    "exceptions": exceptions,
-                },
-            )
-            threads.append(full_text_index_thread)
-            full_text_index_thread.start()
-
-        for thread in threads:
-            thread.join()
+        # Optimize multithreading with thread pools
+        with concurrent.futures.ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_WORKER) as executor:  # type: ignore
+            futures = []
+            if retrieval_method == "keyword_search":
+                futures.append(
+                    executor.submit(
+                        cls.keyword_search,
+                        flask_app=current_app._get_current_object(),  # type: ignore
+                        dataset_id=dataset_id,
+                        query=query,
+                        top_k=top_k,
+                        all_documents=all_documents,
+                        exceptions=exceptions,
+                    )
+                )
+            if RetrievalMethod.is_support_semantic_search(retrieval_method):
+                futures.append(
+                    executor.submit(
+                        cls.embedding_search,
+                        flask_app=current_app._get_current_object(),  # type: ignore
+                        dataset_id=dataset_id,
+                        query=query,
+                        top_k=top_k,
+                        score_threshold=score_threshold,
+                        reranking_model=reranking_model,
+                        all_documents=all_documents,
+                        retrieval_method=retrieval_method,
+                        exceptions=exceptions,
+                    )
+                )
+            if RetrievalMethod.is_support_fulltext_search(retrieval_method):
+                futures.append(
+                    executor.submit(
+                        cls.full_text_index_search,
+                        flask_app=current_app._get_current_object(),  # type: ignore
+                        dataset_id=dataset_id,
+                        query=query,
+                        top_k=top_k,
+                        score_threshold=score_threshold,
+                        reranking_model=reranking_model,
+                        all_documents=all_documents,
+                        retrieval_method=retrieval_method,
+                        exceptions=exceptions,
+                    )
+                )
+            concurrent.futures.wait(futures, timeout=30, return_when=concurrent.futures.ALL_COMPLETED)
 
         if exceptions:
-            exception_message = ";\n".join(exceptions)
-            raise ValueError(exception_message)
+            raise ValueError(";\n".join(exceptions))
 
         if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
             data_post_processor = DataPostProcessor(
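This hunk replaces hand-rolled `threading.Thread` bookkeeping (create, append, start, join each thread) with a pooled fan-out capped by the new `RETRIEVAL_SERVICE_WORKER` config. A minimal sketch of the same pattern, with a stand-in worker in place of the real search calls:

```python
import concurrent.futures


def fan_out(items: list[int], max_workers: int = 4, timeout: int = 30) -> list[int]:
    # Shared mutable lists collect results and errors, mirroring the
    # all_documents/exceptions lists the service hands to each worker.
    results: list[int] = []
    errors: list[str] = []

    def worker(item: int) -> None:
        try:
            results.append(item * 2)  # stand-in for a real search call
        except Exception as e:  # collect rather than propagate, like the service
            errors.append(str(e))

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(worker, item) for item in items]
        # Bounds only this explicit wait; exiting the with-block still calls
        # shutdown(wait=True), which blocks until every future finishes.
        concurrent.futures.wait(futures, timeout=timeout, return_when=concurrent.futures.ALL_COMPLETED)

    if errors:
        raise ValueError(";\n".join(errors))
    return results


print(fan_out([1, 2, 3]))  # e.g. [2, 4, 6]; completion order may vary
```

One caveat worth noting about the diff itself: because `ThreadPoolExecutor` is used as a context manager, the `timeout=30` on `concurrent.futures.wait` does not cap the overall call; leaving the `with` block triggers `shutdown(wait=True)`, which still waits for any stragglers.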
@@ -133,18 +124,21 @@ class RetrievalService:
             )
         return all_documents
 
+    @classmethod
+    def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]:
+        return db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+
     @classmethod
     def keyword_search(
         cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
     ):
         with flask_app.app_context():
             try:
-                dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+                dataset = cls._get_dataset(dataset_id)
                 if not dataset:
                     raise ValueError("dataset not found")
 
                 keyword = Keyword(dataset=dataset)
 
                 documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k)
                 all_documents.extend(documents)
             except Exception as e:
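The new `_get_dataset` helper deduplicates the `Dataset` lookup that was previously inlined in `retrieve`, `keyword_search`, `embedding_search`, and `full_text_index_search`. Note that each worker still wraps its body in `flask_app.app_context()`: pool threads have no request or app context of their own, so anything touching `current_app` or the Flask-SQLAlchemy session must push one. A minimal sketch of that pattern (the worker body here is a stand-in, not the PR's code):

```python
import concurrent.futures

from flask import Flask

app = Flask(__name__)


def worker(flask_app: Flask, out: list, errors: list) -> None:
    # Pool threads have no app context; push one so code that relies on
    # current_app or a Flask-SQLAlchemy session keeps working off-request.
    with flask_app.app_context():
        try:
            out.append("ok")  # stand-in for db.session.query(...)
        except Exception as e:
            errors.append(str(e))


out: list = []
errors: list = []
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
    executor.submit(worker, app, out, errors).result()
print(out)  # ['ok']
```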
@@ -165,12 +159,11 @@ class RetrievalService:
     ):
         with flask_app.app_context():
             try:
-                dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+                dataset = cls._get_dataset(dataset_id)
                 if not dataset:
                     raise ValueError("dataset not found")
 
                 vector = Vector(dataset=dataset)
-
                 documents = vector.search_by_vector(
                     query,
                     search_type="similarity_score_threshold",
@@ -187,7 +180,7 @@ class RetrievalService:
                     and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
                 ):
                     data_post_processor = DataPostProcessor(
-                        str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
+                        str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
                     )
                     all_documents.extend(
                         data_post_processor.invoke(
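The `str(...)` wrapper around `RerankMode.RERANKING_MODEL.value` is a typing fix rather than a behavior change: assuming `RerankMode` is a str-valued `Enum`, as the surrounding call suggests, `str()` is an identity at runtime and only reassures a type checker that sees `.value` as a broader type. A tiny illustration under that assumption:

```python
from enum import Enum


class RerankMode(Enum):
    RERANKING_MODEL = "reranking_model"


value = RerankMode.RERANKING_MODEL.value
# For a str-valued Enum member, str() returns the same string unchanged.
assert str(value) == value == "reranking_model"
```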
@@ -217,13 +210,11 @@ class RetrievalService:
     ):
         with flask_app.app_context():
             try:
-                dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+                dataset = cls._get_dataset(dataset_id)
                 if not dataset:
                     raise ValueError("dataset not found")
 
-                vector_processor = Vector(
-                    dataset=dataset,
-                )
+                vector_processor = Vector(dataset=dataset)
 
                 documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k)
                 if documents:
@@ -234,7 +225,7 @@ class RetrievalService:
                     and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
                 ):
                     data_post_processor = DataPostProcessor(
-                        str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
+                        str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
                     )
                     all_documents.extend(
                         data_post_processor.invoke(
@@ -253,64 +244,105 @@ class RetrievalService:
     def escape_query_for_search(query: str) -> str:
         return json.dumps(query).strip('"')
 
-    @staticmethod
-    def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegments]:
-        records = []
-        include_segment_ids = []
-        segment_child_map = {}
-        for document in documents:
-            document_id = document.metadata.get("document_id")
-            dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
-            if dataset_document:
+    @classmethod
+    def format_retrieval_documents(cls, documents: list[Document]) -> list[RetrievalSegments]:
+        """Format retrieval documents with optimized batch processing"""
+        if not documents:
+            return []
+
+        try:
+            # Collect document IDs
+            document_ids = {doc.metadata.get("document_id") for doc in documents if "document_id" in doc.metadata}
+            if not document_ids:
+                return []
+
+            # Batch query dataset documents
+            dataset_documents = {
+                doc.id: doc
+                for doc in db.session.query(DatasetDocument)
+                .filter(DatasetDocument.id.in_(document_ids))
+                .options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
+                .all()
+            }
+
+            records = []
+            include_segment_ids = set()
+            segment_child_map = {}
+
+            # Process documents
+            for document in documents:
+                document_id = document.metadata.get("document_id")
+                if document_id not in dataset_documents:
+                    continue
+
+                dataset_document = dataset_documents[document_id]
+
                 if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                    # Handle parent-child documents
                     child_index_node_id = document.metadata.get("doc_id")
-
-                    child_chunk = (
-                        db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first()
-                    )
-
-                    if not child_chunk:
-                        continue
-
-                    segment = (
-                        db.session.query(DocumentSegment)
+                    result = (
+                        db.session.query(ChildChunk, DocumentSegment)
+                        .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
                         .filter(
+                            ChildChunk.index_node_id == child_index_node_id,
                             DocumentSegment.dataset_id == dataset_document.dataset_id,
                             DocumentSegment.enabled == True,
                             DocumentSegment.status == "completed",
-                            DocumentSegment.id == child_chunk.segment_id,
                         )
+                        .options(
+                            load_only(
+                                DocumentSegment.id,
+                                DocumentSegment.content,
+                                DocumentSegment.answer,
+                                DocumentSegment.doc_metadata,
+                            )
+                        )
                         .first()
                     )
-                    if not segment:
-                        continue
-                    if segment.id not in include_segment_ids:
-                        include_segment_ids.append(segment.id)
-                        child_chunk_detail = {
-                            "id": child_chunk.id,
-                            "content": child_chunk.content,
-                            "position": child_chunk.position,
-                            "score": document.metadata.get("score", 0.0),
-                        }
-                        map_detail = {
-                            "max_score": document.metadata.get("score", 0.0),
-                            "child_chunks": [child_chunk_detail],
-                        }
-                        segment_child_map[segment.id] = map_detail
-                        record = {
-                            "segment": segment,
-                        }
-                        records.append(record)
-                    else:
-                        child_chunk_detail = {
-                            "id": child_chunk.id,
-                            "content": child_chunk.content,
-                            "position": child_chunk.position,
-                            "score": document.metadata.get("score", 0.0),
-                        }
-                        segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
-                        segment_child_map[segment.id]["max_score"] = max(
-                            segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
-                        )
+                    if result:
+                        child_chunk, segment = result
+
+                        if not segment:
+                            continue
+
+                        if segment.id not in include_segment_ids:
+                            include_segment_ids.add(segment.id)
+                            child_chunk_detail = {
+                                "id": child_chunk.id,
+                                "content": child_chunk.content,
+                                "position": child_chunk.position,
+                                "score": document.metadata.get("score", 0.0),
+                            }
+                            map_detail = {
+                                "max_score": document.metadata.get("score", 0.0),
+                                "child_chunks": [child_chunk_detail],
+                            }
+                            segment_child_map[segment.id] = map_detail
+                            record = {
+                                "segment": segment,
+                            }
+                            records.append(record)
+                        else:
+                            child_chunk_detail = {
+                                "id": child_chunk.id,
+                                "content": child_chunk.content,
+                                "position": child_chunk.position,
+                                "score": document.metadata.get("score", 0.0),
+                            }
+                            segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
+                            segment_child_map[segment.id]["max_score"] = max(
+                                segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
+                            )
                 else:
-                    index_node_id = document.metadata["doc_id"]
+                    # Handle normal documents
+                    index_node_id = document.metadata.get("doc_id")
+                    if not index_node_id:
+                        continue
+
                     segment = (
                         db.session.query(DocumentSegment)
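Two query-shape changes carry this hunk: the per-document `DatasetDocument` lookup becomes a single `IN (...)` query materialized into a dict (removing an N+1), the child-chunk-then-segment pair of queries becomes one join, and `load_only` trims both to the columns actually read. A self-contained sketch of the batch-fetch half of that pattern, using a hypothetical `Doc` model and an in-memory SQLite engine rather than the PR's models:

```python
from sqlalchemy import String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, load_only, mapped_column


class Base(DeclarativeBase):
    pass


class Doc(Base):
    __tablename__ = "docs"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    doc_form: Mapped[str] = mapped_column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Doc(id="a", doc_form="text_model"), Doc(id="b", doc_form="parent_child_index")])
    session.commit()

    wanted = {"a", "b", "missing"}
    # One IN (...) query replaces len(wanted) round trips; load_only trims
    # the SELECT list to the columns the caller actually reads.
    docs_by_id = {
        doc.id: doc
        for doc in session.query(Doc)
        .filter(Doc.id.in_(wanted))
        .options(load_only(Doc.id, Doc.doc_form))
        .all()
    }
    print(sorted(docs_by_id))  # ['a', 'b'] -- ids with no row are simply absent
```

Membership tests against the dict then replace the old `if dataset_document:` guard, which is why the rewritten loop can `continue` on unknown ids without touching the database again.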
@@ -325,16 +357,24 @@ class RetrievalService:
 
                     if not segment:
                         continue
-                    include_segment_ids.append(segment.id)
+
+                    include_segment_ids.add(segment.id)
                     record = {
                         "segment": segment,
-                        "score": document.metadata.get("score", None),
+                        "score": document.metadata.get("score"),  # type: ignore
                         "segment_metadata": segment.doc_metadata,
                     }
 
                     records.append(record)
 
+            # Add child chunks information to records
             for record in records:
                 if record["segment"].id in segment_child_map:
-                    record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None)
+                    record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks")  # type: ignore
                     record["score"] = segment_child_map[record["segment"].id]["max_score"]
 
-        return [RetrievalSegments(**record) for record in records]
+            return [RetrievalSegments(**record) for record in records]
+        except Exception as e:
+            db.session.rollback()
+            raise e
+        finally:
+            db.session.close()
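The closing `except`/`finally` gives `format_retrieval_documents` a consistent session lifecycle: roll back on any failure so the session is not left in an aborted transaction, and close it either way. One common way to package that same discipline for reuse, shown here as an illustrative helper rather than anything the PR adds:

```python
from contextlib import contextmanager


@contextmanager
def session_scope(session):
    # Same discipline as the hunk: roll back on any failure so the session
    # is not stuck in an aborted transaction, and always release it.
    try:
        yield session
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()


# Usage sketch:
#     with session_scope(db.session) as s:
#         s.query(...)
```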