選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

tavily.py 9.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import os
  18. import time
  19. from abc import ABC
  20. from tavily import TavilyClient
  21. from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
  22. from api.utils.api_utils import timeout
  23. class TavilySearchParam(ToolParamBase):
  24. """
  25. Define the Retrieval component parameters.
  26. """
  27. def __init__(self):
  28. self.meta:ToolMeta = {
  29. "name": "tavily_search",
  30. "description": """
  31. Tavily is a search engine optimized for LLMs, aimed at efficient, quick and persistent search results.
  32. When searching:
  33. - Start with specific query which should focus on just a single aspect.
  34. - Number of keywords in query should be less than 5.
  35. - Broaden search terms if needed
  36. - Cross-reference information from multiple sources
  37. """,
  38. "parameters": {
  39. "query": {
  40. "type": "string",
  41. "description": "The search keywords to execute with Tavily. The keywords should be the most important words/terms(includes synonyms) from the original request.",
  42. "default": "{sys.query}",
  43. "required": True
  44. },
  45. "topic": {
  46. "type": "string",
  47. "description": "default:general. The category of the search.news is useful for retrieving real-time updates, particularly about politics, sports, and major current events covered by mainstream media sources. general is for broader, more general-purpose searches that may include a wide range of sources.",
  48. "enum": ["general", "news"],
  49. "default": "general",
  50. "required": False,
  51. },
  52. "include_domains": {
  53. "type": "array",
  54. "description": "default:[]. A list of domains only from which the search results can be included.",
  55. "default": [],
  56. "items": {
  57. "type": "string",
  58. "description": "Domain name that must be included, e.g. www.yahoo.com"
  59. },
  60. "required": False
  61. },
  62. "exclude_domains": {
  63. "type": "array",
  64. "description": "default:[]. A list of domains from which the search results can not be included",
  65. "default": [],
  66. "items": {
  67. "type": "string",
  68. "description": "Domain name that must be excluded, e.g. www.yahoo.com"
  69. },
  70. "required": False
  71. },
  72. }
  73. }
  74. super().__init__()
  75. self.api_key = ""
  76. self.search_depth = "basic" # basic/advanced
  77. self.max_results = 6
  78. self.days = 14
  79. self.include_answer = False
  80. self.include_raw_content = False
  81. self.include_images = False
  82. self.include_image_descriptions = False
  83. def check(self):
  84. self.check_valid_value(self.topic, "Tavily topic: should be in 'general/news'", ["general", "news"])
  85. self.check_valid_value(self.search_depth, "Tavily search depth should be in 'basic/advanced'", ["basic", "advanced"])
  86. self.check_positive_integer(self.max_results, "Tavily max result number should be within [1, 20]")
  87. self.check_positive_integer(self.days, "Tavily days should be greater than 1")
  88. def get_input_form(self) -> dict[str, dict]:
  89. return {
  90. "query": {
  91. "name": "Query",
  92. "type": "line"
  93. }
  94. }
  95. class TavilySearch(ToolBase, ABC):
  96. component_name = "TavilySearch"
  97. @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
  98. def _invoke(self, **kwargs):
  99. if not kwargs.get("query"):
  100. self.set_output("formalized_content", "")
  101. return ""
  102. self.tavily_client = TavilyClient(api_key=self._param.api_key)
  103. last_e = None
  104. for fld in ["search_depth", "topic", "max_results", "days", "include_answer", "include_raw_content", "include_images", "include_image_descriptions", "include_domains", "exclude_domains"]:
  105. if fld not in kwargs:
  106. kwargs[fld] = getattr(self._param, fld)
  107. for _ in range(self._param.max_retries+1):
  108. try:
  109. kwargs["include_images"] = False
  110. kwargs["include_raw_content"] = False
  111. res = self.tavily_client.search(**kwargs)
  112. self._retrieve_chunks(res["results"],
  113. get_title=lambda r: r["title"],
  114. get_url=lambda r: r["url"],
  115. get_content=lambda r: r["raw_content"] if r["raw_content"] else r["content"],
  116. get_score=lambda r: r["score"])
  117. self.set_output("json", res["results"])
  118. return self.output("formalized_content")
  119. except Exception as e:
  120. last_e = e
  121. logging.exception(f"Tavily error: {e}")
  122. time.sleep(self._param.delay_after_error)
  123. if last_e:
  124. self.set_output("_ERROR", str(last_e))
  125. return f"Tavily error: {last_e}"
  126. assert False, self.output()
  127. def thoughts(self) -> str:
  128. return """
  129. Keywords: {}
  130. Looking for the most relevant articles.
  131. """.format(self.get_input().get("query", "-_-!"))
  132. class TavilyExtractParam(ToolParamBase):
  133. """
  134. Define the Retrieval component parameters.
  135. """
  136. def __init__(self):
  137. self.meta:ToolMeta = {
  138. "name": "tavily_extract",
  139. "description": "Extract web page content from one or more specified URLs using Tavily Extract.",
  140. "parameters": {
  141. "urls": {
  142. "type": "array",
  143. "description": "The URLs to extract content from.",
  144. "default": "",
  145. "items": {
  146. "type": "string",
  147. "description": "The URL to extract content from, e.g. www.yahoo.com"
  148. },
  149. "required": True
  150. },
  151. "extract_depth": {
  152. "type": "string",
  153. "description": "The depth of the extraction process. advanced extraction retrieves more data, including tables and embedded content, with higher success but may increase latency.basic extraction costs 1 credit per 5 successful URL extractions, while advanced extraction costs 2 credits per 5 successful URL extractions.",
  154. "enum": ["basic", "advanced"],
  155. "default": "basic",
  156. "required": False,
  157. },
  158. "format": {
  159. "type": "string",
  160. "description": "The format of the extracted web page content. markdown returns content in markdown format. text returns plain text and may increase latency.",
  161. "enum": ["markdown", "text"],
  162. "default": "markdown",
  163. "required": False,
  164. }
  165. }
  166. }
  167. super().__init__()
  168. self.api_key = ""
  169. self.extract_depth = "basic" # basic/advanced
  170. self.urls = []
  171. self.format = "markdown"
  172. self.include_images = False
  173. def check(self):
  174. self.check_valid_value(self.extract_depth, "Tavily extract depth should be in 'basic/advanced'", ["basic", "advanced"])
  175. self.check_valid_value(self.format, "Tavily extract format should be in 'markdown/text'", ["markdown", "text"])
  176. def get_input_form(self) -> dict[str, dict]:
  177. return {
  178. "urls": {
  179. "name": "URLs",
  180. "type": "line"
  181. }
  182. }
  183. class TavilyExtract(ToolBase, ABC):
  184. component_name = "TavilyExtract"
  185. @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
  186. def _invoke(self, **kwargs):
  187. self.tavily_client = TavilyClient(api_key=self._param.api_key)
  188. last_e = None
  189. for fld in ["urls", "extract_depth", "format"]:
  190. if fld not in kwargs:
  191. kwargs[fld] = getattr(self._param, fld)
  192. if kwargs.get("urls") and isinstance(kwargs["urls"], str):
  193. kwargs["urls"] = kwargs["urls"].split(",")
  194. for _ in range(self._param.max_retries+1):
  195. try:
  196. kwargs["include_images"] = False
  197. res = self.tavily_client.extract(**kwargs)
  198. self.set_output("json", res["results"])
  199. return self.output("json")
  200. except Exception as e:
  201. last_e = e
  202. logging.exception(f"Tavily error: {e}")
  203. if last_e:
  204. self.set_output("_ERROR", str(last_e))
  205. return f"Tavily error: {last_e}"
  206. assert False, self.output()
  207. def thoughts(self) -> str:
  208. return "Opened {}—pulling out the main text…".format(self.get_input().get("urls", "-_-!"))