|
|
|
@@ -24,6 +24,7 @@ from api.db.services.llm_service import LLMBundle |
|
|
|
from api import settings |
|
|
|
from agent.component.base import ComponentBase, ComponentParamBase |
|
|
|
from rag.app.tag import label_question |
|
|
|
from rag.utils.tavily_conn import Tavily |
|
|
|
|
|
|
|
|
|
|
|
class RetrievalParam(ComponentParamBase): |
|
|
|
@@ -40,6 +41,7 @@ class RetrievalParam(ComponentParamBase): |
|
|
|
self.kb_ids = [] |
|
|
|
self.rerank_id = "" |
|
|
|
self.empty_response = "" |
|
|
|
self.tavily_api_key = "" |
|
|
|
|
|
|
|
def check(self): |
|
|
|
self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold") |
|
|
|
@@ -75,6 +77,11 @@ class Retrieval(ComponentBase, ABC): |
|
|
|
self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight, |
|
|
|
aggs=False, rerank_mdl=rerank_mdl, |
|
|
|
rank_feature=label_question(query, kbs)) |
|
|
|
if self._param.tavily_api_key: |
|
|
|
tav = Tavily(self._param.tavily_api_key) |
|
|
|
tav_res = tav.retrieve_chunks(query) |
|
|
|
kbinfos["chunks"].extend(tav_res["chunks"]) |
|
|
|
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"]) |
|
|
|
|
|
|
|
if not kbinfos["chunks"]: |
|
|
|
df = Retrieval.be_output("") |