You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

test_tool.py 3.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import time
  2. import uuid
  3. from unittest.mock import MagicMock
  4. from core.app.entities.app_invoke_entities import InvokeFrom
  5. from core.tools.utils.configuration import ToolParameterConfigurationManager
  6. from core.workflow.entities.variable_pool import VariablePool
  7. from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
  8. from core.workflow.enums import SystemVariableKey
  9. from core.workflow.graph_engine.entities.graph import Graph
  10. from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
  11. from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
  12. from core.workflow.nodes.event.event import RunCompletedEvent
  13. from core.workflow.nodes.tool.tool_node import ToolNode
  14. from models.enums import UserFrom
  15. from models.workflow import WorkflowType
  def init_tool_node(config: dict) -> "ToolNode":
      """Build a ``ToolNode`` inside a minimal two-node workflow graph for tests.

      The graph contains a ``start`` node and the node described by *config*,
      joined by a single edge from ``start`` to that node. The node gets a fresh
      variable pool whose system variables are an empty ``FILES`` list and
      ``USER_ID`` set to ``"aaa"``; user inputs, environment variables and
      conversation variables are all empty.

      Args:
          config: Node configuration dict of the form ``{"id": ..., "data": {...}}``
              for the node under test. Its ``"id"`` is used as the edge target;
              ``"1"`` is used when the key is absent, which matches the previous
              hard-coded target.

      Returns:
          A ``ToolNode`` built from *config*. The ``id`` passed to the node
          constructor is a random UUID and is independent of ``config["id"]``.
      """
      # Target the node actually under test rather than a hard-coded "1", so a
      # config with any other id still yields a connected start -> node graph.
      target_id = config.get("id", "1")
      graph_config = {
          "edges": [
              {
                  "id": "start-source-next-target",
                  "source": "start",
                  "target": target_id,
              },
          ],
          "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
      }
      graph = Graph.init(graph_config=graph_config)

      init_params = GraphInitParams(
          tenant_id="1",
          app_id="1",
          workflow_type=WorkflowType.WORKFLOW,
          workflow_id="1",
          graph_config=graph_config,
          user_id="1",
          user_from=UserFrom.ACCOUNT,
          invoke_from=InvokeFrom.DEBUGGER,
          call_depth=0,
      )

      # Fresh, mostly empty variable pool; only FILES and USER_ID are seeded.
      variable_pool = VariablePool(
          system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
          user_inputs={},
          environment_variables=[],
          conversation_variables=[],
      )

      return ToolNode(
          id=str(uuid.uuid4()),
          graph_init_params=init_params,
          graph=graph,
          graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
          config=config,
      )
  def test_tool_variable_invoke():
      """Run the builtin ``time``/``current_time`` tool and expect a successful result.

      ``ToolParameterConfigurationManager.decrypt_tool_parameters`` is patched to
      return a fixed ``format`` value for the duration of the run. The test fails
      if the node never produces a ``RunCompletedEvent``, or if any such event is
      not ``SUCCEEDED`` with a non-None ``"text"`` output.
      """
      # Local import keeps this edit self-contained; `patch` is not imported at
      # module level in this file.
      from unittest.mock import patch

      node = init_tool_node(
          config={
              "id": "1",
              "data": {
                  "title": "a",
                  "desc": "a",
                  "provider_id": "time",
                  "provider_type": "builtin",
                  "provider_name": "time",
                  "tool_name": "current_time",
                  "tool_label": "current_time",
                  "tool_configurations": {},
                  "tool_parameters": {},
              },
          }
      )
      # NOTE(review): this selector is added to the pool, but nothing in
      # "tool_parameters" references it, so the tool is not actually fed from a
      # variable despite the test's name. Confirm the intended parameter wiring.
      node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], "1+1")

      completed = []
      # patch.object restores the real method on exit. The previous direct
      # assignment to the class attribute leaked the mock into every later test.
      # The result of `_run()` may produce events lazily, so it is consumed
      # inside the patch.
      with patch.object(
          ToolParameterConfigurationManager,
          "decrypt_tool_parameters",
          return_value={"format": "%Y-%m-%d %H:%M:%S"},
      ):
          for item in node._run():
              if isinstance(item, RunCompletedEvent):
                  completed.append(item)

      # Without this check the test would pass vacuously when no completion
      # event is produced.
      assert completed, "ToolNode._run() produced no RunCompletedEvent"
      for item in completed:
          assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
          assert item.run_result.outputs is not None
          assert item.run_result.outputs.get("text") is not None
  def test_tool_mixed_invoke():
      """Run ``time``/``current_time`` with a ``format`` set in tool_configurations.

      ``ToolParameterConfigurationManager.decrypt_tool_parameters`` is patched to
      return the same ``format`` value for the duration of the run. The test
      fails if the node never produces a ``RunCompletedEvent``, or if any such
      event is not ``SUCCEEDED`` with a non-None ``"text"`` output.
      """
      # Local import keeps this edit self-contained; `patch` is not imported at
      # module level in this file.
      from unittest.mock import patch

      node = init_tool_node(
          config={
              "id": "1",
              "data": {
                  "title": "a",
                  "desc": "a",
                  "provider_id": "time",
                  "provider_type": "builtin",
                  "provider_name": "time",
                  "tool_name": "current_time",
                  "tool_label": "current_time",
                  "tool_configurations": {
                      "format": "%Y-%m-%d %H:%M:%S",
                  },
                  "tool_parameters": {},
              },
          }
      )

      completed = []
      # patch.object restores the real method on exit. The previous direct
      # assignment to the class attribute leaked the mock into every later test.
      # The result of `_run()` may produce events lazily, so it is consumed
      # inside the patch.
      with patch.object(
          ToolParameterConfigurationManager,
          "decrypt_tool_parameters",
          return_value={"format": "%Y-%m-%d %H:%M:%S"},
      ):
          for item in node._run():
              if isinstance(item, RunCompletedEvent):
                  completed.append(item)

      # Without this check the test would pass vacuously when no completion
      # event is produced.
      assert completed, "ToolNode._run() produced no RunCompletedEvent"
      for item in completed:
          assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
          assert item.run_result.outputs is not None
          assert item.run_result.outputs.get("text") is not None