# test_llm.py — integration tests for the workflow LLM node.
import json
import os
import time
import uuid
from collections.abc import Generator
from decimal import Decimal
from unittest.mock import MagicMock, patch

import pytest

from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.output_parser.structured_output import _parse_structured_output
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.llm.node import LLMNode
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowType

"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
  27. def init_llm_node(config: dict) -> LLMNode:
  28. graph_config = {
  29. "edges": [
  30. {
  31. "id": "start-source-next-target",
  32. "source": "start",
  33. "target": "llm",
  34. },
  35. ],
  36. "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
  37. }
  38. graph = Graph.init(graph_config=graph_config)
  39. # Use proper UUIDs for database compatibility
  40. tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
  41. app_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056c"
  42. workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d"
  43. user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e"
  44. init_params = GraphInitParams(
  45. tenant_id=tenant_id,
  46. app_id=app_id,
  47. workflow_type=WorkflowType.WORKFLOW,
  48. workflow_id=workflow_id,
  49. graph_config=graph_config,
  50. user_id=user_id,
  51. user_from=UserFrom.ACCOUNT,
  52. invoke_from=InvokeFrom.DEBUGGER,
  53. call_depth=0,
  54. )
  55. # construct variable pool
  56. variable_pool = VariablePool(
  57. system_variables={
  58. SystemVariableKey.QUERY: "what's the weather today?",
  59. SystemVariableKey.FILES: [],
  60. SystemVariableKey.CONVERSATION_ID: "abababa",
  61. SystemVariableKey.USER_ID: "aaa",
  62. },
  63. user_inputs={},
  64. environment_variables=[],
  65. conversation_variables=[],
  66. )
  67. variable_pool.add(["abc", "output"], "sunny")
  68. node = LLMNode(
  69. id=str(uuid.uuid4()),
  70. graph_init_params=init_params,
  71. graph=graph,
  72. graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
  73. config=config,
  74. )
  75. return node
  76. def test_execute_llm(flask_req_ctx):
  77. node = init_llm_node(
  78. config={
  79. "id": "llm",
  80. "data": {
  81. "title": "123",
  82. "type": "llm",
  83. "model": {
  84. "provider": "langgenius/openai/openai",
  85. "name": "gpt-3.5-turbo",
  86. "mode": "chat",
  87. "completion_params": {},
  88. },
  89. "prompt_template": [
  90. {
  91. "role": "system",
  92. "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.",
  93. },
  94. {"role": "user", "text": "{{#sys.query#}}"},
  95. ],
  96. "memory": None,
  97. "context": {"enabled": False},
  98. "vision": {"enabled": False},
  99. },
  100. },
  101. )
  102. credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
  103. # Create a proper LLM result with real entities
  104. mock_usage = LLMUsage(
  105. prompt_tokens=30,
  106. prompt_unit_price=Decimal("0.001"),
  107. prompt_price_unit=Decimal("1000"),
  108. prompt_price=Decimal("0.00003"),
  109. completion_tokens=20,
  110. completion_unit_price=Decimal("0.002"),
  111. completion_price_unit=Decimal("1000"),
  112. completion_price=Decimal("0.00004"),
  113. total_tokens=50,
  114. total_price=Decimal("0.00007"),
  115. currency="USD",
  116. latency=0.5,
  117. )
  118. mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")
  119. mock_llm_result = LLMResult(
  120. model="gpt-3.5-turbo",
  121. prompt_messages=[],
  122. message=mock_message,
  123. usage=mock_usage,
  124. )
  125. # Create a simple mock model instance that doesn't call real providers
  126. mock_model_instance = MagicMock()
  127. mock_model_instance.invoke_llm.return_value = mock_llm_result
  128. # Create a simple mock model config with required attributes
  129. mock_model_config = MagicMock()
  130. mock_model_config.mode = "chat"
  131. mock_model_config.provider = "langgenius/openai/openai"
  132. mock_model_config.model = "gpt-3.5-turbo"
  133. mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
  134. # Mock the _fetch_model_config method
  135. def mock_fetch_model_config_func(_node_data_model):
  136. return mock_model_instance, mock_model_config
  137. # Also mock ModelManager.get_model_instance to avoid database calls
  138. def mock_get_model_instance(_self, **kwargs):
  139. return mock_model_instance
  140. with (
  141. patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
  142. patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
  143. ):
  144. # execute node
  145. result = node._run()
  146. assert isinstance(result, Generator)
  147. for item in result:
  148. if isinstance(item, RunCompletedEvent):
  149. assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
  150. assert item.run_result.process_data is not None
  151. assert item.run_result.outputs is not None
  152. assert item.run_result.outputs.get("text") is not None
  153. assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
  154. @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
  155. def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
  156. """
  157. Test execute LLM node with jinja2
  158. """
  159. node = init_llm_node(
  160. config={
  161. "id": "llm",
  162. "data": {
  163. "title": "123",
  164. "type": "llm",
  165. "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
  166. "prompt_config": {
  167. "jinja2_variables": [
  168. {"variable": "sys_query", "value_selector": ["sys", "query"]},
  169. {"variable": "output", "value_selector": ["abc", "output"]},
  170. ]
  171. },
  172. "prompt_template": [
  173. {
  174. "role": "system",
  175. "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
  176. "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
  177. "edition_type": "jinja2",
  178. },
  179. {
  180. "role": "user",
  181. "text": "{{#sys.query#}}",
  182. "jinja2_text": "{{sys_query}}",
  183. "edition_type": "basic",
  184. },
  185. ],
  186. "memory": None,
  187. "context": {"enabled": False},
  188. "vision": {"enabled": False},
  189. },
  190. },
  191. )
  192. # Mock db.session.close()
  193. db.session.close = MagicMock()
  194. # Create a proper LLM result with real entities
  195. mock_usage = LLMUsage(
  196. prompt_tokens=30,
  197. prompt_unit_price=Decimal("0.001"),
  198. prompt_price_unit=Decimal("1000"),
  199. prompt_price=Decimal("0.00003"),
  200. completion_tokens=20,
  201. completion_unit_price=Decimal("0.002"),
  202. completion_price_unit=Decimal("1000"),
  203. completion_price=Decimal("0.00004"),
  204. total_tokens=50,
  205. total_price=Decimal("0.00007"),
  206. currency="USD",
  207. latency=0.5,
  208. )
  209. mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
  210. mock_llm_result = LLMResult(
  211. model="gpt-3.5-turbo",
  212. prompt_messages=[],
  213. message=mock_message,
  214. usage=mock_usage,
  215. )
  216. # Create a simple mock model instance that doesn't call real providers
  217. mock_model_instance = MagicMock()
  218. mock_model_instance.invoke_llm.return_value = mock_llm_result
  219. # Create a simple mock model config with required attributes
  220. mock_model_config = MagicMock()
  221. mock_model_config.mode = "chat"
  222. mock_model_config.provider = "openai"
  223. mock_model_config.model = "gpt-3.5-turbo"
  224. mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
  225. # Mock the _fetch_model_config method
  226. def mock_fetch_model_config_func(_node_data_model):
  227. return mock_model_instance, mock_model_config
  228. # Also mock ModelManager.get_model_instance to avoid database calls
  229. def mock_get_model_instance(_self, **kwargs):
  230. return mock_model_instance
  231. with (
  232. patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
  233. patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
  234. ):
  235. # execute node
  236. result = node._run()
  237. for item in result:
  238. if isinstance(item, RunCompletedEvent):
  239. assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
  240. assert item.run_result.process_data is not None
  241. assert "sunny" in json.dumps(item.run_result.process_data)
  242. assert "what's the weather today?" in json.dumps(item.run_result.process_data)
  243. def test_extract_json():
  244. llm_texts = [
  245. '<think>\n\n</think>{"name": "test", "age": 123', # resoning model (deepseek-r1)
  246. '{"name":"test","age":123}', # json schema model (gpt-4o)
  247. '{\n "name": "test",\n "age": 123\n}', # small model (llama-3.2-1b)
  248. '```json\n{"name": "test", "age": 123}\n```', # json markdown (deepseek-chat)
  249. '{"name":"test",age:123}', # without quotes (qwen-2.5-0.5b)
  250. ]
  251. result = {"name": "test", "age": 123}
  252. assert all(_parse_structured_output(item) == result for item in llm_texts)