選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

request.py 5.8KB


  1. from typing import Any, Literal, Optional
  2. from pydantic import BaseModel, ConfigDict, Field, field_validator
  3. from core.entities.provider_entities import BasicProviderConfig
  4. from core.model_runtime.entities.message_entities import (
  5. AssistantPromptMessage,
  6. PromptMessage,
  7. PromptMessageRole,
  8. PromptMessageTool,
  9. SystemPromptMessage,
  10. ToolPromptMessage,
  11. UserPromptMessage,
  12. )
  13. from core.model_runtime.entities.model_entities import ModelType
  14. from core.workflow.nodes.parameter_extractor.entities import (
  15. ModelConfig as ParameterExtractorModelConfig,
  16. )
  17. from core.workflow.nodes.parameter_extractor.entities import (
  18. ParameterConfig,
  19. )
  20. from core.workflow.nodes.question_classifier.entities import (
  21. ClassConfig,
  22. )
  23. from core.workflow.nodes.question_classifier.entities import (
  24. ModelConfig as QuestionClassifierModelConfig,
  25. )
  class InvokeCredentials(BaseModel):
      """Credential references that accompany a plugin invocation."""

      # Tool provider -> credential id selected for that provider.
      tool_credentials: dict[str, str] = Field(
          description="Map of tool provider to credential id, used to store the credential id for the tool provider.",
          default_factory=dict,
      )
  class PluginInvokeContext(BaseModel):
      """Context carried with a plugin invocation or backward invocation."""

      # When omitted, each instance gets its own empty InvokeCredentials.
      credentials: Optional[InvokeCredentials] = Field(
          description="Credentials context for the plugin invocation or backward invocation.",
          default_factory=InvokeCredentials,
      )
  class RequestInvokeTool(BaseModel):
      """Request to invoke a tool."""

      # Which category of tool provider the tool belongs to.
      tool_type: Literal["builtin", "workflow", "api", "mcp"]
      provider: str
      tool: str
      tool_parameters: dict
      # Optional explicit credential id; absent means none was specified.
      credential_id: Optional[str] = Field(default=None)
  class BaseRequestInvokeModel(BaseModel):
      """Fields shared by the model-invocation request classes below."""

      provider: str
      model: str
      # Required here; concrete subclasses override it with a fixed default.
      model_type: ModelType

      # Pydantic v2 reserves the "model_" attribute prefix; clearing the
      # protected namespaces lets the `model_type` field exist without a warning.
      model_config = ConfigDict(protected_namespaces=())
  class RequestInvokeLLM(BaseRequestInvokeModel):
      """
      Request to invoke LLM
      """

      model_type: ModelType = ModelType.LLM
      mode: str
      completion_params: dict[str, Any] = Field(default_factory=dict)
      prompt_messages: list[PromptMessage] = Field(default_factory=list)
      # Plain `list` is the idiomatic factory; `list[T]()` also returns [] but
      # misleadingly reads as if it built something typed.
      tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
      stop: Optional[list[str]] = Field(default_factory=list)
      stream: Optional[bool] = False

      # Same setting as the base class: allows the `model_type` field despite
      # pydantic's protected "model_" namespace.
      model_config = ConfigDict(protected_namespaces=())

      @field_validator("prompt_messages", mode="before")
      @classmethod
      def convert_prompt_messages(cls, v):
          """Convert raw prompt-message mappings into role-specific message models.

          Each mapping is dispatched on its ``"role"`` value to the matching
          ``PromptMessage`` subclass (user / assistant / system / tool). Unknown
          or missing roles are handed to the base ``PromptMessage``, which then
          validates them. Items that are already ``PromptMessage`` instances are
          kept as-is. A new list is returned; the caller's list is not mutated.

          Raises:
              ValueError: if ``v`` is not a list, or an item is neither a
                  mapping nor a ``PromptMessage`` instance.
          """
          from collections.abc import Mapping

          if not isinstance(v, list):
              raise ValueError("prompt_messages must be a list")

          # Ordered (role value, message class) pairs, compared with `==` exactly
          # like an if/elif chain, so an unhashable role value cannot raise TypeError.
          message_classes = (
              (PromptMessageRole.USER.value, UserPromptMessage),
              (PromptMessageRole.ASSISTANT.value, AssistantPromptMessage),
              (PromptMessageRole.SYSTEM.value, SystemPromptMessage),
              (PromptMessageRole.TOOL.value, ToolPromptMessage),
          )

          converted = []
          for message in v:
              if isinstance(message, PromptMessage):
                  # Already a model (e.g. constructed in Python); subscripting it
                  # would raise TypeError, which pydantic does not wrap.
                  converted.append(message)
                  continue
              if not isinstance(message, Mapping):
                  raise ValueError("prompt_messages items must be mappings or PromptMessage instances")
              # .get() so a missing "role" reaches PromptMessage validation
              # (a ValidationError) instead of escaping as a bare KeyError.
              role = message.get("role")
              message_class = next(
                  (klass for role_value, klass in message_classes if role == role_value),
                  PromptMessage,
              )
              converted.append(message_class(**message))
          return converted
  class RequestInvokeLLMWithStructuredOutput(RequestInvokeLLM):
      """Request to invoke an LLM with structured output."""

      # Every RequestInvokeLLM field is inherited; this adds the target schema.
      structured_output_schema: dict[str, Any] = Field(
          description="The schema of the structured output in JSON schema format",
          default_factory=dict,
      )
  class RequestInvokeTextEmbedding(BaseRequestInvokeModel):
      """Request to invoke a text-embedding model."""

      # Gives the base class's required model_type a fixed default.
      model_type: ModelType = ModelType.TEXT_EMBEDDING
      texts: list[str]
  class RequestInvokeRerank(BaseRequestInvokeModel):
      """Request to invoke a rerank model."""

      # Gives the base class's required model_type a fixed default.
      model_type: ModelType = ModelType.RERANK
      query: str
      docs: list[str]
      # All remaining fields are required (no defaults).
      score_threshold: float
      top_n: int
  class RequestInvokeTTS(BaseRequestInvokeModel):
      """Request to invoke a text-to-speech model."""

      # Gives the base class's required model_type a fixed default.
      model_type: ModelType = ModelType.TTS
      content_text: str
      voice: str
  class RequestInvokeSpeech2Text(BaseRequestInvokeModel):
      """
      Request to invoke speech2text
      """

      model_type: ModelType = ModelType.SPEECH2TEXT
      # Payload (presumably audio); supplied as a hex string and decoded to raw bytes.
      file: bytes

      @field_validator("file", mode="before")
      @classmethod
      def convert_file(cls, v):
          """Decode the hex-encoded ``file`` payload into raw bytes.

          ``bytes`` / ``bytearray`` values are accepted unchanged (normalized to
          ``bytes``) so that an already-decoded instance can be re-validated,
          e.g. ``RequestInvokeSpeech2Text.model_validate(req.model_dump())``;
          previously that round trip was rejected.

          Raises:
              ValueError: if ``v`` is neither a str nor bytes-like, or the
                  string is not valid hex (raised by ``bytes.fromhex``).
          """
          if isinstance(v, (bytes, bytearray)):
              return bytes(v)
          if isinstance(v, str):
              # hex string to bytes
              return bytes.fromhex(v)
          raise ValueError("file must be a hex string or bytes")
  class RequestInvokeModeration(BaseRequestInvokeModel):
      """Request to invoke a moderation model."""

      # Gives the base class's required model_type a fixed default.
      model_type: ModelType = ModelType.MODERATION
      text: str
  class RequestInvokeParameterExtractorNode(BaseModel):
      """Request to invoke a parameter-extractor workflow node."""

      parameters: list[ParameterConfig]
      # The parameter-extractor node's own ModelConfig (aliased at import time
      # to avoid clashing with the question-classifier ModelConfig).
      model: ParameterExtractorModelConfig
      instruction: str
      query: str
  class RequestInvokeQuestionClassifierNode(BaseModel):
      """Request to invoke a question-classifier workflow node."""

      query: str
      # The question-classifier node's own ModelConfig (aliased at import time
      # to avoid clashing with the parameter-extractor ModelConfig).
      model: QuestionClassifierModelConfig
      classes: list[ClassConfig]
      instruction: str
  class RequestInvokeApp(BaseModel):
      """Request to invoke an app."""

      app_id: str
      inputs: dict[str, Any]
      query: Optional[str] = Field(default=None)
      # Required; only these two response modes are accepted.
      response_mode: Literal["blocking", "streaming"]
      conversation_id: Optional[str] = Field(default=None)
      user: Optional[str] = Field(default=None)
      files: list[dict] = Field(default_factory=list)
  class RequestInvokeEncrypt(BaseModel):
      """Request to encrypt, decrypt, or clear data."""

      # Operation to perform on `data`.
      opt: Literal["encrypt", "decrypt", "clear"]
      # Only the "endpoint" namespace is accepted.
      namespace: Literal["endpoint"]
      identity: str
      data: dict = Field(default_factory=dict)
      config: list[BasicProviderConfig] = Field(default_factory=list)
  class RequestInvokeSummary(BaseModel):
      """Request to summarize text."""

      # Both fields are required.
      text: str
      instruction: str
  class RequestRequestUploadFile(BaseModel):
      """Request to upload a file."""

      # Both fields are required.
      filename: str
      mimetype: str
  class RequestFetchAppInfo(BaseModel):
      """Request to fetch app info."""

      # Identifier of the app whose info is requested.
      app_id: str