您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

test_llm.py 10KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. import json
  2. import time
  3. import uuid
  4. from collections.abc import Generator
  5. from decimal import Decimal
  6. from unittest.mock import MagicMock, patch
  7. import pytest
  8. from core.app.entities.app_invoke_entities import InvokeFrom
  9. from core.llm_generator.output_parser.structured_output import _parse_structured_output
  10. from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
  11. from core.model_runtime.entities.message_entities import AssistantPromptMessage
  12. from core.workflow.entities.variable_pool import VariablePool
  13. from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
  14. from core.workflow.enums import SystemVariableKey
  15. from core.workflow.graph_engine.entities.graph import Graph
  16. from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
  17. from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
  18. from core.workflow.nodes.event import RunCompletedEvent
  19. from core.workflow.nodes.llm.node import LLMNode
  20. from extensions.ext_database import db
  21. from models.enums import UserFrom
  22. from models.workflow import WorkflowType
  23. """FOR MOCK FIXTURES, DO NOT REMOVE"""
  24. from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
  25. from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
  26. def init_llm_node(config: dict) -> LLMNode:
  27. graph_config = {
  28. "edges": [
  29. {
  30. "id": "start-source-next-target",
  31. "source": "start",
  32. "target": "llm",
  33. },
  34. ],
  35. "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
  36. }
  37. graph = Graph.init(graph_config=graph_config)
  38. # Use proper UUIDs for database compatibility
  39. tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
  40. app_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056c"
  41. workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d"
  42. user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e"
  43. init_params = GraphInitParams(
  44. tenant_id=tenant_id,
  45. app_id=app_id,
  46. workflow_type=WorkflowType.WORKFLOW,
  47. workflow_id=workflow_id,
  48. graph_config=graph_config,
  49. user_id=user_id,
  50. user_from=UserFrom.ACCOUNT,
  51. invoke_from=InvokeFrom.DEBUGGER,
  52. call_depth=0,
  53. )
  54. # construct variable pool
  55. variable_pool = VariablePool(
  56. system_variables={
  57. SystemVariableKey.QUERY: "what's the weather today?",
  58. SystemVariableKey.FILES: [],
  59. SystemVariableKey.CONVERSATION_ID: "abababa",
  60. SystemVariableKey.USER_ID: "aaa",
  61. },
  62. user_inputs={},
  63. environment_variables=[],
  64. conversation_variables=[],
  65. )
  66. variable_pool.add(["abc", "output"], "sunny")
  67. node = LLMNode(
  68. id=str(uuid.uuid4()),
  69. graph_init_params=init_params,
  70. graph=graph,
  71. graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
  72. config=config,
  73. )
  74. return node
  75. def test_execute_llm(flask_req_ctx):
  76. node = init_llm_node(
  77. config={
  78. "id": "llm",
  79. "data": {
  80. "title": "123",
  81. "type": "llm",
  82. "model": {
  83. "provider": "langgenius/openai/openai",
  84. "name": "gpt-3.5-turbo",
  85. "mode": "chat",
  86. "completion_params": {},
  87. },
  88. "prompt_template": [
  89. {
  90. "role": "system",
  91. "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.",
  92. },
  93. {"role": "user", "text": "{{#sys.query#}}"},
  94. ],
  95. "memory": None,
  96. "context": {"enabled": False},
  97. "vision": {"enabled": False},
  98. },
  99. },
  100. )
  101. # Create a proper LLM result with real entities
  102. mock_usage = LLMUsage(
  103. prompt_tokens=30,
  104. prompt_unit_price=Decimal("0.001"),
  105. prompt_price_unit=Decimal(1000),
  106. prompt_price=Decimal("0.00003"),
  107. completion_tokens=20,
  108. completion_unit_price=Decimal("0.002"),
  109. completion_price_unit=Decimal(1000),
  110. completion_price=Decimal("0.00004"),
  111. total_tokens=50,
  112. total_price=Decimal("0.00007"),
  113. currency="USD",
  114. latency=0.5,
  115. )
  116. mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")
  117. mock_llm_result = LLMResult(
  118. model="gpt-3.5-turbo",
  119. prompt_messages=[],
  120. message=mock_message,
  121. usage=mock_usage,
  122. )
  123. # Create a simple mock model instance that doesn't call real providers
  124. mock_model_instance = MagicMock()
  125. mock_model_instance.invoke_llm.return_value = mock_llm_result
  126. # Create a simple mock model config with required attributes
  127. mock_model_config = MagicMock()
  128. mock_model_config.mode = "chat"
  129. mock_model_config.provider = "langgenius/openai/openai"
  130. mock_model_config.model = "gpt-3.5-turbo"
  131. mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
  132. # Mock the _fetch_model_config method
  133. def mock_fetch_model_config_func(_node_data_model):
  134. return mock_model_instance, mock_model_config
  135. # Also mock ModelManager.get_model_instance to avoid database calls
  136. def mock_get_model_instance(_self, **kwargs):
  137. return mock_model_instance
  138. with (
  139. patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
  140. patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
  141. ):
  142. # execute node
  143. result = node._run()
  144. assert isinstance(result, Generator)
  145. for item in result:
  146. if isinstance(item, RunCompletedEvent):
  147. assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
  148. assert item.run_result.process_data is not None
  149. assert item.run_result.outputs is not None
  150. assert item.run_result.outputs.get("text") is not None
  151. assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
  152. @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
  153. def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
  154. """
  155. Test execute LLM node with jinja2
  156. """
  157. node = init_llm_node(
  158. config={
  159. "id": "llm",
  160. "data": {
  161. "title": "123",
  162. "type": "llm",
  163. "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
  164. "prompt_config": {
  165. "jinja2_variables": [
  166. {"variable": "sys_query", "value_selector": ["sys", "query"]},
  167. {"variable": "output", "value_selector": ["abc", "output"]},
  168. ]
  169. },
  170. "prompt_template": [
  171. {
  172. "role": "system",
  173. "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
  174. "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
  175. "edition_type": "jinja2",
  176. },
  177. {
  178. "role": "user",
  179. "text": "{{#sys.query#}}",
  180. "jinja2_text": "{{sys_query}}",
  181. "edition_type": "basic",
  182. },
  183. ],
  184. "memory": None,
  185. "context": {"enabled": False},
  186. "vision": {"enabled": False},
  187. },
  188. },
  189. )
  190. # Mock db.session.close()
  191. db.session.close = MagicMock()
  192. # Create a proper LLM result with real entities
  193. mock_usage = LLMUsage(
  194. prompt_tokens=30,
  195. prompt_unit_price=Decimal("0.001"),
  196. prompt_price_unit=Decimal(1000),
  197. prompt_price=Decimal("0.00003"),
  198. completion_tokens=20,
  199. completion_unit_price=Decimal("0.002"),
  200. completion_price_unit=Decimal(1000),
  201. completion_price=Decimal("0.00004"),
  202. total_tokens=50,
  203. total_price=Decimal("0.00007"),
  204. currency="USD",
  205. latency=0.5,
  206. )
  207. mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
  208. mock_llm_result = LLMResult(
  209. model="gpt-3.5-turbo",
  210. prompt_messages=[],
  211. message=mock_message,
  212. usage=mock_usage,
  213. )
  214. # Create a simple mock model instance that doesn't call real providers
  215. mock_model_instance = MagicMock()
  216. mock_model_instance.invoke_llm.return_value = mock_llm_result
  217. # Create a simple mock model config with required attributes
  218. mock_model_config = MagicMock()
  219. mock_model_config.mode = "chat"
  220. mock_model_config.provider = "openai"
  221. mock_model_config.model = "gpt-3.5-turbo"
  222. mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
  223. # Mock the _fetch_model_config method
  224. def mock_fetch_model_config_func(_node_data_model):
  225. return mock_model_instance, mock_model_config
  226. # Also mock ModelManager.get_model_instance to avoid database calls
  227. def mock_get_model_instance(_self, **kwargs):
  228. return mock_model_instance
  229. with (
  230. patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
  231. patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
  232. ):
  233. # execute node
  234. result = node._run()
  235. for item in result:
  236. if isinstance(item, RunCompletedEvent):
  237. assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
  238. assert item.run_result.process_data is not None
  239. assert "sunny" in json.dumps(item.run_result.process_data)
  240. assert "what's the weather today?" in json.dumps(item.run_result.process_data)
  241. def test_extract_json():
  242. llm_texts = [
  243. '<think>\n\n</think>{"name": "test", "age": 123', # resoning model (deepseek-r1)
  244. '{"name":"test","age":123}', # json schema model (gpt-4o)
  245. '{\n "name": "test",\n "age": 123\n}', # small model (llama-3.2-1b)
  246. '```json\n{"name": "test", "age": 123}\n```', # json markdown (deepseek-chat)
  247. '{"name":"test",age:123}', # without quotes (qwen-2.5-0.5b)
  248. ]
  249. result = {"name": "test", "age": 123}
  250. assert all(_parse_structured_output(item) == result for item in llm_texts)