您最多选择25个主题。主题必须以字母或数字开头，可以包含连字符 (-)，并且长度不得超过35个字符。

test_tool.py 4.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. import time
  2. import uuid
  3. from unittest.mock import MagicMock
  4. from core.app.entities.app_invoke_entities import InvokeFrom
  5. from core.tools.utils.configuration import ToolParameterConfigurationManager
  6. from core.workflow.entities.variable_pool import VariablePool
  7. from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
  8. from core.workflow.graph_engine.entities.graph import Graph
  9. from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
  10. from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
  11. from core.workflow.nodes.event.event import RunCompletedEvent
  12. from core.workflow.nodes.tool.tool_node import ToolNode
  13. from core.workflow.system_variable import SystemVariable
  14. from models.enums import UserFrom
  15. from models.workflow import WorkflowType
  16. def init_tool_node(config: dict):
  17. graph_config = {
  18. "edges": [
  19. {
  20. "id": "start-source-next-target",
  21. "source": "start",
  22. "target": "1",
  23. },
  24. ],
  25. "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
  26. }
  27. graph = Graph.init(graph_config=graph_config)
  28. init_params = GraphInitParams(
  29. tenant_id="1",
  30. app_id="1",
  31. workflow_type=WorkflowType.WORKFLOW,
  32. workflow_id="1",
  33. graph_config=graph_config,
  34. user_id="1",
  35. user_from=UserFrom.ACCOUNT,
  36. invoke_from=InvokeFrom.DEBUGGER,
  37. call_depth=0,
  38. )
  39. # construct variable pool
  40. variable_pool = VariablePool(
  41. system_variables=SystemVariable(user_id="aaa", files=[]),
  42. user_inputs={},
  43. environment_variables=[],
  44. conversation_variables=[],
  45. )
  46. node = ToolNode(
  47. id=str(uuid.uuid4()),
  48. graph_init_params=init_params,
  49. graph=graph,
  50. graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
  51. config=config,
  52. )
  53. node.init_node_data(config.get("data", {}))
  54. return node
  55. def test_tool_variable_invoke():
  56. node = init_tool_node(
  57. config={
  58. "id": "1",
  59. "data": {
  60. "title": "a",
  61. "desc": "a",
  62. "provider_id": "time",
  63. "provider_type": "builtin",
  64. "provider_name": "time",
  65. "tool_name": "current_time",
  66. "tool_label": "current_time",
  67. "tool_configurations": {},
  68. "tool_parameters": {},
  69. },
  70. }
  71. )
  72. ToolParameterConfigurationManager.decrypt_tool_parameters = MagicMock(return_value={"format": "%Y-%m-%d %H:%M:%S"})
  73. node.graph_runtime_state.variable_pool.add(["1", "args1"], "1+1")
  74. # execute node
  75. result = node._run()
  76. for item in result:
  77. if isinstance(item, RunCompletedEvent):
  78. assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
  79. assert item.run_result.outputs is not None
  80. assert item.run_result.outputs.get("text") is not None
  81. def test_tool_mixed_invoke():
  82. node = init_tool_node(
  83. config={
  84. "id": "1",
  85. "data": {
  86. "title": "a",
  87. "desc": "a",
  88. "provider_id": "time",
  89. "provider_type": "builtin",
  90. "provider_name": "time",
  91. "tool_name": "current_time",
  92. "tool_label": "current_time",
  93. "tool_configurations": {
  94. "format": "%Y-%m-%d %H:%M:%S",
  95. },
  96. "tool_parameters": {},
  97. },
  98. }
  99. )
  100. ToolParameterConfigurationManager.decrypt_tool_parameters = MagicMock(return_value={"format": "%Y-%m-%d %H:%M:%S"})
  101. # execute node
  102. result = node._run()
  103. for item in result:
  104. if isinstance(item, RunCompletedEvent):
  105. assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
  106. assert item.run_result.outputs is not None
  107. assert item.run_result.outputs.get("text") is not None