| @@ -1,12 +1,32 @@ | |||
| import os | |||
| from typing import Any, Optional, Union | |||
| from typing import Any, Optional, TextIO, Union | |||
| from langchain.callbacks.base import BaseCallbackHandler | |||
| from langchain.input import print_text | |||
| from pydantic import BaseModel | |||
| _TEXT_COLOR_MAPPING = { | |||
| "blue": "36;1", | |||
| "yellow": "33;1", | |||
| "pink": "38;5;200", | |||
| "green": "32;1", | |||
| "red": "31;1", | |||
| } | |||
| class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel): | |||
def get_colored_text(text: str, color: str) -> str:
    """Wrap ``text`` in ANSI escape sequences for the named ``color``.

    Args:
        text: The text to decorate.
        color: A key of ``_TEXT_COLOR_MAPPING``.

    Returns:
        ``text`` prefixed with the color and ``1;3`` escape sequences and
        followed by the reset sequence.

    Raises:
        KeyError: If ``color`` is not a key of ``_TEXT_COLOR_MAPPING``.
    """
    sgr_params = _TEXT_COLOR_MAPPING[color]
    opening = f"\u001b[{sgr_params}m\033[1;3m"
    closing = "\u001b[0m"
    return f"{opening}{text}{closing}"
def print_text(
    text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
) -> None:
    """Print ``text``, optionally colored, without a trailing newline by default.

    Args:
        text: The text to print.
        color: Color name passed to ``get_colored_text``; a falsy value prints
            the text undecorated.
        end: String appended after the text (defaults to nothing).
        file: Destination stream handed to ``print``; when truthy it is flushed
            after the write.
    """
    if color:
        text = get_colored_text(text, color)
    print(text, end=end, file=file)
    if not file:
        return
    # Flush so that everything printed so far reaches the underlying file.
    file.flush()
| class DifyAgentCallbackHandler(BaseModel): | |||
| """Callback Handler that prints to std out.""" | |||
| color: Optional[str] = '' | |||
| current_loop = 1 | |||
| @@ -1,11 +1,92 @@ | |||
| from typing import Any | |||
| import logging | |||
| from typing import Any, Optional | |||
| from langchain.utilities import ArxivAPIWrapper | |||
| import arxiv | |||
| from pydantic import BaseModel, Field | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage | |||
| from core.tools.tool.builtin_tool import BuiltinTool | |||
| logger = logging.getLogger(__name__) | |||
class ArxivAPIWrapper(BaseModel):
    """Wrapper around ArxivAPI.

    To use, you should have the ``arxiv`` python package installed.
    https://lukasschwab.me/arxiv.py/index.html
    This wrapper will use the Arxiv API to conduct searches and
    fetch document summaries. By default, it will return the document summaries
    of the top-k results.
    It limits the Document content by doc_content_chars_max.
    Set doc_content_chars_max=None if you don't want to limit the content size.

    Args:
        top_k_results: number of the top-scored document used for the arxiv tool
        ARXIV_MAX_QUERY_LENGTH: the cut limit on the query used for the arxiv tool.
        load_max_docs: a limit to the number of loaded documents
            (not read by ``run``)
        load_all_available_meta:
            if True: the `metadata` of the loaded Documents contains all available
            meta info (see https://lukasschwab.me/arxiv.py/index.html#Result),
            if False: the `metadata` contains only the published date, title,
            authors and summary.
            (not read by ``run``)
        doc_content_chars_max: an optional cut limit for the length of a document's
            content

    Example:
        .. code-block:: python

            arxiv = ArxivAPIWrapper(
                top_k_results = 3,
                ARXIV_MAX_QUERY_LENGTH = 300,
                load_max_docs = 3,
                load_all_available_meta = False,
                doc_content_chars_max = 40000
            )
            arxiv.run("tree of thought llm")
    """

    arxiv_search = arxiv.Search  #: :meta private:
    arxiv_exceptions = (
        arxiv.ArxivError,
        arxiv.UnexpectedEmptyPageError,
        arxiv.HTTPError,
    )  # :meta private:
    top_k_results: int = 3
    ARXIV_MAX_QUERY_LENGTH = 300
    load_max_docs: int = 100
    load_all_available_meta: bool = False
    doc_content_chars_max: Optional[int] = 4000

    def run(self, query: str) -> str:
        """Perform an arxiv search and return the results as a single string.

        The string contains the date, title, authors, and summary of each
        article, with articles separated by two newlines. If an arxiv error
        occurs or no documents are found, an error text is returned instead.
        Wrapper for https://lukasschwab.me/arxiv.py/index.html#Search

        Args:
            query: a plaintext search query; only the first
                ``ARXIV_MAX_QUERY_LENGTH`` characters are sent.

        Returns:
            The joined article descriptions truncated to
            ``doc_content_chars_max`` characters, or an error / no-result
            message.
        """  # noqa: E501
        try:
            results = self.arxiv_search(  # type: ignore
                query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
            ).results()
            # Build the documents inside the ``try``: if ``results`` is a lazy
            # iterator, API errors surface while iterating, not when it is created.
            # NOTE(review): arxiv.py's results() appears to be lazy - confirm
            # against the pinned arxiv version.
            docs = [
                # The "Published" label is filled from ``result.updated``.
                f"Published: {result.updated.date()}\n"
                f"Title: {result.title}\n"
                f"Authors: {', '.join(a.name for a in result.authors)}\n"
                f"Summary: {result.summary}"
                for result in results
            ]
        except self.arxiv_exceptions as ex:
            return f"Arxiv exception: {ex}"
        if not docs:
            return "No good Arxiv Result was found"
        return "\n\n".join(docs)[: self.doc_content_chars_max]
