Parcourir la source

add rerank check when doing mutil-retrieval (#9998)

tags/0.11.0
Jyong il y a 1 an
Parent
révision
9ebd453b87
Aucun compte lié à l'adresse e-mail de l'auteur

+ 1
- 1
api/core/rag/rerank/rerank_type.py Voir le fichier

from enum import Enum from enum import Enum




class RerankMode(Enum):
class RerankMode(str, Enum):
RERANKING_MODEL = "reranking_model" RERANKING_MODEL = "reranking_model"
WEIGHTED_SCORE = "weighted_score" WEIGHTED_SCORE = "weighted_score"

+ 31
- 1
api/core/rag/retrieval/dataset_retrieval.py Voir le fichier

from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.context_entities import DocumentContext from core.rag.entities.context_entities import DocumentContext
from core.rag.models.document import Document from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
reranking_enable: bool = True, reranking_enable: bool = True,
message_id: Optional[str] = None, message_id: Optional[str] = None,
): ):
if not available_datasets:
return []
threads = [] threads = []
all_documents = [] all_documents = []
dataset_ids = [dataset.id for dataset in available_datasets] dataset_ids = [dataset.id for dataset in available_datasets]
index_type = None
index_type_check = all(
item.indexing_technique == available_datasets[0].indexing_technique for item in available_datasets
)
if not index_type_check and (not reranking_enable or reranking_mode != RerankMode.RERANKING_MODEL):
raise ValueError(
"The configured knowledge base list have different indexing technique, please set reranking model."
)
index_type = available_datasets[0].indexing_technique
if index_type == "high_quality":
embedding_model_check = all(
item.embedding_model == available_datasets[0].embedding_model for item in available_datasets
)
embedding_model_provider_check = all(
item.embedding_model_provider == available_datasets[0].embedding_model_provider
for item in available_datasets
)
if (
reranking_enable
and reranking_mode == "weighted_score"
and (not embedding_model_check or not embedding_model_provider_check)
):
raise ValueError(
"The configured knowledge base list have different embedding model, please set reranking model."
)
if reranking_enable and reranking_mode == RerankMode.WEIGHTED_SCORE:
weights["vector_setting"]["embedding_provider_name"] = available_datasets[0].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model

for dataset in available_datasets: for dataset in available_datasets:
index_type = dataset.indexing_technique index_type = dataset.indexing_technique
retrieval_thread = threading.Thread( retrieval_thread = threading.Thread(

Chargement…
Annuler
Enregistrer