You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

request.py 5.3KB


  1. from typing import Any, Literal, Optional
  2. from pydantic import BaseModel, ConfigDict, Field, field_validator
  3. from core.entities.provider_entities import BasicProviderConfig
  4. from core.model_runtime.entities.message_entities import (
  5. AssistantPromptMessage,
  6. PromptMessage,
  7. PromptMessageRole,
  8. PromptMessageTool,
  9. SystemPromptMessage,
  10. ToolPromptMessage,
  11. UserPromptMessage,
  12. )
  13. from core.model_runtime.entities.model_entities import ModelType
  14. from core.workflow.nodes.parameter_extractor.entities import (
  15. ModelConfig as ParameterExtractorModelConfig,
  16. )
  17. from core.workflow.nodes.parameter_extractor.entities import (
  18. ParameterConfig,
  19. )
  20. from core.workflow.nodes.question_classifier.entities import (
  21. ClassConfig,
  22. )
  23. from core.workflow.nodes.question_classifier.entities import (
  24. ModelConfig as QuestionClassifierModelConfig,
  25. )
  26. class RequestInvokeTool(BaseModel):
  27. """
  28. Request to invoke a tool
  29. """
  30. tool_type: Literal["builtin", "workflow", "api", "mcp"]
  31. provider: str
  32. tool: str
  33. tool_parameters: dict
  34. class BaseRequestInvokeModel(BaseModel):
  35. provider: str
  36. model: str
  37. model_type: ModelType
  38. model_config = ConfigDict(protected_namespaces=())
  39. class RequestInvokeLLM(BaseRequestInvokeModel):
  40. """
  41. Request to invoke LLM
  42. """
  43. model_type: ModelType = ModelType.LLM
  44. mode: str
  45. completion_params: dict[str, Any] = Field(default_factory=dict)
  46. prompt_messages: list[PromptMessage] = Field(default_factory=list)
  47. tools: Optional[list[PromptMessageTool]] = Field(default_factory=list[PromptMessageTool])
  48. stop: Optional[list[str]] = Field(default_factory=list[str])
  49. stream: Optional[bool] = False
  50. model_config = ConfigDict(protected_namespaces=())
  51. @field_validator("prompt_messages", mode="before")
  52. @classmethod
  53. def convert_prompt_messages(cls, v):
  54. if not isinstance(v, list):
  55. raise ValueError("prompt_messages must be a list")
  56. for i in range(len(v)):
  57. if v[i]["role"] == PromptMessageRole.USER.value:
  58. v[i] = UserPromptMessage(**v[i])
  59. elif v[i]["role"] == PromptMessageRole.ASSISTANT.value:
  60. v[i] = AssistantPromptMessage(**v[i])
  61. elif v[i]["role"] == PromptMessageRole.SYSTEM.value:
  62. v[i] = SystemPromptMessage(**v[i])
  63. elif v[i]["role"] == PromptMessageRole.TOOL.value:
  64. v[i] = ToolPromptMessage(**v[i])
  65. else:
  66. v[i] = PromptMessage(**v[i])
  67. return v
  68. class RequestInvokeLLMWithStructuredOutput(RequestInvokeLLM):
  69. """
  70. Request to invoke LLM with structured output
  71. """
  72. structured_output_schema: dict[str, Any] = Field(
  73. default_factory=dict, description="The schema of the structured output in JSON schema format"
  74. )
  75. class RequestInvokeTextEmbedding(BaseRequestInvokeModel):
  76. """
  77. Request to invoke text embedding
  78. """
  79. model_type: ModelType = ModelType.TEXT_EMBEDDING
  80. texts: list[str]
  81. class RequestInvokeRerank(BaseRequestInvokeModel):
  82. """
  83. Request to invoke rerank
  84. """
  85. model_type: ModelType = ModelType.RERANK
  86. query: str
  87. docs: list[str]
  88. score_threshold: float
  89. top_n: int
  90. class RequestInvokeTTS(BaseRequestInvokeModel):
  91. """
  92. Request to invoke TTS
  93. """
  94. model_type: ModelType = ModelType.TTS
  95. content_text: str
  96. voice: str
  97. class RequestInvokeSpeech2Text(BaseRequestInvokeModel):
  98. """
  99. Request to invoke speech2text
  100. """
  101. model_type: ModelType = ModelType.SPEECH2TEXT
  102. file: bytes
  103. @field_validator("file", mode="before")
  104. @classmethod
  105. def convert_file(cls, v):
  106. # hex string to bytes
  107. if isinstance(v, str):
  108. return bytes.fromhex(v)
  109. else:
  110. raise ValueError("file must be a hex string")
  111. class RequestInvokeModeration(BaseRequestInvokeModel):
  112. """
  113. Request to invoke moderation
  114. """
  115. model_type: ModelType = ModelType.MODERATION
  116. text: str
  117. class RequestInvokeParameterExtractorNode(BaseModel):
  118. """
  119. Request to invoke parameter extractor node
  120. """
  121. parameters: list[ParameterConfig]
  122. model: ParameterExtractorModelConfig
  123. instruction: str
  124. query: str
  125. class RequestInvokeQuestionClassifierNode(BaseModel):
  126. """
  127. Request to invoke question classifier node
  128. """
  129. query: str
  130. model: QuestionClassifierModelConfig
  131. classes: list[ClassConfig]
  132. instruction: str
  133. class RequestInvokeApp(BaseModel):
  134. """
  135. Request to invoke app
  136. """
  137. app_id: str
  138. inputs: dict[str, Any]
  139. query: Optional[str] = None
  140. response_mode: Literal["blocking", "streaming"]
  141. conversation_id: Optional[str] = None
  142. user: Optional[str] = None
  143. files: list[dict] = Field(default_factory=list)
  144. class RequestInvokeEncrypt(BaseModel):
  145. """
  146. Request to encryption
  147. """
  148. opt: Literal["encrypt", "decrypt", "clear"]
  149. namespace: Literal["endpoint"]
  150. identity: str
  151. data: dict = Field(default_factory=dict)
  152. config: list[BasicProviderConfig] = Field(default_factory=list)
  153. class RequestInvokeSummary(BaseModel):
  154. """
  155. Request to summary
  156. """
  157. text: str
  158. instruction: str
  159. class RequestRequestUploadFile(BaseModel):
  160. """
  161. Request to upload file
  162. """
  163. filename: str
  164. mimetype: str
  165. class RequestFetchAppInfo(BaseModel):
  166. """
  167. Request to fetch app info
  168. """
  169. app_id: str