|
|
|
@@ -1,4 +1,6 @@ |
|
|
|
import concurrent.futures |
|
|
|
import logging |
|
|
|
import time |
|
|
|
from concurrent.futures import ThreadPoolExecutor |
|
|
|
from typing import Optional |
|
|
|
|
|
|
|
@@ -46,7 +48,7 @@ class RetrievalService: |
|
|
|
if not query: |
|
|
|
return [] |
|
|
|
dataset = cls._get_dataset(dataset_id) |
|
|
|
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0: |
|
|
|
if not dataset: |
|
|
|
return [] |
|
|
|
|
|
|
|
all_documents: list[Document] = [] |
|
|
|
@@ -178,6 +180,7 @@ class RetrievalService: |
|
|
|
if not dataset: |
|
|
|
raise ValueError("dataset not found") |
|
|
|
|
|
|
|
start = time.time() |
|
|
|
vector = Vector(dataset=dataset) |
|
|
|
documents = vector.search_by_vector( |
|
|
|
query, |
|
|
|
@@ -187,6 +190,7 @@ class RetrievalService: |
|
|
|
filter={"group_id": [dataset.id]}, |
|
|
|
document_ids_filter=document_ids_filter, |
|
|
|
) |
|
|
|
logging.debug(f"embedding_search ends at {time.time() - start:.2f} seconds") |
|
|
|
|
|
|
|
if documents: |
|
|
|
if ( |
|
|
|
@@ -270,7 +274,8 @@ class RetrievalService: |
|
|
|
return [] |
|
|
|
|
|
|
|
try: |
|
|
|
# Collect document IDs |
|
|
|
start_time = time.time() |
|
|
|
# Collect document IDs with existence check |
|
|
|
document_ids = {doc.metadata.get("document_id") for doc in documents if "document_id" in doc.metadata} |
|
|
|
if not document_ids: |
|
|
|
return [] |
|
|
|
@@ -288,110 +293,126 @@ class RetrievalService: |
|
|
|
include_segment_ids = set() |
|
|
|
segment_child_map = {} |
|
|
|
|
|
|
|
# Process documents |
|
|
|
# Precompute doc_forms to avoid redundant checks |
|
|
|
doc_forms = {} |
|
|
|
for doc in documents: |
|
|
|
document_id = doc.metadata.get("document_id") |
|
|
|
dataset_doc = dataset_documents.get(document_id) |
|
|
|
if dataset_doc: |
|
|
|
doc_forms[document_id] = dataset_doc.doc_form |
|
|
|
|
|
|
|
# Batch collect index node IDs with type safety |
|
|
|
child_index_node_ids = [] |
|
|
|
index_node_ids = [] |
|
|
|
for doc in documents: |
|
|
|
document_id = doc.metadata.get("document_id") |
|
|
|
if doc_forms.get(document_id) == IndexType.PARENT_CHILD_INDEX: |
|
|
|
child_index_node_ids.append(doc.metadata.get("doc_id")) |
|
|
|
else: |
|
|
|
index_node_ids.append(doc.metadata.get("doc_id")) |
|
|
|
|
|
|
|
# Batch query ChildChunk |
|
|
|
child_chunks = db.session.query(ChildChunk).filter(ChildChunk.index_node_id.in_(child_index_node_ids)).all() |
|
|
|
child_chunk_map = {chunk.index_node_id: chunk for chunk in child_chunks} |
|
|
|
|
|
|
|
# Batch query DocumentSegment with unified conditions |
|
|
|
segment_map = { |
|
|
|
segment.id: segment |
|
|
|
for segment in db.session.query(DocumentSegment) |
|
|
|
.filter( |
|
|
|
( |
|
|
|
DocumentSegment.index_node_id.in_(index_node_ids) |
|
|
|
| DocumentSegment.id.in_([chunk.segment_id for chunk in child_chunks]) |
|
|
|
), |
|
|
|
DocumentSegment.enabled == True, |
|
|
|
DocumentSegment.status == "completed", |
|
|
|
) |
|
|
|
.options( |
|
|
|
load_only( |
|
|
|
DocumentSegment.id, |
|
|
|
DocumentSegment.content, |
|
|
|
DocumentSegment.answer, |
|
|
|
) |
|
|
|
) |
|
|
|
.all() |
|
|
|
} |
|
|
|
|
|
|
|
for document in documents: |
|
|
|
document_id = document.metadata.get("document_id") |
|
|
|
if document_id not in dataset_documents: |
|
|
|
continue |
|
|
|
|
|
|
|
dataset_document = dataset_documents[document_id] |
|
|
|
dataset_document = dataset_documents.get(document_id) |
|
|
|
if not dataset_document: |
|
|
|
continue |
|
|
|
|
|
|
|
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: |
|
|
|
# Handle parent-child documents |
|
|
|
doc_form = doc_forms.get(document_id) |
|
|
|
if doc_form == IndexType.PARENT_CHILD_INDEX: |
|
|
|
# Handle parent-child documents using preloaded data |
|
|
|
child_index_node_id = document.metadata.get("doc_id") |
|
|
|
if not child_index_node_id: |
|
|
|
continue |
|
|
|
|
|
|
|
child_chunk = ( |
|
|
|
db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first() |
|
|
|
) |
|
|
|
|
|
|
|
child_chunk = child_chunk_map.get(child_index_node_id) |
|
|
|
if not child_chunk: |
|
|
|
continue |
|
|
|
|
|
|
|
segment = ( |
|
|
|
db.session.query(DocumentSegment) |
|
|
|
.filter( |
|
|
|
DocumentSegment.dataset_id == dataset_document.dataset_id, |
|
|
|
DocumentSegment.enabled == True, |
|
|
|
DocumentSegment.status == "completed", |
|
|
|
DocumentSegment.id == child_chunk.segment_id, |
|
|
|
) |
|
|
|
.options( |
|
|
|
load_only( |
|
|
|
DocumentSegment.id, |
|
|
|
DocumentSegment.content, |
|
|
|
DocumentSegment.answer, |
|
|
|
) |
|
|
|
) |
|
|
|
.first() |
|
|
|
) |
|
|
|
|
|
|
|
segment = segment_map.get(child_chunk.segment_id) |
|
|
|
if not segment: |
|
|
|
continue |
|
|
|
|
|
|
|
if segment.id not in include_segment_ids: |
|
|
|
include_segment_ids.add(segment.id) |
|
|
|
child_chunk_detail = { |
|
|
|
"id": child_chunk.id, |
|
|
|
"content": child_chunk.content, |
|
|
|
"position": child_chunk.position, |
|
|
|
"score": document.metadata.get("score", 0.0), |
|
|
|
} |
|
|
|
map_detail = { |
|
|
|
"max_score": document.metadata.get("score", 0.0), |
|
|
|
"child_chunks": [child_chunk_detail], |
|
|
|
} |
|
|
|
map_detail = {"max_score": document.metadata.get("score", 0.0), "child_chunks": []} |
|
|
|
segment_child_map[segment.id] = map_detail |
|
|
|
record = { |
|
|
|
"segment": segment, |
|
|
|
} |
|
|
|
records.append(record) |
|
|
|
else: |
|
|
|
child_chunk_detail = { |
|
|
|
"id": child_chunk.id, |
|
|
|
"content": child_chunk.content, |
|
|
|
"position": child_chunk.position, |
|
|
|
"score": document.metadata.get("score", 0.0), |
|
|
|
} |
|
|
|
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) |
|
|
|
segment_child_map[segment.id]["max_score"] = max( |
|
|
|
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) |
|
|
|
) |
|
|
|
records.append({"segment": segment}) |
|
|
|
|
|
|
|
# Append child chunk details |
|
|
|
child_chunk_detail = { |
|
|
|
"id": child_chunk.id, |
|
|
|
"content": child_chunk.content, |
|
|
|
"position": child_chunk.position, |
|
|
|
"score": document.metadata.get("score", 0.0), |
|
|
|
} |
|
|
|
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) |
|
|
|
segment_child_map[segment.id]["max_score"] = max( |
|
|
|
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) |
|
|
|
) |
|
|
|
|
|
|
|
else: |
|
|
|
# Handle normal documents |
|
|
|
index_node_id = document.metadata.get("doc_id") |
|
|
|
if not index_node_id: |
|
|
|
continue |
|
|
|
|
|
|
|
segment = ( |
|
|
|
db.session.query(DocumentSegment) |
|
|
|
.filter( |
|
|
|
DocumentSegment.dataset_id == dataset_document.dataset_id, |
|
|
|
DocumentSegment.enabled == True, |
|
|
|
DocumentSegment.status == "completed", |
|
|
|
DocumentSegment.index_node_id == index_node_id, |
|
|
|
) |
|
|
|
.first() |
|
|
|
segment = next( |
|
|
|
( |
|
|
|
s |
|
|
|
for s in segment_map.values() |
|
|
|
if s.index_node_id == index_node_id and s.dataset_id == dataset_document.dataset_id |
|
|
|
), |
|
|
|
None, |
|
|
|
) |
|
|
|
|
|
|
|
if not segment: |
|
|
|
continue |
|
|
|
|
|
|
|
include_segment_ids.add(segment.id) |
|
|
|
record = { |
|
|
|
"segment": segment, |
|
|
|
"score": document.metadata.get("score"), # type: ignore |
|
|
|
} |
|
|
|
records.append(record) |
|
|
|
if segment.id not in include_segment_ids: |
|
|
|
include_segment_ids.add(segment.id) |
|
|
|
records.append( |
|
|
|
{ |
|
|
|
"segment": segment, |
|
|
|
"score": document.metadata.get("score", 0.0), |
|
|
|
} |
|
|
|
) |
|
|
|
|
|
|
|
# Add child chunks information to records |
|
|
|
# Merge child chunks information |
|
|
|
for record in records: |
|
|
|
if record["segment"].id in segment_child_map: |
|
|
|
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore |
|
|
|
record["score"] = segment_child_map[record["segment"].id]["max_score"] |
|
|
|
segment_id = record["segment"].id |
|
|
|
if segment_id in segment_child_map: |
|
|
|
record["child_chunks"] = segment_child_map[segment_id]["child_chunks"] |
|
|
|
record["score"] = segment_child_map[segment_id]["max_score"] |
|
|
|
|
|
|
|
logging.debug(f"Formatting retrieval documents took {time.time() - start_time:.2f} seconds") |
|
|
|
return [RetrievalSegments(**record) for record in records] |
|
|
|
except Exception as e: |
|
|
|
# Only rollback if there were write operations |
|
|
|
db.session.rollback() |
|
|
|
raise e |