|
|
|
|
|
|
|
|
self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold") |
|
|
self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold") |
|
|
self.check_decimal_float(self.keywords_similarity_weight, "[Retrieval] Keywords similarity weight") |
|
|
self.check_decimal_float(self.keywords_similarity_weight, "[Retrieval] Keywords similarity weight") |
|
|
self.check_positive_number(self.top_n, "[Retrieval] Top N") |
|
|
self.check_positive_number(self.top_n, "[Retrieval] Top N") |
|
|
self.check_empty(self.kb_ids, "[Retrieval] Knowledge bases") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Retrieval(ComponentBase, ABC): |
|
|
class Retrieval(ComponentBase, ABC): |
|
|
component_name = "Retrieval" |
|
|
component_name = "Retrieval" |
|
|
|
|
|
|
|
|
def _run(self, history, **kwargs): |
|
|
def _run(self, history, **kwargs): |
|
|
# query = [] |
|
|
|
|
|
# for role, cnt in history[::-1][:self._param.message_history_window_size]: |
|
|
|
|
|
# if role != "user":continue |
|
|
|
|
|
# query.append(cnt) |
|
|
|
|
|
# # query = "\n".join(query) |
|
|
|
|
|
# query = query[0] |
|
|
|
|
|
query = self.get_input() |
|
|
query = self.get_input() |
|
|
query = str(query["content"][0]) if "content" in query else "" |
|
|
query = str(query["content"][0]) if "content" in query else "" |
|
|
|
|
|
|
|
|
kbs = KnowledgebaseService.get_by_ids(self._param.kb_ids) |
|
|
kbs = KnowledgebaseService.get_by_ids(self._param.kb_ids) |
|
|
if not kbs: |
|
|
if not kbs: |
|
|
raise ValueError("Can't find knowledgebases by {}".format(self._param.kb_ids)) |
|
|
|
|
|
|
|
|
return Retrieval.be_output("") |
|
|
|
|
|
|
|
|
embd_nms = list(set([kb.embd_id for kb in kbs])) |
|
|
embd_nms = list(set([kb.embd_id for kb in kbs])) |
|
|
assert len(embd_nms) == 1, "Knowledge bases use different embedding models." |
|
|
assert len(embd_nms) == 1, "Knowledge bases use different embedding models." |
|
|
|
|
|
|