You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

test_llm.py 7.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. import json
  2. import os
  3. import time
  4. import uuid
  5. from collections.abc import Generator
  6. from unittest.mock import MagicMock
  7. import pytest
  8. from core.app.entities.app_invoke_entities import InvokeFrom
  9. from core.workflow.entities.variable_pool import VariablePool
  10. from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
  11. from core.workflow.enums import SystemVariableKey
  12. from core.workflow.graph_engine.entities.graph import Graph
  13. from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
  14. from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
  15. from core.workflow.nodes.event import RunCompletedEvent
  16. from core.workflow.nodes.llm.node import LLMNode
  17. from extensions.ext_database import db
  18. from models.enums import UserFrom
  19. from models.workflow import WorkflowType
  20. from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
  21. """FOR MOCK FIXTURES, DO NOT REMOVE"""
  22. from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
  23. from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
def init_llm_node(config: dict) -> LLMNode:
    """Build an LLMNode wired into a minimal ``start -> llm`` graph for testing.

    The node's variable pool is pre-seeded with chat-style system variables
    (query, files, conversation id, user id) plus a user-added ``abc.output``
    value ("sunny") that the prompt templates in the tests below reference via
    ``{{#abc.output#}}``.

    :param config: the "llm" node config dict, inserted verbatim into the graph.
    :return: a fully constructed LLMNode ready for ``_run()``.
    """
    # Two-node graph: a bare "start" node feeding the llm node under test.
    graph_config = {
        "edges": [
            {
                "id": "start-source-next-target",
                "source": "start",
                "target": "llm",
            },
        ],
        "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
    }

    graph = Graph.init(graph_config=graph_config)

    # Dummy tenant/app/workflow identifiers — only their presence matters here.
    init_params = GraphInitParams(
        tenant_id="1",
        app_id="1",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id="1",
        graph_config=graph_config,
        user_id="1",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        call_depth=0,
    )

    # construct variable pool
    variable_pool = VariablePool(
        system_variables={
            SystemVariableKey.QUERY: "what's the weather today?",
            SystemVariableKey.FILES: [],
            SystemVariableKey.CONVERSATION_ID: "abababa",
            SystemVariableKey.USER_ID: "aaa",
        },
        user_inputs={},
        environment_variables=[],
        conversation_variables=[],
    )
    # Value referenced by the prompt templates as {{#abc.output#}}.
    variable_pool.add(["abc", "output"], "sunny")

    node = LLMNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        config=config,
    )

    return node
  68. def test_execute_llm(setup_model_mock):
  69. node = init_llm_node(
  70. config={
  71. "id": "llm",
  72. "data": {
  73. "title": "123",
  74. "type": "llm",
  75. "model": {
  76. "provider": "langgenius/openai/openai",
  77. "name": "gpt-3.5-turbo",
  78. "mode": "chat",
  79. "completion_params": {},
  80. },
  81. "prompt_template": [
  82. {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."},
  83. {"role": "user", "text": "{{#sys.query#}}"},
  84. ],
  85. "memory": None,
  86. "context": {"enabled": False},
  87. "vision": {"enabled": False},
  88. },
  89. },
  90. )
  91. credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
  92. # Mock db.session.close()
  93. db.session.close = MagicMock()
  94. node._fetch_model_config = get_mocked_fetch_model_config(
  95. provider="langgenius/openai/openai",
  96. model="gpt-3.5-turbo",
  97. mode="chat",
  98. credentials=credentials,
  99. )
  100. # execute node
  101. result = node._run()
  102. assert isinstance(result, Generator)
  103. for item in result:
  104. if isinstance(item, RunCompletedEvent):
  105. assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
  106. assert item.run_result.process_data is not None
  107. assert item.run_result.outputs is not None
  108. assert item.run_result.outputs.get("text") is not None
  109. assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
  110. @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
  111. def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_model_mock):
  112. """
  113. Test execute LLM node with jinja2
  114. """
  115. node = init_llm_node(
  116. config={
  117. "id": "llm",
  118. "data": {
  119. "title": "123",
  120. "type": "llm",
  121. "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
  122. "prompt_config": {
  123. "jinja2_variables": [
  124. {"variable": "sys_query", "value_selector": ["sys", "query"]},
  125. {"variable": "output", "value_selector": ["abc", "output"]},
  126. ]
  127. },
  128. "prompt_template": [
  129. {
  130. "role": "system",
  131. "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
  132. "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
  133. "edition_type": "jinja2",
  134. },
  135. {
  136. "role": "user",
  137. "text": "{{#sys.query#}}",
  138. "jinja2_text": "{{sys_query}}",
  139. "edition_type": "basic",
  140. },
  141. ],
  142. "memory": None,
  143. "context": {"enabled": False},
  144. "vision": {"enabled": False},
  145. },
  146. },
  147. )
  148. credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
  149. # Mock db.session.close()
  150. db.session.close = MagicMock()
  151. node._fetch_model_config = get_mocked_fetch_model_config(
  152. provider="langgenius/openai/openai",
  153. model="gpt-3.5-turbo",
  154. mode="chat",
  155. credentials=credentials,
  156. )
  157. # execute node
  158. result = node._run()
  159. for item in result:
  160. if isinstance(item, RunCompletedEvent):
  161. assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
  162. assert item.run_result.process_data is not None
  163. assert "sunny" in json.dumps(item.run_result.process_data)
  164. assert "what's the weather today?" in json.dumps(item.run_result.process_data)
  165. def test_extract_json():
  166. node = init_llm_node(
  167. config={
  168. "id": "llm",
  169. "data": {
  170. "title": "123",
  171. "type": "llm",
  172. "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
  173. "prompt_config": {
  174. "structured_output": {
  175. "enabled": True,
  176. "schema": {
  177. "type": "object",
  178. "properties": {"name": {"type": "string"}, "age": {"type": "number"}},
  179. },
  180. }
  181. },
  182. "prompt_template": [{"role": "user", "text": "{{#sys.query#}}"}],
  183. "memory": None,
  184. "context": {"enabled": False},
  185. "vision": {"enabled": False},
  186. },
  187. },
  188. )
  189. llm_texts = [
  190. '<think>\n\n</think>{"name": "test", "age": 123', # resoning model (deepseek-r1)
  191. '{"name":"test","age":123}', # json schema model (gpt-4o)
  192. '{\n "name": "test",\n "age": 123\n}', # small model (llama-3.2-1b)
  193. '```json\n{"name": "test", "age": 123}\n```', # json markdown (deepseek-chat)
  194. '{"name":"test",age:123}', # without quotes (qwen-2.5-0.5b)
  195. ]
  196. result = {"name": "test", "age": 123}
  197. assert all(node._parse_structured_output(item) == result for item in llm_texts)