class ArxivSearchInput(BaseModel):
    # Input schema with a single required field: the search query text.
    query: str = Field(default=..., description="Search query.")
| @@ -1,11 +1,95 @@ | |||
| from typing import Any | |||
| import json | |||
| from typing import Any, Optional | |||
| from langchain.tools import BraveSearch | |||
| import requests | |||
| from pydantic import BaseModel, Field | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage | |||
| from core.tools.tool.builtin_tool import BuiltinTool | |||
class BraveSearchWrapper(BaseModel):
    """Wrapper around the Brave search engine."""

    api_key: str
    """The API key to use for the Brave search engine."""
    search_kwargs: dict = Field(default_factory=dict)
    """Additional keyword arguments to pass to the search request."""
    base_url = "https://api.search.brave.com/res/v1/web/search"
    """The base URL for the Brave search engine."""

    def run(self, query: str) -> str:
        """Query the Brave search engine and return the results as a JSON string.

        Args:
            query: The query to search for.

        Returns:
            A JSON array of objects with ``title``, ``link`` and ``snippet``
            keys, one per web result.

        Raises:
            Exception: If the HTTP response status is not OK.
        """
        web_search_results = self._search_request(query=query)
        final_results = [
            {
                "title": item.get("title"),
                "link": item.get("url"),
                "snippet": item.get("description"),
            }
            for item in web_search_results
        ]
        return json.dumps(final_results)

    def _search_request(self, query: str) -> list[dict]:
        """Send the search request and return the raw web results.

        Args:
            query: The query to search for; sent as the ``q`` parameter and
                overriding any ``q`` present in ``search_kwargs``.

        Returns:
            The list found under ``web.results`` in the JSON response, or an
            empty list when that path is absent.

        Raises:
            Exception: If the HTTP response status is not OK.
        """
        headers = {
            "X-Subscription-Token": self.api_key,
            "Accept": "application/json",
        }
        params = {**self.search_kwargs, "q": query}
        # Let requests encode the query string instead of hand-preparing a URL.
        # NOTE(review): no timeout is set, so a stalled connection blocks
        # indefinitely - consider adding one.
        response = requests.get(self.base_url, params=params, headers=headers)
        if not response.ok:
            raise Exception(f"HTTP error {response.status_code}")
        return response.json().get("web", {}).get("results", [])
