|
|
|
@@ -7,6 +7,7 @@ from core.tools.tool.builtin_tool import BuiltinTool |
|
|
|
|
|
|
|
WIKIPEDIA_MAX_QUERY_LENGTH = 300 |
|
|
|
|
|
|
|
|
|
|
|
class WikipediaAPIWrapper: |
|
|
|
"""Wrapper around WikipediaAPI. |
|
|
|
|
|
|
|
@@ -25,7 +26,10 @@ class WikipediaAPIWrapper: |
|
|
|
def __init__(self, doc_content_chars_max: int = 4000): |
|
|
|
self.doc_content_chars_max = doc_content_chars_max |
|
|
|
|
|
|
|
def run(self, query: str) -> str: |
|
|
|
def run(self, query: str, lang: str = "") -> str: |
|
|
|
if lang in wikipedia.languages().keys(): |
|
|
|
self.lang = lang |
|
|
|
|
|
|
|
wikipedia.set_lang(self.lang) |
|
|
|
wiki_client = wikipedia |
|
|
|
|
|
|
|
@@ -53,6 +57,7 @@ class WikipediaAPIWrapper: |
|
|
|
): |
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
class WikipediaQueryRun: |
|
|
|
"""Tool that searches the Wikipedia API.""" |
|
|
|
|
|
|
|
@@ -71,26 +76,31 @@ class WikipediaQueryRun: |
|
|
|
def _run( |
|
|
|
self, |
|
|
|
query: str, |
|
|
|
lang: str = "", |
|
|
|
) -> str: |
|
|
|
"""Use the Wikipedia tool.""" |
|
|
|
return self.api_wrapper.run(query) |
|
|
|
return self.api_wrapper.run(query, lang) |
|
|
|
|
|
|
|
|
|
|
|
class WikiPediaSearchTool(BuiltinTool): |
|
|
|
def _invoke(self, |
|
|
|
user_id: str, |
|
|
|
tool_parameters: dict[str, Any], |
|
|
|
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: |
|
|
|
|
|
|
|
def _invoke( |
|
|
|
self, |
|
|
|
user_id: str, |
|
|
|
tool_parameters: dict[str, Any], |
|
|
|
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: |
|
|
|
""" |
|
|
|
invoke tools |
|
|
|
invoke tools |
|
|
|
""" |
|
|
|
query = tool_parameters.get('query', '') |
|
|
|
query = tool_parameters.get("query", "") |
|
|
|
lang = tool_parameters.get("language", "") |
|
|
|
if not query: |
|
|
|
return self.create_text_message('Please input query') |
|
|
|
|
|
|
|
return self.create_text_message("Please input query") |
|
|
|
|
|
|
|
tool = WikipediaQueryRun( |
|
|
|
api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000), |
|
|
|
) |
|
|
|
|
|
|
|
result = tool._run(query) |
|
|
|
result = tool._run(query, lang) |
|
|
|
|
|
|
|
return self.create_text_message(self.summary(user_id=user_id,content=result)) |
|
|
|
|
|
|
|
return self.create_text_message(self.summary(user_id=user_id, content=result)) |