### What problem does this PR solve? #1965 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>tags/v0.10.0
| from .github import GitHub, GitHubParam | from .github import GitHub, GitHubParam | ||||
| from .baidufanyi import BaiduFanyi, BaiduFanyiParam | from .baidufanyi import BaiduFanyi, BaiduFanyiParam | ||||
| from .qweather import QWeather, QWeatherParam | from .qweather import QWeather, QWeatherParam | ||||
| from .exesql import ExeSQL, ExeSQLParam | |||||
| def component_class(class_name): | def component_class(class_name): | ||||
| m = importlib.import_module("agent.component") | m = importlib.import_module("agent.component") |
| # | |||||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # | |||||
| from abc import ABC | |||||
| import pandas as pd | |||||
| from peewee import MySQLDatabase, PostgresqlDatabase | |||||
| from agent.component.base import ComponentBase, ComponentParamBase | |||||
| class ExeSQLParam(ComponentParamBase): | |||||
| """ | |||||
| Define the ExeSQL component parameters. | |||||
| """ | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.db_type = "mysql" | |||||
| self.database = "" | |||||
| self.username = "" | |||||
| self.host = "" | |||||
| self.port = 3306 | |||||
| self.password = "" | |||||
| self.loop = 3 | |||||
| self.top_n = 30 | |||||
| def check(self): | |||||
| self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgresql', 'mariadb']) | |||||
| self.check_empty(self.database, "Database name") | |||||
| self.check_empty(self.username, "database username") | |||||
| self.check_empty(self.host, "IP Address") | |||||
| self.check_positive_integer(self.port, "IP Port") | |||||
| self.check_empty(self.password, "Database password") | |||||
| self.check_positive_integer(self.top_n, "Number of records") | |||||
| class ExeSQL(ComponentBase, ABC): | |||||
| component_name = "ExeSQL" | |||||
| def _run(self, history, **kwargs): | |||||
| if not hasattr(self, "_loop"): | |||||
| setattr(self, "_loop", 0) | |||||
| if self._loop >= self._param.loop: | |||||
| self._loop = 0 | |||||
| raise Exception("Maximum loop time exceeds. Can't query the correct data via sql statement.") | |||||
| self._loop += 1 | |||||
| ans = self.get_input() | |||||
| ans = "".join(ans["content"]) if "content" in ans else "" | |||||
| if not ans: | |||||
| return ExeSQL.be_output("SQL statement not found!") | |||||
| if self._param.db_type in ["mysql", "mariadb"]: | |||||
| db = MySQLDatabase(self._param.database, user=self._param.username, host=self._param.host, | |||||
| port=self._param.port, password=self._param.password) | |||||
| elif self._param.db_type == 'postgresql': | |||||
| db = PostgresqlDatabase(self._param.database, user=self._param.username, host=self._param.host, | |||||
| port=self._param.port, password=self._param.password) | |||||
| try: | |||||
| db.connect() | |||||
| query = db.execute_sql(ans) | |||||
| sql_res = [{"content": rec + "\n"} for rec in [str(i) for i in query.fetchall()]] | |||||
| db.close() | |||||
| except Exception as e: | |||||
| return ExeSQL.be_output("**Error**:" + str(e)) | |||||
| if not sql_res: | |||||
| return ExeSQL.be_output("No record in the database!") | |||||
| sql_res.insert(0, {"content": "Number of records retrieved from the database is " + str(len(sql_res)) + "\n"}) | |||||
| df = pd.DataFrame(sql_res[0:self._param.top_n + 1]) | |||||
| return ExeSQL.be_output(df.to_markdown()) |
| { | |||||
| "components": { | |||||
| "begin": { | |||||
| "obj":{ | |||||
| "component_name": "Begin", | |||||
| "params": { | |||||
| "prologue": "Hi there!" | |||||
| } | |||||
| }, | |||||
| "downstream": ["answer:0"], | |||||
| "upstream": [] | |||||
| }, | |||||
| "answer:0": { | |||||
| "obj": { | |||||
| "component_name": "Answer", | |||||
| "params": {} | |||||
| }, | |||||
| "downstream": ["exesql:0"], | |||||
| "upstream": ["begin", "exesql:0"] | |||||
| }, | |||||
| "exesql:0": { | |||||
| "obj": { | |||||
| "component_name": "ExeSQL", | |||||
| "params": { | |||||
| "database": "rag_flow", | |||||
| "username": "root", | |||||
| "host": "mysql", | |||||
| "port": 3306, | |||||
| "password": "infini_rag_flow", | |||||
| "top_n": 3 | |||||
| } | |||||
| }, | |||||
| "downstream": ["answer:0"], | |||||
| "upstream": ["answer:0"] | |||||
| } | |||||
| }, | |||||
| "history": [], | |||||
| "messages": [], | |||||
| "reference": {}, | |||||
| "path": [], | |||||
| "answer": [] | |||||
| } | |||||
| from api.utils import get_uuid | from api.utils import get_uuid | ||||
| from api.utils.api_utils import get_json_result, server_error_response, validate_request | from api.utils.api_utils import get_json_result, server_error_response, validate_request | ||||
| from agent.canvas import Canvas | from agent.canvas import Canvas | ||||
| from peewee import MySQLDatabase, PostgresqlDatabase | |||||
| @manager.route('/templates', methods=['GET']) | @manager.route('/templates', methods=['GET']) | ||||
| return get_json_result(data=req["dsl"]) | return get_json_result(data=req["dsl"]) | ||||
| except Exception as e: | except Exception as e: | ||||
| return server_error_response(e) | return server_error_response(e) | ||||
| @manager.route('/test_db_connect', methods=['POST']) | |||||
| @validate_request("db_type", "database", "username", "host", "port", "password") | |||||
| @login_required | |||||
| def test_db_connect(): | |||||
| req = request.json | |||||
| try: | |||||
| if req["db_type"] in ["mysql", "mariadb"]: | |||||
| db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"], | |||||
| password=req["password"]) | |||||
| elif req["db_type"] == 'postgresql': | |||||
| db = PostgresqlDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"], | |||||
| password=req["password"]) | |||||
| db.connect() | |||||
| db.close() | |||||
| return get_json_result(retmsg="Database Connection Successful!") | |||||
| except Exception as e: | |||||
| return server_error_response(str(e)) |
| Pillow==10.3.0 | Pillow==10.3.0 | ||||
| pipreqs==0.5.0 | pipreqs==0.5.0 | ||||
| protobuf==5.27.2 | protobuf==5.27.2 | ||||
| psycopg2-binary==2.9.9 | |||||
| pyclipper==1.3.0.post5 | pyclipper==1.3.0.post5 | ||||
| pycryptodomex==3.20.0 | pycryptodomex==3.20.0 | ||||
| pypdf==4.3.0 | pypdf==4.3.0 | ||||
| Shapely==2.0.5 | Shapely==2.0.5 | ||||
| six==1.16.0 | six==1.16.0 | ||||
| StrEnum==0.4.15 | StrEnum==0.4.15 | ||||
| tabulate==0.9.0 | |||||
| tika==2.6.0 | tika==2.6.0 | ||||
| tiktoken==0.6.0 | tiktoken==0.6.0 | ||||
| torch==2.3.0 | torch==2.3.0 |
| markdown_to_json==2.1.1 | markdown_to_json==2.1.1 | ||||
| scholarly==1.7.11 | scholarly==1.7.11 | ||||
| deepl==1.18.0 | deepl==1.18.0 | ||||
| psycopg2-binary==2.9.9 | |||||
| tabulate-0.9.0 |