### What problem does this PR solve?

Hello, our use case requires the LLM agent to invoke some tools, so I made a simple implementation here. This PR does two things:

1. A simple plugin mechanism based on `pluginlib`:

   This mechanism lives in the `plugin` directory. For now it only loads plugins from `plugin/embedded_plugins`. A sample plugin, `bad_calculator.py`, is placed in `plugin/embedded_plugins/llm_tools`; it accepts two numbers `a` and `b`, then returns the deliberately wrong result `a + b + 100`. In the future, it can load plugins from external locations with little code change.

   Plugins are divided into different types. The only plugin type supported in this PR is `llm_tools`, which must implement the `LLMToolPlugin` class in `plugin/llm_tool_plugin.py`. More plugin types can be added in the future.

2. A tool selector in the `Generate` component:

   Added a tool selector for choosing one or more tools for the LLM. With the `bad_calculator` tool enabled, the `qwen-max` model produces the expected (wrong) result.

### Type of change

- [ ] Bug Fix (non-breaking change which fixes an issue)
- [x] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):

Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
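For reviewers, here is a minimal sketch (illustrative only, not part of the diff) of how the two pieces fit together at runtime. It assumes `LLMToolPluginCallSession` is importable from the `Generate` component's module, which this diff does not pin down:

```python
# Illustrative sketch of the runtime flow added by this PR (not shipped code).
# Assumes agent.component.generate is the module containing the Generate component.
from plugin import GlobalPluginManager
from agent.component.generate import LLMToolPluginCallSession

GlobalPluginManager.load_plugins()  # done once at server startup

# When the LLM emits a tool call, the bound session resolves the plugin
# by name and invokes it with the JSON arguments from the model.
session = LLMToolPluginCallSession()
result = session.tool_call("bad_calculator", {"a": 1, "b": 2})
print(result)  # "103" -- intentionally wrong, since the sample plugin adds 100
```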
@@ -199,6 +199,7 @@ COPY graphrag graphrag
COPY agentic_reasoning agentic_reasoning
COPY pyproject.toml uv.lock ./
COPY mcp mcp
COPY plugin plugin
COPY docker/service_conf.yaml.template ./conf/service_conf.yaml.template
COPY docker/entrypoint.sh ./

@@ -33,6 +33,7 @@ ADD ./rag ./rag
ADD ./requirements.txt ./requirements.txt
ADD ./agent ./agent
ADD ./graphrag ./graphrag
ADD ./plugin ./plugin
RUN dnf install -y openmpi openmpi-devel python3-openmpi
ENV C_INCLUDE_PATH /usr/include/openmpi-x86_64:$C_INCLUDE_PATH
@@ -16,15 +16,29 @@
import json
import re
from functools import partial
from typing import Any

import pandas as pd

from api.db import LLMType
from api.db.services.conversation_service import structure_answer
from api.db.services.llm_service import LLMBundle
from api import settings
from agent.component.base import ComponentBase, ComponentParamBase
from plugin import GlobalPluginManager
from plugin.llm_tool_plugin import llm_tool_metadata_to_openai_tool
from rag.llm.chat_model import ToolCallSession
from rag.prompts import message_fit_in


class LLMToolPluginCallSession(ToolCallSession):
    def tool_call(self, name: str, arguments: dict[str, Any]) -> str:
        tool = GlobalPluginManager.get_llm_tool_by_name(name)

        if tool is None:
            raise ValueError(f"LLM tool {name} does not exist")

        return tool().invoke(**arguments)


class GenerateParam(ComponentParamBase):
    """
    Define the Generate component parameters.
    """
@@ -41,6 +55,7 @@ class GenerateParam(ComponentParamBase):
        self.frequency_penalty = 0
        self.cite = True
        self.parameters = []
        self.llm_enabled_tools = []

    def check(self):
        self.check_decimal_float(self.temperature, "[Generate] Temperature")

@@ -133,6 +148,15 @@ class Generate(ComponentBase):
    def _run(self, history, **kwargs):
        chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)

        if len(self._param.llm_enabled_tools) > 0:
            tools = GlobalPluginManager.get_llm_tools_by_names(self._param.llm_enabled_tools)
            chat_mdl.bind_tools(
                LLMToolPluginCallSession(),
                [llm_tool_metadata_to_openai_tool(t.get_metadata()) for t in tools]
            )

        prompt = self._param.prompt
        retrieval_res = []
@@ -0,0 +1,12 @@
from flask import Response
from flask_login import login_required

from api.utils.api_utils import get_json_result
from plugin import GlobalPluginManager


@manager.route('/llm_tools', methods=['GET'])  # noqa: F821
@login_required
def llm_tools() -> Response:
    tools = GlobalPluginManager.get_llm_tools()
    tools_metadata = [t.get_metadata() for t in tools]

    return get_json_result(data=tools_metadata)
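For reference, a hedged sketch of calling this new endpoint from a script; the `/v1` path prefix, port, and session-cookie auth are assumptions about a typical deployment, not something this diff pins down:

```python
# Hypothetical client call for the new /plugin/llm_tools endpoint.
# The base URL and auth cookie below are deployment assumptions.
import requests

resp = requests.get(
    "http://localhost:9380/v1/plugin/llm_tools",
    cookies={"session": "<your-session-cookie>"},  # the endpoint is @login_required
)
print(resp.json()["data"])  # list of LLMToolMetadata dicts
```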
@@ -226,6 +226,7 @@ class LLMBundle:
    def bind_tools(self, toolcall_session, tools):
        if not self.is_tools:
            logging.warning(f"Model {self.llm_name} does not support tool call, but you have assigned one or more tools to it!")
            return

        self.mdl.bind_tools(toolcall_session, tools)

@@ -19,6 +19,7 @@
# beartype_all(conf=BeartypeConf(violation_type=UserWarning))  # <-- emit warnings from all code

from api.utils.log_utils import initRootLogger
from plugin import GlobalPluginManager

initRootLogger("ragflow_server")

import logging

@@ -119,6 +120,8 @@ if __name__ == '__main__':
    RuntimeConfig.init_env()
    RuntimeConfig.init_config(JOB_SERVER_HOST=settings.HOST_IP, HTTP_PORT=settings.HOST_PORT)

    GlobalPluginManager.load_plugins()

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)
@@ -0,0 +1,97 @@
# Plugins

This directory contains the plugin mechanism for RAGFlow.

RAGFlow loads plugins recursively from the `embedded_plugins` subdirectory.

## Supported plugin types

Currently, the only supported plugin type is `llm_tools`.

- `llm_tools`: A tool for the LLM to call.

## How to add a plugin

Adding an LLM tool plugin is simple: create a plugin file, put a class that inherits from the `LLMToolPlugin` class in it, then implement the `get_metadata` and `invoke` methods.

- `get_metadata` method: This method returns a `LLMToolMetadata` object, which contains the description of this tool.
  The description is provided to the LLM, and to the RAGFlow web frontend for display.
- `invoke` method: This method accepts the parameters generated by the LLM and returns a `str` containing the tool execution result.
  All the execution logic of this tool should go into this method.

When you start RAGFlow, you can see in the log that your plugin was loaded:

```
2025-05-15 19:29:08,959 INFO 34670 Recursively importing plugins from path `/some-path/ragflow/plugin/embedded_plugins`
2025-05-15 19:29:08,960 INFO 34670 Loaded llm_tools plugin BadCalculatorPlugin version 1.0.0
```

Or the log may contain errors for you to fix in your plugin.
### Demo

We will demonstrate how to add a plugin with a calculator tool that gives wrong answers.

First, create a plugin file `bad_calculator.py` under the `embedded_plugins/llm_tools` directory.

Then, we create a `BadCalculatorPlugin` class, extending the `LLMToolPlugin` base class:

```python
class BadCalculatorPlugin(LLMToolPlugin):
    _version_ = "1.0.0"
```

The `_version_` field is required; it specifies the version of the plugin.

Our calculator takes two numbers `a` and `b` as inputs, so we add an `invoke` method to our `BadCalculatorPlugin` class:

```python
def invoke(self, a: int, b: int) -> str:
    return str(a + b + 100)
```

The `invoke` method will be called by the LLM. It can have many parameters, but its return type must be `str`.

Finally, we add a `get_metadata` method to tell the LLM how to use our `bad_calculator`:

```python
@classmethod
def get_metadata(cls) -> LLMToolMetadata:
    return {
        # Name of this tool, provided to the LLM
        "name": "bad_calculator",
        # Display name of this tool, provided to the RAGFlow frontend
        "displayName": "$t:bad_calculator.name",
        # Description of the usage of this tool, provided to the LLM
        "description": "A tool to calculate the sum of two numbers (will give wrong answer)",
        # Description of this tool, provided to the RAGFlow frontend
        "displayDescription": "$t:bad_calculator.description",
        # Parameters of this tool
        "parameters": {
            # The first parameter - a
            "a": {
                # Parameter type; options are: number, string, or whatever the LLM can recognise
                "type": "number",
                # Description of this parameter, provided to the LLM
                "description": "The first number",
                # Description of this parameter, provided to the RAGFlow frontend
                "displayDescription": "$t:bad_calculator.params.a",
                # Whether this parameter is required
                "required": True
            },
            # The second parameter - b
            "b": {
                "type": "number",
                "description": "The second number",
                "displayDescription": "$t:bad_calculator.params.b",
                "required": True
            }
        }
    }
```

The `get_metadata` method is a `classmethod`. It provides the description of this tool to the LLM.

Fields starting with `display` can use a special notation, `$t:xxx`, which uses the i18n mechanism in the RAGFlow frontend to fetch text from the `llmTools` category. If you don't use this notation, the frontend displays the text verbatim.

Now our tool is ready. You can select it in the `Generate` component and try it out.
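Under the hood, this metadata is converted to the OpenAI function-calling schema by `llm_tool_metadata_to_openai_tool` in `plugin/llm_tool_plugin.py`. A sketch of the resulting payload for `bad_calculator` (the shape is taken from that function; the import path for the sample plugin assumes the directory layout described above):

```python
from plugin.llm_tool_plugin import llm_tool_metadata_to_openai_tool
from plugin.embedded_plugins.llm_tools.bad_calculator import BadCalculatorPlugin

openai_tool = llm_tool_metadata_to_openai_tool(BadCalculatorPlugin.get_metadata())
assert openai_tool == {
    "type": "function",
    "function": {
        "name": "bad_calculator",
        "description": "A tool to calculate the sum of two numbers (will give wrong answer)",
        "parameters": {
            "type": "object",
            "properties": {
                "a": {"type": "number", "description": "The first number"},
                "b": {"type": "number", "description": "The second number"},
            },
            "required": ["a", "b"],
        },
    },
}
```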
@@ -0,0 +1,98 @@
# Plugins

This folder contains the plugin mechanism for RAGFlow.

RAGFlow recursively loads all plugins from the `embedded_plugins` subfolder.

## Supported plugin types

Currently, the only supported plugin type is `llm_tools`.

- `llm_tools`: Tools for the LLM to call.

## How to add a plugin

Adding an LLM tool plugin is simple: create a plugin file, put a class that inherits from `LLMToolPlugin` into it, then implement its `get_metadata` and `invoke` methods.

- `get_metadata` method: This method returns a `LLMToolMetadata` object containing the description of this tool.
  The description is provided to the LLM for tool calling, and to the RAGFlow web frontend for display.
- `invoke` method: This method accepts the parameters generated by the LLM and returns a `str` containing the execution result of this tool.
  All the execution logic of this tool should go into this method.

When you start RAGFlow, you can see in the log that your plugin was loaded:

```
2025-05-15 19:29:08,959 INFO 34670 Recursively importing plugins from path `/some-path/ragflow/plugin/embedded_plugins`
2025-05-15 19:29:08,960 INFO 34670 Loaded llm_tools plugin BadCalculatorPlugin version 1.0.0
```

The log may also report errors; in that case, fix your plugin according to the errors.

### Demo

We will demonstrate the process of adding a plugin with a calculator tool that gives wrong answers.

First, create a plugin file `bad_calculator.py` under the `embedded_plugins/llm_tools` folder.

Next, we create a `BadCalculatorPlugin` class that inherits from the base class `LLMToolPlugin`:

```python
class BadCalculatorPlugin(LLMToolPlugin):
    _version_ = "1.0.0"
```

The `_version_` field is required; it specifies the version number of this plugin.

Our calculator has two input fields `a` and `b`, so we add the following `invoke` method to the `BadCalculatorPlugin` class:

```python
def invoke(self, a: int, b: int) -> str:
    return str(a + b + 100)
```

The `invoke` method will be called by the LLM. It can have many parameters, but it must return a `str`.

Finally, we need to add a `get_metadata` method to tell the LLM how to use our `bad_calculator` tool:

```python
@classmethod
def get_metadata(cls) -> LLMToolMetadata:
    return {
        # Name of this tool, provided to the LLM
        "name": "bad_calculator",
        # Display name of this tool, provided to the RAGFlow web frontend
        "displayName": "$t:bad_calculator.name",
        # Usage description of this tool, provided to the LLM
        "description": "A tool to calculate the sum of two numbers (will give wrong answer)",
        # Description of this tool, provided to the RAGFlow web frontend
        "displayDescription": "$t:bad_calculator.description",
        # Parameters of this tool
        "parameters": {
            # The first parameter - a
            "a": {
                # Parameter type; options are: number, string, or any type the LLM can recognise
                "type": "number",
                # Description of this parameter, provided to the LLM
                "description": "The first number",
                # Description of this parameter, provided to the RAGFlow web frontend
                "displayDescription": "$t:bad_calculator.params.a",
                # Whether this parameter is required
                "required": True
            },
            # The second parameter - b
            "b": {
                "type": "number",
                "description": "The second number",
                "displayDescription": "$t:bad_calculator.params.b",
                "required": True
            }
        }
    }
```

The `get_metadata` method is a `classmethod`. It provides the description of this tool to the LLM.

Fields starting with `display` can use a special notation, `$t:xxx`, which uses RAGFlow's i18n mechanism to fetch text from the `llmTools` category. If you don't use this notation, the frontend will display the raw content verbatim.

Now our tool is ready. You can select it in the `Generate` component and give it a try.
@@ -0,0 +1,3 @@
from .plugin_manager import PluginManager

GlobalPluginManager = PluginManager()

@@ -0,0 +1 @@
PLUGIN_TYPE_LLM_TOOLS = "llm_tools"
@@ -0,0 +1,37 @@
import logging

from plugin.llm_tool_plugin import LLMToolMetadata, LLMToolPlugin


class BadCalculatorPlugin(LLMToolPlugin):
    """
    A sample LLM tool plugin that adds two numbers plus 100.

    It is present only for demo purposes. Do not use it in production.
    """

    _version_ = "1.0.0"

    @classmethod
    def get_metadata(cls) -> LLMToolMetadata:
        return {
            "name": "bad_calculator",
            "displayName": "$t:bad_calculator.name",
            "description": "A tool to calculate the sum of two numbers (will give wrong answer)",
            "displayDescription": "$t:bad_calculator.description",
            "parameters": {
                "a": {
                    "type": "number",
                    "description": "The first number",
                    "displayDescription": "$t:bad_calculator.params.a",
                    "required": True
                },
                "b": {
                    "type": "number",
                    "description": "The second number",
                    "displayDescription": "$t:bad_calculator.params.b",
                    "required": True
                }
            }
        }

    def invoke(self, a: int, b: int) -> str:
        logging.info(f"Bad calculator tool was called with arguments {a} and {b}")
        return str(a + b + 100)
@@ -0,0 +1,51 @@
from typing import Any, TypedDict

import pluginlib

from .common import PLUGIN_TYPE_LLM_TOOLS


class LLMToolParameter(TypedDict):
    type: str
    description: str
    displayDescription: str
    required: bool


class LLMToolMetadata(TypedDict):
    name: str
    displayName: str
    description: str
    displayDescription: str
    parameters: dict[str, LLMToolParameter]


@pluginlib.Parent(PLUGIN_TYPE_LLM_TOOLS)
class LLMToolPlugin:
    @classmethod
    @pluginlib.abstractmethod
    def get_metadata(cls) -> LLMToolMetadata:
        pass

    def invoke(self, **kwargs) -> str:
        raise NotImplementedError


def llm_tool_metadata_to_openai_tool(llm_tool_metadata: LLMToolMetadata) -> dict[str, Any]:
    return {
        "type": "function",
        "function": {
            "name": llm_tool_metadata["name"],
            "description": llm_tool_metadata["description"],
            "parameters": {
                "type": "object",
                "properties": {
                    k: {
                        "type": p["type"],
                        "description": p["description"]
                    }
                    for k, p in llm_tool_metadata["parameters"].items()
                },
                "required": [k for k, p in llm_tool_metadata["parameters"].items() if p["required"]]
            }
        }
    }
@@ -0,0 +1,45 @@
import logging
import os
from pathlib import Path

import pluginlib

from .common import PLUGIN_TYPE_LLM_TOOLS
from .llm_tool_plugin import LLMToolPlugin


class PluginManager:
    _llm_tool_plugins: dict[str, LLMToolPlugin]

    def __init__(self) -> None:
        self._llm_tool_plugins = {}

    def load_plugins(self) -> None:
        loader = pluginlib.PluginLoader(
            paths=[str(Path(os.path.dirname(__file__), "embedded_plugins"))]
        )

        for type, plugins in loader.plugins.items():
            for name, plugin in plugins.items():
                logging.info(f"Loaded {type} plugin {name} version {plugin.version}")

                if type == PLUGIN_TYPE_LLM_TOOLS:
                    metadata = plugin.get_metadata()
                    self._llm_tool_plugins[metadata["name"]] = plugin

    def get_llm_tools(self) -> list[LLMToolPlugin]:
        return list(self._llm_tool_plugins.values())

    def get_llm_tool_by_name(self, name: str) -> LLMToolPlugin | None:
        return self._llm_tool_plugins.get(name)

    def get_llm_tools_by_names(self, tool_names: list[str]) -> list[LLMToolPlugin]:
        results = []

        for name in tool_names:
            plugin = self._llm_tool_plugins.get(name)

            if plugin is not None:
                results.append(plugin)

        return results
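As a quick sanity check, a minimal sketch of driving `PluginManager` directly, mirroring what the server startup code in this PR does (assumes it is run from the repository root so that `plugin` is importable):

```python
# Illustrative only: exercise the plugin manager standalone.
from plugin import GlobalPluginManager

GlobalPluginManager.load_plugins()
for tool_cls in GlobalPluginManager.get_llm_tools():
    meta = tool_cls.get_metadata()  # get_metadata is a classmethod
    print(f"{meta['name']}: {meta['description']}")
```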
@@ -125,7 +125,8 @@ dependencies = [
    "langfuse>=2.60.0",
    "debugpy>=1.8.13",
    "mcp>=1.6.0",
    "opensearch-py==2.7.1"
    "opensearch-py==2.7.1",
    "pluginlib==0.9.4",
]

[project.optional-dependencies]
@@ -21,6 +21,7 @@ import random
import re
import time
from abc import ABC
from typing import Any, Protocol

import openai
import requests

@@ -51,6 +52,10 @@ LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小
LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length."


class ToolCallSession(Protocol):
    def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...


class Base(ABC):
    def __init__(self, key, model_name, base_url):
        timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))

@@ -251,10 +256,8 @@ class Base(ABC):
                if index not in final_tool_calls:
                    final_tool_calls[index] = tool_call
                    final_tool_calls[index].function.arguments += tool_call.function.arguments

                    if resp.choices[0].finish_reason != "stop":
                        continue
                else:
                    final_tool_calls[index].function.arguments += tool_call.function.arguments
            else:
                if not resp.choices:
                    continue
@@ -276,58 +279,57 @@ class Base(ABC):
                else:
                    total_tokens += tol

                finish_reason = resp.choices[0].finish_reason
                if finish_reason == "tool_calls" and final_tool_calls:
                    for tool_call in final_tool_calls.values():
                        name = tool_call.function.name
                        try:
                            if name == "get_current_weather":
                                args = json.loads('{"location":"Shanghai"}')
                            else:
                                args = json.loads(tool_call.function.arguments)
                        except Exception:
                            continue
                        # args = json.loads(tool_call.function.arguments)
                        tool_response = self.toolcall_session.tool_call(name, args)
                        history.append(
                            {
                                "role": "assistant",
                                "refusal": "",
                                "content": "",
                                "audio": "",
                                "function_call": "",
                                "tool_calls": [
                                    {
                                        "index": tool_call.index,
                                        "id": tool_call.id,
                                        "function": tool_call.function,
                                        "type": "function",
                                    },
                                ],
                            }
                        )
                        # if tool_response.choices[0].finish_reason == "length":
                        #     if is_chinese(ans):
                        #         ans += LENGTH_NOTIFICATION_CN
                        #     else:
                        #         ans += LENGTH_NOTIFICATION_EN
                        #     return ans, total_tokens + self.total_token_count(tool_response)
                        history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
                    final_tool_calls = {}
                    response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
                    continue
                if finish_reason == "length":
                    if is_chinese(ans):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN
                    return ans, total_tokens + self.total_token_count(resp)
                if finish_reason == "stop":
                    finish_completion = True
                    yield ans
                    break
                yield ans

                finish_reason = resp.choices[0].finish_reason
                if finish_reason == "tool_calls" and final_tool_calls:
                    for tool_call in final_tool_calls.values():
                        name = tool_call.function.name
                        try:
                            args = json.loads(tool_call.function.arguments)
                        except Exception as e:
                            logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
                            yield ans + "\n**ERROR**: " + str(e)
                            finish_completion = True
                            break

                        tool_response = self.toolcall_session.tool_call(name, args)
                        history.append(
                            {
                                "role": "assistant",
                                "tool_calls": [
                                    {
                                        "index": tool_call.index,
                                        "id": tool_call.id,
                                        "function": {
                                            "name": tool_call.function.name,
                                            "arguments": tool_call.function.arguments,
                                        },
                                        "type": "function",
                                    },
                                ],
                            }
                        )
                        # if tool_response.choices[0].finish_reason == "length":
                        #     if is_chinese(ans):
                        #         ans += LENGTH_NOTIFICATION_CN
                        #     else:
                        #         ans += LENGTH_NOTIFICATION_EN
                        #     return ans, total_tokens + self.total_token_count(tool_response)
                        history.append({"role": "tool", "tool_call_id": tool_call.id, "content": str(tool_response)})
                    final_tool_calls = {}
                    response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
                    continue
                if finish_reason == "length":
                    if is_chinese(ans):
                        ans += LENGTH_NOTIFICATION_CN
                    else:
                        ans += LENGTH_NOTIFICATION_EN
                    return ans, total_tokens
                if finish_reason == "stop":
                    finish_completion = True
                    yield ans
                    break
                yield ans
                continue
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)
@@ -854,6 +856,14 @@ class ZhipuChat(Base):
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_with_tools(self, system: str, history: list, gen_conf: dict):
        if "presence_penalty" in gen_conf:
            del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            del gen_conf["frequency_penalty"]

        return super().chat_with_tools(system, history, gen_conf)

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})

@@ -886,6 +896,14 @@ class ZhipuChat(Base):
            yield tk_count

    def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
        if "presence_penalty" in gen_conf:
            del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            del gen_conf["frequency_penalty"]

        return super().chat_streamly_with_tools(system, history, gen_conf)


class OllamaChat(Base):
    def __init__(self, key, model_name, **kwargs):
@@ -3952,6 +3952,18 @@ wheels = [
    { url = "https://mirrors.aliyun.com/pypi/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669" },
]

[[package]]
name = "pluginlib"
version = "0.9.4"
source = { registry = "https://mirrors.aliyun.com/pypi/simple" }
dependencies = [
    { name = "setuptools" },
]
sdist = { url = "https://mirrors.aliyun.com/pypi/packages/58/38/ca974ba2d8ccc7954d8ccb0394cce184ac6269bd1fbfe06f70a0da3c8946/pluginlib-0.9.4.tar.gz", hash = "sha256:88727037138f759a3952f6391ae3751536f04ad8be6023607620ea49695a3a83" }
wheels = [
    { url = "https://mirrors.aliyun.com/pypi/packages/b0/b5/c869b3d2ed1613afeb02c635be11f5d35fa5b2b665f4d059cfe5b8e82941/pluginlib-0.9.4-py2.py3-none-any.whl", hash = "sha256:d4cfb7d74a6d2454e256b6512fbc4bc2dd7620cb7764feb67331ef56ce4b33f2" },
]

[[package]]
name = "polars-lts-cpu"
version = "1.9.0"

@@ -4872,6 +4884,7 @@ dependencies = [
    { name = "pdfplumber" },
    { name = "peewee" },
    { name = "pillow" },
    { name = "pluginlib" },
    { name = "protobuf" },
    { name = "psycopg2-binary" },
    { name = "pyclipper" },

@@ -5009,6 +5022,7 @@ requires-dist = [
    { name = "pdfplumber", specifier = "==0.10.4" },
    { name = "peewee", specifier = "==3.17.1" },
    { name = "pillow", specifier = "==10.4.0" },
    { name = "pluginlib", specifier = "==0.9.4" },
    { name = "protobuf", specifier = "==5.27.2" },
    { name = "psycopg2-binary", specifier = "==2.9.9" },
    { name = "pyclipper", specifier = "==1.3.0.post5" },
@@ -11,19 +11,31 @@ import { Select, SelectTrigger, SelectValue } from '../ui/select';
interface IProps {
  id?: string;
  value?: string;
  onChange?: (value: string) => void;
  onInitialValue?: (value: string, option: any) => void;
  onChange?: (value: string, option: any) => void;
  disabled?: boolean;
}

const LLMSelect = ({ id, value, onChange, disabled }: IProps) => {
const LLMSelect = ({ id, value, onInitialValue, onChange, disabled }: IProps) => {
  const modelOptions = useComposeLlmOptionsByModelTypes([
    LlmModelType.Chat,
    LlmModelType.Image2text,
  ]);

  if (onInitialValue && value) {
    for (const modelOption of modelOptions) {
      for (const option of modelOption.options) {
        if (option.value === value) {
          onInitialValue(value, option);
          break;
        }
      }
    }
  }

  const content = (
    <div style={{ width: 400 }}>
      <LlmSettingItems
      <LlmSettingItems onChange={onChange}
        formItemLayout={{ labelCol: { span: 10 }, wrapperCol: { span: 14 } }}
      ></LlmSettingItems>
    </div>
@@ -16,9 +16,10 @@ interface IProps {
  prefix?: string;
  formItemLayout?: any;
  handleParametersChange?(value: ModelVariableType): void;
  onChange?(value: string, option: any): void;
}

const LlmSettingItems = ({ prefix, formItemLayout = {} }: IProps) => {
const LlmSettingItems = ({ prefix, formItemLayout = {}, onChange }: IProps) => {
  const form = Form.useFormInstance();
  const { t } = useTranslate('chat');

  const parameterOptions = Object.values(ModelVariableType).map((x) => ({

@@ -58,6 +59,7 @@ const LlmSettingItems = ({ prefix, formItemLayout = {} }: IProps) => {
            options={modelOptions}
            showSearch
            popupMatchSelectWidth={false}
            onChange={onChange}
          />
        </Form.Item>
        <div className="border rounded-md">
@@ -0,0 +1,51 @@
import { useTranslate } from '@/hooks/common-hooks';
import { useLlmToolsList } from '@/hooks/plugin-hooks';
import { Select, Space } from 'antd';

interface IProps {
  value?: string;
  onChange?: (value: string) => void;
  disabled?: boolean;
}

const LLMToolsSelect = ({ value, onChange, disabled }: IProps) => {
  const { t } = useTranslate("llmTools");
  const tools = useLlmToolsList();

  function wrapTranslation(text: string): string {
    if (!text) {
      return text;
    }

    if (text.startsWith("$t:")) {
      return t(text.substring(3));
    }

    return text;
  }

  const toolOptions = tools.map(t => ({
    label: wrapTranslation(t.displayName),
    description: wrapTranslation(t.displayDescription),
    value: t.name,
    title: wrapTranslation(t.displayDescription),
  }));

  return (
    <Select
      mode="multiple"
      options={toolOptions}
      optionRender={option => (
        <Space size="large">
          {option.label}
          {option.data.description}
        </Space>
      )}
      onChange={onChange}
      value={value}
      disabled={disabled}
    ></Select>
  );
};

export default LLMToolsSelect;
@@ -71,6 +71,7 @@ function buildLlmOptionsWithIcon(x: IThirdOAIModel) {
    ),
    value: `${x.llm_name}@${x.fid}`,
    disabled: !x.available,
    is_tools: x.is_tools,
  };
}

@@ -142,7 +143,7 @@ export const useComposeLlmOptionsByModelTypes = (
  return modelTypes.reduce<
    (DefaultOptionType & {
      options: { label: JSX.Element; value: string; disabled: boolean }[];
      options: { label: JSX.Element; value: string; disabled: boolean; is_tools: boolean }[];
    })[]
  >((pre, cur) => {
    const options = allOptions[cur];
@@ -0,0 +1,17 @@
import { ILLMTools } from '@/interfaces/database/plugin';
import pluginService from '@/services/plugin-service';
import { useQuery } from '@tanstack/react-query';

export const useLlmToolsList = (): ILLMTools => {
  const { data } = useQuery({
    queryKey: ['llmTools'],
    initialData: [],
    queryFn: async () => {
      const { data } = await pluginService.getLlmTools();

      return data?.data ?? [];
    },
  });

  return data;
};
@@ -13,6 +13,7 @@ export interface IThirdOAIModel {
  update_time: number;
  tenant_id?: string;
  tenant_name?: string;
  is_tools: boolean;
}

export type IThirdOAIModelCollection = Record<string, IThirdOAIModel[]>;

@@ -0,0 +1,13 @@
export type ILLMTools = ILLMToolMetadata[];

export interface ILLMToolMetadata {
  name: string;
  displayName: string;
  displayDescription: string;
  parameters: Map<string, ILLMToolParameter>;
}

export interface ILLMToolParameter {
  type: string;
  displayDescription: string;
}
@@ -454,6 +454,8 @@ This auto-tagging feature enhances retrieval by adding another layer of domain-s
    model: 'Model',
    modelTip: 'Large language chat model',
    modelMessage: 'Please select!',
    modelEnabledTools: 'Enabled tools',
    modelEnabledToolsTip: 'Select one or more tools for the chat model to use. This has no effect for models that do not support tool calls.',
    freedom: 'Freedom',
    improvise: 'Improvise',
    precise: 'Precise',

@@ -1267,5 +1269,15 @@ This delimiter is used to split the input text into several text pieces echo of
      inputVariables: 'Input variables',
      runningHintText: 'is running...🕞',
    },
    llmTools: {
      bad_calculator: {
        name: "Calculator",
        description: "A tool to calculate the sum of two numbers (will give wrong answer)",
        params: {
          a: "The first number",
          b: "The second number",
        },
      },
    },
  },
};
@@ -461,6 +461,8 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于
    model: '模型',
    modelTip: '大语言聊天模型',
    modelMessage: '请选择',
    modelEnabledTools: '可用的工具',
    modelEnabledToolsTip: '请选择一个或多个可供该模型所使用的工具。仅对支持工具调用的模型生效。',
    freedom: '自由度',
    improvise: '即兴创作',
    precise: '精确',

@@ -1231,5 +1233,15 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于
      knowledge: 'knowledge',
      chat: 'chat',
    },
    llmTools: {
      bad_calculator: {
        name: "计算器",
        description: "用于计算两个数的和的工具(会给出错误答案)",
        params: {
          a: "第一个数",
          b: "第二个数",
        },
      },
    },
  },
};
@@ -4,10 +4,18 @@ import { PromptEditor } from '@/components/prompt-editor';
import { useTranslate } from '@/hooks/common-hooks';
import { Form, Switch } from 'antd';
import { IOperatorForm } from '../../interface';
import LLMToolsSelect from '@/components/llm-tools-select';
import { useState } from 'react';

const GenerateForm = ({ onValuesChange, form }: IOperatorForm) => {
  const { t } = useTranslate('flow');

  const [isCurrentLlmSupportTools, setCurrentLlmSupportTools] = useState(false);

  const onLlmSelectChanged = (_: string, option: any) => {
    setCurrentLlmSupportTools(option.is_tools);
  };

  return (
    <Form
      name="basic"

@@ -21,7 +29,7 @@ const GenerateForm = ({ onValuesChange, form }: IOperatorForm) => {
        label={t('model', { keyPrefix: 'chat' })}
        tooltip={t('modelTip', { keyPrefix: 'chat' })}
      >
        <LLMSelect></LLMSelect>
        <LLMSelect onInitialValue={onLlmSelectChanged} onChange={onLlmSelectChanged}></LLMSelect>
      </Form.Item>
      <Form.Item
        name={['prompt']}

@@ -38,6 +46,13 @@ const GenerateForm = ({ onValuesChange, form }: IOperatorForm) => {
        {/* <Input.TextArea rows={8}></Input.TextArea> */}
        <PromptEditor></PromptEditor>
      </Form.Item>
      <Form.Item
        name={'llm_enabled_tools'}
        label={t('modelEnabledTools', { keyPrefix: 'chat' })}
        tooltip={t('modelEnabledToolsTip', { keyPrefix: 'chat' })}
      >
        <LLMToolsSelect disabled={!isCurrentLlmSupportTools}></LLMToolsSelect>
      </Form.Item>
      <Form.Item
        name={['cite']}
        label={t('cite')}
@@ -0,0 +1,18 @@
import api from '@/utils/api';
import registerServer from '@/utils/register-server';
import request from '@/utils/request';

const {
  llm_tools
} = api;

const methods = {
  getLlmTools: {
    url: llm_tools,
    method: 'get',
  },
} as const;

const pluginService = registerServer<keyof typeof methods>(methods, request);

export default pluginService;

@@ -32,6 +32,9 @@ export default {
  delete_llm: `${api_host}/llm/delete_llm`,
  deleteFactory: `${api_host}/llm/delete_factory`,

  // plugin
  llm_tools: `${api_host}/plugin/llm_tools`,

  // knowledge base
  kb_list: `${api_host}/kb/list`,
  create_kb: `${api_host}/kb/create`,