| 
                        123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167 | 
                        - #
 - #  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
 - #
 - #  Licensed under the Apache License, Version 2.0 (the "License");
 - #  you may not use this file except in compliance with the License.
 - #  You may obtain a copy of the License at
 - #
 - #      http://www.apache.org/licenses/LICENSE-2.0
 - #
 - #  Unless required by applicable law or agreed to in writing, software
 - #  distributed under the License is distributed on an "AS IS" BASIS,
 - #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 - #  See the License for the specific language governing permissions and
 - #  limitations under the License.
 - #
 - import os
 - import re
 - from abc import ABC
 - from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
 - from api.db import LLMType
 - from api.db.services.knowledgebase_service import KnowledgebaseService
 - from api.db.services.llm_service import LLMBundle
 - from api import settings
 - from api.utils.api_utils import timeout
 - from rag.app.tag import label_question
 - from rag.prompts import kb_prompt
 - from rag.prompts.prompts import cross_languages
 - 
 - 
 - class RetrievalParam(ToolParamBase):
 -     """
 -     Define the Retrieval component parameters.
 -     """
 - 
 -     def __init__(self):
 -         self.meta:ToolMeta = {
 -             "name": "search_my_dateset",
 -             "description": "This tool can be utilized for relevant content searching in the datasets.",
 -             "parameters": {
 -                 "query": {
 -                     "type": "string",
 -                     "description": "The keywords to search the dataset. The keywords should be the most important words/terms(includes synonyms) from the original request.",
 -                     "default": "",
 -                     "required": True
 -                 }
 -             }
 -         }
 -         super().__init__()
 -         self.function_name = "search_my_dateset"
 -         self.description = "This tool can be utilized for relevant content searching in the datasets."
 -         self.similarity_threshold = 0.2
 -         self.keywords_similarity_weight = 0.5
 -         self.top_n = 8
 -         self.top_k = 1024
 -         self.kb_ids = []
 -         self.kb_vars = []
 -         self.rerank_id = ""
 -         self.empty_response = ""
 -         self.use_kg = False
 -         self.cross_languages = []
 - 
 -     def check(self):
 -         self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold")
 -         self.check_decimal_float(self.keywords_similarity_weight, "[Retrieval] Keyword similarity weight")
 -         self.check_positive_number(self.top_n, "[Retrieval] Top N")
 - 
 -     def get_input_form(self) -> dict[str, dict]:
 -         return {
 -             "query": {
 -                 "name": "Query",
 -                 "type": "line"
 -             }
 -         }
 - 
 - class Retrieval(ToolBase, ABC):
 -     component_name = "Retrieval"
 - 
 -     @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
 -     def _invoke(self, **kwargs):
 -         if not kwargs.get("query"):
 -             self.set_output("formalized_content", self._param.empty_response)
 - 
 -         kb_ids: list[str] = []
 -         for id in self._param.kb_ids:
 -             if id.find("@") < 0:
 -                 kb_ids.append(id)
 -                 continue
 -             kb_nm = self._canvas.get_variable_value(id)
 -             e, kb = KnowledgebaseService.get_by_name(kb_nm, self._canvas._tenant_id)
 -             if not e:
 -                 raise Exception(f"Dataset({kb_nm}) does not exist.")
 -             kb_ids.append(kb.id)
 - 
 -         filtered_kb_ids: list[str] = list(set([kb_id for kb_id in kb_ids if kb_id]))
 - 
 -         kbs = KnowledgebaseService.get_by_ids(filtered_kb_ids)
 -         if not kbs:
 -             raise Exception("No dataset is selected.")
 - 
 -         embd_nms = list(set([kb.embd_id for kb in kbs]))
 -         assert len(embd_nms) == 1, "Knowledge bases use different embedding models."
 - 
 -         embd_mdl = None
 -         if embd_nms:
 -             embd_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, embd_nms[0])
 - 
 -         rerank_mdl = None
 -         if self._param.rerank_id:
 -             rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
 - 
 -         query = kwargs["query"]
 -         if self._param.cross_languages:
 -             query = cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)
 - 
 -         if kbs:
 -             query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE)
 -             kbinfos = settings.retrievaler.retrieval(
 -                 query,
 -                 embd_mdl,
 -                 [kb.tenant_id for kb in kbs],
 -                 filtered_kb_ids,
 -                 1,
 -                 self._param.top_n,
 -                 self._param.similarity_threshold,
 -                 1 - self._param.keywords_similarity_weight,
 -                 aggs=False,
 -                 rerank_mdl=rerank_mdl,
 -                 rank_feature=label_question(query, kbs),
 -             )
 -             if self._param.use_kg:
 -                 ck = settings.kg_retrievaler.retrieval(query,
 -                                                        [kb.tenant_id for kb in kbs],
 -                                                        kb_ids,
 -                                                        embd_mdl,
 -                                                        LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT))
 -                 if ck["content_with_weight"]:
 -                     kbinfos["chunks"].insert(0, ck)
 -         else:
 -             kbinfos = {"chunks": [], "doc_aggs": []}
 - 
 -         if self._param.use_kg and kbs:
 -             ck = settings.kg_retrievaler.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
 -             if ck["content_with_weight"]:
 -                 ck["content"] = ck["content_with_weight"]
 -                 del ck["content_with_weight"]
 -                 kbinfos["chunks"].insert(0, ck)
 - 
 -         for ck in kbinfos["chunks"]:
 -             if "vector" in ck:
 -                 del ck["vector"]
 -             if "content_ltks" in ck:
 -                 del ck["content_ltks"]
 - 
 -         if not kbinfos["chunks"]:
 -             self.set_output("formalized_content", self._param.empty_response)
 -             return
 - 
 -         self._canvas.add_refernce(kbinfos["chunks"], kbinfos["doc_aggs"])
 -         form_cnt = "\n".join(kb_prompt(kbinfos, 200000, True))
 -         self.set_output("formalized_content", form_cnt)
 -         return form_cnt
 - 
 -     def thoughts(self) -> str:
 -         return """
 - Keywords: {} 
 - Looking for the most relevant articles.
 -         """.format(self.get_input().get("query", "-_-!"))
 
 
  |