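"""Integration tests for the LLM workflow node (LLMNode).

These tests build a minimal start -> llm graph, mock out the model provider so no
real API calls are made, and assert on the events and outputs the node produces.
"""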
import json
import os
import time
import uuid
from collections.abc import Generator
from decimal import Decimal
from unittest.mock import MagicMock, patch

import pytest

from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.llm.node import LLMNode
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowType

"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock


def init_llm_node(config: dict) -> LLMNode:
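    """Build an LLMNode for the given node config, wired into a minimal
    start -> llm graph with a pre-seeded variable pool."""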
    graph_config = {
        "edges": [
            {
                "id": "start-source-next-target",
                "source": "start",
                "target": "llm",
            },
        ],
        "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
    }

    graph = Graph.init(graph_config=graph_config)

    # Use proper UUIDs for database compatibility
    tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
    app_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056c"
    workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d"
    user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e"

    init_params = GraphInitParams(
        tenant_id=tenant_id,
        app_id=app_id,
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id=workflow_id,
        graph_config=graph_config,
        user_id=user_id,
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        call_depth=0,
    )

    # construct variable pool
    variable_pool = VariablePool(
        system_variables={
            SystemVariableKey.QUERY: "what's the weather today?",
            SystemVariableKey.FILES: [],
            SystemVariableKey.CONVERSATION_ID: "abababa",
            SystemVariableKey.USER_ID: "aaa",
        },
        user_inputs={},
        environment_variables=[],
        conversation_variables=[],
    )
    variable_pool.add(["abc", "output"], "sunny")
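    # The ["abc", "output"] selector seeded above is what the tests' prompt
    # templates reference as {{#abc.output#}}.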
    node = LLMNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        config=config,
    )

    return node


def test_execute_llm(flask_req_ctx):
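    """Run the LLM node end to end against a mocked model instance and check
    the success event, outputs, and usage totals."""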
    node = init_llm_node(
        config={
            "id": "llm",
            "data": {
                "title": "123",
                "type": "llm",
                "model": {
                    "provider": "langgenius/openai/openai",
                    "name": "gpt-3.5-turbo",
                    "mode": "chat",
                    "completion_params": {},
                },
                "prompt_template": [
                    {
                        "role": "system",
                        "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.",
                    },
                    {"role": "user", "text": "{{#sys.query#}}"},
                ],
                "memory": None,
                "context": {"enabled": False},
                "vision": {"enabled": False},
            },
        },
    )
    credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
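    # Note: these credentials are never used below; the model instance is fully
    # mocked, so no real OPENAI_API_KEY is required to run this test.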
    # Create a proper LLM result with real entities
    mock_usage = LLMUsage(
        prompt_tokens=30,
        prompt_unit_price=Decimal("0.001"),
        prompt_price_unit=Decimal("1000"),
        prompt_price=Decimal("0.00003"),
        completion_tokens=20,
        completion_unit_price=Decimal("0.002"),
        completion_price_unit=Decimal("1000"),
        completion_price=Decimal("0.00004"),
        total_tokens=50,
        total_price=Decimal("0.00007"),
        currency="USD",
        latency=0.5,
    )

    mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")

    mock_llm_result = LLMResult(
        model="gpt-3.5-turbo",
        prompt_messages=[],
        message=mock_message,
        usage=mock_usage,
    )

    # Create a simple mock model instance that doesn't call real providers
    mock_model_instance = MagicMock()
    mock_model_instance.invoke_llm.return_value = mock_llm_result

    # Create a simple mock model config with required attributes
    mock_model_config = MagicMock()
    mock_model_config.mode = "chat"
    mock_model_config.provider = "langgenius/openai/openai"
    mock_model_config.model = "gpt-3.5-turbo"
    mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"

    # Mock the _fetch_model_config method
    def mock_fetch_model_config_func(_node_data_model):
        return mock_model_instance, mock_model_config

    # Also mock ModelManager.get_model_instance to avoid database calls
    def mock_get_model_instance(_self, **kwargs):
        return mock_model_instance

    with (
        patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
        patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
    ):
        # execute node
        result = node._run()
        assert isinstance(result, Generator)

        for item in result:
            if isinstance(item, RunCompletedEvent):
                assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
                assert item.run_result.process_data is not None
                assert item.run_result.outputs is not None
                assert item.run_result.outputs.get("text") is not None
                assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0


@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
    """Run the LLM node with jinja2-rendered prompt templates and verify the
    rendered values surface in the node's process data."""
    node = init_llm_node(
        config={
            "id": "llm",
            "data": {
                "title": "123",
                "type": "llm",
                "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
                "prompt_config": {
                    "jinja2_variables": [
                        {"variable": "sys_query", "value_selector": ["sys", "query"]},
                        {"variable": "output", "value_selector": ["abc", "output"]},
                    ]
                },
                "prompt_template": [
                    {
                        "role": "system",
                        "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
                        "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
                        "edition_type": "jinja2",
                    },
                    {
                        "role": "user",
                        "text": "{{#sys.query#}}",
                        "jinja2_text": "{{sys_query}}",
                        "edition_type": "basic",
                    },
                ],
                "memory": None,
                "context": {"enabled": False},
                "vision": {"enabled": False},
            },
        },
    )

    # Mock db.session.close()
    db.session.close = MagicMock()

    # Create a proper LLM result with real entities
    mock_usage = LLMUsage(
        prompt_tokens=30,
        prompt_unit_price=Decimal("0.001"),
        prompt_price_unit=Decimal("1000"),
        prompt_price=Decimal("0.00003"),
        completion_tokens=20,
        completion_unit_price=Decimal("0.002"),
        completion_price_unit=Decimal("1000"),
        completion_price=Decimal("0.00004"),
        total_tokens=50,
        total_price=Decimal("0.00007"),
        currency="USD",
        latency=0.5,
    )

    mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")

    mock_llm_result = LLMResult(
        model="gpt-3.5-turbo",
        prompt_messages=[],
        message=mock_message,
        usage=mock_usage,
    )

    # Create a simple mock model instance that doesn't call real providers
    mock_model_instance = MagicMock()
    mock_model_instance.invoke_llm.return_value = mock_llm_result

    # Create a simple mock model config with required attributes
    mock_model_config = MagicMock()
    mock_model_config.mode = "chat"
    mock_model_config.provider = "openai"
    mock_model_config.model = "gpt-3.5-turbo"
    mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"

    # Mock the _fetch_model_config method
    def mock_fetch_model_config_func(_node_data_model):
        return mock_model_instance, mock_model_config

    # Also mock ModelManager.get_model_instance to avoid database calls
    def mock_get_model_instance(_self, **kwargs):
        return mock_model_instance

    with (
        patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
        patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
    ):
        # execute node
        result = node._run()

        for item in result:
            if isinstance(item, RunCompletedEvent):
                assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
                assert item.run_result.process_data is not None
                assert "sunny" in json.dumps(item.run_result.process_data)
                assert "what's the weather today?" in json.dumps(item.run_result.process_data)


def test_extract_json():
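    """Feed _parse_structured_output a range of imperfect model outputs and
    check that they all normalize to the same JSON object."""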
    node = init_llm_node(
        config={
            "id": "llm",
            "data": {
                "title": "123",
                "type": "llm",
                "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
                "prompt_config": {
                    "structured_output": {
                        "enabled": True,
                        "schema": {
                            "type": "object",
                            "properties": {"name": {"type": "string"}, "age": {"type": "number"}},
                        },
                    }
                },
                "prompt_template": [{"role": "user", "text": "{{#sys.query#}}"}],
                "memory": None,
                "context": {"enabled": False},
                "vision": {"enabled": False},
            },
        },
    )

    llm_texts = [
        '<think>\n\n</think>{"name": "test", "age": 123',  # reasoning model, JSON truncated (deepseek-r1)
        '{"name":"test","age":123}',  # json schema model (gpt-4o)
        '{\n "name": "test",\n "age": 123\n}',  # small model (llama-3.2-1b)
        '```json\n{"name": "test", "age": 123}\n```',  # json wrapped in markdown (deepseek-chat)
        '{"name":"test",age:123}',  # unquoted key (qwen-2.5-0.5b)
    ]
    result = {"name": "test", "age": 123}
    assert all(node._parse_structured_output(item) == result for item in llm_texts)
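

# A minimal sketch of how to run just this module locally (the path below is an
# assumption inferred from the imports above; adjust it to your checkout):
#
#   pytest tests/integration_tests/workflow/nodes/test_llm.py -v
#
# Note that test_execute_llm and test_execute_llm_with_jinja2 depend on the
# flask_req_ctx fixture provided by the surrounding test suite.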