選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import os
  17. import re
  18. from abc import ABC
  19. import pandas as pd
  20. import pymysql
  21. import psycopg2
  22. import pyodbc
  23. from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
  24. from api.utils.api_utils import timeout
  25. class ExeSQLParam(ToolParamBase):
  26. """
  27. Define the ExeSQL component parameters.
  28. """
  29. def __init__(self):
  30. self.meta:ToolMeta = {
  31. "name": "execute_sql",
  32. "description": "This is a tool that can execute SQL.",
  33. "parameters": {
  34. "sql": {
  35. "type": "string",
  36. "description": "The SQL needs to be executed.",
  37. "default": "{sys.query}",
  38. "required": True
  39. }
  40. }
  41. }
  42. super().__init__()
  43. self.db_type = "mysql"
  44. self.database = ""
  45. self.username = ""
  46. self.host = ""
  47. self.port = 3306
  48. self.password = ""
  49. self.max_records = 1024
  50. def check(self):
  51. self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgresql', 'mariadb', 'mssql'])
  52. self.check_empty(self.database, "Database name")
  53. self.check_empty(self.username, "database username")
  54. self.check_empty(self.host, "IP Address")
  55. self.check_positive_integer(self.port, "IP Port")
  56. self.check_empty(self.password, "Database password")
  57. self.check_positive_integer(self.max_records, "Maximum number of records")
  58. if self.database == "rag_flow":
  59. if self.host == "ragflow-mysql":
  60. raise ValueError("For the security reason, it dose not support database named rag_flow.")
  61. if self.password == "infini_rag_flow":
  62. raise ValueError("For the security reason, it dose not support database named rag_flow.")
  63. def get_input_form(self) -> dict[str, dict]:
  64. return {
  65. "sql": {
  66. "name": "SQL",
  67. "type": "line"
  68. }
  69. }
  70. class ExeSQL(ToolBase, ABC):
  71. component_name = "ExeSQL"
  72. @timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
  73. def _invoke(self, **kwargs):
  74. def convert_decimals(obj):
  75. from decimal import Decimal
  76. if isinstance(obj, Decimal):
  77. return float(obj) # 或 str(obj)
  78. elif isinstance(obj, dict):
  79. return {k: convert_decimals(v) for k, v in obj.items()}
  80. elif isinstance(obj, list):
  81. return [convert_decimals(item) for item in obj]
  82. return obj
  83. sql = kwargs.get("sql")
  84. if not sql:
  85. raise Exception("SQL for `ExeSQL` MUST not be empty.")
  86. sqls = sql.split(";")
  87. if self._param.db_type in ["mysql", "mariadb"]:
  88. db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
  89. port=self._param.port, password=self._param.password)
  90. elif self._param.db_type == 'postgresql':
  91. db = psycopg2.connect(dbname=self._param.database, user=self._param.username, host=self._param.host,
  92. port=self._param.port, password=self._param.password)
  93. elif self._param.db_type == 'mssql':
  94. conn_str = (
  95. r'DRIVER={ODBC Driver 17 for SQL Server};'
  96. r'SERVER=' + self._param.host + ',' + str(self._param.port) + ';'
  97. r'DATABASE=' + self._param.database + ';'
  98. r'UID=' + self._param.username + ';'
  99. r'PWD=' + self._param.password
  100. )
  101. db = pyodbc.connect(conn_str)
  102. try:
  103. cursor = db.cursor()
  104. except Exception as e:
  105. raise Exception("Database Connection Failed! \n" + str(e))
  106. sql_res = []
  107. formalized_content = []
  108. for single_sql in sqls:
  109. single_sql = single_sql.replace('```','')
  110. if not single_sql:
  111. continue
  112. single_sql = re.sub(r"\[ID:[0-9]+\]", "", single_sql)
  113. cursor.execute(single_sql)
  114. if cursor.rowcount == 0:
  115. sql_res.append({"content": "No record in the database!"})
  116. break
  117. if self._param.db_type == 'mssql':
  118. single_res = pd.DataFrame.from_records(cursor.fetchmany(self._param.max_records),
  119. columns=[desc[0] for desc in cursor.description])
  120. else:
  121. single_res = pd.DataFrame([i for i in cursor.fetchmany(self._param.max_records)])
  122. single_res.columns = [i[0] for i in cursor.description]
  123. for col in single_res.columns:
  124. if pd.api.types.is_datetime64_any_dtype(single_res[col]):
  125. single_res[col] = single_res[col].dt.strftime('%Y-%m-%d')
  126. sql_res.append(convert_decimals(single_res.to_dict(orient='records')))
  127. formalized_content.append(single_res.to_markdown(index=False, floatfmt=".6f"))
  128. self.set_output("json", sql_res)
  129. self.set_output("formalized_content", "\n\n".join(formalized_content))
  130. return self.output("formalized_content")
  131. def thoughts(self) -> str:
  132. return "Query sent—waiting for the data."