|
|
|
@@ -36,17 +36,20 @@ class DeepResearcher: |
|
|
|
self._kb_retrieve = kb_retrieve |
|
|
|
self._kg_retrieve = kg_retrieve |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _remove_query_tags(text): |
|
|
|
"""Remove query tags from text""" |
|
|
|
pattern = re.escape(BEGIN_SEARCH_QUERY) + r"(.*?)" + re.escape(END_SEARCH_QUERY) |
|
|
|
def _remove_tags(text: str, start_tag: str, end_tag: str) -> str: |
|
|
|
"""General Tag Removal Method""" |
|
|
|
pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag) |
|
|
|
return re.sub(pattern, "", text) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _remove_result_tags(text): |
|
|
|
"""Remove result tags from text""" |
|
|
|
pattern = re.escape(BEGIN_SEARCH_RESULT) + r"(.*?)" + re.escape(END_SEARCH_RESULT) |
|
|
|
return re.sub(pattern, "", text) |
|
|
|
def _remove_query_tags(text: str) -> str: |
|
|
|
"""Remove Query Tags""" |
|
|
|
return DeepResearcher._remove_tags(text, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _remove_result_tags(text: str) -> str: |
|
|
|
"""Remove Result Tags""" |
|
|
|
return DeepResearcher._remove_tags(text, BEGIN_SEARCH_RESULT, END_SEARCH_RESULT) |
|
|
|
|
|
|
|
def _generate_reasoning(self, msg_history): |
|
|
|
"""Generate reasoning steps""" |
|
|
|
@@ -95,21 +98,31 @@ class DeepResearcher: |
|
|
|
def _retrieve_information(self, search_query): |
|
|
|
"""Retrieve information from different sources""" |
|
|
|
# 1. Knowledge base retrieval |
|
|
|
kbinfos = self._kb_retrieve(question=search_query) if self._kb_retrieve else {"chunks": [], "doc_aggs": []} |
|
|
|
|
|
|
|
kbinfos = [] |
|
|
|
try: |
|
|
|
kbinfos = self._kb_retrieve(question=search_query) if self._kb_retrieve else {"chunks": [], "doc_aggs": []} |
|
|
|
except Exception as e: |
|
|
|
logging.error(f"Knowledge base retrieval error: {e}") |
|
|
|
|
|
|
|
# 2. Web retrieval (if Tavily API is configured) |
|
|
|
if self.prompt_config.get("tavily_api_key"): |
|
|
|
tav = Tavily(self.prompt_config["tavily_api_key"]) |
|
|
|
tav_res = tav.retrieve_chunks(search_query) |
|
|
|
kbinfos["chunks"].extend(tav_res["chunks"]) |
|
|
|
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"]) |
|
|
|
|
|
|
|
try: |
|
|
|
if self.prompt_config.get("tavily_api_key"): |
|
|
|
tav = Tavily(self.prompt_config["tavily_api_key"]) |
|
|
|
tav_res = tav.retrieve_chunks(search_query) |
|
|
|
kbinfos["chunks"].extend(tav_res["chunks"]) |
|
|
|
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"]) |
|
|
|
except Exception as e: |
|
|
|
logging.error(f"Web retrieval error: {e}") |
|
|
|
|
|
|
|
# 3. Knowledge graph retrieval (if configured) |
|
|
|
if self.prompt_config.get("use_kg") and self._kg_retrieve: |
|
|
|
ck = self._kg_retrieve(question=search_query) |
|
|
|
if ck["content_with_weight"]: |
|
|
|
kbinfos["chunks"].insert(0, ck) |
|
|
|
|
|
|
|
try: |
|
|
|
if self.prompt_config.get("use_kg") and self._kg_retrieve: |
|
|
|
ck = self._kg_retrieve(question=search_query) |
|
|
|
if ck["content_with_weight"]: |
|
|
|
kbinfos["chunks"].insert(0, ck) |
|
|
|
except Exception as e: |
|
|
|
logging.error(f"Knowledge graph retrieval error: {e}") |
|
|
|
|
|
|
|
return kbinfos |
|
|
|
|
|
|
|
def _update_chunk_info(self, chunk_info, kbinfos): |