class BraveSearch(BaseModel):
    """Tool object that forwards search queries to a ``BraveSearchWrapper``."""

    name = "brave_search"
    description = (
        "a search engine. "
        "useful for when you need to answer questions about current events."
        " input should be a search query."
    )
    search_wrapper: BraveSearchWrapper

    @classmethod
    def from_api_key(
        cls, api_key: str, search_kwargs: Optional[dict] = None, **kwargs: Any
    ) -> "BraveSearch":
        """Build a tool from an API key.

        Args:
            api_key: The Brave API key handed to the wrapper.
            search_kwargs: Extra request parameters for the wrapper; a falsy
                value is replaced by an empty dict.
            **kwargs: Additional keyword arguments forwarded to the tool
                constructor.

        Returns:
            A tool instance backed by a freshly created wrapper.
        """
        extra_params = search_kwargs if search_kwargs else {}
        return cls(
            search_wrapper=BraveSearchWrapper(
                api_key=api_key, search_kwargs=extra_params
            ),
            **kwargs,
        )

    def _run(
        self,
        query: str,
    ) -> str:
        """Run ``query`` through the wrapper and return its result string."""
        wrapper = self.search_wrapper
        return wrapper.run(query)
| class BraveSearchTool(BuiltinTool): | |||
| """ | |||
| Tool for performing a search using Brave search engine. | |||
| @@ -31,7 +115,7 @@ class BraveSearchTool(BuiltinTool): | |||
| tool = BraveSearch.from_api_key(api_key=api_key, search_kwargs={"count": count}) | |||
| results = tool.run(query) | |||
| results = tool._run(query) | |||
| if not results: | |||
| return self.create_text_message(f"No results found for '{query}' in Tavily") | |||
| @@ -1,16 +1,147 @@ | |||
| from typing import Any | |||
| from typing import Any, Optional | |||
| from langchain.tools import DuckDuckGoSearchRun | |||
| from pydantic import BaseModel, Field | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage | |||
| from core.tools.tool.builtin_tool import BuiltinTool | |||
class DuckDuckGoSearchAPIWrapper(BaseModel):
    """Wrapper for DuckDuckGo Search API.

    Free and does not require any setup.
    """

    region: Optional[str] = "wt-wt"
    safesearch: str = "moderate"
    time: Optional[str] = "y"
    max_results: int = 5

    def get_snippets(self, query: str) -> list[str]:
        """Search DuckDuckGo and return the ``body`` text of the hits.

        Args:
            query: The query to search for.

        Returns:
            The ``body`` values of the hits; collection stops once
            ``max_results`` have been gathered. A single placeholder message is
            returned when the search yields ``None``.
        """
        from duckduckgo_search import DDGS

        with DDGS() as ddgs:
            hits = ddgs.text(
                query,
                region=self.region,
                safesearch=self.safesearch,
                timelimit=self.time,
            )
            if hits is None:
                return ["No good DuckDuckGo Search Result was found"]
            bodies = []
            for hit in hits:
                if hit is not None:
                    bodies.append(hit["body"])
                if len(bodies) == self.max_results:
                    break
        return bodies

    def run(self, query: str) -> str:
        """Return the snippets for ``query`` joined by single spaces."""
        return " ".join(self.get_snippets(query))

    def results(
        self, query: str, num_results: int, backend: str = "api"
    ) -> list[dict[str, str]]:
        """Run query through DuckDuckGo and return metadata.

        Args:
            query: The query to search for.
            num_results: The number of results to return.
            backend: Backend name forwarded to the search call; ``"news"``
                selects the news-shaped result mapping.

        Returns:
            A list of dictionaries with the following keys:
                snippet - The description of the result.
                title - The title of the result.
                link - The link to the result.
            For the ``"news"`` backend each dictionary also carries ``date``
            and ``source``.
        """
        from duckduckgo_search import DDGS

        def to_metadata(result: dict) -> dict[str, str]:
            # News hits expose different keys than plain web hits.
            if backend != "news":
                return {
                    "snippet": result["body"],
                    "title": result["title"],
                    "link": result["href"],
                }
            return {
                "date": result["date"],
                "title": result["title"],
                "snippet": result["body"],
                "source": result["source"],
                "link": result["url"],
            }

        with DDGS() as ddgs:
            hits = ddgs.text(
                query,
                region=self.region,
                safesearch=self.safesearch,
                timelimit=self.time,
                backend=backend,
            )
            if hits is None:
                return [{"Result": "No good DuckDuckGo Search Result was found"}]
            formatted = []
            for hit in hits:
                if hit is not None:
                    formatted.append(to_metadata(hit))
                if len(formatted) == num_results:
                    break
        return formatted
class DuckDuckGoSearchRun(BaseModel):
    """Tool object that answers a query with DuckDuckGo result snippets."""

    name = "duckduckgo_search"
    description = (
        "A wrapper around DuckDuckGo Search. "
        "Useful for when you need to answer questions about current events. "
        "Input should be a search query."
    )
    # A default wrapper is created per instance when none is supplied.
    api_wrapper: DuckDuckGoSearchAPIWrapper = Field(
        default_factory=DuckDuckGoSearchAPIWrapper
    )

    def _run(
        self,
        query: str,
    ) -> str:
        """Run ``query`` through the API wrapper and return its text output."""
        wrapper = self.api_wrapper
        return wrapper.run(query)
class DuckDuckGoSearchResults(BaseModel):
    """Tool object that returns DuckDuckGo result metadata as one string."""

    name = "DuckDuckGo Results JSON"
    description = (
        "A wrapper around Duck Duck Go Search. "
        "Useful for when you need to answer questions about current events. "
        "Input should be a search query. Output is a JSON array of the query results"
    )
    num_results: int = 4
    api_wrapper: DuckDuckGoSearchAPIWrapper = Field(
        default_factory=DuckDuckGoSearchAPIWrapper
    )
    backend: str = "api"

    def _run(
        self,
        query: str,
    ) -> str:
        """Search for ``query`` and render every hit as ``[key: value, ...]``.

        Returns:
            The rendered hits joined by ``", "``. The text is built by plain
            string formatting, not by a JSON encoder.
        """
        hits = self.api_wrapper.results(query, self.num_results, backend=self.backend)
        rendered = []
        for hit in hits:
            pairs = ", ".join(f"{key}: {value}" for key, value in hit.items())
            rendered.append(f"[{pairs}]")
        return ", ".join(rendered)
class DuckDuckGoInput(BaseModel):
    # Input schema with a single required field: the search query text.
    query: str = Field(default=..., description="Search query.")
| class DuckDuckGoSearchTool(BuiltinTool): | |||
| """ | |||
| Tool for performing a search using DuckDuckGo search engine. | |||
| @@ -34,7 +165,7 @@ class DuckDuckGoSearchTool(BuiltinTool): | |||
| tool = DuckDuckGoSearchRun(args_schema=DuckDuckGoInput) | |||
| result = tool.run(query) | |||
| result = tool._run(query) | |||
| return self.create_text_message(self.summary(user_id=user_id, content=result)) | |||
| @@ -1,16 +1,187 @@ | |||
| import json | |||
| import time | |||
| import urllib.error | |||
| import urllib.parse | |||
| import urllib.request | |||
| from typing import Any | |||
| from langchain.tools import PubmedQueryRun | |||
| from pydantic import BaseModel, Field | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage | |||
| from core.tools.tool.builtin_tool import BuiltinTool | |||
class PubMedAPIWrapper(BaseModel):
    """
    Wrapper around PubMed API.

    This wrapper will use the PubMed API to conduct searches and fetch
    document summaries. By default, it will return the document summaries
    of the top-k results of an input search.

    Parameters:
        top_k_results: number of the top-scored document used for the PubMed tool
        load_max_docs: a limit to the number of loaded documents
        load_all_available_meta:
          if True: the `metadata` of the loaded Documents gets all available meta info
            (see https://www.ncbi.nlm.nih.gov/books/NBK25499/#chapter4.ESearch)
          if False: the `metadata` gets only the most informative fields.
        doc_content_chars_max: cut limit applied to the text returned by ``run``
    """

    base_url_esearch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?"
    base_url_efetch = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?"
    max_retry = 5
    sleep_time = 0.2

    # Default values for the parameters
    top_k_results: int = 3
    load_max_docs: int = 25
    # Despite its name, this caps the length of the PubMed query in ``run``.
    ARXIV_MAX_QUERY_LENGTH = 300
    doc_content_chars_max: int = 2000
    load_all_available_meta: bool = False
    email: str = "your_email@example.com"

    def run(self, query: str) -> str:
        """
        Run PubMed search and get the article meta information.
        See https://www.ncbi.nlm.nih.gov/books/NBK25499/#chapter4.ESearch
        It uses only the most informative fields of article meta information.

        Args:
            query: search text; only the first ``ARXIV_MAX_QUERY_LENGTH``
                characters are used.

        Returns:
            Publication date, title and summary of each article, separated by
            blank lines and truncated to ``doc_content_chars_max`` characters;
            a no-result message when nothing was found; or a
            ``"PubMed exception: ..."`` message when any error occurred.
        """
        try:
            # Retrieve the top-k results for the query
            docs = [
                f"Published: {result['pub_date']}\nTitle: {result['title']}\n"
                f"Summary: {result['summary']}"
                for result in self.load(query[: self.ARXIV_MAX_QUERY_LENGTH])
            ]

            # Join the results and limit the character count
            return (
                "\n\n".join(docs)[: self.doc_content_chars_max]
                if docs
                else "No good PubMed Result was found"
            )
        except Exception as ex:
            # Deliberate best-effort: failures are reported as text to the caller.
            return f"PubMed exception: {ex}"

    def load(self, query: str) -> list[dict]:
        """
        Search PubMed for documents matching the query.
        Return a list of dictionaries containing the document metadata.

        Args:
            query: search text; it is URL-quoted before being sent.

        Returns:
            One dictionary per id returned by the search, as produced by
            ``retrieve_article``.
        """
        url = (
            self.base_url_esearch
            + "db=pubmed&term="
            # Quote the term directly; wrapping it in a set literal would send
            # the set's repr (braces and quotes) as the search term.
            + urllib.parse.quote(query)
            + f"&retmode=json&retmax={self.top_k_results}&usehistory=y"
        )
        with urllib.request.urlopen(url) as result:
            text = result.read().decode("utf-8")
        json_text = json.loads(text)

        webenv = json_text["esearchresult"]["webenv"]
        return [
            self.retrieve_article(uid, webenv)
            for uid in json_text["esearchresult"]["idlist"]
        ]

    def retrieve_article(self, uid: str, webenv: str) -> dict:
        """Fetch one article and extract its title, abstract and date.

        Args:
            uid: article id taken from the search result's id list.
            webenv: the ``webenv`` value returned by the search request.

        Returns:
            A dictionary with ``uid``, ``title``, ``summary`` and ``pub_date``;
            a field is an empty string when its tag pair is not found.

        Raises:
            urllib.error.HTTPError: for any HTTP error other than 429, or for
                429 once ``max_retry`` retries have been used.
        """
        url = (
            self.base_url_efetch
            + "db=pubmed&retmode=xml&id="
            + uid
            + "&webenv="
            + webenv
        )

        retry = 0
        # Back off with a local delay so one throttled request does not
        # permanently inflate the instance-wide ``sleep_time``.
        sleep_time = self.sleep_time
        while True:
            try:
                result = urllib.request.urlopen(url)
                break
            except urllib.error.HTTPError as e:
                if e.code == 429 and retry < self.max_retry:
                    # Too Many Requests error
                    # wait for an exponentially increasing amount of time
                    print(
                        f"Too Many Requests, "
                        f"waiting for {sleep_time:.2f} seconds..."
                    )
                    time.sleep(sleep_time)
                    sleep_time *= 2
                    retry += 1
                else:
                    raise

        with result:
            xml_text = result.read().decode("utf-8")

        # Return article as dictionary
        return {
            "uid": uid,
            "title": self._extract_tag_text(xml_text, "ArticleTitle"),
            "summary": self._extract_tag_text(xml_text, "AbstractText"),
            "pub_date": self._extract_tag_text(xml_text, "PubDate"),
        }

    @staticmethod
    def _extract_tag_text(xml_text: str, tag: str) -> str:
        """Return the text between the first ``<tag>`` and first ``</tag>``, or ""."""
        start_tag = f"<{tag}>"
        end_tag = f"</{tag}>"
        if start_tag in xml_text and end_tag in xml_text:
            return xml_text[
                xml_text.index(start_tag) + len(start_tag) : xml_text.index(end_tag)
            ]
        return ""
