您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

generator.py 10KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. from collections.abc import Sequence
  2. from flask_login import current_user
  3. from flask_restx import Resource, reqparse
  4. from controllers.console import api
  5. from controllers.console.app.error import (
  6. CompletionRequestError,
  7. ProviderModelCurrentlyNotSupportError,
  8. ProviderNotInitializeError,
  9. ProviderQuotaExceededError,
  10. )
  11. from controllers.console.wraps import account_initialization_required, setup_required
  12. from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
  13. from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
  14. from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
  15. from core.llm_generator.llm_generator import LLMGenerator
  16. from core.model_runtime.errors.invoke import InvokeError
  17. from extensions.ext_database import db
  18. from libs.login import login_required
  19. from models import App
  20. from services.workflow_service import WorkflowService
  21. class RuleGenerateApi(Resource):
  22. @setup_required
  23. @login_required
  24. @account_initialization_required
  25. def post(self):
  26. parser = reqparse.RequestParser()
  27. parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
  28. parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
  29. parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
  30. args = parser.parse_args()
  31. account = current_user
  32. try:
  33. rules = LLMGenerator.generate_rule_config(
  34. tenant_id=account.current_tenant_id,
  35. instruction=args["instruction"],
  36. model_config=args["model_config"],
  37. no_variable=args["no_variable"],
  38. )
  39. except ProviderTokenNotInitError as ex:
  40. raise ProviderNotInitializeError(ex.description)
  41. except QuotaExceededError:
  42. raise ProviderQuotaExceededError()
  43. except ModelCurrentlyNotSupportError:
  44. raise ProviderModelCurrentlyNotSupportError()
  45. except InvokeError as e:
  46. raise CompletionRequestError(e.description)
  47. return rules
  48. class RuleCodeGenerateApi(Resource):
  49. @setup_required
  50. @login_required
  51. @account_initialization_required
  52. def post(self):
  53. parser = reqparse.RequestParser()
  54. parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
  55. parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
  56. parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
  57. parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
  58. args = parser.parse_args()
  59. account = current_user
  60. try:
  61. code_result = LLMGenerator.generate_code(
  62. tenant_id=account.current_tenant_id,
  63. instruction=args["instruction"],
  64. model_config=args["model_config"],
  65. code_language=args["code_language"],
  66. )
  67. except ProviderTokenNotInitError as ex:
  68. raise ProviderNotInitializeError(ex.description)
  69. except QuotaExceededError:
  70. raise ProviderQuotaExceededError()
  71. except ModelCurrentlyNotSupportError:
  72. raise ProviderModelCurrentlyNotSupportError()
  73. except InvokeError as e:
  74. raise CompletionRequestError(e.description)
  75. return code_result
  76. class RuleStructuredOutputGenerateApi(Resource):
  77. @setup_required
  78. @login_required
  79. @account_initialization_required
  80. def post(self):
  81. parser = reqparse.RequestParser()
  82. parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
  83. parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
  84. args = parser.parse_args()
  85. account = current_user
  86. try:
  87. structured_output = LLMGenerator.generate_structured_output(
  88. tenant_id=account.current_tenant_id,
  89. instruction=args["instruction"],
  90. model_config=args["model_config"],
  91. )
  92. except ProviderTokenNotInitError as ex:
  93. raise ProviderNotInitializeError(ex.description)
  94. except QuotaExceededError:
  95. raise ProviderQuotaExceededError()
  96. except ModelCurrentlyNotSupportError:
  97. raise ProviderModelCurrentlyNotSupportError()
  98. except InvokeError as e:
  99. raise CompletionRequestError(e.description)
  100. return structured_output
  101. class InstructionGenerateApi(Resource):
  102. @setup_required
  103. @login_required
  104. @account_initialization_required
  105. def post(self):
  106. parser = reqparse.RequestParser()
  107. parser.add_argument("flow_id", type=str, required=True, default="", location="json")
  108. parser.add_argument("node_id", type=str, required=False, default="", location="json")
  109. parser.add_argument("current", type=str, required=False, default="", location="json")
  110. parser.add_argument("language", type=str, required=False, default="javascript", location="json")
  111. parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
  112. parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
  113. parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
  114. args = parser.parse_args()
  115. code_template = (
  116. Python3CodeProvider.get_default_code()
  117. if args["language"] == "python"
  118. else (JavascriptCodeProvider.get_default_code())
  119. if args["language"] == "javascript"
  120. else ""
  121. )
  122. try:
  123. # Generate from nothing for a workflow node
  124. if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
  125. app = db.session.query(App).where(App.id == args["flow_id"]).first()
  126. if not app:
  127. return {"error": f"app {args['flow_id']} not found"}, 400
  128. workflow = WorkflowService().get_draft_workflow(app_model=app)
  129. if not workflow:
  130. return {"error": f"workflow {args['flow_id']} not found"}, 400
  131. nodes: Sequence = workflow.graph_dict["nodes"]
  132. node = [node for node in nodes if node["id"] == args["node_id"]]
  133. if len(node) == 0:
  134. return {"error": f"node {args['node_id']} not found"}, 400
  135. node_type = node[0]["data"]["type"]
  136. match node_type:
  137. case "llm":
  138. return LLMGenerator.generate_rule_config(
  139. current_user.current_tenant_id,
  140. instruction=args["instruction"],
  141. model_config=args["model_config"],
  142. no_variable=True,
  143. )
  144. case "agent":
  145. return LLMGenerator.generate_rule_config(
  146. current_user.current_tenant_id,
  147. instruction=args["instruction"],
  148. model_config=args["model_config"],
  149. no_variable=True,
  150. )
  151. case "code":
  152. return LLMGenerator.generate_code(
  153. tenant_id=current_user.current_tenant_id,
  154. instruction=args["instruction"],
  155. model_config=args["model_config"],
  156. code_language=args["language"],
  157. )
  158. case _:
  159. return {"error": f"invalid node type: {node_type}"}
  160. if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
  161. return LLMGenerator.instruction_modify_legacy(
  162. tenant_id=current_user.current_tenant_id,
  163. flow_id=args["flow_id"],
  164. current=args["current"],
  165. instruction=args["instruction"],
  166. model_config=args["model_config"],
  167. ideal_output=args["ideal_output"],
  168. )
  169. if args["node_id"] != "" and args["current"] != "": # For workflow node
  170. return LLMGenerator.instruction_modify_workflow(
  171. tenant_id=current_user.current_tenant_id,
  172. flow_id=args["flow_id"],
  173. node_id=args["node_id"],
  174. current=args["current"],
  175. instruction=args["instruction"],
  176. model_config=args["model_config"],
  177. ideal_output=args["ideal_output"],
  178. )
  179. return {"error": "incompatible parameters"}, 400
  180. except ProviderTokenNotInitError as ex:
  181. raise ProviderNotInitializeError(ex.description)
  182. except QuotaExceededError:
  183. raise ProviderQuotaExceededError()
  184. except ModelCurrentlyNotSupportError:
  185. raise ProviderModelCurrentlyNotSupportError()
  186. except InvokeError as e:
  187. raise CompletionRequestError(e.description)
  188. class InstructionGenerationTemplateApi(Resource):
  189. @setup_required
  190. @login_required
  191. @account_initialization_required
  192. def post(self):
  193. parser = reqparse.RequestParser()
  194. parser.add_argument("type", type=str, required=True, default=False, location="json")
  195. args = parser.parse_args()
  196. match args["type"]:
  197. case "prompt":
  198. from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT
  199. return {"data": INSTRUCTION_GENERATE_TEMPLATE_PROMPT}
  200. case "code":
  201. from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_CODE
  202. return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
  203. case _:
  204. raise ValueError(f"Invalid type: {args['type']}")
  205. api.add_resource(RuleGenerateApi, "/rule-generate")
  206. api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")
  207. api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate")
  208. api.add_resource(InstructionGenerateApi, "/instruction-generate")
  209. api.add_resource(InstructionGenerationTemplateApi, "/instruction-generate/template")