您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

workflow.py 12KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. import logging
  2. from dateutil.parser import isoparse
  3. from flask import request
  4. from flask_restx import Api, Namespace, Resource, fields, reqparse
  5. from flask_restx.inputs import int_range
  6. from sqlalchemy.orm import Session, sessionmaker
  7. from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
  8. from controllers.service_api import service_api_ns
  9. from controllers.service_api.app.error import (
  10. CompletionRequestError,
  11. NotWorkflowAppError,
  12. ProviderModelCurrentlyNotSupportError,
  13. ProviderNotInitializeError,
  14. ProviderQuotaExceededError,
  15. )
  16. from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
  17. from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
  18. from core.app.apps.base_app_queue_manager import AppQueueManager
  19. from core.app.entities.app_invoke_entities import InvokeFrom
  20. from core.errors.error import (
  21. ModelCurrentlyNotSupportError,
  22. ProviderTokenNotInitError,
  23. QuotaExceededError,
  24. )
  25. from core.helper.trace_id_helper import get_external_trace_id
  26. from core.model_runtime.errors.invoke import InvokeError
  27. from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
  28. from extensions.ext_database import db
  29. from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
  30. from libs import helper
  31. from libs.helper import TimestampField
  32. from models.model import App, AppMode, EndUser
  33. from repositories.factory import DifyAPIRepositoryFactory
  34. from services.app_generate_service import AppGenerateService
  35. from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
  36. from services.errors.llm import InvokeRateLimitError
  37. from services.workflow_app_service import WorkflowAppService
  logger = logging.getLogger(__name__)

  # --- Request parsing for the workflow execution endpoints (JSON body) ---
  workflow_run_parser = reqparse.RequestParser()
  workflow_run_parser.add_argument("inputs", location="json", type=dict, nullable=False, required=True)
  workflow_run_parser.add_argument("files", location="json", type=list, required=False)
  workflow_run_parser.add_argument(
      "response_mode",
      location="json",
      type=str,
      choices=["blocking", "streaming"],
  )

  # --- Query-string parsing for the workflow log listing endpoint ---
  workflow_log_parser = reqparse.RequestParser()
  workflow_log_parser.add_argument("keyword", location="args", type=str)
  workflow_log_parser.add_argument(
      "status",
      location="args",
      type=str,
      choices=["succeeded", "failed", "stopped"],
  )
  # Timestamp bounds arrive as raw strings; the log handler converts them with isoparse.
  workflow_log_parser.add_argument("created_at__before", location="args", type=str)
  workflow_log_parser.add_argument("created_at__after", location="args", type=str)
  # Optional creator filters (end-user session id or account).
  workflow_log_parser.add_argument(
      "created_by_end_user_session_id", location="args", type=str, default=None, required=False
  )
  workflow_log_parser.add_argument(
      "created_by_account", location="args", type=str, default=None, required=False
  )
  # Pagination: page is bounded to [1, 99999], page size to [1, 100].
  workflow_log_parser.add_argument("page", location="args", type=int_range(1, 99999), default=1)
  workflow_log_parser.add_argument("limit", location="args", type=int_range(1, 100), default=20)
  # Marshalling schema for a single workflow run (used by build_workflow_run_model).
  # Entries are listed in the order used for the model definition.
  workflow_run_fields = {
      # identity
      "id": fields.String,
      "workflow_id": fields.String,
      # outcome
      "status": fields.String,
      "inputs": fields.Raw,
      "outputs": fields.Raw,
      "error": fields.String,
      # accounting
      "total_steps": fields.Integer,
      "total_tokens": fields.Integer,
      # timing
      "created_at": TimestampField,
      "finished_at": TimestampField,
      "elapsed_time": fields.Float,
  }
  def build_workflow_run_model(api_or_ns: Api | Namespace):
      """Register the ``WorkflowRun`` response model and return it.

      Args:
          api_or_ns: The Api or Namespace on which the model is declared.

      Returns:
          The model object created by ``api_or_ns.model`` from ``workflow_run_fields``.
      """
      model_name = "WorkflowRun"
      return api_or_ns.model(model_name, workflow_run_fields)
  @service_api_ns.route("/workflows/run/<string:workflow_run_id>")
  class WorkflowRunDetailApi(Resource):
      @service_api_ns.doc("get_workflow_run_detail")
      @service_api_ns.doc(description="Get workflow run details")
      @service_api_ns.doc(params={"workflow_run_id": "Workflow run ID"})
      @service_api_ns.doc(
          responses={
              200: "Workflow run details retrieved successfully",
              401: "Unauthorized - invalid API token",
              404: "Workflow run not found",
          }
      )
      @validate_app_token
      @service_api_ns.marshal_with(build_workflow_run_model(service_api_ns))
      def get(self, app_model: App, workflow_run_id: str):
          """Get a workflow task running detail.

          Returns detailed information about a specific workflow run, looked up
          by ``workflow_run_id`` within the app's tenant and app id.

          Raises:
              NotWorkflowAppError: the app mode is neither WORKFLOW nor ADVANCED_CHAT.
              NotFound: no run with ``workflow_run_id`` was found for this app.
          """
          app_mode = AppMode.value_of(app_model.mode)
          if app_mode not in (AppMode.WORKFLOW, AppMode.ADVANCED_CHAT):
              raise NotWorkflowAppError()

          # Use repository to get workflow run
          session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
          workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
          workflow_run = workflow_run_repo.get_workflow_run_by_id(
              tenant_id=app_model.tenant_id,
              app_id=app_model.id,
              run_id=workflow_run_id,
          )
          # NOTE(review): assumes the repository returns None for a missing run.
          # Without this check marshal_with would serialize None into a 200 response
          # with every field null, contradicting the documented 404.
          if workflow_run is None:
              raise NotFound("Workflow run not found.")
          return workflow_run
  @service_api_ns.route("/workflows/run")
  class WorkflowRunApi(Resource):
      @service_api_ns.expect(workflow_run_parser)
      @service_api_ns.doc("run_workflow")
      @service_api_ns.doc(description="Execute a workflow")
      @service_api_ns.doc(
          responses={
              200: "Workflow executed successfully",
              400: "Bad request - invalid parameters or workflow issues",
              401: "Unauthorized - invalid API token",
              404: "Workflow not found",
              429: "Rate limit exceeded",
              500: "Internal server error",
          }
      )
      @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
      def post(self, app_model: App, end_user: EndUser):
          """Execute a workflow.

          Runs the app's workflow with the inputs from the JSON body. The
          ``response_mode`` field chooses between a blocking and a streaming
          response.

          Raises:
              NotWorkflowAppError: the app is not in WORKFLOW mode.
              ProviderNotInitializeError, ProviderQuotaExceededError,
              ProviderModelCurrentlyNotSupportError, InvokeRateLimitHttpError,
              CompletionRequestError: translated from the matching generation errors.
              ValueError: propagated unchanged.
              InternalServerError: any other failure (logged first).
          """
          if AppMode.value_of(app_model.mode) != AppMode.WORKFLOW:
              raise NotWorkflowAppError()

          args = workflow_run_parser.parse_args()

          # Forward a caller-provided trace id to the generator when present.
          trace_id = get_external_trace_id(request)
          if trace_id:
              args["external_trace_id"] = trace_id

          is_streaming = args.get("response_mode") == "streaming"

          try:
              generated = AppGenerateService.generate(
                  app_model=app_model,
                  user=end_user,
                  args=args,
                  invoke_from=InvokeFrom.SERVICE_API,
                  streaming=is_streaming,
              )
              return helper.compact_generate_response(generated)
          except ProviderTokenNotInitError as err:
              raise ProviderNotInitializeError(err.description)
          except QuotaExceededError:
              raise ProviderQuotaExceededError()
          except ModelCurrentlyNotSupportError:
              raise ProviderModelCurrentlyNotSupportError()
          except InvokeRateLimitError as err:
              raise InvokeRateLimitHttpError(err.description)
          except InvokeError as err:
              raise CompletionRequestError(err.description)
          except ValueError:
              raise
          except Exception:
              logger.exception("internal server error.")
              raise InternalServerError()
  @service_api_ns.route("/workflows/<string:workflow_id>/run")
  class WorkflowRunByIdApi(Resource):
      @service_api_ns.expect(workflow_run_parser)
      @service_api_ns.doc("run_workflow_by_id")
      @service_api_ns.doc(description="Execute a specific workflow by ID")
      @service_api_ns.doc(params={"workflow_id": "Workflow ID to execute"})
      @service_api_ns.doc(
          responses={
              200: "Workflow executed successfully",
              400: "Bad request - invalid parameters or workflow issues",
              401: "Unauthorized - invalid API token",
              404: "Workflow not found",
              429: "Rate limit exceeded",
              500: "Internal server error",
          }
      )
      @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
      def post(self, app_model: App, end_user: EndUser, workflow_id: str):
          """Run specific workflow by ID.

          Same as the plain run endpoint, but pins execution to the workflow
          version identified by ``workflow_id`` (forwarded in ``args``).

          Raises:
              NotWorkflowAppError: the app is not in WORKFLOW mode.
              NotFound: the service reports WorkflowNotFoundError.
              BadRequest: the service reports IsDraftWorkflowError or WorkflowIdFormatError.
              ProviderNotInitializeError, ProviderQuotaExceededError,
              ProviderModelCurrentlyNotSupportError, InvokeRateLimitHttpError,
              CompletionRequestError: translated from the matching generation errors.
              ValueError: propagated unchanged.
              InternalServerError: any other failure (logged first).
          """
          if AppMode.value_of(app_model.mode) != AppMode.WORKFLOW:
              raise NotWorkflowAppError()

          args = workflow_run_parser.parse_args()
          # AppGenerateService reads the target workflow id from args.
          args["workflow_id"] = workflow_id

          trace_id = get_external_trace_id(request)
          if trace_id:
              args["external_trace_id"] = trace_id

          is_streaming = args.get("response_mode") == "streaming"

          try:
              generated = AppGenerateService.generate(
                  app_model=app_model,
                  user=end_user,
                  args=args,
                  invoke_from=InvokeFrom.SERVICE_API,
                  streaming=is_streaming,
              )
              return helper.compact_generate_response(generated)
          except WorkflowNotFoundError as err:
              raise NotFound(str(err))
          except IsDraftWorkflowError as err:
              raise BadRequest(str(err))
          except WorkflowIdFormatError as err:
              raise BadRequest(str(err))
          except ProviderTokenNotInitError as err:
              raise ProviderNotInitializeError(err.description)
          except QuotaExceededError:
              raise ProviderQuotaExceededError()
          except ModelCurrentlyNotSupportError:
              raise ProviderModelCurrentlyNotSupportError()
          except InvokeRateLimitError as err:
              raise InvokeRateLimitHttpError(err.description)
          except InvokeError as err:
              raise CompletionRequestError(err.description)
          except ValueError:
              raise
          except Exception:
              logger.exception("internal server error.")
              raise InternalServerError()
  @service_api_ns.route("/workflows/tasks/<string:task_id>/stop")
  class WorkflowTaskStopApi(Resource):
      @service_api_ns.doc("stop_workflow_task")
      @service_api_ns.doc(description="Stop a running workflow task")
      @service_api_ns.doc(params={"task_id": "Task ID to stop"})
      @service_api_ns.doc(
          responses={
              200: "Task stopped successfully",
              401: "Unauthorized - invalid API token",
              404: "Task not found",
          }
      )
      @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
      def post(self, app_model: App, end_user: EndUser, task_id: str):
          """Stop a running workflow task.

          Calls ``AppQueueManager.set_stop_flag`` for ``task_id`` with the
          SERVICE_API invoke source and the calling end user's id, then returns
          ``{"result": "success"}``.

          Raises:
              NotWorkflowAppError: the app is not in WORKFLOW mode.
          """
          is_workflow_app = AppMode.value_of(app_model.mode) == AppMode.WORKFLOW
          if not is_workflow_app:
              raise NotWorkflowAppError()

          requester_id = end_user.id
          AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, requester_id)

          stop_result = {"result": "success"}
          return stop_result
  @service_api_ns.route("/workflows/logs")
  class WorkflowAppLogApi(Resource):
      @service_api_ns.expect(workflow_log_parser)
      @service_api_ns.doc("get_workflow_logs")
      @service_api_ns.doc(description="Get workflow execution logs")
      @service_api_ns.doc(
          responses={
              200: "Logs retrieved successfully",
              401: "Unauthorized - invalid API token",
          }
      )
      @validate_app_token
      @service_api_ns.marshal_with(build_workflow_app_log_pagination_model(service_api_ns))
      def get(self, app_model: App):
          """Get workflow app logs.

          Returns one page of workflow execution logs for ``app_model``. Results
          can be filtered by keyword, status, a creation-time window, and the
          creator (end-user session id or account).
          """
          args = workflow_log_parser.parse_args()

          # Convert raw query-string values into the types the service expects.
          # A missing status becomes None; missing date bounds are passed as-is.
          status_filter = WorkflowExecutionStatus(args.status) if args.status else None

          created_before = args.created_at__before
          if created_before:
              created_before = isoparse(created_before)

          created_after = args.created_at__after
          if created_after:
              created_after = isoparse(created_after)

          log_service = WorkflowAppService()
          with Session(db.engine) as session:
              page_result = log_service.get_paginate_workflow_app_logs(
                  session=session,
                  app_model=app_model,
                  keyword=args.keyword,
                  status=status_filter,
                  created_at_before=created_before,
                  created_at_after=created_after,
                  page=args.page,
                  limit=args.limit,
                  created_by_end_user_session_id=args.created_by_end_user_session_id,
                  created_by_account=args.created_by_account,
              )
              return page_result