Signed-off-by: -LAN- <laipz8200@outlook.com>tags/1.4.2
| @@ -55,6 +55,25 @@ class ProviderModelWithStatusEntity(ProviderModel): | |||
| status: ModelStatus | |||
| load_balancing_enabled: bool = False | |||
| def raise_for_status(self) -> None: | |||
| """ | |||
| Check model status and raise ValueError if not active. | |||
| :raises ValueError: When model status is not active, with a descriptive message | |||
| """ | |||
| if self.status == ModelStatus.ACTIVE: | |||
| return | |||
| error_messages = { | |||
| ModelStatus.NO_CONFIGURE: "Model is not configured", | |||
| ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded", | |||
| ModelStatus.NO_PERMISSION: "No permission to use this model", | |||
| ModelStatus.DISABLED: "Model is disabled", | |||
| } | |||
| if self.status in error_messages: | |||
| raise ValueError(error_messages[self.status]) | |||
| class ModelWithProviderEntity(ProviderModelWithStatusEntity): | |||
| """ | |||
| @@ -41,45 +41,53 @@ class Extensible: | |||
| extensions = [] | |||
| position_map: dict[str, int] = {} | |||
| # get the path of the current class | |||
| current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py") | |||
| current_dir_path = os.path.dirname(current_path) | |||
| # traverse subdirectories | |||
| for subdir_name in os.listdir(current_dir_path): | |||
| if subdir_name.startswith("__"): | |||
| continue | |||
| subdir_path = os.path.join(current_dir_path, subdir_name) | |||
| extension_name = subdir_name | |||
| if os.path.isdir(subdir_path): | |||
| # Get the package name from the module path | |||
| package_name = ".".join(cls.__module__.split(".")[:-1]) | |||
| try: | |||
| # Get package directory path | |||
| package_spec = importlib.util.find_spec(package_name) | |||
| if not package_spec or not package_spec.origin: | |||
| raise ImportError(f"Could not find package {package_name}") | |||
| package_dir = os.path.dirname(package_spec.origin) | |||
| # Traverse subdirectories | |||
| for subdir_name in os.listdir(package_dir): | |||
| if subdir_name.startswith("__"): | |||
| continue | |||
| subdir_path = os.path.join(package_dir, subdir_name) | |||
| if not os.path.isdir(subdir_path): | |||
| continue | |||
| extension_name = subdir_name | |||
| file_names = os.listdir(subdir_path) | |||
| # is builtin extension, builtin extension | |||
| # in the front-end page and business logic, there are special treatments. | |||
| # Check for extension module file | |||
| if (extension_name + ".py") not in file_names: | |||
| logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") | |||
| continue | |||
| # Check for builtin flag and position | |||
| builtin = False | |||
| # default position is 0 can not be None for sort_to_dict_by_position_map | |||
| position = 0 | |||
| if "__builtin__" in file_names: | |||
| builtin = True | |||
| builtin_file_path = os.path.join(subdir_path, "__builtin__") | |||
| if os.path.exists(builtin_file_path): | |||
| position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip()) | |||
| position_map[extension_name] = position | |||
| if (extension_name + ".py") not in file_names: | |||
| logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") | |||
| continue | |||
| # Dynamic loading {subdir_name}.py file and find the subclass of Extensible | |||
| py_path = os.path.join(subdir_path, extension_name + ".py") | |||
| spec = importlib.util.spec_from_file_location(extension_name, py_path) | |||
| # Import the extension module | |||
| module_name = f"{package_name}.{extension_name}.{extension_name}" | |||
| spec = importlib.util.find_spec(module_name) | |||
| if not spec or not spec.loader: | |||
| raise Exception(f"Failed to load module {extension_name} from {py_path}") | |||
| raise ImportError(f"Failed to load module {module_name}") | |||
| mod = importlib.util.module_from_spec(spec) | |||
| spec.loader.exec_module(mod) | |||
| # Find extension class | |||
| extension_class = None | |||
| for name, obj in vars(mod).items(): | |||
| if isinstance(obj, type) and issubclass(obj, cls) and obj != cls: | |||
| @@ -87,21 +95,21 @@ class Extensible: | |||
| break | |||
| if not extension_class: | |||
| logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.") | |||
| logging.warning(f"Missing subclass of {cls.__name__} in {module_name}, Skip.") | |||
| continue | |||
| # Load schema if not builtin | |||
| json_data: dict[str, Any] = {} | |||
| if not builtin: | |||
| if "schema.json" not in file_names: | |||
| json_path = os.path.join(subdir_path, "schema.json") | |||
| if not os.path.exists(json_path): | |||
| logging.warning(f"Missing schema.json file in {subdir_path}, Skip.") | |||
| continue | |||
| json_path = os.path.join(subdir_path, "schema.json") | |||
| json_data = {} | |||
| if os.path.exists(json_path): | |||
| with open(json_path, encoding="utf-8") as f: | |||
| json_data = json.load(f) | |||
| with open(json_path, encoding="utf-8") as f: | |||
| json_data = json.load(f) | |||
| # Create extension | |||
| extensions.append( | |||
| ModuleExtension( | |||
| extension_class=extension_class, | |||
| @@ -113,6 +121,11 @@ class Extensible: | |||
| ) | |||
| ) | |||
| except Exception as e: | |||
| logging.exception("Error scanning extensions") | |||
| raise | |||
| # Sort extensions by position | |||
| sorted_extensions = sort_to_dict_by_position_map( | |||
| position_map=position_map, data=extensions, name_func=lambda x: x.name | |||
| ) | |||
| @@ -160,6 +160,10 @@ class ProviderModel(BaseModel): | |||
| deprecated: bool = False | |||
| model_config = ConfigDict(protected_namespaces=()) | |||
| @property | |||
| def support_structure_output(self) -> bool: | |||
| return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features | |||
| class ParameterRule(BaseModel): | |||
| """ | |||
| @@ -3,7 +3,9 @@ from collections import defaultdict | |||
| from json import JSONDecodeError | |||
| from typing import Any, Optional, cast | |||
| from sqlalchemy import select | |||
| from sqlalchemy.exc import IntegrityError | |||
| from sqlalchemy.orm import Session | |||
| from configs import dify_config | |||
| from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity | |||
| @@ -393,19 +395,13 @@ class ProviderManager: | |||
| @staticmethod | |||
| def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]: | |||
| """ | |||
| Get all provider records of the workspace. | |||
| :param tenant_id: workspace id | |||
| :return: | |||
| """ | |||
| providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all() | |||
| provider_name_to_provider_records_dict = defaultdict(list) | |||
| for provider in providers: | |||
| # TODO: Use provider name with prefix after the data migration | |||
| provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider) | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True) | |||
| providers = session.scalars(stmt) | |||
| for provider in providers: | |||
| # Use provider name with prefix after the data migration | |||
| provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider) | |||
| return provider_name_to_provider_records_dict | |||
| @staticmethod | |||
| @@ -416,17 +412,12 @@ class ProviderManager: | |||
| :param tenant_id: workspace id | |||
| :return: | |||
| """ | |||
| # Get all provider model records of the workspace | |||
| provider_models = ( | |||
| db.session.query(ProviderModel) | |||
| .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) | |||
| .all() | |||
| ) | |||
| provider_name_to_provider_model_records_dict = defaultdict(list) | |||
| for provider_model in provider_models: | |||
| provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model) | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) | |||
| provider_models = session.scalars(stmt) | |||
| for provider_model in provider_models: | |||
| provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model) | |||
| return provider_name_to_provider_model_records_dict | |||
| @staticmethod | |||
| @@ -437,17 +428,14 @@ class ProviderManager: | |||
| :param tenant_id: workspace id | |||
| :return: | |||
| """ | |||
| preferred_provider_types = ( | |||
| db.session.query(TenantPreferredModelProvider) | |||
| .filter(TenantPreferredModelProvider.tenant_id == tenant_id) | |||
| .all() | |||
| ) | |||
| provider_name_to_preferred_provider_type_records_dict = { | |||
| preferred_provider_type.provider_name: preferred_provider_type | |||
| for preferred_provider_type in preferred_provider_types | |||
| } | |||
| provider_name_to_preferred_provider_type_records_dict = {} | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id) | |||
| preferred_provider_types = session.scalars(stmt) | |||
| provider_name_to_preferred_provider_type_records_dict = { | |||
| preferred_provider_type.provider_name: preferred_provider_type | |||
| for preferred_provider_type in preferred_provider_types | |||
| } | |||
| return provider_name_to_preferred_provider_type_records_dict | |||
| @staticmethod | |||
| @@ -458,18 +446,14 @@ class ProviderManager: | |||
| :param tenant_id: workspace id | |||
| :return: | |||
| """ | |||
| provider_model_settings = ( | |||
| db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all() | |||
| ) | |||
| provider_name_to_provider_model_settings_dict = defaultdict(list) | |||
| for provider_model_setting in provider_model_settings: | |||
| ( | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id) | |||
| provider_model_settings = session.scalars(stmt) | |||
| for provider_model_setting in provider_model_settings: | |||
| provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append( | |||
| provider_model_setting | |||
| ) | |||
| ) | |||
| return provider_name_to_provider_model_settings_dict | |||
| @staticmethod | |||
| @@ -492,15 +476,14 @@ class ProviderManager: | |||
| if not model_load_balancing_enabled: | |||
| return {} | |||
| provider_load_balancing_configs = ( | |||
| db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all() | |||
| ) | |||
| provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list) | |||
| for provider_load_balancing_config in provider_load_balancing_configs: | |||
| provider_name_to_provider_load_balancing_model_configs_dict[ | |||
| provider_load_balancing_config.provider_name | |||
| ].append(provider_load_balancing_config) | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id) | |||
| provider_load_balancing_configs = session.scalars(stmt) | |||
| for provider_load_balancing_config in provider_load_balancing_configs: | |||
| provider_name_to_provider_load_balancing_model_configs_dict[ | |||
| provider_load_balancing_config.provider_name | |||
| ].append(provider_load_balancing_config) | |||
| return provider_name_to_provider_load_balancing_model_configs_dict | |||
| @@ -626,10 +609,9 @@ class ProviderManager: | |||
| if not cached_provider_credentials: | |||
| try: | |||
| # fix origin data | |||
| if ( | |||
| custom_provider_record.encrypted_config | |||
| and not custom_provider_record.encrypted_config.startswith("{") | |||
| ): | |||
| if custom_provider_record.encrypted_config is None: | |||
| raise ValueError("No credentials found") | |||
| if not custom_provider_record.encrypted_config.startswith("{"): | |||
| provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} | |||
| else: | |||
| provider_credentials = json.loads(custom_provider_record.encrypted_config) | |||
| @@ -733,7 +715,7 @@ class ProviderManager: | |||
| return SystemConfiguration(enabled=False) | |||
| # Convert provider_records to dict | |||
| quota_type_to_provider_records_dict = {} | |||
| quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {} | |||
| for provider_record in provider_records: | |||
| if provider_record.provider_type != ProviderType.SYSTEM.value: | |||
| continue | |||
| @@ -758,6 +740,11 @@ class ProviderManager: | |||
| else: | |||
| provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type] | |||
| if provider_record.quota_used is None: | |||
| raise ValueError("quota_used is None") | |||
| if provider_record.quota_limit is None: | |||
| raise ValueError("quota_limit is None") | |||
| quota_configuration = QuotaConfiguration( | |||
| quota_type=provider_quota.quota_type, | |||
| quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, | |||
| @@ -791,10 +778,9 @@ class ProviderManager: | |||
| cached_provider_credentials = provider_credentials_cache.get() | |||
| if not cached_provider_credentials: | |||
| try: | |||
| provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config) | |||
| except JSONDecodeError: | |||
| provider_credentials = {} | |||
| provider_credentials: dict[str, Any] = {} | |||
| if provider_records and provider_records[0].encrypted_config: | |||
| provider_credentials = json.loads(provider_records[0].encrypted_config) | |||
| # Get provider credential secret variables | |||
| provider_credential_secret_variables = self._extract_secret_variables( | |||
| @@ -66,7 +66,8 @@ class LLMNodeData(BaseNodeData): | |||
| context: ContextConfig | |||
| vision: VisionConfig = Field(default_factory=VisionConfig) | |||
| structured_output: dict | None = None | |||
| structured_output_enabled: bool = False | |||
| # We used 'structured_output_enabled' in the past, but it's not a good name. | |||
| structured_output_switch_on: bool = Field(False, alias="structured_output_enabled") | |||
| @field_validator("prompt_config", mode="before") | |||
| @classmethod | |||
| @@ -74,3 +75,7 @@ class LLMNodeData(BaseNodeData): | |||
| if v is None: | |||
| return PromptConfig() | |||
| return v | |||
| @property | |||
| def structured_output_enabled(self) -> bool: | |||
| return self.structured_output_switch_on and self.structured_output is not None | |||
| @@ -12,9 +12,7 @@ from sqlalchemy.orm import Session | |||
| from configs import dify_config | |||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | |||
| from core.entities.model_entities import ModelStatus | |||
| from core.entities.provider_entities import QuotaUnit | |||
| from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError | |||
| from core.file import FileType, file_manager | |||
| from core.helper.code_executor import CodeExecutor, CodeLanguage | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| @@ -74,7 +72,6 @@ from core.workflow.nodes.event import ( | |||
| from core.workflow.utils.structured_output.entities import ( | |||
| ResponseFormat, | |||
| SpecialModelType, | |||
| SupportStructuredOutputStatus, | |||
| ) | |||
| from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT | |||
| from core.workflow.utils.variable_template_parser import VariableTemplateParser | |||
| @@ -277,7 +274,7 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| llm_usage=usage, | |||
| ) | |||
| ) | |||
| except LLMNodeError as e: | |||
| except ValueError as e: | |||
| yield RunCompletedEvent( | |||
| run_result=NodeRunResult( | |||
| status=WorkflowNodeExecutionStatus.FAILED, | |||
| @@ -527,65 +524,53 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| def _fetch_model_config( | |||
| self, node_data_model: ModelConfig | |||
| ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: | |||
| model_name = node_data_model.name | |||
| provider_name = node_data_model.provider | |||
| if not node_data_model.mode: | |||
| raise LLMModeRequiredError("LLM mode is required.") | |||
| model_manager = ModelManager() | |||
| model_instance = model_manager.get_model_instance( | |||
| tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name | |||
| model = ModelManager().get_model_instance( | |||
| tenant_id=self.tenant_id, | |||
| model_type=ModelType.LLM, | |||
| provider=node_data_model.provider, | |||
| model=node_data_model.name, | |||
| ) | |||
| provider_model_bundle = model_instance.provider_model_bundle | |||
| model_type_instance = model_instance.model_type_instance | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| model_credentials = model_instance.credentials | |||
| model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance) | |||
| # check model | |||
| provider_model = provider_model_bundle.configuration.get_provider_model( | |||
| model=model_name, model_type=ModelType.LLM | |||
| provider_model = model.provider_model_bundle.configuration.get_provider_model( | |||
| model=node_data_model.name, model_type=ModelType.LLM | |||
| ) | |||
| if provider_model is None: | |||
| raise ModelNotExistError(f"Model {model_name} not exist.") | |||
| if provider_model.status == ModelStatus.NO_CONFIGURE: | |||
| raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") | |||
| elif provider_model.status == ModelStatus.NO_PERMISSION: | |||
| raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") | |||
| elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: | |||
| raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") | |||
| raise ModelNotExistError(f"Model {node_data_model.name} not exist.") | |||
| provider_model.raise_for_status() | |||
| # model config | |||
| completion_params = node_data_model.completion_params | |||
| stop = [] | |||
| if "stop" in completion_params: | |||
| stop = completion_params["stop"] | |||
| del completion_params["stop"] | |||
| # get model mode | |||
| model_mode = node_data_model.mode | |||
| if not model_mode: | |||
| raise LLMModeRequiredError("LLM mode is required.") | |||
| model_schema = model_type_instance.get_model_schema(model_name, model_credentials) | |||
| stop: list[str] = [] | |||
| if "stop" in node_data_model.completion_params: | |||
| stop = node_data_model.completion_params.pop("stop") | |||
| model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials) | |||
| if not model_schema: | |||
| raise ModelNotExistError(f"Model {model_name} not exist.") | |||
| support_structured_output = self._check_model_structured_output_support() | |||
| if support_structured_output == SupportStructuredOutputStatus.SUPPORTED: | |||
| completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules) | |||
| elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED: | |||
| # Set appropriate response format based on model capabilities | |||
| self._set_response_format(completion_params, model_schema.parameter_rules) | |||
| return model_instance, ModelConfigWithCredentialsEntity( | |||
| provider=provider_name, | |||
| model=model_name, | |||
| raise ModelNotExistError(f"Model {node_data_model.name} not exist.") | |||
| if self.node_data.structured_output_enabled: | |||
| if model_schema.support_structure_output: | |||
| node_data_model.completion_params = self._handle_native_json_schema( | |||
| node_data_model.completion_params, model_schema.parameter_rules | |||
| ) | |||
| else: | |||
| # Set appropriate response format based on model capabilities | |||
| self._set_response_format(node_data_model.completion_params, model_schema.parameter_rules) | |||
| return model, ModelConfigWithCredentialsEntity( | |||
| provider=node_data_model.provider, | |||
| model=node_data_model.name, | |||
| model_schema=model_schema, | |||
| mode=model_mode, | |||
| provider_model_bundle=provider_model_bundle, | |||
| credentials=model_credentials, | |||
| parameters=completion_params, | |||
| mode=node_data_model.mode, | |||
| provider_model_bundle=model.provider_model_bundle, | |||
| credentials=model.credentials, | |||
| parameters=node_data_model.completion_params, | |||
| stop=stop, | |||
| ) | |||
| @@ -786,13 +771,25 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| "No prompt found in the LLM configuration. " | |||
| "Please ensure a prompt is properly configured before proceeding." | |||
| ) | |||
| support_structured_output = self._check_model_structured_output_support() | |||
| if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED: | |||
| filtered_prompt_messages = self._handle_prompt_based_schema( | |||
| prompt_messages=filtered_prompt_messages, | |||
| ) | |||
| stop = model_config.stop | |||
| return filtered_prompt_messages, stop | |||
| model = ModelManager().get_model_instance( | |||
| tenant_id=self.tenant_id, | |||
| model_type=ModelType.LLM, | |||
| provider=self.node_data.model.provider, | |||
| model=self.node_data.model.name, | |||
| ) | |||
| model_schema = model.model_type_instance.get_model_schema( | |||
| model=self.node_data.model.name, | |||
| credentials=model.credentials, | |||
| ) | |||
| if not model_schema: | |||
| raise ModelNotExistError(f"Model {self.node_data.model.name} not exist.") | |||
| if self.node_data.structured_output_enabled: | |||
| if not model_schema.support_structure_output: | |||
| filtered_prompt_messages = self._handle_prompt_based_schema( | |||
| prompt_messages=filtered_prompt_messages, | |||
| ) | |||
| return filtered_prompt_messages, model_config.stop | |||
| def _parse_structured_output(self, result_text: str) -> dict[str, Any]: | |||
| structured_output: dict[str, Any] = {} | |||
| @@ -1185,32 +1182,6 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| except json.JSONDecodeError: | |||
| raise LLMNodeError("structured_output_schema is not valid JSON format") | |||
| def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus: | |||
| """ | |||
| Check if the current model supports structured output. | |||
| Returns: | |||
| SupportStructuredOutput: The support status of structured output | |||
| """ | |||
| # Early return if structured output is disabled | |||
| if ( | |||
| not isinstance(self.node_data, LLMNodeData) | |||
| or not self.node_data.structured_output_enabled | |||
| or not self.node_data.structured_output | |||
| ): | |||
| return SupportStructuredOutputStatus.DISABLED | |||
| # Get model schema and check if it exists | |||
| model_schema = self._fetch_model_schema(self.node_data.model.provider) | |||
| if not model_schema: | |||
| return SupportStructuredOutputStatus.DISABLED | |||
| # Check if model supports structured output feature | |||
| return ( | |||
| SupportStructuredOutputStatus.SUPPORTED | |||
| if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features) | |||
| else SupportStructuredOutputStatus.UNSUPPORTED | |||
| ) | |||
| def _save_multimodal_output_and_convert_result_to_markdown( | |||
| self, | |||
| contents: str | list[PromptMessageContentUnionTypes] | None, | |||
| @@ -14,11 +14,3 @@ class SpecialModelType(StrEnum): | |||
| GEMINI = "gemini" | |||
| OLLAMA = "ollama" | |||
| class SupportStructuredOutputStatus(StrEnum): | |||
| """Constants for structured output support status""" | |||
| SUPPORTED = "supported" | |||
| UNSUPPORTED = "unsupported" | |||
| DISABLED = "disabled" | |||
| @@ -1,6 +1,9 @@ | |||
| from datetime import datetime | |||
| from enum import Enum | |||
| from typing import Optional | |||
| from sqlalchemy import func | |||
| from sqlalchemy import func, text | |||
| from sqlalchemy.orm import Mapped, mapped_column | |||
| from .base import Base | |||
| from .engine import db | |||
| @@ -51,20 +54,24 @@ class Provider(Base): | |||
| ), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| provider_name = db.Column(db.String(255), nullable=False) | |||
| provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying")) | |||
| encrypted_config = db.Column(db.Text, nullable=True) | |||
| is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) | |||
| last_used = db.Column(db.DateTime, nullable=True) | |||
| id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) | |||
| tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||
| provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||
| provider_type: Mapped[str] = mapped_column( | |||
| db.String(40), nullable=False, server_default=text("'custom'::character varying") | |||
| ) | |||
| encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) | |||
| is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) | |||
| last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) | |||
| quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying")) | |||
| quota_limit = db.Column(db.BigInteger, nullable=True) | |||
| quota_used = db.Column(db.BigInteger, default=0) | |||
| quota_type: Mapped[Optional[str]] = mapped_column( | |||
| db.String(40), nullable=True, server_default=text("''::character varying") | |||
| ) | |||
| quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True) | |||
| quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| def __repr__(self): | |||
| return ( | |||
| @@ -104,15 +111,15 @@ class ProviderModel(Base): | |||
| ), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| provider_name = db.Column(db.String(255), nullable=False) | |||
| model_name = db.Column(db.String(255), nullable=False) | |||
| model_type = db.Column(db.String(40), nullable=False) | |||
| encrypted_config = db.Column(db.Text, nullable=True) | |||
| is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) | |||
| tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||
| provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||
| model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||
| model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) | |||
| encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) | |||
| is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) | |||
| created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| class TenantDefaultModel(Base): | |||
| @@ -122,13 +129,13 @@ class TenantDefaultModel(Base): | |||
| db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| provider_name = db.Column(db.String(255), nullable=False) | |||
| model_name = db.Column(db.String(255), nullable=False) | |||
| model_type = db.Column(db.String(40), nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) | |||
| tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||
| provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||
| model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||
| model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) | |||
| created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| class TenantPreferredModelProvider(Base): | |||
| @@ -138,12 +145,12 @@ class TenantPreferredModelProvider(Base): | |||
| db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| provider_name = db.Column(db.String(255), nullable=False) | |||
| preferred_provider_type = db.Column(db.String(40), nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) | |||
| tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||
| provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||
| preferred_provider_type: Mapped[str] = mapped_column(db.String(40), nullable=False) | |||
| created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| class ProviderOrder(Base): | |||
| @@ -153,22 +160,24 @@ class ProviderOrder(Base): | |||
| db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| provider_name = db.Column(db.String(255), nullable=False) | |||
| account_id = db.Column(StringUUID, nullable=False) | |||
| payment_product_id = db.Column(db.String(191), nullable=False) | |||
| payment_id = db.Column(db.String(191)) | |||
| transaction_id = db.Column(db.String(191)) | |||
| quantity = db.Column(db.Integer, nullable=False, server_default=db.text("1")) | |||
| currency = db.Column(db.String(40)) | |||
| total_amount = db.Column(db.Integer) | |||
| payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying")) | |||
| paid_at = db.Column(db.DateTime) | |||
| pay_failed_at = db.Column(db.DateTime) | |||
| refunded_at = db.Column(db.DateTime) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) | |||
| tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||
| provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||
| account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||
| payment_product_id: Mapped[str] = mapped_column(db.String(191), nullable=False) | |||
| payment_id: Mapped[Optional[str]] = mapped_column(db.String(191)) | |||
| transaction_id: Mapped[Optional[str]] = mapped_column(db.String(191)) | |||
| quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1")) | |||
| currency: Mapped[Optional[str]] = mapped_column(db.String(40)) | |||
| total_amount: Mapped[Optional[int]] = mapped_column(db.Integer) | |||
| payment_status: Mapped[str] = mapped_column( | |||
| db.String(40), nullable=False, server_default=text("'wait_pay'::character varying") | |||
| ) | |||
| paid_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) | |||
| pay_failed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) | |||
| refunded_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) | |||
| created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| class ProviderModelSetting(Base): | |||
| @@ -182,15 +191,15 @@ class ProviderModelSetting(Base): | |||
| db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| provider_name = db.Column(db.String(255), nullable=False) | |||
| model_name = db.Column(db.String(255), nullable=False) | |||
| model_type = db.Column(db.String(40), nullable=False) | |||
| enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) | |||
| load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) | |||
| tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||
| provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||
| model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||
| model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) | |||
| enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) | |||
| load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) | |||
| created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| class LoadBalancingModelConfig(Base): | |||
| @@ -204,13 +213,13 @@ class LoadBalancingModelConfig(Base): | |||
| db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| provider_name = db.Column(db.String(255), nullable=False) | |||
| model_name = db.Column(db.String(255), nullable=False) | |||
| model_type = db.Column(db.String(40), nullable=False) | |||
| name = db.Column(db.String(255), nullable=False) | |||
| encrypted_config = db.Column(db.Text, nullable=True) | |||
| enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) | |||
| tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||
| provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||
| model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||
| model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) | |||
| name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||
| encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) | |||
| enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) | |||
| created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| @@ -3,11 +3,16 @@ import os | |||
| import time | |||
| import uuid | |||
| from collections.abc import Generator | |||
| from unittest.mock import MagicMock | |||
| from decimal import Decimal | |||
| from unittest.mock import MagicMock, patch | |||
| import pytest | |||
| from app_factory import create_app | |||
| from configs import dify_config | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage | |||
| from core.model_runtime.entities.message_entities import AssistantPromptMessage | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | |||
| from core.workflow.enums import SystemVariableKey | |||
| @@ -19,13 +24,27 @@ from core.workflow.nodes.llm.node import LLMNode | |||
| from extensions.ext_database import db | |||
| from models.enums import UserFrom | |||
| from models.workflow import WorkflowType | |||
| from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config | |||
| """FOR MOCK FIXTURES, DO NOT REMOVE""" | |||
| from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock | |||
| from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock | |||
| @pytest.fixture(scope="session") | |||
| def app(): | |||
| # Set up storage configuration | |||
| os.environ["STORAGE_TYPE"] = "opendal" | |||
| os.environ["OPENDAL_SCHEME"] = "fs" | |||
| os.environ["OPENDAL_FS_ROOT"] = "storage" | |||
| # Ensure storage directory exists | |||
| os.makedirs("storage", exist_ok=True) | |||
| app = create_app() | |||
| dify_config.LOGIN_DISABLED = True | |||
| return app | |||
| def init_llm_node(config: dict) -> LLMNode: | |||
| graph_config = { | |||
| "edges": [ | |||
| @@ -40,13 +59,19 @@ def init_llm_node(config: dict) -> LLMNode: | |||
| graph = Graph.init(graph_config=graph_config) | |||
| # Use proper UUIDs for database compatibility | |||
| tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" | |||
| app_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056c" | |||
| workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d" | |||
| user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e" | |||
| init_params = GraphInitParams( | |||
| tenant_id="1", | |||
| app_id="1", | |||
| tenant_id=tenant_id, | |||
| app_id=app_id, | |||
| workflow_type=WorkflowType.WORKFLOW, | |||
| workflow_id="1", | |||
| workflow_id=workflow_id, | |||
| graph_config=graph_config, | |||
| user_id="1", | |||
| user_id=user_id, | |||
| user_from=UserFrom.ACCOUNT, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| call_depth=0, | |||
| @@ -77,115 +102,197 @@ def init_llm_node(config: dict) -> LLMNode: | |||
| return node | |||
| def test_execute_llm(setup_model_mock): | |||
| node = init_llm_node( | |||
| config={ | |||
| "id": "llm", | |||
| "data": { | |||
| "title": "123", | |||
| "type": "llm", | |||
| "model": { | |||
| "provider": "langgenius/openai/openai", | |||
| "name": "gpt-3.5-turbo", | |||
| "mode": "chat", | |||
| "completion_params": {}, | |||
| def test_execute_llm(app): | |||
| with app.app_context(): | |||
| node = init_llm_node( | |||
| config={ | |||
| "id": "llm", | |||
| "data": { | |||
| "title": "123", | |||
| "type": "llm", | |||
| "model": { | |||
| "provider": "langgenius/openai/openai", | |||
| "name": "gpt-3.5-turbo", | |||
| "mode": "chat", | |||
| "completion_params": {}, | |||
| }, | |||
| "prompt_template": [ | |||
| { | |||
| "role": "system", | |||
| "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.", | |||
| }, | |||
| {"role": "user", "text": "{{#sys.query#}}"}, | |||
| ], | |||
| "memory": None, | |||
| "context": {"enabled": False}, | |||
| "vision": {"enabled": False}, | |||
| }, | |||
| "prompt_template": [ | |||
| {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."}, | |||
| {"role": "user", "text": "{{#sys.query#}}"}, | |||
| ], | |||
| "memory": None, | |||
| "context": {"enabled": False}, | |||
| "vision": {"enabled": False}, | |||
| }, | |||
| }, | |||
| ) | |||
| ) | |||
| credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} | |||
| credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} | |||
| # Mock db.session.close() | |||
| db.session.close = MagicMock() | |||
| # Create a proper LLM result with real entities | |||
| mock_usage = LLMUsage( | |||
| prompt_tokens=30, | |||
| prompt_unit_price=Decimal("0.001"), | |||
| prompt_price_unit=Decimal("1000"), | |||
| prompt_price=Decimal("0.00003"), | |||
| completion_tokens=20, | |||
| completion_unit_price=Decimal("0.002"), | |||
| completion_price_unit=Decimal("1000"), | |||
| completion_price=Decimal("0.00004"), | |||
| total_tokens=50, | |||
| total_price=Decimal("0.00007"), | |||
| currency="USD", | |||
| latency=0.5, | |||
| ) | |||
| node._fetch_model_config = get_mocked_fetch_model_config( | |||
| provider="langgenius/openai/openai", | |||
| model="gpt-3.5-turbo", | |||
| mode="chat", | |||
| credentials=credentials, | |||
| ) | |||
| mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.") | |||
| mock_llm_result = LLMResult( | |||
| model="gpt-3.5-turbo", | |||
| prompt_messages=[], | |||
| message=mock_message, | |||
| usage=mock_usage, | |||
| ) | |||
| # Create a simple mock model instance that doesn't call real providers | |||
| mock_model_instance = MagicMock() | |||
| mock_model_instance.invoke_llm.return_value = mock_llm_result | |||
| # Create a simple mock model config with required attributes | |||
| mock_model_config = MagicMock() | |||
| mock_model_config.mode = "chat" | |||
| mock_model_config.provider = "langgenius/openai/openai" | |||
| mock_model_config.model = "gpt-3.5-turbo" | |||
| mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" | |||
| # Mock the _fetch_model_config method | |||
| def mock_fetch_model_config_func(_node_data_model): | |||
| return mock_model_instance, mock_model_config | |||
| # Also mock ModelManager.get_model_instance to avoid database calls | |||
| def mock_get_model_instance(_self, **kwargs): | |||
| return mock_model_instance | |||
| # execute node | |||
| result = node._run() | |||
| assert isinstance(result, Generator) | |||
| with ( | |||
| patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), | |||
| patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), | |||
| ): | |||
| # execute node | |||
| result = node._run() | |||
| assert isinstance(result, Generator) | |||
| for item in result: | |||
| if isinstance(item, RunCompletedEvent): | |||
| assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert item.run_result.process_data is not None | |||
| assert item.run_result.outputs is not None | |||
| assert item.run_result.outputs.get("text") is not None | |||
| assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 | |||
| for item in result: | |||
| if isinstance(item, RunCompletedEvent): | |||
| assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert item.run_result.process_data is not None | |||
| assert item.run_result.outputs is not None | |||
| assert item.run_result.outputs.get("text") is not None | |||
| assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 | |||
| @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) | |||
| def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_model_mock): | |||
| def test_execute_llm_with_jinja2(app, setup_code_executor_mock): | |||
| """ | |||
| Test execute LLM node with jinja2 | |||
| """ | |||
| node = init_llm_node( | |||
| config={ | |||
| "id": "llm", | |||
| "data": { | |||
| "title": "123", | |||
| "type": "llm", | |||
| "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, | |||
| "prompt_config": { | |||
| "jinja2_variables": [ | |||
| {"variable": "sys_query", "value_selector": ["sys", "query"]}, | |||
| {"variable": "output", "value_selector": ["abc", "output"]}, | |||
| ] | |||
| }, | |||
| "prompt_template": [ | |||
| { | |||
| "role": "system", | |||
| "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}", | |||
| "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.", | |||
| "edition_type": "jinja2", | |||
| with app.app_context(): | |||
| node = init_llm_node( | |||
| config={ | |||
| "id": "llm", | |||
| "data": { | |||
| "title": "123", | |||
| "type": "llm", | |||
| "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, | |||
| "prompt_config": { | |||
| "jinja2_variables": [ | |||
| {"variable": "sys_query", "value_selector": ["sys", "query"]}, | |||
| {"variable": "output", "value_selector": ["abc", "output"]}, | |||
| ] | |||
| }, | |||
| { | |||
| "role": "user", | |||
| "text": "{{#sys.query#}}", | |||
| "jinja2_text": "{{sys_query}}", | |||
| "edition_type": "basic", | |||
| }, | |||
| ], | |||
| "memory": None, | |||
| "context": {"enabled": False}, | |||
| "vision": {"enabled": False}, | |||
| "prompt_template": [ | |||
| { | |||
| "role": "system", | |||
| "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}", | |||
| "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.", | |||
| "edition_type": "jinja2", | |||
| }, | |||
| { | |||
| "role": "user", | |||
| "text": "{{#sys.query#}}", | |||
| "jinja2_text": "{{sys_query}}", | |||
| "edition_type": "basic", | |||
| }, | |||
| ], | |||
| "memory": None, | |||
| "context": {"enabled": False}, | |||
| "vision": {"enabled": False}, | |||
| }, | |||
| }, | |||
| }, | |||
| ) | |||
| ) | |||
| credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} | |||
| # Mock db.session.close() | |||
| db.session.close = MagicMock() | |||
| # Mock db.session.close() | |||
| db.session.close = MagicMock() | |||
| # Create a proper LLM result with real entities | |||
| mock_usage = LLMUsage( | |||
| prompt_tokens=30, | |||
| prompt_unit_price=Decimal("0.001"), | |||
| prompt_price_unit=Decimal("1000"), | |||
| prompt_price=Decimal("0.00003"), | |||
| completion_tokens=20, | |||
| completion_unit_price=Decimal("0.002"), | |||
| completion_price_unit=Decimal("1000"), | |||
| completion_price=Decimal("0.00004"), | |||
| total_tokens=50, | |||
| total_price=Decimal("0.00007"), | |||
| currency="USD", | |||
| latency=0.5, | |||
| ) | |||
| node._fetch_model_config = get_mocked_fetch_model_config( | |||
| provider="langgenius/openai/openai", | |||
| model="gpt-3.5-turbo", | |||
| mode="chat", | |||
| credentials=credentials, | |||
| ) | |||
| mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?") | |||
| mock_llm_result = LLMResult( | |||
| model="gpt-3.5-turbo", | |||
| prompt_messages=[], | |||
| message=mock_message, | |||
| usage=mock_usage, | |||
| ) | |||
| # Create a simple mock model instance that doesn't call real providers | |||
| mock_model_instance = MagicMock() | |||
| mock_model_instance.invoke_llm.return_value = mock_llm_result | |||
| # Create a simple mock model config with required attributes | |||
| mock_model_config = MagicMock() | |||
| mock_model_config.mode = "chat" | |||
| mock_model_config.provider = "openai" | |||
| mock_model_config.model = "gpt-3.5-turbo" | |||
| mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" | |||
| # Mock the _fetch_model_config method | |||
| def mock_fetch_model_config_func(_node_data_model): | |||
| return mock_model_instance, mock_model_config | |||
| # Also mock ModelManager.get_model_instance to avoid database calls | |||
| def mock_get_model_instance(_self, **kwargs): | |||
| return mock_model_instance | |||
| # execute node | |||
| result = node._run() | |||
| with ( | |||
| patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), | |||
| patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), | |||
| ): | |||
| # execute node | |||
| result = node._run() | |||
| for item in result: | |||
| if isinstance(item, RunCompletedEvent): | |||
| assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert item.run_result.process_data is not None | |||
| assert "sunny" in json.dumps(item.run_result.process_data) | |||
| assert "what's the weather today?" in json.dumps(item.run_result.process_data) | |||
| for item in result: | |||
| if isinstance(item, RunCompletedEvent): | |||
| assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||
| assert item.run_result.process_data is not None | |||
| assert "sunny" in json.dumps(item.run_result.process_data) | |||
| assert "what's the weather today?" in json.dumps(item.run_result.process_data) | |||
| def test_extract_json(): | |||