class PubmedQueryRun(BaseModel):
    """Tool that searches the PubMed API."""

    name = "PubMed"
    # The description names biomedical subjects; the previous wording listed
    # the arXiv subject areas, which do not match this tool.
    description = (
        "A wrapper around PubMed. "
        "Useful for when you need to answer questions about medicine, health, "
        "and biomedical topics "
        "from biomedical literature, MEDLINE, life science journals, and online books. "
        "Input should be a search query."
    )
    api_wrapper: PubMedAPIWrapper = Field(default_factory=PubMedAPIWrapper)

    def _run(
        self,
        query: str,
    ) -> str:
        """Use the PubMed tool.

        Args:
            query: the search query forwarded to the API wrapper.

        Returns:
            The string produced by the wrapper's ``run`` method.
        """
        return self.api_wrapper.run(query)
class PubMedInput(BaseModel):
    # Input schema with a single required field: the search query text.
    query: str = Field(default=..., description="Search query.")
| class PubMedSearchTool(BuiltinTool): | |||
| """ | |||
| Tool for performing a search using PubMed search engine. | |||
| @@ -34,7 +205,7 @@ class PubMedSearchTool(BuiltinTool): | |||
| tool = PubmedQueryRun(args_schema=PubMedInput) | |||
| result = tool.run(query) | |||
| result = tool._run(query) | |||
| return self.create_text_message(self.summary(user_id=user_id, content=result)) | |||
| @@ -1,11 +1,81 @@ | |||
| from typing import Any, Union | |||
| from typing import Any, Optional, Union | |||
| from langchain.utilities import TwilioAPIWrapper | |||
| from pydantic import BaseModel, validator | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage | |||
| from core.tools.tool.builtin_tool import BuiltinTool | |||
class TwilioAPIWrapper(BaseModel):
    """Messaging Client using Twilio.

    To use, you should have the ``twilio`` python package installed and pass
    `account_sid`, `auth_token`, and `from_number` as named parameters to the
    constructor. A ready-made `client` may be passed instead of credentials.

    Example:
        .. code-block:: python

            twilio = TwilioAPIWrapper(
                account_sid="ACxxx",
                auth_token="xxx",
                from_number="+10123456789"
            )
            twilio.run('test', '+12484345508')
    """

    account_sid: Optional[str] = None
    """Twilio account string identifier."""
    auth_token: Optional[str] = None
    """Twilio auth token."""
    from_number: Optional[str] = None
    """A Twilio phone number in [E.164](https://www.twilio.com/docs/glossary/what-e164)
        format, an
        [alphanumeric sender ID](https://www.twilio.com/docs/sms/send-messages#use-an-alphanumeric-sender-id),
        or a [Channel Endpoint address](https://www.twilio.com/docs/sms/channels#channel-addresses)
        that is enabled for the type of message you want to send. Phone numbers or
        [short codes](https://www.twilio.com/docs/sms/api/short-code) purchased from
        Twilio also work here. You cannot, for example, spoof messages from a private
        cell phone number. If you are using `messaging_service_sid`, this parameter
        must be empty.
    """  # noqa: E501
    # Declared after the credential fields on purpose: a field validator only
    # sees previously validated fields in ``values``, and the client is built
    # from ``account_sid`` / ``auth_token``.
    client: Any = None  #: :meta private:

    @validator("client", pre=True, always=True)
    def set_validator(cls, v: Any, values: dict) -> Any:
        """Validate that the twilio package exists and build the client.

        Args:
            v: the ``client`` value supplied by the caller, or ``None``.
            values: the already validated fields of the model.

        Returns:
            ``v`` when a client was supplied, otherwise a new
            ``twilio.rest.Client`` created from ``account_sid`` and
            ``auth_token``.

        Raises:
            ImportError: if the ``twilio`` package is not installed.
        """
        try:
            from twilio.rest import Client
        except ImportError:
            raise ImportError(
                "Could not import twilio python package. "
                "Please install it with `pip install twilio`."
            )
        if v is not None:
            return v
        return Client(values.get("account_sid"), values.get("auth_token"))

    def run(self, body: str, to: str) -> str:
        """Run body through Twilio and respond with message sid.

        Args:
            body: The text of the message you want to send. Can be up to 1,600
                characters in length.
            to: The destination phone number in
                [E.164](https://www.twilio.com/docs/glossary/what-e164) format for
                SMS/MMS or
                [Channel user address](https://www.twilio.com/docs/sms/channels#channel-addresses)
                for other 3rd-party channels.

        Returns:
            The ``sid`` attribute of the created message.
        """  # noqa: E501
        message = self.client.messages.create(to, from_=self.from_number, body=body)
        return message.sid
| class SendMessageTool(BuiltinTool): | |||
| """ | |||
| A tool for sending messages using Twilio API. | |||
| @@ -1,16 +1,79 @@ | |||
| from typing import Any, Union | |||
| from typing import Any, Optional, Union | |||
| from langchain import WikipediaAPIWrapper | |||
| from langchain.tools import WikipediaQueryRun | |||
| from pydantic import BaseModel, Field | |||
| import wikipedia | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage | |||
| from core.tools.tool.builtin_tool import BuiltinTool | |||
# Maximum number of query characters passed to the Wikipedia search call.
WIKIPEDIA_MAX_QUERY_LENGTH = 300
class WikipediaInput(BaseModel):
    # Input schema with a single required field: the search query text.
    query: str = Field(default=..., description="search query.")
class WikipediaAPIWrapper:
    """Wrapper around WikipediaAPI.

    To use, you should have the ``wikipedia`` python package installed.
    This wrapper will use the Wikipedia API to conduct searches and
    fetch page summaries. By default, it will return the page summaries
    of the top-k results.
    It limits the Document content by doc_content_chars_max.
    """

    top_k_results: int = 3
    lang: str = "en"
    load_all_available_meta: bool = False
    doc_content_chars_max: int = 4000

    def __init__(self, doc_content_chars_max: int = 4000):
        """Create a wrapper.

        Args:
            doc_content_chars_max: cut limit applied to the text returned by
                ``run``.
        """
        self.doc_content_chars_max = doc_content_chars_max

    def run(self, query: str) -> str:
        """Run Wikipedia search and get page summaries.

        Args:
            query: search text; only the first ``WIKIPEDIA_MAX_QUERY_LENGTH``
                characters are sent to the search.

        Returns:
            The summaries of up to ``top_k_results`` pages separated by blank
            lines and truncated to ``doc_content_chars_max`` characters, or a
            no-result message when no page could be summarized.
        """
        wikipedia.set_lang(self.lang)
        page_titles = wikipedia.search(query[:WIKIPEDIA_MAX_QUERY_LENGTH])
        summaries = []
        for page_title in page_titles[: self.top_k_results]:
            if wiki_page := self._fetch_page(page_title):
                if summary := self._formatted_page_summary(page_title, wiki_page):
                    summaries.append(summary)
        if not summaries:
            return "No good Wikipedia Search Result was found"
        return "\n\n".join(summaries)[: self.doc_content_chars_max]

    @staticmethod
    def _formatted_page_summary(page_title: str, wiki_page: Any) -> Optional[str]:
        """Format a page title and the page's ``summary`` attribute as text."""
        return f"Page: {page_title}\nSummary: {wiki_page.summary}"

    def _fetch_page(self, page: str) -> Optional[Any]:
        """Fetch the page titled ``page`` without auto-suggestion.

        Returns:
            The object returned by ``wikipedia.page``, or ``None`` when the
            title is missing or ambiguous.
        """
        try:
            return wikipedia.page(title=page, auto_suggest=False)
        except (
            wikipedia.exceptions.PageError,
            wikipedia.exceptions.DisambiguationError,
        ):
            return None
class WikipediaQueryRun:
    """Tool object that answers a query through a ``WikipediaAPIWrapper``."""

    name = "Wikipedia"
    description = (
        "A wrapper around Wikipedia. "
        "Useful for when you need to answer general questions about "
        "people, places, companies, facts, historical events, or other subjects. "
        "Input should be a search query."
    )
    api_wrapper: WikipediaAPIWrapper

    def __init__(self, api_wrapper: WikipediaAPIWrapper):
        """Store the wrapper that ``_run`` delegates to."""
        self.api_wrapper = api_wrapper

    def _run(
        self,
        query: str,
    ) -> str:
        """Run ``query`` through the stored wrapper and return its output."""
        wrapper = self.api_wrapper
        return wrapper.run(query)
| class WikiPediaSearchTool(BuiltinTool): | |||
| def _invoke(self, | |||
| user_id: str, | |||
| @@ -24,14 +87,10 @@ class WikiPediaSearchTool(BuiltinTool): | |||
| return self.create_text_message('Please input query') | |||
| tool = WikipediaQueryRun( | |||
| name="wikipedia", | |||
| api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000), | |||
| args_schema=WikipediaInput | |||
| ) | |||
| result = tool.run(tool_input={ | |||
| 'query': query | |||
| }) | |||
| result = tool._run(query) | |||
| return self.create_text_message(self.summary(user_id=user_id,content=result)) | |||