Vous ne pouvez pas sélectionner plus de 25 sujets. Les noms de sujets doivent commencer par une lettre ou un chiffre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

completion.py 10KB

il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans
il y a 2 ans

  1. import logging
  2. from flask import request
  3. from flask_restx import Resource, reqparse
  4. from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
  5. import services
  6. from controllers.service_api import service_api_ns
  7. from controllers.service_api.app.error import (
  8. AppUnavailableError,
  9. CompletionRequestError,
  10. ConversationCompletedError,
  11. NotChatAppError,
  12. ProviderModelCurrentlyNotSupportError,
  13. ProviderNotInitializeError,
  14. ProviderQuotaExceededError,
  15. )
  16. from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
  17. from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
  18. from core.app.apps.base_app_queue_manager import AppQueueManager
  19. from core.app.entities.app_invoke_entities import InvokeFrom
  20. from core.errors.error import (
  21. ModelCurrentlyNotSupportError,
  22. ProviderTokenNotInitError,
  23. QuotaExceededError,
  24. )
  25. from core.helper.trace_id_helper import get_external_trace_id
  26. from core.model_runtime.errors.invoke import InvokeError
  27. from libs import helper
  28. from libs.helper import uuid_value
  29. from models.model import App, AppMode, EndUser
  30. from services.app_generate_service import AppGenerateService
  31. from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
  32. from services.errors.llm import InvokeRateLimitError
  # Request parser for POST /completion-messages. Every argument is read from
  # the JSON body; "inputs" is the only argument marked required.
  completion_parser = reqparse.RequestParser()
  completion_parser.add_argument(
      "inputs",
      type=dict,
      required=True,
      location="json",
      help="Input parameters for completion",
  )
  # Prompt text; defaults to an empty string when the client omits it.
  completion_parser.add_argument(
      "query",
      type=str,
      location="json",
      default="",
      help="The query string",
  )
  completion_parser.add_argument(
      "files",
      type=list,
      required=False,
      location="json",
      help="List of file attachments",
  )
  # Restricted to "blocking" or "streaming" through `choices`.
  completion_parser.add_argument(
      "response_mode",
      type=str,
      choices=["blocking", "streaming"],
      location="json",
      help="Response mode",
  )
  completion_parser.add_argument(
      "retriever_from",
      type=str,
      required=False,
      default="dev",
      location="json",
      help="Retriever source",
  )
  # Request parser for POST /chat-messages. Every argument is read from the
  # JSON body; "inputs" and "query" are the arguments marked required.
  chat_parser = reqparse.RequestParser()
  chat_parser.add_argument(
      "inputs",
      type=dict,
      required=True,
      location="json",
      help="Input parameters for chat",
  )
  chat_parser.add_argument(
      "query",
      type=str,
      required=True,
      location="json",
      help="The chat query",
  )
  chat_parser.add_argument(
      "files",
      type=list,
      required=False,
      location="json",
      help="List of file attachments",
  )
  # Restricted to "blocking" or "streaming" through `choices`.
  chat_parser.add_argument(
      "response_mode",
      type=str,
      choices=["blocking", "streaming"],
      location="json",
      help="Response mode",
  )
  # Existing conversation to continue; the raw value is passed through
  # `uuid_value` (presumably rejects non-UUID strings -- verify in libs.helper).
  chat_parser.add_argument(
      "conversation_id",
      type=uuid_value,
      location="json",
      help="Existing conversation ID",
  )
  chat_parser.add_argument(
      "retriever_from",
      type=str,
      required=False,
      default="dev",
      location="json",
      help="Retriever source",
  )
  # NOTE(review): with type=bool a JSON string such as "false" is coerced to
  # True (bool() of any non-empty string); JSON true/false pass through as-is.
  chat_parser.add_argument(
      "auto_generate_name",
      type=bool,
      required=False,
      default=True,
      location="json",
      help="Auto generate conversation name",
  )
  chat_parser.add_argument(
      "workflow_id",
      type=str,
      required=False,
      location="json",
      help="Workflow ID for advanced chat",
  )
  @service_api_ns.route("/completion-messages")
  class CompletionApi(Resource):
      @service_api_ns.expect(completion_parser)
      @service_api_ns.doc("create_completion")
      @service_api_ns.doc(description="Create a completion for the given prompt")
      @service_api_ns.doc(
          responses={
              200: "Completion created successfully",
              400: "Bad request - invalid parameters",
              401: "Unauthorized - invalid API token",
              404: "Conversation not found",
              429: "Rate limit exceeded",
              500: "Internal server error",
          }
      )
      @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
      def post(self, app_model: App, end_user: EndUser):
          """Create a completion for the given prompt.

          This endpoint generates a completion based on the provided inputs and query.
          Supports both blocking and streaming response modes.

          Args:
              app_model: App injected by ``validate_app_token``; must be in
                  ``"completion"`` mode.
              end_user: End user injected by ``validate_app_token`` (fetched
                  according to ``WhereisUserArg.JSON``, required).

          Returns:
              The response built by ``helper.compact_generate_response`` from the
              ``AppGenerateService.generate`` result.

          Raises:
              AppUnavailableError: The app is not a completion app, or its model
                  config is broken.
              NotFound: The generator raised ``ConversationNotExistsError``.
              ConversationCompletedError: The conversation is already completed.
              ProviderNotInitializeError: On ``ProviderTokenNotInitError``.
              ProviderQuotaExceededError: On ``QuotaExceededError``.
              ProviderModelCurrentlyNotSupportError: On ``ModelCurrentlyNotSupportError``.
              InvokeRateLimitHttpError: On ``InvokeRateLimitError``.
              CompletionRequestError: On a model ``InvokeError``.
              ValueError: Propagated unchanged.
              InternalServerError: Any other exception (logged first).
          """
          if app_model.mode != "completion":
              raise AppUnavailableError()

          args = completion_parser.parse_args()
          external_trace_id = get_external_trace_id(request)
          if external_trace_id:
              args["external_trace_id"] = external_trace_id
          streaming = args["response_mode"] == "streaming"
          # completion_parser does not accept "auto_generate_name"; always switch
          # it off for completion apps.
          args["auto_generate_name"] = False

          try:
              response = AppGenerateService.generate(
                  app_model=app_model,
                  user=end_user,
                  args=args,
                  invoke_from=InvokeFrom.SERVICE_API,
                  streaming=streaming,
              )
              return helper.compact_generate_response(response)
          except services.errors.conversation.ConversationNotExistsError:
              raise NotFound("Conversation Not Exists.")
          except services.errors.conversation.ConversationCompletedError:
              raise ConversationCompletedError()
          except services.errors.app_model_config.AppModelConfigBrokenError:
              logging.exception("App model config broken.")
              raise AppUnavailableError()
          except ProviderTokenNotInitError as ex:
              raise ProviderNotInitializeError(ex.description)
          except QuotaExceededError:
              raise ProviderQuotaExceededError()
          except ModelCurrentlyNotSupportError:
              raise ProviderModelCurrentlyNotSupportError()
          except InvokeRateLimitError as ex:
              # Map service-level rate limiting to the HTTP rate-limit error, as
              # ChatApi does for the same generate() call, instead of letting it
              # fall through to the generic 500 handler below.
              raise InvokeRateLimitHttpError(ex.description)
          except InvokeError as e:
              raise CompletionRequestError(e.description)
          except ValueError:
              # Propagate unchanged rather than converting it to InternalServerError.
              raise
          except Exception:
              logging.exception("internal server error.")
              raise InternalServerError()
  @service_api_ns.route("/completion-messages/<string:task_id>/stop")
  class CompletionStopApi(Resource):
      @service_api_ns.doc("stop_completion")
      @service_api_ns.doc(description="Stop a running completion task")
      @service_api_ns.doc(params={"task_id": "The ID of the task to stop"})
      @service_api_ns.doc(
          responses={
              200: "Task stopped successfully",
              401: "Unauthorized - invalid API token",
              404: "Task not found",
          }
      )
      @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
      def post(self, app_model: App, end_user: EndUser, task_id: str):
          """Stop a running completion task.

          Args:
              app_model: App injected by ``validate_app_token``; must be in
                  ``"completion"`` mode.
              end_user: End user injected by ``validate_app_token``.
              task_id: ID of the task to stop, taken from the URL path.

          Returns:
              ``({"result": "success"}, 200)`` whenever ``set_stop_flag`` returns
              without raising.

          Raises:
              AppUnavailableError: The app is not a completion app.
          """
          if app_model.mode == "completion":
              # Hand the task id, the invoke source and the caller's id to the
              # queue manager so the task gets flagged for stopping.
              AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
              return {"result": "success"}, 200
          raise AppUnavailableError()
  @service_api_ns.route("/chat-messages")
  class ChatApi(Resource):
      @service_api_ns.expect(chat_parser)
      @service_api_ns.doc("create_chat_message")
      @service_api_ns.doc(description="Send a message in a chat conversation")
      @service_api_ns.doc(
          responses={
              200: "Message sent successfully",
              400: "Bad request - invalid parameters or workflow issues",
              401: "Unauthorized - invalid API token",
              404: "Conversation or workflow not found",
              429: "Rate limit exceeded",
              500: "Internal server error",
          }
      )
      @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
      def post(self, app_model: App, end_user: EndUser):
          """Send a message in a chat conversation.

          Serves chat, agent-chat and advanced-chat apps; the ``response_mode``
          argument selects blocking or streaming output.

          Args:
              app_model: App injected by ``validate_app_token``.
              end_user: End user injected by ``validate_app_token``.

          Returns:
              The response built by ``helper.compact_generate_response`` from the
              ``AppGenerateService.generate`` result.

          Raises:
              NotChatAppError: The app mode is not CHAT, AGENT_CHAT or ADVANCED_CHAT.
              NotFound: On ``WorkflowNotFoundError`` or ``ConversationNotExistsError``.
              BadRequest: On ``IsDraftWorkflowError`` or ``WorkflowIdFormatError``.
              ConversationCompletedError: The conversation is already completed.
              AppUnavailableError: The app model config is broken.
              ProviderNotInitializeError: On ``ProviderTokenNotInitError``.
              ProviderQuotaExceededError: On ``QuotaExceededError``.
              ProviderModelCurrentlyNotSupportError: On ``ModelCurrentlyNotSupportError``.
              InvokeRateLimitHttpError: On ``InvokeRateLimitError``.
              CompletionRequestError: On a model ``InvokeError``.
              ValueError: Propagated unchanged.
              InternalServerError: Any other exception (logged first).
          """
          chat_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
          if AppMode.value_of(app_model.mode) not in chat_modes:
              raise NotChatAppError()

          request_args = chat_parser.parse_args()
          trace_id = get_external_trace_id(request)
          if trace_id:
              request_args["external_trace_id"] = trace_id
          is_streaming = request_args["response_mode"] == "streaming"

          try:
              generated = AppGenerateService.generate(
                  app_model=app_model,
                  user=end_user,
                  args=request_args,
                  invoke_from=InvokeFrom.SERVICE_API,
                  streaming=is_streaming,
              )
              return helper.compact_generate_response(generated)
          except WorkflowNotFoundError as err:
              raise NotFound(str(err))
          except (IsDraftWorkflowError, WorkflowIdFormatError) as err:
              raise BadRequest(str(err))
          except services.errors.conversation.ConversationNotExistsError:
              raise NotFound("Conversation Not Exists.")
          except services.errors.conversation.ConversationCompletedError:
              raise ConversationCompletedError()
          except services.errors.app_model_config.AppModelConfigBrokenError:
              logging.exception("App model config broken.")
              raise AppUnavailableError()
          except ProviderTokenNotInitError as err:
              raise ProviderNotInitializeError(err.description)
          except QuotaExceededError:
              raise ProviderQuotaExceededError()
          except ModelCurrentlyNotSupportError:
              raise ProviderModelCurrentlyNotSupportError()
          except InvokeRateLimitError as err:
              raise InvokeRateLimitHttpError(err.description)
          except InvokeError as err:
              raise CompletionRequestError(err.description)
          except ValueError:
              # Let it propagate as-is rather than mapping it to InternalServerError.
              raise
          except Exception:
              logging.exception("internal server error.")
              raise InternalServerError()
  @service_api_ns.route("/chat-messages/<string:task_id>/stop")
  class ChatStopApi(Resource):
      @service_api_ns.doc("stop_chat_message")
      @service_api_ns.doc(description="Stop a running chat message generation")
      @service_api_ns.doc(params={"task_id": "The ID of the task to stop"})
      @service_api_ns.doc(
          responses={
              200: "Task stopped successfully",
              401: "Unauthorized - invalid API token",
              404: "Task not found",
          }
      )
      @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
      def post(self, app_model: App, end_user: EndUser, task_id: str):
          """Stop a running chat message generation.

          Args:
              app_model: App injected by ``validate_app_token``; its mode must be
                  CHAT, AGENT_CHAT or ADVANCED_CHAT.
              end_user: End user injected by ``validate_app_token``.
              task_id: ID of the task to stop, taken from the URL path.

          Returns:
              ``({"result": "success"}, 200)`` whenever ``set_stop_flag`` returns
              without raising.

          Raises:
              NotChatAppError: The app is not one of the chat modes.
          """
          chat_modes = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}
          if AppMode.value_of(app_model.mode) in chat_modes:
              # Hand the task id, the invoke source and the caller's id to the
              # queue manager so the task gets flagged for stopping.
              AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
              return {"result": "success"}, 200
          raise NotChatAppError()