您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

workflow.py 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. import logging
  2. from dateutil.parser import isoparse
  3. from flask import request
  4. from flask_restx import Api, Namespace, Resource, fields, reqparse
  5. from flask_restx.inputs import int_range
  6. from sqlalchemy.orm import Session, sessionmaker
  7. from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
  8. from controllers.service_api import service_api_ns
  9. from controllers.service_api.app.error import (
  10. CompletionRequestError,
  11. NotWorkflowAppError,
  12. ProviderModelCurrentlyNotSupportError,
  13. ProviderNotInitializeError,
  14. ProviderQuotaExceededError,
  15. )
  16. from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
  17. from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
  18. from core.app.apps.base_app_queue_manager import AppQueueManager
  19. from core.app.entities.app_invoke_entities import InvokeFrom
  20. from core.errors.error import (
  21. ModelCurrentlyNotSupportError,
  22. ProviderTokenNotInitError,
  23. QuotaExceededError,
  24. )
  25. from core.helper.trace_id_helper import get_external_trace_id
  26. from core.model_runtime.errors.invoke import InvokeError
  27. from core.workflow.enums import WorkflowExecutionStatus
  28. from core.workflow.graph_engine.manager import GraphEngineManager
  29. from extensions.ext_database import db
  30. from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
  31. from libs import helper
  32. from libs.helper import TimestampField
  33. from models.model import App, AppMode, EndUser
  34. from repositories.factory import DifyAPIRepositoryFactory
  35. from services.app_generate_service import AppGenerateService
  36. from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
  37. from services.errors.llm import InvokeRateLimitError
  38. from services.workflow_app_service import WorkflowAppService
  39. logger = logging.getLogger(__name__)
  40. # Define parsers for workflow APIs
  41. workflow_run_parser = reqparse.RequestParser()
  42. workflow_run_parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
  43. workflow_run_parser.add_argument("files", type=list, required=False, location="json")
  44. workflow_run_parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
  45. workflow_log_parser = reqparse.RequestParser()
  46. workflow_log_parser.add_argument("keyword", type=str, location="args")
  47. workflow_log_parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
  48. workflow_log_parser.add_argument("created_at__before", type=str, location="args")
  49. workflow_log_parser.add_argument("created_at__after", type=str, location="args")
  50. workflow_log_parser.add_argument(
  51. "created_by_end_user_session_id",
  52. type=str,
  53. location="args",
  54. required=False,
  55. default=None,
  56. )
  57. workflow_log_parser.add_argument(
  58. "created_by_account",
  59. type=str,
  60. location="args",
  61. required=False,
  62. default=None,
  63. )
  64. workflow_log_parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
  65. workflow_log_parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
  66. workflow_run_fields = {
  67. "id": fields.String,
  68. "workflow_id": fields.String,
  69. "status": fields.String,
  70. "inputs": fields.Raw,
  71. "outputs": fields.Raw,
  72. "error": fields.String,
  73. "total_steps": fields.Integer,
  74. "total_tokens": fields.Integer,
  75. "created_at": TimestampField,
  76. "finished_at": TimestampField,
  77. "elapsed_time": fields.Float,
  78. }
  79. def build_workflow_run_model(api_or_ns: Api | Namespace):
  80. """Build the workflow run model for the API or Namespace."""
  81. return api_or_ns.model("WorkflowRun", workflow_run_fields)
  82. @service_api_ns.route("/workflows/run/<string:workflow_run_id>")
  83. class WorkflowRunDetailApi(Resource):
  84. @service_api_ns.doc("get_workflow_run_detail")
  85. @service_api_ns.doc(description="Get workflow run details")
  86. @service_api_ns.doc(params={"workflow_run_id": "Workflow run ID"})
  87. @service_api_ns.doc(
  88. responses={
  89. 200: "Workflow run details retrieved successfully",
  90. 401: "Unauthorized - invalid API token",
  91. 404: "Workflow run not found",
  92. }
  93. )
  94. @validate_app_token
  95. @service_api_ns.marshal_with(build_workflow_run_model(service_api_ns))
  96. def get(self, app_model: App, workflow_run_id: str):
  97. """Get a workflow task running detail.
  98. Returns detailed information about a specific workflow run.
  99. """
  100. app_mode = AppMode.value_of(app_model.mode)
  101. if app_mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]:
  102. raise NotWorkflowAppError()
  103. # Use repository to get workflow run
  104. session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
  105. workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
  106. workflow_run = workflow_run_repo.get_workflow_run_by_id(
  107. tenant_id=app_model.tenant_id,
  108. app_id=app_model.id,
  109. run_id=workflow_run_id,
  110. )
  111. return workflow_run
  112. @service_api_ns.route("/workflows/run")
  113. class WorkflowRunApi(Resource):
  114. @service_api_ns.expect(workflow_run_parser)
  115. @service_api_ns.doc("run_workflow")
  116. @service_api_ns.doc(description="Execute a workflow")
  117. @service_api_ns.doc(
  118. responses={
  119. 200: "Workflow executed successfully",
  120. 400: "Bad request - invalid parameters or workflow issues",
  121. 401: "Unauthorized - invalid API token",
  122. 404: "Workflow not found",
  123. 429: "Rate limit exceeded",
  124. 500: "Internal server error",
  125. }
  126. )
  127. @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
  128. def post(self, app_model: App, end_user: EndUser):
  129. """Execute a workflow.
  130. Runs a workflow with the provided inputs and returns the results.
  131. Supports both blocking and streaming response modes.
  132. """
  133. app_mode = AppMode.value_of(app_model.mode)
  134. if app_mode != AppMode.WORKFLOW:
  135. raise NotWorkflowAppError()
  136. args = workflow_run_parser.parse_args()
  137. external_trace_id = get_external_trace_id(request)
  138. if external_trace_id:
  139. args["external_trace_id"] = external_trace_id
  140. streaming = args.get("response_mode") == "streaming"
  141. try:
  142. response = AppGenerateService.generate(
  143. app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming
  144. )
  145. return helper.compact_generate_response(response)
  146. except ProviderTokenNotInitError as ex:
  147. raise ProviderNotInitializeError(ex.description)
  148. except QuotaExceededError:
  149. raise ProviderQuotaExceededError()
  150. except ModelCurrentlyNotSupportError:
  151. raise ProviderModelCurrentlyNotSupportError()
  152. except InvokeRateLimitError as ex:
  153. raise InvokeRateLimitHttpError(ex.description)
  154. except InvokeError as e:
  155. raise CompletionRequestError(e.description)
  156. except ValueError as e:
  157. raise e
  158. except Exception:
  159. logger.exception("internal server error.")
  160. raise InternalServerError()
  161. @service_api_ns.route("/workflows/<string:workflow_id>/run")
  162. class WorkflowRunByIdApi(Resource):
  163. @service_api_ns.expect(workflow_run_parser)
  164. @service_api_ns.doc("run_workflow_by_id")
  165. @service_api_ns.doc(description="Execute a specific workflow by ID")
  166. @service_api_ns.doc(params={"workflow_id": "Workflow ID to execute"})
  167. @service_api_ns.doc(
  168. responses={
  169. 200: "Workflow executed successfully",
  170. 400: "Bad request - invalid parameters or workflow issues",
  171. 401: "Unauthorized - invalid API token",
  172. 404: "Workflow not found",
  173. 429: "Rate limit exceeded",
  174. 500: "Internal server error",
  175. }
  176. )
  177. @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
  178. def post(self, app_model: App, end_user: EndUser, workflow_id: str):
  179. """Run specific workflow by ID.
  180. Executes a specific workflow version identified by its ID.
  181. """
  182. app_mode = AppMode.value_of(app_model.mode)
  183. if app_mode != AppMode.WORKFLOW:
  184. raise NotWorkflowAppError()
  185. args = workflow_run_parser.parse_args()
  186. # Add workflow_id to args for AppGenerateService
  187. args["workflow_id"] = workflow_id
  188. external_trace_id = get_external_trace_id(request)
  189. if external_trace_id:
  190. args["external_trace_id"] = external_trace_id
  191. streaming = args.get("response_mode") == "streaming"
  192. try:
  193. response = AppGenerateService.generate(
  194. app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming
  195. )
  196. return helper.compact_generate_response(response)
  197. except WorkflowNotFoundError as ex:
  198. raise NotFound(str(ex))
  199. except IsDraftWorkflowError as ex:
  200. raise BadRequest(str(ex))
  201. except WorkflowIdFormatError as ex:
  202. raise BadRequest(str(ex))
  203. except ProviderTokenNotInitError as ex:
  204. raise ProviderNotInitializeError(ex.description)
  205. except QuotaExceededError:
  206. raise ProviderQuotaExceededError()
  207. except ModelCurrentlyNotSupportError:
  208. raise ProviderModelCurrentlyNotSupportError()
  209. except InvokeRateLimitError as ex:
  210. raise InvokeRateLimitHttpError(ex.description)
  211. except InvokeError as e:
  212. raise CompletionRequestError(e.description)
  213. except ValueError as e:
  214. raise e
  215. except Exception:
  216. logger.exception("internal server error.")
  217. raise InternalServerError()
  218. @service_api_ns.route("/workflows/tasks/<string:task_id>/stop")
  219. class WorkflowTaskStopApi(Resource):
  220. @service_api_ns.doc("stop_workflow_task")
  221. @service_api_ns.doc(description="Stop a running workflow task")
  222. @service_api_ns.doc(params={"task_id": "Task ID to stop"})
  223. @service_api_ns.doc(
  224. responses={
  225. 200: "Task stopped successfully",
  226. 401: "Unauthorized - invalid API token",
  227. 404: "Task not found",
  228. }
  229. )
  230. @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
  231. def post(self, app_model: App, end_user: EndUser, task_id: str):
  232. """Stop a running workflow task."""
  233. app_mode = AppMode.value_of(app_model.mode)
  234. if app_mode != AppMode.WORKFLOW:
  235. raise NotWorkflowAppError()
  236. # Stop using both mechanisms for backward compatibility
  237. # Legacy stop flag mechanism (without user check)
  238. AppQueueManager.set_stop_flag_no_user_check(task_id)
  239. # New graph engine command channel mechanism
  240. GraphEngineManager.send_stop_command(task_id)
  241. return {"result": "success"}
  242. @service_api_ns.route("/workflows/logs")
  243. class WorkflowAppLogApi(Resource):
  244. @service_api_ns.expect(workflow_log_parser)
  245. @service_api_ns.doc("get_workflow_logs")
  246. @service_api_ns.doc(description="Get workflow execution logs")
  247. @service_api_ns.doc(
  248. responses={
  249. 200: "Logs retrieved successfully",
  250. 401: "Unauthorized - invalid API token",
  251. }
  252. )
  253. @validate_app_token
  254. @service_api_ns.marshal_with(build_workflow_app_log_pagination_model(service_api_ns))
  255. def get(self, app_model: App):
  256. """Get workflow app logs.
  257. Returns paginated workflow execution logs with filtering options.
  258. """
  259. args = workflow_log_parser.parse_args()
  260. args.status = WorkflowExecutionStatus(args.status) if args.status else None
  261. if args.created_at__before:
  262. args.created_at__before = isoparse(args.created_at__before)
  263. if args.created_at__after:
  264. args.created_at__after = isoparse(args.created_at__after)
  265. # get paginate workflow app logs
  266. workflow_app_service = WorkflowAppService()
  267. with Session(db.engine) as session:
  268. workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
  269. session=session,
  270. app_model=app_model,
  271. keyword=args.keyword,
  272. status=args.status,
  273. created_at_before=args.created_at__before,
  274. created_at_after=args.created_at__after,
  275. page=args.page,
  276. limit=args.limit,
  277. created_by_end_user_session_id=args.created_by_end_user_session_id,
  278. created_by_account=args.created_by_account,
  279. )
  280. return workflow_app_log_pagination