Ver código fonte

Add component ExeSQL (#1966)

### What problem does this PR solve?

#1965 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
tags/v0.10.0
H 1 ano atrás
pai
commit
644f68de97
Nenhuma conta vinculada ao e-mail do autor do commit

+ 1
- 0
agent/component/__init__.py Ver arquivo

from .github import GitHub, GitHubParam from .github import GitHub, GitHubParam
from .baidufanyi import BaiduFanyi, BaiduFanyiParam from .baidufanyi import BaiduFanyi, BaiduFanyiParam
from .qweather import QWeather, QWeatherParam from .qweather import QWeather, QWeatherParam
from .exesql import ExeSQL, ExeSQLParam


def component_class(class_name): def component_class(class_name):
m = importlib.import_module("agent.component") m = importlib.import_module("agent.component")

+ 85
- 0
agent/component/exesql.py Ver arquivo

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC

import pandas as pd
from peewee import MySQLDatabase, PostgresqlDatabase
from agent.component.base import ComponentBase, ComponentParamBase


class ExeSQLParam(ComponentParamBase):
"""
Define the ExeSQL component parameters.
"""

def __init__(self):
super().__init__()
self.db_type = "mysql"
self.database = ""
self.username = ""
self.host = ""
self.port = 3306
self.password = ""
self.loop = 3
self.top_n = 30

def check(self):
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgresql', 'mariadb'])
self.check_empty(self.database, "Database name")
self.check_empty(self.username, "database username")
self.check_empty(self.host, "IP Address")
self.check_positive_integer(self.port, "IP Port")
self.check_empty(self.password, "Database password")
self.check_positive_integer(self.top_n, "Number of records")


class ExeSQL(ComponentBase, ABC):
component_name = "ExeSQL"

def _run(self, history, **kwargs):
if not hasattr(self, "_loop"):
setattr(self, "_loop", 0)
if self._loop >= self._param.loop:
self._loop = 0
raise Exception("Maximum loop time exceeds. Can't query the correct data via sql statement.")
self._loop += 1

ans = self.get_input()
ans = "".join(ans["content"]) if "content" in ans else ""
if not ans:
return ExeSQL.be_output("SQL statement not found!")

if self._param.db_type in ["mysql", "mariadb"]:
db = MySQLDatabase(self._param.database, user=self._param.username, host=self._param.host,
port=self._param.port, password=self._param.password)
elif self._param.db_type == 'postgresql':
db = PostgresqlDatabase(self._param.database, user=self._param.username, host=self._param.host,
port=self._param.port, password=self._param.password)

try:
db.connect()
query = db.execute_sql(ans)
sql_res = [{"content": rec + "\n"} for rec in [str(i) for i in query.fetchall()]]
db.close()
except Exception as e:
return ExeSQL.be_output("**Error**:" + str(e))

if not sql_res:
return ExeSQL.be_output("No record in the database!")

sql_res.insert(0, {"content": "Number of records retrieved from the database is " + str(len(sql_res)) + "\n"})
df = pd.DataFrame(sql_res[0:self._param.top_n + 1])
return ExeSQL.be_output(df.to_markdown())

+ 43
- 0
agent/test/dsl_examples/exesql.json Ver arquivo

{
"components": {
"begin": {
"obj":{
"component_name": "Begin",
"params": {
"prologue": "Hi there!"
}
},
"downstream": ["answer:0"],
"upstream": []
},
"answer:0": {
"obj": {
"component_name": "Answer",
"params": {}
},
"downstream": ["exesql:0"],
"upstream": ["begin", "exesql:0"]
},
"exesql:0": {
"obj": {
"component_name": "ExeSQL",
"params": {
"database": "rag_flow",
"username": "root",
"host": "mysql",
"port": 3306,
"password": "infini_rag_flow",
"top_n": 3
}
},
"downstream": ["answer:0"],
"upstream": ["answer:0"]
}
},
"history": [],
"messages": [],
"reference": {},
"path": [],
"answer": []
}


+ 20
- 0
api/apps/canvas_app.py Ver arquivo

from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request from api.utils.api_utils import get_json_result, server_error_response, validate_request
from agent.canvas import Canvas from agent.canvas import Canvas
from peewee import MySQLDatabase, PostgresqlDatabase




@manager.route('/templates', methods=['GET']) @manager.route('/templates', methods=['GET'])
return get_json_result(data=req["dsl"]) return get_json_result(data=req["dsl"])
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)


@manager.route('/test_db_connect', methods=['POST'])
@validate_request("db_type", "database", "username", "host", "port", "password")
@login_required
def test_db_connect():
req = request.json
try:
if req["db_type"] in ["mysql", "mariadb"]:
db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
password=req["password"])
elif req["db_type"] == 'postgresql':
db = PostgresqlDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
password=req["password"])
db.connect()
db.close()
return get_json_result(retmsg="Database Connection Successful!")
except Exception as e:
return server_error_response(str(e))

+ 2
- 0
requirements.txt Ver arquivo

Pillow==10.3.0 Pillow==10.3.0
pipreqs==0.5.0 pipreqs==0.5.0
protobuf==5.27.2 protobuf==5.27.2
psycopg2-binary==2.9.9
pyclipper==1.3.0.post5 pyclipper==1.3.0.post5
pycryptodomex==3.20.0 pycryptodomex==3.20.0
pypdf==4.3.0 pypdf==4.3.0
Shapely==2.0.5 Shapely==2.0.5
six==1.16.0 six==1.16.0
StrEnum==0.4.15 StrEnum==0.4.15
tabulate==0.9.0
tika==2.6.0 tika==2.6.0
tiktoken==0.6.0 tiktoken==0.6.0
torch==2.3.0 torch==2.3.0

+ 2
- 0
requirements_arm.txt Ver arquivo

markdown_to_json==2.1.1 markdown_to_json==2.1.1
scholarly==1.7.11 scholarly==1.7.11
deepl==1.18.0 deepl==1.18.0
psycopg2-binary==2.9.9
tabulate-0.9.0

Carregando…
Cancelar
Salvar