@@ -1,13 +1,9 @@
 import concurrent.futures
-import logging
-import time
 from concurrent.futures import ThreadPoolExecutor
 from typing import Optional
 
 from flask import Flask, current_app
-from sqlalchemy import and_, or_
 from sqlalchemy.orm import load_only
-from sqlalchemy.sql.expression import false
 
 from configs import dify_config
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
@@ -182,7 +178,6 @@ class RetrievalService:
                 if not dataset:
                     raise ValueError("dataset not found")
 
-                start = time.time()
                 vector = Vector(dataset=dataset)
                 documents = vector.search_by_vector(
                     query,
@@ -192,7 +187,6 @@ class RetrievalService:
                     filter={"group_id": [dataset.id]},
                     document_ids_filter=document_ids_filter,
                 )
-                logging.debug(f"embedding_search ends at {time.time() - start:.2f} seconds")
 
                 if documents:
                     if (
| @@ -276,8 +270,7 @@ class RetrievalService: | |||
| return [] | |||
| try: | |||
| start_time = time.time() | |||
| # Collect document IDs with existence check | |||
| # Collect document IDs | |||
| document_ids = {doc.metadata.get("document_id") for doc in documents if "document_id" in doc.metadata} | |||
| if not document_ids: | |||
| return [] | |||
@@ -295,138 +288,110 @@ class RetrievalService:
             include_segment_ids = set()
             segment_child_map = {}
-            # Precompute doc_forms to avoid redundant checks
-            doc_forms = {}
-            for doc in documents:
-                document_id = doc.metadata.get("document_id")
-                dataset_doc = dataset_documents.get(document_id)
-                if dataset_doc:
-                    doc_forms[document_id] = dataset_doc.doc_form
-            # Batch collect index node IDs with type safety
-            child_index_node_ids = []
-            index_node_ids = []
-            for doc in documents:
-                document_id = doc.metadata.get("document_id")
-                if doc_forms.get(document_id) == IndexType.PARENT_CHILD_INDEX:
-                    child_index_node_ids.append(doc.metadata.get("doc_id"))
-                else:
-                    index_node_ids.append(doc.metadata.get("doc_id"))
-            # Batch query ChildChunk
-            child_chunks = db.session.query(ChildChunk).filter(ChildChunk.index_node_id.in_(child_index_node_ids)).all()
-            child_chunk_map = {chunk.index_node_id: chunk for chunk in child_chunks}
-            segment_ids_from_child = [chunk.segment_id for chunk in child_chunks]
-            segment_conditions = []
-            if index_node_ids:
-                segment_conditions.append(DocumentSegment.index_node_id.in_(index_node_ids))
-            if segment_ids_from_child:
-                segment_conditions.append(DocumentSegment.id.in_(segment_ids_from_child))
-            if segment_conditions:
-                filter_expr = or_(*segment_conditions)
-            else:
-                filter_expr = false()
-            segment_map = {
-                segment.id: segment
-                for segment in db.session.query(DocumentSegment)
-                .filter(
-                    and_(
-                        filter_expr,
-                        DocumentSegment.enabled == True,
-                        DocumentSegment.status == "completed",
-                    )
-                )
-                .options(
-                    load_only(
-                        DocumentSegment.id,
-                        DocumentSegment.content,
-                        DocumentSegment.answer,
-                    )
-                )
-                .all()
-            }
             # Process documents
             for document in documents:
                 document_id = document.metadata.get("document_id")
-                if document_id not in dataset_documents:
-                    continue
-                dataset_document = dataset_documents[document_id]
+                dataset_document = dataset_documents.get(document_id)
                 if not dataset_document:
                     continue
-                doc_form = doc_forms.get(document_id)
-                if doc_form == IndexType.PARENT_CHILD_INDEX:
-                    # Handle parent-child documents using preloaded data
+                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                    # Handle parent-child documents
                     child_index_node_id = document.metadata.get("doc_id")
                     if not child_index_node_id:
                         continue
-                    child_chunk = child_chunk_map.get(child_index_node_id)
+                    child_chunk = (
+                        db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first()
+                    )
                     if not child_chunk:
                         continue
-                    segment = segment_map.get(child_chunk.segment_id)
+                    segment = (
+                        db.session.query(DocumentSegment)
+                        .filter(
+                            DocumentSegment.dataset_id == dataset_document.dataset_id,
+                            DocumentSegment.enabled == True,
+                            DocumentSegment.status == "completed",
+                            DocumentSegment.id == child_chunk.segment_id,
+                        )
+                        .options(
+                            load_only(
+                                DocumentSegment.id,
+                                DocumentSegment.content,
+                                DocumentSegment.answer,
+                            )
+                        )
+                        .first()
+                    )
                     if not segment:
                         continue
                     if segment.id not in include_segment_ids:
                         include_segment_ids.add(segment.id)
-                        map_detail = {"max_score": document.metadata.get("score", 0.0), "child_chunks": []}
+                        child_chunk_detail = {
+                            "id": child_chunk.id,
+                            "content": child_chunk.content,
+                            "position": child_chunk.position,
+                            "score": document.metadata.get("score", 0.0),
+                        }
+                        map_detail = {
+                            "max_score": document.metadata.get("score", 0.0),
+                            "child_chunks": [child_chunk_detail],
+                        }
                         segment_child_map[segment.id] = map_detail
-                        records.append({"segment": segment})
-                    # Append child chunk details
-                    child_chunk_detail = {
-                        "id": child_chunk.id,
-                        "content": child_chunk.content,
-                        "position": child_chunk.position,
-                        "score": document.metadata.get("score", 0.0),
-                    }
-                    segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
-                    segment_child_map[segment.id]["max_score"] = max(
-                        segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
-                    )
+                        record = {
+                            "segment": segment,
+                        }
+                        records.append(record)
+                    else:
+                        child_chunk_detail = {
+                            "id": child_chunk.id,
+                            "content": child_chunk.content,
+                            "position": child_chunk.position,
+                            "score": document.metadata.get("score", 0.0),
+                        }
+                        segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
+                        segment_child_map[segment.id]["max_score"] = max(
+                            segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
+                        )
                 else:
                     # Handle normal documents
                     index_node_id = document.metadata.get("doc_id")
                     if not index_node_id:
                         continue
-                    segment = next(
-                        (
-                            s
-                            for s in segment_map.values()
-                            if s.index_node_id == index_node_id and s.dataset_id == dataset_document.dataset_id
-                        ),
-                        None,
+                    segment = (
+                        db.session.query(DocumentSegment)
+                        .filter(
+                            DocumentSegment.dataset_id == dataset_document.dataset_id,
+                            DocumentSegment.enabled == True,
+                            DocumentSegment.status == "completed",
+                            DocumentSegment.index_node_id == index_node_id,
+                        )
+                        .first()
                     )
                     if not segment:
                         continue
                     if segment.id not in include_segment_ids:
-                        include_segment_ids.add(segment.id)
-                        records.append(
-                            {
-                                "segment": segment,
-                                "score": document.metadata.get("score", 0.0),
-                            }
-                        )
+                        include_segment_ids.add(segment.id)
+                        record = {
+                            "segment": segment,
+                            "score": document.metadata.get("score"),  # type: ignore
+                        }
+                        records.append(record)
-            # Merge child chunks information
+            # Add child chunks information to records
             for record in records:
-                segment_id = record["segment"].id
-                if segment_id in segment_child_map:
-                    record["child_chunks"] = segment_child_map[segment_id]["child_chunks"]
-                    record["score"] = segment_child_map[segment_id]["max_score"]
+                if record["segment"].id in segment_child_map:
+                    record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks")  # type: ignore
+                    record["score"] = segment_child_map[record["segment"].id]["max_score"]
-            logging.debug(f"Formatting retrieval documents took {time.time() - start_time:.2f} seconds")
             return [RetrievalSegments(**record) for record in records]
         except Exception as e:
             # Only rollback if there were write operations
             db.session.rollback()
             raise e
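
Note on the hunk above: the removed batch version resolved all hits with two `IN (...)` queries, but its `segment_map` query filtered `DocumentSegment` only by node ids, `enabled`, and `status`, never constraining `DocumentSegment.dataset_id` in SQL, and `child_chunk_map` kept a single `ChildChunk` per `index_node_id`. The per-document queries added here apply the `dataset_id` filter in every lookup, at the cost of one round trip per retrieved document. If batching is reattempted, a minimal sketch of a dataset-scoped batch lookup follows; `Segment`, `session`, and the `hits` shape are illustrative assumptions, not this PR's API:

```python
from sqlalchemy import String, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Segment(Base):
    # Stand-in for DocumentSegment, trimmed to the columns the lookup needs.
    __tablename__ = "segments"

    id: Mapped[str] = mapped_column(String, primary_key=True)
    dataset_id: Mapped[str] = mapped_column(String)
    index_node_id: Mapped[str] = mapped_column(String)


def fetch_segments_batched(session: Session, hits: list[dict]) -> list[Segment]:
    """Resolve retrieval hits to segments with one query instead of N.

    Each hit is assumed to carry "dataset_id" and "doc_id" (the index node id).
    """
    node_ids = {hit["doc_id"] for hit in hits}
    dataset_ids = {hit["dataset_id"] for hit in hits}
    rows = (
        session.execute(
            select(Segment).where(
                Segment.index_node_id.in_(node_ids),
                # Keep the dataset scope in SQL, which the removed batch
                # query omitted.
                Segment.dataset_id.in_(dataset_ids),
            )
        )
        .scalars()
        .all()
    )
    # Key by (dataset_id, index_node_id) so a node id from one dataset can
    # never satisfy a hit from another.
    by_key = {(s.dataset_id, s.index_node_id): s for s in rows}
    return [
        by_key[(hit["dataset_id"], hit["doc_id"])]
        for hit in hits
        if (hit["dataset_id"], hit["doc_id"]) in by_key
    ]
```

The sketch deliberately lets duplicate hits resolve to the same segment; deduplication by segment id, as the loop above does with `include_segment_ids`, would layer on top.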