@@ -1,9 +1,11 @@
+import concurrent.futures
 import json
-import threading
 from typing import Optional
 
 from flask import Flask, current_app
+from sqlalchemy.orm import load_only
 
+from configs import dify_config
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.vdb.vector_factory import Vector
@@ -27,6 +29,7 @@ default_retrieval_model = {
 
 
 class RetrievalService:
+    # Cache precompiled regular expressions to avoid repeated compilation
     @classmethod
     def retrieve(
         cls,
@@ -41,74 +44,62 @@ class RetrievalService:
     ):
         if not query:
             return []
-        dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
-        if not dataset:
-            return []
+        dataset = cls._get_dataset(dataset_id)
+        if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
+            return []
 
         all_documents: list[Document] = []
-        threads: list[threading.Thread] = []
         exceptions: list[str] = []
-        # retrieval_model source with keyword
-        if retrieval_method == "keyword_search":
-            keyword_thread = threading.Thread(
-                target=RetrievalService.keyword_search,
-                kwargs={
-                    "flask_app": current_app._get_current_object(),  # type: ignore
-                    "dataset_id": dataset_id,
-                    "query": query,
-                    "top_k": top_k,
-                    "all_documents": all_documents,
-                    "exceptions": exceptions,
-                },
-            )
-            threads.append(keyword_thread)
-            keyword_thread.start()
-        # retrieval_model source with semantic
-        if RetrievalMethod.is_support_semantic_search(retrieval_method):
-            embedding_thread = threading.Thread(
-                target=RetrievalService.embedding_search,
-                kwargs={
-                    "flask_app": current_app._get_current_object(),  # type: ignore
-                    "dataset_id": dataset_id,
-                    "query": query,
-                    "top_k": top_k,
-                    "score_threshold": score_threshold,
-                    "reranking_model": reranking_model,
-                    "all_documents": all_documents,
-                    "retrieval_method": retrieval_method,
-                    "exceptions": exceptions,
-                },
-            )
-            threads.append(embedding_thread)
-            embedding_thread.start()
-
-        # retrieval source with full text
-        if RetrievalMethod.is_support_fulltext_search(retrieval_method):
-            full_text_index_thread = threading.Thread(
-                target=RetrievalService.full_text_index_search,
-                kwargs={
-                    "flask_app": current_app._get_current_object(),  # type: ignore
-                    "dataset_id": dataset_id,
-                    "query": query,
-                    "retrieval_method": retrieval_method,
-                    "score_threshold": score_threshold,
-                    "top_k": top_k,
-                    "reranking_model": reranking_model,
-                    "all_documents": all_documents,
-                    "exceptions": exceptions,
-                },
-            )
-            threads.append(full_text_index_thread)
-            full_text_index_thread.start()
-
-        for thread in threads:
-            thread.join()
+        # Optimize multithreading with thread pools
+        with concurrent.futures.ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_WORKER) as executor:  # type: ignore
+            futures = []
+            if retrieval_method == "keyword_search":
+                futures.append(
+                    executor.submit(
+                        cls.keyword_search,
+                        flask_app=current_app._get_current_object(),  # type: ignore
+                        dataset_id=dataset_id,
+                        query=query,
+                        top_k=top_k,
+                        all_documents=all_documents,
+                        exceptions=exceptions,
+                    )
+                )
+            if RetrievalMethod.is_support_semantic_search(retrieval_method):
+                futures.append(
+                    executor.submit(
+                        cls.embedding_search,
+                        flask_app=current_app._get_current_object(),  # type: ignore
+                        dataset_id=dataset_id,
+                        query=query,
+                        top_k=top_k,
+                        score_threshold=score_threshold,
+                        reranking_model=reranking_model,
+                        all_documents=all_documents,
+                        retrieval_method=retrieval_method,
+                        exceptions=exceptions,
+                    )
+                )
+            if RetrievalMethod.is_support_fulltext_search(retrieval_method):
+                futures.append(
+                    executor.submit(
+                        cls.full_text_index_search,
+                        flask_app=current_app._get_current_object(),  # type: ignore
+                        dataset_id=dataset_id,
+                        query=query,
+                        top_k=top_k,
+                        score_threshold=score_threshold,
+                        reranking_model=reranking_model,
+                        all_documents=all_documents,
+                        retrieval_method=retrieval_method,
+                        exceptions=exceptions,
+                    )
+                )
+            concurrent.futures.wait(futures, timeout=30, return_when=concurrent.futures.ALL_COMPLETED)
 
         if exceptions:
-            exception_message = ";\n".join(exceptions)
-            raise ValueError(exception_message)
+            raise ValueError(";\n".join(exceptions))
 
         if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
            data_post_processor = DataPostProcessor(
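The core change in this hunk swaps ad-hoc `threading.Thread` objects for a shared `ThreadPoolExecutor`. Below is a minimal, self-contained sketch of the same fan-out/wait pattern with illustrative names (only the `concurrent.futures` calls mirror the diff); note that `concurrent.futures.wait` returns rather than raises on timeout, so a slow worker is abandoned, not cancelled:

```python
import concurrent.futures

# Hypothetical worker: appends results/errors to shared lists, as the
# diff's search workers do.
def _search(method: str, query: str, results: list, errors: list) -> None:
    try:
        results.append(f"{method}: results for {query!r}")  # stand-in for a real search
    except Exception as e:
        errors.append(str(e))

def retrieve(query: str, methods: list[str], max_workers: int = 4) -> list[str]:
    results: list[str] = []
    errors: list[str] = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(_search, m, query, results, errors) for m in methods]
        # wait() returns a (done, not_done) pair; it does NOT raise on timeout.
        concurrent.futures.wait(futures, timeout=30, return_when=concurrent.futures.ALL_COMPLETED)
    if errors:
        raise ValueError(";\n".join(errors))
    return results

print(retrieve("example query", ["keyword_search", "semantic_search"]))
```

One subtlety: leaving the `with` block calls `shutdown(wait=True)`, so the 30-second timeout only bounds the explicit `wait`; the pool still blocks on stragglers at exit.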
@@ -133,18 +124,21 @@ class RetrievalService:
             )
         return all_documents
 
     @classmethod
+    def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]:
+        return db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+
+    @classmethod
     def keyword_search(
         cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
     ):
         with flask_app.app_context():
             try:
-                dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+                dataset = cls._get_dataset(dataset_id)
                 if not dataset:
                     raise ValueError("dataset not found")
 
                 keyword = Keyword(dataset=dataset)
 
                 documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k)
                 all_documents.extend(documents)
             except Exception as e:
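Every worker, `keyword_search` included, re-enters a Flask application context because threads do not inherit the caller's context; this is also why the callers pass `current_app._get_current_object()` (the real app behind the proxy) rather than `current_app` itself. A small sketch of that hand-off, assuming a throwaway `app`:

```python
from flask import Flask, current_app

app = Flask(__name__)

def worker(flask_app: Flask) -> str:
    # A new thread has no app context; push one so current_app (and
    # context-bound extensions such as Flask-SQLAlchemy) resolve.
    with flask_app.app_context():
        return current_app.name

with app.app_context():
    # current_app is a context-local proxy and would break once this
    # context ends; unwrap the real object before crossing threads.
    real_app = current_app._get_current_object()

print(worker(real_app))
```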
@@ -165,12 +159,11 @@ class RetrievalService:
     ):
         with flask_app.app_context():
             try:
-                dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+                dataset = cls._get_dataset(dataset_id)
                 if not dataset:
                     raise ValueError("dataset not found")
 
                 vector = Vector(dataset=dataset)
-
                 documents = vector.search_by_vector(
                     query,
                     search_type="similarity_score_threshold",
@@ -187,7 +180,7 @@ class RetrievalService:
                 and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
             ):
                 data_post_processor = DataPostProcessor(
-                    str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
+                    str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
                 )
                 all_documents.extend(
                     data_post_processor.invoke(
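The only change in this hunk wraps the enum value in `str(...)`. For an `Enum` whose values are already strings this is a no-op, but it guarantees `DataPostProcessor` receives a plain `str` even if the enum's value type changes later; a quick illustration with a stand-in enum:

```python
from enum import Enum

class RerankMode(Enum):  # stand-in for the real RerankMode
    RERANKING_MODEL = "reranking_model"

value = RerankMode.RERANKING_MODEL.value
assert value == "reranking_model"
assert str(value) == "reranking_model"  # defensive, type-stable coercion
```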
@@ -217,13 +210,11 @@ class RetrievalService:
     ):
         with flask_app.app_context():
             try:
-                dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+                dataset = cls._get_dataset(dataset_id)
                 if not dataset:
                     raise ValueError("dataset not found")
 
-                vector_processor = Vector(
-                    dataset=dataset,
-                )
+                vector_processor = Vector(dataset=dataset)
 
                 documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k)
                 if documents:
@@ -234,7 +225,7 @@ class RetrievalService:
                 and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
             ):
                 data_post_processor = DataPostProcessor(
-                    str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
+                    str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL.value), reranking_model, None, False
                 )
                 all_documents.extend(
                     data_post_processor.invoke(
@@ -253,64 +244,105 @@ class RetrievalService:
     def escape_query_for_search(query: str) -> str:
         return json.dumps(query).strip('"')
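`escape_query_for_search` uses `json.dumps` to backslash-escape quotes, backslashes, and control characters, then strips the surrounding double quotes that `dumps` adds. The outputs below are exactly what the standard library produces:

```python
import json

def escape_query_for_search(query: str) -> str:
    return json.dumps(query).strip('"')

print(escape_query_for_search("plain text"))    # plain text
print(escape_query_for_search('say "hello"'))   # say \"hello\"
print(escape_query_for_search("a\\b"))          # a\\b
```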
 
-    @staticmethod
-    def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegments]:
-        records = []
-        include_segment_ids = []
-        segment_child_map = {}
-        for document in documents:
-            document_id = document.metadata.get("document_id")
-            dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
-            if dataset_document:
+    @classmethod
+    def format_retrieval_documents(cls, documents: list[Document]) -> list[RetrievalSegments]:
+        """Format retrieval documents with optimized batch processing"""
+        if not documents:
+            return []
+
+        try:
+            # Collect document IDs
+            document_ids = {doc.metadata.get("document_id") for doc in documents if "document_id" in doc.metadata}
+            if not document_ids:
+                return []
+
+            # Batch query dataset documents
+            dataset_documents = {
+                doc.id: doc
+                for doc in db.session.query(DatasetDocument)
+                .filter(DatasetDocument.id.in_(document_ids))
+                .options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
+                .all()
+            }
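The dictionary comprehension above is the N+1 fix: instead of one `SELECT` per retrieved document (the old loop), it issues a single `IN (...)` query and prunes columns with `load_only`. Restated as a standalone helper (assuming `DatasetDocument` is the mapped model from the diff and `session` is a live SQLAlchemy session):

```python
from sqlalchemy.orm import load_only

def fetch_documents_by_ids(session, document_ids: set[str]) -> dict:
    """Hypothetical helper: one IN (...) query instead of one query per id."""
    if not document_ids:
        return {}
    rows = (
        session.query(DatasetDocument)  # assumes the mapped model is imported
        .filter(DatasetDocument.id.in_(document_ids))
        .options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
        .all()
    )
    return {row.id: row for row in rows}  # id -> row, for O(1) lookups in the loop
```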
+
+            records = []
+            include_segment_ids = set()
+            segment_child_map = {}
+
+            # Process documents
+            for document in documents:
+                document_id = document.metadata.get("document_id")
+                if document_id not in dataset_documents:
+                    continue
+
+                dataset_document = dataset_documents[document_id]
+
                 if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                    # Handle parent-child documents
                     child_index_node_id = document.metadata.get("doc_id")
-                    result = (
-                        db.session.query(ChildChunk, DocumentSegment)
-                        .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
-                        .filter(
-                            ChildChunk.index_node_id == child_index_node_id,
-                            DocumentSegment.dataset_id == dataset_document.dataset_id,
-                            DocumentSegment.enabled == True,
-                            DocumentSegment.status == "completed",
-                        )
-                        .first()
-                    )
-                    if result:
-                        child_chunk, segment = result
-                        if segment.id not in include_segment_ids:
-                            include_segment_ids.append(segment.id)
-                            child_chunk_detail = {
-                                "id": child_chunk.id,
-                                "content": child_chunk.content,
-                                "position": child_chunk.position,
-                                "score": document.metadata.get("score", 0.0),
-                            }
-                            map_detail = {
-                                "max_score": document.metadata.get("score", 0.0),
-                                "child_chunks": [child_chunk_detail],
-                            }
-                            segment_child_map[segment.id] = map_detail
-                            record = {
-                                "segment": segment,
-                            }
-                            records.append(record)
-                        else:
-                            child_chunk_detail = {
-                                "id": child_chunk.id,
-                                "content": child_chunk.content,
-                                "position": child_chunk.position,
-                                "score": document.metadata.get("score", 0.0),
-                            }
-                            segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
-                            segment_child_map[segment.id]["max_score"] = max(
-                                segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
-                            )
+                    child_chunk = (
+                        db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first()
+                    )
+                    if not child_chunk:
+                        continue
+
+                    segment = (
+                        db.session.query(DocumentSegment)
+                        .filter(
+                            DocumentSegment.dataset_id == dataset_document.dataset_id,
+                            DocumentSegment.enabled == True,
+                            DocumentSegment.status == "completed",
+                            DocumentSegment.id == child_chunk.segment_id,
+                        )
+                        .options(
+                            load_only(
+                                DocumentSegment.id,
+                                DocumentSegment.content,
+                                DocumentSegment.answer,
+                                DocumentSegment.doc_metadata,
+                            )
+                        )
+                        .first()
+                    )
+                    if not segment:
+                        continue
+
+                    if segment.id not in include_segment_ids:
+                        include_segment_ids.add(segment.id)
+                        child_chunk_detail = {
+                            "id": child_chunk.id,
+                            "content": child_chunk.content,
+                            "position": child_chunk.position,
+                            "score": document.metadata.get("score", 0.0),
+                        }
+                        map_detail = {
+                            "max_score": document.metadata.get("score", 0.0),
+                            "child_chunks": [child_chunk_detail],
+                        }
+                        segment_child_map[segment.id] = map_detail
+                        record = {
+                            "segment": segment,
+                        }
+                        records.append(record)
+                    else:
+                        child_chunk_detail = {
+                            "id": child_chunk.id,
+                            "content": child_chunk.content,
+                            "position": child_chunk.position,
+                            "score": document.metadata.get("score", 0.0),
+                        }
+                        segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
+                        segment_child_map[segment.id]["max_score"] = max(
+                            segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
+                        )
                 else:
-                    index_node_id = document.metadata["doc_id"]
+                    # Handle normal documents
+                    index_node_id = document.metadata.get("doc_id")
+                    if not index_node_id:
+                        continue
+
                     segment = (
                         db.session.query(DocumentSegment)
@@ -325,16 +357,24 @@ class RetrievalService:
 
                     if not segment:
                         continue
-                    include_segment_ids.append(segment.id)
+
+                    include_segment_ids.add(segment.id)
                     record = {
                         "segment": segment,
-                        "score": document.metadata.get("score", None),
+                        "score": document.metadata.get("score"),  # type: ignore
+                        "segment_metadata": segment.doc_metadata,
                     }
 
                     records.append(record)
 
+            # Add child chunks information to records
             for record in records:
                 if record["segment"].id in segment_child_map:
-                    record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None)
+                    record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks")  # type: ignore
                     record["score"] = segment_child_map[record["segment"].id]["max_score"]
 
-            return [RetrievalSegments(**record) for record in records]
+            return [RetrievalSegments(**record) for record in records]
+        except Exception as e:
+            db.session.rollback()
+            raise e
+        finally:
+            db.session.close()
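The new `try/except/finally` gives `format_retrieval_documents` a deterministic session lifecycle: roll back on any error, always close so the connection returns to the pool. The same contract as a reusable context manager (a sketch around a hypothetical session factory, not Dify's `db` object):

```python
from contextlib import contextmanager

@contextmanager
def session_scope(session_factory):
    """Roll back on error, always close -- mirrors the try/finally above."""
    session = session_factory()
    try:
        yield session
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()

# Usage sketch (hypothetical Session factory and Item model):
# with session_scope(Session) as session:
#     rows = session.query(Item).all()
```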