You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

node.py 3.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
  2. from core.workflow.nodes.enums import NodeType
  3. from core.workflow.nodes.parameter_extractor.entities import (
  4. ModelConfig as ParameterExtractorModelConfig,
  5. )
  6. from core.workflow.nodes.parameter_extractor.entities import (
  7. ParameterConfig,
  8. ParameterExtractorNodeData,
  9. )
  10. from core.workflow.nodes.question_classifier.entities import (
  11. ClassConfig,
  12. QuestionClassifierNodeData,
  13. )
  14. from core.workflow.nodes.question_classifier.entities import (
  15. ModelConfig as QuestionClassifierModelConfig,
  16. )
  17. from services.workflow_service import WorkflowService
  18. class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
  19. @classmethod
  20. def invoke_parameter_extractor(
  21. cls,
  22. tenant_id: str,
  23. user_id: str,
  24. parameters: list[ParameterConfig],
  25. model_config: ParameterExtractorModelConfig,
  26. instruction: str,
  27. query: str,
  28. ) -> dict:
  29. """
  30. Invoke parameter extractor node.
  31. :param tenant_id: str
  32. :param user_id: str
  33. :param parameters: list[ParameterConfig]
  34. :param model_config: ModelConfig
  35. :param instruction: str
  36. :param query: str
  37. :return: dict
  38. """
  39. # FIXME(-LAN-): Avoid import service into core
  40. workflow_service = WorkflowService()
  41. node_id = "1919810"
  42. node_data = ParameterExtractorNodeData(
  43. title="parameter_extractor",
  44. desc="parameter_extractor",
  45. parameters=parameters,
  46. reasoning_mode="function_call",
  47. query=[node_id, "query"],
  48. model=model_config,
  49. instruction=instruction, # instruct with variables are not supported
  50. )
  51. node_data_dict = node_data.model_dump()
  52. node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR.value
  53. execution = workflow_service.run_free_workflow_node(
  54. node_data_dict,
  55. tenant_id=tenant_id,
  56. user_id=user_id,
  57. node_id=node_id,
  58. user_inputs={
  59. f"{node_id}.query": query,
  60. },
  61. )
  62. return {
  63. "inputs": execution.inputs_dict,
  64. "outputs": execution.outputs_dict,
  65. "process_data": execution.process_data_dict,
  66. }
  67. @classmethod
  68. def invoke_question_classifier(
  69. cls,
  70. tenant_id: str,
  71. user_id: str,
  72. model_config: QuestionClassifierModelConfig,
  73. classes: list[ClassConfig],
  74. instruction: str,
  75. query: str,
  76. ) -> dict:
  77. """
  78. Invoke question classifier node.
  79. :param tenant_id: str
  80. :param user_id: str
  81. :param model_config: ModelConfig
  82. :param classes: list[ClassConfig]
  83. :param instruction: str
  84. :param query: str
  85. :return: dict
  86. """
  87. # FIXME(-LAN-): Avoid import service into core
  88. workflow_service = WorkflowService()
  89. node_id = "1919810"
  90. node_data = QuestionClassifierNodeData(
  91. title="question_classifier",
  92. desc="question_classifier",
  93. query_variable_selector=[node_id, "query"],
  94. model=model_config,
  95. classes=classes,
  96. instruction=instruction, # instruct with variables are not supported
  97. )
  98. node_data_dict = node_data.model_dump()
  99. execution = workflow_service.run_free_workflow_node(
  100. node_data_dict,
  101. tenant_id=tenant_id,
  102. user_id=user_id,
  103. node_id=node_id,
  104. user_inputs={
  105. f"{node_id}.query": query,
  106. },
  107. )
  108. return {
  109. "inputs": execution.inputs_dict,
  110. "outputs": execution.outputs_dict,
  111. "process_data": execution.process_data_dict,
  112. }