- Extract methods used by `ParameterExtractorNode` from `LLMNode` into a separate file.
- Convert `ParameterExtractorNode` into a subclass of `BaseNode`.
- Refactor code referencing the extracted methods to preserve functionality and improve clarity.
- Fix the issue where `ParameterExtractorNode` returns an error when executed.
- Fix the relevant test cases.

Closes #20840.
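The heart of the change is that the helpers previously exposed as `LLMNode` class methods (quota deduction, model-config fetching, file fetching, and memory fetching) now live as module-level functions in a new `core/workflow/nodes/llm/llm_utils.py`, so callers such as `ParameterExtractorNode` no longer need to import or inherit from `LLMNode`. A minimal sketch of the call-site change, using the signatures shown in the hunks below:

# Before: quota deduction was a classmethod on LLMNode
from core.workflow.nodes.llm.node import LLMNode

LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=usage)

# After: the same logic is a plain function in the extracted module
from core.workflow.nodes.llm import llm_utils

llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=usage)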
| @@ -21,7 +21,7 @@ from core.plugin.entities.request import ( | |||
| ) | |||
| from core.tools.entities.tool_entities import ToolProviderType | |||
| from core.tools.utils.model_invocation_utils import ModelInvocationUtils | |||
| from core.workflow.nodes.llm.node import LLMNode | |||
| from core.workflow.nodes.llm import llm_utils | |||
| from models.account import Tenant | |||
| @@ -55,7 +55,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): | |||
| def handle() -> Generator[LLMResultChunk, None, None]: | |||
| for chunk in response: | |||
| if chunk.delta.usage: | |||
| LLMNode.deduct_llm_quota( | |||
| llm_utils.deduct_llm_quota( | |||
| tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage | |||
| ) | |||
| chunk.prompt_messages = [] | |||
| @@ -64,7 +64,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): | |||
| return handle() | |||
| else: | |||
| if response.usage: | |||
| LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage) | |||
| llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage) | |||
| def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]: | |||
| yield LLMResultChunk( | |||
| @@ -9,7 +9,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform | |||
| from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate | |||
| from core.rag.retrieval.output_parser.react_output import ReactAction | |||
| from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser | |||
| from core.workflow.nodes.llm import LLMNode | |||
| from core.workflow.nodes.llm import llm_utils | |||
| PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:""" | |||
| @@ -165,7 +165,7 @@ class ReactMultiDatasetRouter: | |||
| text, usage = self._handle_invoke_result(invoke_result=invoke_result) | |||
| # deduct quota | |||
| LLMNode.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage) | |||
| llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage) | |||
| return text, usage | |||
| @@ -0,0 +1,156 @@ | |||
| from collections.abc import Sequence | |||
| from datetime import UTC, datetime | |||
| from typing import Optional, cast | |||
| from sqlalchemy import select, update | |||
| from sqlalchemy.orm import Session | |||
| from configs import dify_config | |||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | |||
| from core.entities.provider_entities import QuotaUnit | |||
| from core.file.models import File | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_manager import ModelInstance, ModelManager | |||
| from core.model_runtime.entities.llm_entities import LLMUsage | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.plugin.entities.plugin import ModelProviderID | |||
| from core.prompt.entities.advanced_prompt_entities import MemoryConfig | |||
| from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.enums import SystemVariableKey | |||
| from core.workflow.nodes.llm.entities import ModelConfig | |||
| from models import db | |||
| from models.model import Conversation | |||
| from models.provider import Provider, ProviderType | |||
| from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError | |||
| def fetch_model_config( | |||
| tenant_id: str, node_data_model: ModelConfig | |||
| ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: | |||
| if not node_data_model.mode: | |||
| raise LLMModeRequiredError("LLM mode is required.") | |||
| model = ModelManager().get_model_instance( | |||
| tenant_id=tenant_id, | |||
| model_type=ModelType.LLM, | |||
| provider=node_data_model.provider, | |||
| model=node_data_model.name, | |||
| ) | |||
| model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance) | |||
| # check model | |||
| provider_model = model.provider_model_bundle.configuration.get_provider_model( | |||
| model=node_data_model.name, model_type=ModelType.LLM | |||
| ) | |||
| if provider_model is None: | |||
| raise ModelNotExistError(f"Model {node_data_model.name} not exist.") | |||
| provider_model.raise_for_status() | |||
| # model config | |||
| stop: list[str] = [] | |||
| if "stop" in node_data_model.completion_params: | |||
| stop = node_data_model.completion_params.pop("stop") | |||
| model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials) | |||
| if not model_schema: | |||
| raise ModelNotExistError(f"Model {node_data_model.name} not exist.") | |||
| return model, ModelConfigWithCredentialsEntity( | |||
| provider=node_data_model.provider, | |||
| model=node_data_model.name, | |||
| model_schema=model_schema, | |||
| mode=node_data_model.mode, | |||
| provider_model_bundle=model.provider_model_bundle, | |||
| credentials=model.credentials, | |||
| parameters=node_data_model.completion_params, | |||
| stop=stop, | |||
| ) | |||
| def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]: | |||
| variable = variable_pool.get(selector) | |||
| if variable is None: | |||
| return [] | |||
| elif isinstance(variable, FileSegment): | |||
| return [variable.value] | |||
| elif isinstance(variable, ArrayFileSegment): | |||
| return variable.value | |||
| elif isinstance(variable, NoneSegment | ArrayAnySegment): | |||
| return [] | |||
| raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}") | |||
| def fetch_memory( | |||
| variable_pool: VariablePool, app_id: str, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance | |||
| ) -> Optional[TokenBufferMemory]: | |||
| if not node_data_memory: | |||
| return None | |||
| # get conversation id | |||
| conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID.value]) | |||
| if not isinstance(conversation_id_variable, StringSegment): | |||
| return None | |||
| conversation_id = conversation_id_variable.value | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id) | |||
| conversation = session.scalar(stmt) | |||
| if not conversation: | |||
| return None | |||
| memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) | |||
| return memory | |||
| def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: | |||
| provider_model_bundle = model_instance.provider_model_bundle | |||
| provider_configuration = provider_model_bundle.configuration | |||
| if provider_configuration.using_provider_type != ProviderType.SYSTEM: | |||
| return | |||
| system_configuration = provider_configuration.system_configuration | |||
| quota_unit = None | |||
| for quota_configuration in system_configuration.quota_configurations: | |||
| if quota_configuration.quota_type == system_configuration.current_quota_type: | |||
| quota_unit = quota_configuration.quota_unit | |||
| if quota_configuration.quota_limit == -1: | |||
| return | |||
| break | |||
| used_quota = None | |||
| if quota_unit: | |||
| if quota_unit == QuotaUnit.TOKENS: | |||
| used_quota = usage.total_tokens | |||
| elif quota_unit == QuotaUnit.CREDITS: | |||
| used_quota = dify_config.get_model_credits(model_instance.model) | |||
| else: | |||
| used_quota = 1 | |||
| if used_quota is not None and system_configuration.current_quota_type is not None: | |||
| with Session(db.engine) as session: | |||
| stmt = ( | |||
| update(Provider) | |||
| .where( | |||
| Provider.tenant_id == tenant_id, | |||
| # TODO: Use provider name with prefix after the data migration. | |||
| Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, | |||
| Provider.provider_type == ProviderType.SYSTEM.value, | |||
| Provider.quota_type == system_configuration.current_quota_type.value, | |||
| Provider.quota_limit > Provider.quota_used, | |||
| ) | |||
| .values( | |||
| quota_used=Provider.quota_used + used_quota, | |||
| last_used=datetime.now(tz=UTC).replace(tzinfo=None), | |||
| ) | |||
| ) | |||
| session.execute(stmt) | |||
| session.commit() | |||
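Taken together, the new module gives nodes four entry points: `fetch_model_config`, `fetch_files`, `fetch_memory`, and `deduct_llm_quota`. A rough sketch of how a node's run method is expected to wire them up, mirroring the call sites in the node diffs below (the `self.*` attributes are those already present on the node classes):

variable_pool = self.graph_runtime_state.variable_pool

model_instance, model_config = llm_utils.fetch_model_config(
    tenant_id=self.tenant_id, node_data_model=self.node_data.model
)
files = llm_utils.fetch_files(
    variable_pool=variable_pool,
    selector=self.node_data.vision.configs.variable_selector,
)
memory = llm_utils.fetch_memory(
    variable_pool=variable_pool,
    app_id=self.app_id,
    node_data_memory=self.node_data.memory,
    model_instance=model_instance,
)
# ...invoke the model, then charge the usage against the system quota
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)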
| @@ -3,16 +3,11 @@ import io | |||
| import json | |||
| import logging | |||
| from collections.abc import Generator, Mapping, Sequence | |||
| from datetime import UTC, datetime | |||
| from typing import TYPE_CHECKING, Any, Optional, cast | |||
| import json_repair | |||
| from sqlalchemy import select, update | |||
| from sqlalchemy.orm import Session | |||
| from configs import dify_config | |||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | |||
| from core.entities.provider_entities import QuotaUnit | |||
| from core.file import FileType, file_manager | |||
| from core.helper.code_executor import CodeExecutor, CodeLanguage | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| @@ -40,12 +35,10 @@ from core.model_runtime.entities.model_entities import ( | |||
| ) | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.plugin.entities.plugin import ModelProviderID | |||
| from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig | |||
| from core.prompt.utils.prompt_message_util import PromptMessageUtil | |||
| from core.rag.entities.citation_metadata import RetrievalSourceMetadata | |||
| from core.variables import ( | |||
| ArrayAnySegment, | |||
| ArrayFileSegment, | |||
| ArraySegment, | |||
| FileSegment, | |||
| @@ -75,10 +68,8 @@ from core.workflow.utils.structured_output.entities import ( | |||
| ) | |||
| from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT | |||
| from core.workflow.utils.variable_template_parser import VariableTemplateParser | |||
| from extensions.ext_database import db | |||
| from models.model import Conversation | |||
| from models.provider import Provider, ProviderType | |||
| from . import llm_utils | |||
| from .entities import ( | |||
| LLMNodeChatModelMessage, | |||
| LLMNodeCompletionModelPromptTemplate, | |||
| @@ -88,7 +79,6 @@ from .entities import ( | |||
| from .exc import ( | |||
| InvalidContextStructureError, | |||
| InvalidVariableTypeError, | |||
| LLMModeRequiredError, | |||
| LLMNodeError, | |||
| MemoryRolePrefixRequiredError, | |||
| ModelNotExistError, | |||
| @@ -160,6 +150,7 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| result_text = "" | |||
| usage = LLMUsage.empty_usage() | |||
| finish_reason = None | |||
| variable_pool = self.graph_runtime_state.variable_pool | |||
| try: | |||
| # init messages template | |||
| @@ -178,7 +169,10 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| # fetch files | |||
| files = ( | |||
| self._fetch_files(selector=self.node_data.vision.configs.variable_selector) | |||
| llm_utils.fetch_files( | |||
| variable_pool=variable_pool, | |||
| selector=self.node_data.vision.configs.variable_selector, | |||
| ) | |||
| if self.node_data.vision.enabled | |||
| else [] | |||
| ) | |||
| @@ -200,15 +194,18 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| model_instance, model_config = self._fetch_model_config(self.node_data.model) | |||
| # fetch memory | |||
| memory = self._fetch_memory(node_data_memory=self.node_data.memory, model_instance=model_instance) | |||
| memory = llm_utils.fetch_memory( | |||
| variable_pool=variable_pool, | |||
| app_id=self.app_id, | |||
| node_data_memory=self.node_data.memory, | |||
| model_instance=model_instance, | |||
| ) | |||
| query = None | |||
| if self.node_data.memory: | |||
| query = self.node_data.memory.query_prompt_template | |||
| if not query and ( | |||
| query_variable := self.graph_runtime_state.variable_pool.get( | |||
| (SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY) | |||
| ) | |||
| query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) | |||
| ): | |||
| query = query_variable.text | |||
| @@ -222,7 +219,7 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| memory_config=self.node_data.memory, | |||
| vision_enabled=self.node_data.vision.enabled, | |||
| vision_detail=self.node_data.vision.configs.detail, | |||
| variable_pool=self.graph_runtime_state.variable_pool, | |||
| variable_pool=variable_pool, | |||
| jinja2_variables=self.node_data.prompt_config.jinja2_variables, | |||
| ) | |||
| @@ -251,7 +248,7 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| usage = event.usage | |||
| finish_reason = event.finish_reason | |||
| # deduct quota | |||
| self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) | |||
| llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) | |||
| break | |||
| outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} | |||
| structured_output = process_structured_output(result_text) | |||
| @@ -447,18 +444,6 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| return inputs | |||
| def _fetch_files(self, *, selector: Sequence[str]) -> Sequence["File"]: | |||
| variable = self.graph_runtime_state.variable_pool.get(selector) | |||
| if variable is None: | |||
| return [] | |||
| elif isinstance(variable, FileSegment): | |||
| return [variable.value] | |||
| elif isinstance(variable, ArrayFileSegment): | |||
| return variable.value | |||
| elif isinstance(variable, NoneSegment | ArrayAnySegment): | |||
| return [] | |||
| raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}") | |||
| def _fetch_context(self, node_data: LLMNodeData): | |||
| if not node_data.context.enabled: | |||
| return | |||
| @@ -524,31 +509,10 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| def _fetch_model_config( | |||
| self, node_data_model: ModelConfig | |||
| ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: | |||
| if not node_data_model.mode: | |||
| raise LLMModeRequiredError("LLM mode is required.") | |||
| model = ModelManager().get_model_instance( | |||
| tenant_id=self.tenant_id, | |||
| model_type=ModelType.LLM, | |||
| provider=node_data_model.provider, | |||
| model=node_data_model.name, | |||
| model, model_config_with_cred = llm_utils.fetch_model_config( | |||
| tenant_id=self.tenant_id, node_data_model=node_data_model | |||
| ) | |||
| model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance) | |||
| # check model | |||
| provider_model = model.provider_model_bundle.configuration.get_provider_model( | |||
| model=node_data_model.name, model_type=ModelType.LLM | |||
| ) | |||
| if provider_model is None: | |||
| raise ModelNotExistError(f"Model {node_data_model.name} not exist.") | |||
| provider_model.raise_for_status() | |||
| # model config | |||
| stop: list[str] = [] | |||
| if "stop" in node_data_model.completion_params: | |||
| stop = node_data_model.completion_params.pop("stop") | |||
| completion_params = model_config_with_cred.parameters | |||
| model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials) | |||
| if not model_schema: | |||
| @@ -556,47 +520,12 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| if self.node_data.structured_output_enabled: | |||
| if model_schema.support_structure_output: | |||
| node_data_model.completion_params = self._handle_native_json_schema( | |||
| node_data_model.completion_params, model_schema.parameter_rules | |||
| ) | |||
| completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules) | |||
| else: | |||
| # Set appropriate response format based on model capabilities | |||
| self._set_response_format(node_data_model.completion_params, model_schema.parameter_rules) | |||
| return model, ModelConfigWithCredentialsEntity( | |||
| provider=node_data_model.provider, | |||
| model=node_data_model.name, | |||
| model_schema=model_schema, | |||
| mode=node_data_model.mode, | |||
| provider_model_bundle=model.provider_model_bundle, | |||
| credentials=model.credentials, | |||
| parameters=node_data_model.completion_params, | |||
| stop=stop, | |||
| ) | |||
| def _fetch_memory( | |||
| self, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance | |||
| ) -> Optional[TokenBufferMemory]: | |||
| if not node_data_memory: | |||
| return None | |||
| # get conversation id | |||
| conversation_id_variable = self.graph_runtime_state.variable_pool.get( | |||
| ["sys", SystemVariableKey.CONVERSATION_ID.value] | |||
| ) | |||
| if not isinstance(conversation_id_variable, StringSegment): | |||
| return None | |||
| conversation_id = conversation_id_variable.value | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id) | |||
| conversation = session.scalar(stmt) | |||
| if not conversation: | |||
| return None | |||
| memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) | |||
| return memory | |||
| self._set_response_format(completion_params, model_schema.parameter_rules) | |||
| model_config_with_cred.parameters = completion_params | |||
| return model, model_config_with_cred | |||
| def _fetch_prompt_messages( | |||
| self, | |||
| @@ -810,55 +739,6 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| structured_output = parsed | |||
| return structured_output | |||
| @classmethod | |||
| def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: | |||
| provider_model_bundle = model_instance.provider_model_bundle | |||
| provider_configuration = provider_model_bundle.configuration | |||
| if provider_configuration.using_provider_type != ProviderType.SYSTEM: | |||
| return | |||
| system_configuration = provider_configuration.system_configuration | |||
| quota_unit = None | |||
| for quota_configuration in system_configuration.quota_configurations: | |||
| if quota_configuration.quota_type == system_configuration.current_quota_type: | |||
| quota_unit = quota_configuration.quota_unit | |||
| if quota_configuration.quota_limit == -1: | |||
| return | |||
| break | |||
| used_quota = None | |||
| if quota_unit: | |||
| if quota_unit == QuotaUnit.TOKENS: | |||
| used_quota = usage.total_tokens | |||
| elif quota_unit == QuotaUnit.CREDITS: | |||
| used_quota = dify_config.get_model_credits(model_instance.model) | |||
| else: | |||
| used_quota = 1 | |||
| if used_quota is not None and system_configuration.current_quota_type is not None: | |||
| with Session(db.engine) as session: | |||
| stmt = ( | |||
| update(Provider) | |||
| .where( | |||
| Provider.tenant_id == tenant_id, | |||
| # TODO: Use provider name with prefix after the data migration. | |||
| Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, | |||
| Provider.provider_type == ProviderType.SYSTEM.value, | |||
| Provider.quota_type == system_configuration.current_quota_type.value, | |||
| Provider.quota_limit > Provider.quota_used, | |||
| ) | |||
| .values( | |||
| quota_used=Provider.quota_used + used_quota, | |||
| last_used=datetime.now(tz=UTC).replace(tzinfo=None), | |||
| ) | |||
| ) | |||
| session.execute(stmt) | |||
| session.commit() | |||
| @classmethod | |||
| def _extract_variable_selector_to_variable_mapping( | |||
| cls, | |||
| @@ -28,8 +28,9 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil | |||
| from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus | |||
| from core.workflow.nodes.base.node import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from core.workflow.nodes.llm import LLMNode, ModelConfig | |||
| from core.workflow.nodes.llm import ModelConfig, llm_utils | |||
| from core.workflow.utils import variable_template_parser | |||
| from .entities import ParameterExtractorNodeData | |||
| @@ -83,7 +84,7 @@ def extract_json(text): | |||
| return None | |||
| class ParameterExtractorNode(LLMNode): | |||
| class ParameterExtractorNode(BaseNode): | |||
| """ | |||
| Parameter Extractor Node. | |||
| """ | |||
| @@ -116,8 +117,11 @@ class ParameterExtractorNode(LLMNode): | |||
| variable = self.graph_runtime_state.variable_pool.get(node_data.query) | |||
| query = variable.text if variable else "" | |||
| variable_pool = self.graph_runtime_state.variable_pool | |||
| files = ( | |||
| self._fetch_files( | |||
| llm_utils.fetch_files( | |||
| variable_pool=variable_pool, | |||
| selector=node_data.vision.configs.variable_selector, | |||
| ) | |||
| if node_data.vision.enabled | |||
| @@ -137,7 +141,9 @@ class ParameterExtractorNode(LLMNode): | |||
| raise ModelSchemaNotFoundError("Model schema not found") | |||
| # fetch memory | |||
| memory = self._fetch_memory( | |||
| memory = llm_utils.fetch_memory( | |||
| variable_pool=variable_pool, | |||
| app_id=self.app_id, | |||
| node_data_memory=node_data.memory, | |||
| model_instance=model_instance, | |||
| ) | |||
| @@ -279,7 +285,7 @@ class ParameterExtractorNode(LLMNode): | |||
| tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None | |||
| # deduct quota | |||
| self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) | |||
| llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) | |||
| if text is None: | |||
| text = "" | |||
| @@ -794,7 +800,9 @@ class ParameterExtractorNode(LLMNode): | |||
| Fetch model config. | |||
| """ | |||
| if not self._model_instance or not self._model_config: | |||
| self._model_instance, self._model_config = super()._fetch_model_config(node_data_model) | |||
| self._model_instance, self._model_config = llm_utils.fetch_model_config( | |||
| tenant_id=self.tenant_id, node_data_model=node_data_model | |||
| ) | |||
| return self._model_instance, self._model_config | |||
| @@ -18,6 +18,7 @@ from core.workflow.nodes.llm import ( | |||
| LLMNode, | |||
| LLMNodeChatModelMessage, | |||
| LLMNodeCompletionModelPromptTemplate, | |||
| llm_utils, | |||
| ) | |||
| from core.workflow.utils.variable_template_parser import VariableTemplateParser | |||
| from libs.json_in_md_parser import parse_and_check_json_markdown | |||
| @@ -50,7 +51,9 @@ class QuestionClassifierNode(LLMNode): | |||
| # fetch model config | |||
| model_instance, model_config = self._fetch_model_config(node_data.model) | |||
| # fetch memory | |||
| memory = self._fetch_memory( | |||
| memory = llm_utils.fetch_memory( | |||
| variable_pool=variable_pool, | |||
| app_id=self.app_id, | |||
| node_data_memory=node_data.memory, | |||
| model_instance=model_instance, | |||
| ) | |||
| @@ -59,7 +62,8 @@ class QuestionClassifierNode(LLMNode): | |||
| node_data.instruction = variable_pool.convert_template(node_data.instruction).text | |||
| files = ( | |||
| self._fetch_files( | |||
| llm_utils.fetch_files( | |||
| variable_pool=variable_pool, | |||
| selector=node_data.vision.configs.variable_selector, | |||
| ) | |||
| if node_data.vision.enabled | |||
| @@ -353,7 +353,7 @@ def test_extract_json_from_tool_call(): | |||
| assert result["location"] == "kawaii" | |||
| def test_chat_parameter_extractor_with_memory(setup_model_mock): | |||
| def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch): | |||
| """ | |||
| Test chat parameter extractor with memory. | |||
| """ | |||
| @@ -384,7 +384,8 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock): | |||
| mode="chat", | |||
| credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, | |||
| ) | |||
| node._fetch_memory = get_mocked_fetch_memory("customized memory") | |||
| # Patch the module-level helper so the node picks up the mocked memory | |||
| monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory")) | |||
| db.session.close = MagicMock() | |||
| result = node._run() | |||
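Because `fetch_memory` is now a free function rather than a method on the node, the memory test can no longer stub it by assigning to the node instance; it patches the module attribute instead. A condensed view of that test change (helper names as in the existing test file):

# Before: the stub was attached directly to the node instance
node._fetch_memory = get_mocked_fetch_memory("customized memory")

# After: the helper lives in llm_utils, so the test patches the module attribute
monkeypatch.setattr(
    "core.workflow.nodes.llm.llm_utils.fetch_memory",
    get_mocked_fetch_memory("customized memory"),
)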
| @@ -25,6 +25,7 @@ from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState | |||
| from core.workflow.nodes.answer import AnswerStreamGenerateRoute | |||
| from core.workflow.nodes.end import EndStreamParam | |||
| from core.workflow.nodes.llm import llm_utils | |||
| from core.workflow.nodes.llm.entities import ( | |||
| ContextConfig, | |||
| LLMNodeChatModelMessage, | |||
| @@ -170,7 +171,7 @@ def model_config(): | |||
| ) | |||
| def test_fetch_files_with_file_segment(llm_node): | |||
| def test_fetch_files_with_file_segment(): | |||
| file = File( | |||
| id="1", | |||
| tenant_id="test", | |||
| @@ -180,13 +181,14 @@ def test_fetch_files_with_file_segment(llm_node): | |||
| related_id="1", | |||
| storage_key="", | |||
| ) | |||
| llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) | |||
| variable_pool = VariablePool() | |||
| variable_pool.add(["sys", "files"], file) | |||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||
| result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) | |||
| assert result == [file] | |||
| def test_fetch_files_with_array_file_segment(llm_node): | |||
| def test_fetch_files_with_array_file_segment(): | |||
| files = [ | |||
| File( | |||
| id="1", | |||
| @@ -207,28 +209,32 @@ def test_fetch_files_with_array_file_segment(llm_node): | |||
| storage_key="", | |||
| ), | |||
| ] | |||
| llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) | |||
| variable_pool = VariablePool() | |||
| variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) | |||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||
| result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) | |||
| assert result == files | |||
| def test_fetch_files_with_none_segment(llm_node): | |||
| llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment()) | |||
| def test_fetch_files_with_none_segment(): | |||
| variable_pool = VariablePool() | |||
| variable_pool.add(["sys", "files"], NoneSegment()) | |||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||
| result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) | |||
| assert result == [] | |||
| def test_fetch_files_with_array_any_segment(llm_node): | |||
| llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) | |||
| def test_fetch_files_with_array_any_segment(): | |||
| variable_pool = VariablePool() | |||
| variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) | |||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||
| result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) | |||
| assert result == [] | |||
| def test_fetch_files_with_non_existent_variable(llm_node): | |||
| result = llm_node._fetch_files(selector=["sys", "files"]) | |||
| def test_fetch_files_with_non_existent_variable(): | |||
| variable_pool = VariablePool() | |||
| result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) | |||
| assert result == [] | |||