Signed-off-by: -LAN- <laipz8200@outlook.com>tags/1.4.2
| status: ModelStatus | status: ModelStatus | ||||
| load_balancing_enabled: bool = False | load_balancing_enabled: bool = False | ||||
| def raise_for_status(self) -> None: | |||||
| """ | |||||
| Check model status and raise ValueError if not active. | |||||
| :raises ValueError: When model status is not active, with a descriptive message | |||||
| """ | |||||
| if self.status == ModelStatus.ACTIVE: | |||||
| return | |||||
| error_messages = { | |||||
| ModelStatus.NO_CONFIGURE: "Model is not configured", | |||||
| ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded", | |||||
| ModelStatus.NO_PERMISSION: "No permission to use this model", | |||||
| ModelStatus.DISABLED: "Model is disabled", | |||||
| } | |||||
| if self.status in error_messages: | |||||
| raise ValueError(error_messages[self.status]) | |||||
| class ModelWithProviderEntity(ProviderModelWithStatusEntity): | class ModelWithProviderEntity(ProviderModelWithStatusEntity): | ||||
| """ | """ |
| extensions = [] | extensions = [] | ||||
| position_map: dict[str, int] = {} | position_map: dict[str, int] = {} | ||||
| # get the path of the current class | |||||
| current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py") | |||||
| current_dir_path = os.path.dirname(current_path) | |||||
| # traverse subdirectories | |||||
| for subdir_name in os.listdir(current_dir_path): | |||||
| if subdir_name.startswith("__"): | |||||
| continue | |||||
| subdir_path = os.path.join(current_dir_path, subdir_name) | |||||
| extension_name = subdir_name | |||||
| if os.path.isdir(subdir_path): | |||||
| # Get the package name from the module path | |||||
| package_name = ".".join(cls.__module__.split(".")[:-1]) | |||||
| try: | |||||
| # Get package directory path | |||||
| package_spec = importlib.util.find_spec(package_name) | |||||
| if not package_spec or not package_spec.origin: | |||||
| raise ImportError(f"Could not find package {package_name}") | |||||
| package_dir = os.path.dirname(package_spec.origin) | |||||
| # Traverse subdirectories | |||||
| for subdir_name in os.listdir(package_dir): | |||||
| if subdir_name.startswith("__"): | |||||
| continue | |||||
| subdir_path = os.path.join(package_dir, subdir_name) | |||||
| if not os.path.isdir(subdir_path): | |||||
| continue | |||||
| extension_name = subdir_name | |||||
| file_names = os.listdir(subdir_path) | file_names = os.listdir(subdir_path) | ||||
| # is builtin extension, builtin extension | |||||
| # in the front-end page and business logic, there are special treatments. | |||||
| # Check for extension module file | |||||
| if (extension_name + ".py") not in file_names: | |||||
| logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") | |||||
| continue | |||||
| # Check for builtin flag and position | |||||
| builtin = False | builtin = False | ||||
| # default position is 0 can not be None for sort_to_dict_by_position_map | |||||
| position = 0 | position = 0 | ||||
| if "__builtin__" in file_names: | if "__builtin__" in file_names: | ||||
| builtin = True | builtin = True | ||||
| builtin_file_path = os.path.join(subdir_path, "__builtin__") | builtin_file_path = os.path.join(subdir_path, "__builtin__") | ||||
| if os.path.exists(builtin_file_path): | if os.path.exists(builtin_file_path): | ||||
| position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip()) | position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip()) | ||||
| position_map[extension_name] = position | position_map[extension_name] = position | ||||
| if (extension_name + ".py") not in file_names: | |||||
| logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") | |||||
| continue | |||||
| # Dynamic loading {subdir_name}.py file and find the subclass of Extensible | |||||
| py_path = os.path.join(subdir_path, extension_name + ".py") | |||||
| spec = importlib.util.spec_from_file_location(extension_name, py_path) | |||||
| # Import the extension module | |||||
| module_name = f"{package_name}.{extension_name}.{extension_name}" | |||||
| spec = importlib.util.find_spec(module_name) | |||||
| if not spec or not spec.loader: | if not spec or not spec.loader: | ||||
| raise Exception(f"Failed to load module {extension_name} from {py_path}") | |||||
| raise ImportError(f"Failed to load module {module_name}") | |||||
| mod = importlib.util.module_from_spec(spec) | mod = importlib.util.module_from_spec(spec) | ||||
| spec.loader.exec_module(mod) | spec.loader.exec_module(mod) | ||||
| # Find extension class | |||||
| extension_class = None | extension_class = None | ||||
| for name, obj in vars(mod).items(): | for name, obj in vars(mod).items(): | ||||
| if isinstance(obj, type) and issubclass(obj, cls) and obj != cls: | if isinstance(obj, type) and issubclass(obj, cls) and obj != cls: | ||||
| break | break | ||||
| if not extension_class: | if not extension_class: | ||||
| logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.") | |||||
| logging.warning(f"Missing subclass of {cls.__name__} in {module_name}, Skip.") | |||||
| continue | continue | ||||
| # Load schema if not builtin | |||||
| json_data: dict[str, Any] = {} | json_data: dict[str, Any] = {} | ||||
| if not builtin: | if not builtin: | ||||
| if "schema.json" not in file_names: | |||||
| json_path = os.path.join(subdir_path, "schema.json") | |||||
| if not os.path.exists(json_path): | |||||
| logging.warning(f"Missing schema.json file in {subdir_path}, Skip.") | logging.warning(f"Missing schema.json file in {subdir_path}, Skip.") | ||||
| continue | continue | ||||
| json_path = os.path.join(subdir_path, "schema.json") | |||||
| json_data = {} | |||||
| if os.path.exists(json_path): | |||||
| with open(json_path, encoding="utf-8") as f: | |||||
| json_data = json.load(f) | |||||
| with open(json_path, encoding="utf-8") as f: | |||||
| json_data = json.load(f) | |||||
| # Create extension | |||||
| extensions.append( | extensions.append( | ||||
| ModuleExtension( | ModuleExtension( | ||||
| extension_class=extension_class, | extension_class=extension_class, | ||||
| ) | ) | ||||
| ) | ) | ||||
| except Exception as e: | |||||
| logging.exception("Error scanning extensions") | |||||
| raise | |||||
| # Sort extensions by position | |||||
| sorted_extensions = sort_to_dict_by_position_map( | sorted_extensions = sort_to_dict_by_position_map( | ||||
| position_map=position_map, data=extensions, name_func=lambda x: x.name | position_map=position_map, data=extensions, name_func=lambda x: x.name | ||||
| ) | ) |
| deprecated: bool = False | deprecated: bool = False | ||||
| model_config = ConfigDict(protected_namespaces=()) | model_config = ConfigDict(protected_namespaces=()) | ||||
| @property | |||||
| def support_structure_output(self) -> bool: | |||||
| return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features | |||||
| class ParameterRule(BaseModel): | class ParameterRule(BaseModel): | ||||
| """ | """ |
| from json import JSONDecodeError | from json import JSONDecodeError | ||||
| from typing import Any, Optional, cast | from typing import Any, Optional, cast | ||||
| from sqlalchemy import select | |||||
| from sqlalchemy.exc import IntegrityError | from sqlalchemy.exc import IntegrityError | ||||
| from sqlalchemy.orm import Session | |||||
| from configs import dify_config | from configs import dify_config | ||||
| from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity | from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity | ||||
| @staticmethod | @staticmethod | ||||
| def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]: | def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]: | ||||
| """ | |||||
| Get all provider records of the workspace. | |||||
| :param tenant_id: workspace id | |||||
| :return: | |||||
| """ | |||||
| providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all() | |||||
| provider_name_to_provider_records_dict = defaultdict(list) | provider_name_to_provider_records_dict = defaultdict(list) | ||||
| for provider in providers: | |||||
| # TODO: Use provider name with prefix after the data migration | |||||
| provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider) | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True) | |||||
| providers = session.scalars(stmt) | |||||
| for provider in providers: | |||||
| # Use provider name with prefix after the data migration | |||||
| provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider) | |||||
| return provider_name_to_provider_records_dict | return provider_name_to_provider_records_dict | ||||
| @staticmethod | @staticmethod | ||||
| :param tenant_id: workspace id | :param tenant_id: workspace id | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| # Get all provider model records of the workspace | |||||
| provider_models = ( | |||||
| db.session.query(ProviderModel) | |||||
| .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) | |||||
| .all() | |||||
| ) | |||||
| provider_name_to_provider_model_records_dict = defaultdict(list) | provider_name_to_provider_model_records_dict = defaultdict(list) | ||||
| for provider_model in provider_models: | |||||
| provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model) | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) | |||||
| provider_models = session.scalars(stmt) | |||||
| for provider_model in provider_models: | |||||
| provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model) | |||||
| return provider_name_to_provider_model_records_dict | return provider_name_to_provider_model_records_dict | ||||
| @staticmethod | @staticmethod | ||||
| :param tenant_id: workspace id | :param tenant_id: workspace id | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| preferred_provider_types = ( | |||||
| db.session.query(TenantPreferredModelProvider) | |||||
| .filter(TenantPreferredModelProvider.tenant_id == tenant_id) | |||||
| .all() | |||||
| ) | |||||
| provider_name_to_preferred_provider_type_records_dict = { | |||||
| preferred_provider_type.provider_name: preferred_provider_type | |||||
| for preferred_provider_type in preferred_provider_types | |||||
| } | |||||
| provider_name_to_preferred_provider_type_records_dict = {} | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id) | |||||
| preferred_provider_types = session.scalars(stmt) | |||||
| provider_name_to_preferred_provider_type_records_dict = { | |||||
| preferred_provider_type.provider_name: preferred_provider_type | |||||
| for preferred_provider_type in preferred_provider_types | |||||
| } | |||||
| return provider_name_to_preferred_provider_type_records_dict | return provider_name_to_preferred_provider_type_records_dict | ||||
| @staticmethod | @staticmethod | ||||
| :param tenant_id: workspace id | :param tenant_id: workspace id | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| provider_model_settings = ( | |||||
| db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all() | |||||
| ) | |||||
| provider_name_to_provider_model_settings_dict = defaultdict(list) | provider_name_to_provider_model_settings_dict = defaultdict(list) | ||||
| for provider_model_setting in provider_model_settings: | |||||
| ( | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id) | |||||
| provider_model_settings = session.scalars(stmt) | |||||
| for provider_model_setting in provider_model_settings: | |||||
| provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append( | provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append( | ||||
| provider_model_setting | provider_model_setting | ||||
| ) | ) | ||||
| ) | |||||
| return provider_name_to_provider_model_settings_dict | return provider_name_to_provider_model_settings_dict | ||||
| @staticmethod | @staticmethod | ||||
| if not model_load_balancing_enabled: | if not model_load_balancing_enabled: | ||||
| return {} | return {} | ||||
| provider_load_balancing_configs = ( | |||||
| db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all() | |||||
| ) | |||||
| provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list) | provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list) | ||||
| for provider_load_balancing_config in provider_load_balancing_configs: | |||||
| provider_name_to_provider_load_balancing_model_configs_dict[ | |||||
| provider_load_balancing_config.provider_name | |||||
| ].append(provider_load_balancing_config) | |||||
| with Session(db.engine, expire_on_commit=False) as session: | |||||
| stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id) | |||||
| provider_load_balancing_configs = session.scalars(stmt) | |||||
| for provider_load_balancing_config in provider_load_balancing_configs: | |||||
| provider_name_to_provider_load_balancing_model_configs_dict[ | |||||
| provider_load_balancing_config.provider_name | |||||
| ].append(provider_load_balancing_config) | |||||
| return provider_name_to_provider_load_balancing_model_configs_dict | return provider_name_to_provider_load_balancing_model_configs_dict | ||||
| if not cached_provider_credentials: | if not cached_provider_credentials: | ||||
| try: | try: | ||||
| # fix origin data | # fix origin data | ||||
| if ( | |||||
| custom_provider_record.encrypted_config | |||||
| and not custom_provider_record.encrypted_config.startswith("{") | |||||
| ): | |||||
| if custom_provider_record.encrypted_config is None: | |||||
| raise ValueError("No credentials found") | |||||
| if not custom_provider_record.encrypted_config.startswith("{"): | |||||
| provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} | provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} | ||||
| else: | else: | ||||
| provider_credentials = json.loads(custom_provider_record.encrypted_config) | provider_credentials = json.loads(custom_provider_record.encrypted_config) | ||||
| return SystemConfiguration(enabled=False) | return SystemConfiguration(enabled=False) | ||||
| # Convert provider_records to dict | # Convert provider_records to dict | ||||
| quota_type_to_provider_records_dict = {} | |||||
| quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {} | |||||
| for provider_record in provider_records: | for provider_record in provider_records: | ||||
| if provider_record.provider_type != ProviderType.SYSTEM.value: | if provider_record.provider_type != ProviderType.SYSTEM.value: | ||||
| continue | continue | ||||
| else: | else: | ||||
| provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type] | provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type] | ||||
| if provider_record.quota_used is None: | |||||
| raise ValueError("quota_used is None") | |||||
| if provider_record.quota_limit is None: | |||||
| raise ValueError("quota_limit is None") | |||||
| quota_configuration = QuotaConfiguration( | quota_configuration = QuotaConfiguration( | ||||
| quota_type=provider_quota.quota_type, | quota_type=provider_quota.quota_type, | ||||
| quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, | quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, | ||||
| cached_provider_credentials = provider_credentials_cache.get() | cached_provider_credentials = provider_credentials_cache.get() | ||||
| if not cached_provider_credentials: | if not cached_provider_credentials: | ||||
| try: | |||||
| provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config) | |||||
| except JSONDecodeError: | |||||
| provider_credentials = {} | |||||
| provider_credentials: dict[str, Any] = {} | |||||
| if provider_records and provider_records[0].encrypted_config: | |||||
| provider_credentials = json.loads(provider_records[0].encrypted_config) | |||||
| # Get provider credential secret variables | # Get provider credential secret variables | ||||
| provider_credential_secret_variables = self._extract_secret_variables( | provider_credential_secret_variables = self._extract_secret_variables( |
| context: ContextConfig | context: ContextConfig | ||||
| vision: VisionConfig = Field(default_factory=VisionConfig) | vision: VisionConfig = Field(default_factory=VisionConfig) | ||||
| structured_output: dict | None = None | structured_output: dict | None = None | ||||
| structured_output_enabled: bool = False | |||||
| # We used 'structured_output_enabled' in the past, but it's not a good name. | |||||
| structured_output_switch_on: bool = Field(False, alias="structured_output_enabled") | |||||
| @field_validator("prompt_config", mode="before") | @field_validator("prompt_config", mode="before") | ||||
| @classmethod | @classmethod | ||||
| if v is None: | if v is None: | ||||
| return PromptConfig() | return PromptConfig() | ||||
| return v | return v | ||||
| @property | |||||
| def structured_output_enabled(self) -> bool: | |||||
| return self.structured_output_switch_on and self.structured_output is not None |
| from configs import dify_config | from configs import dify_config | ||||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | ||||
| from core.entities.model_entities import ModelStatus | |||||
| from core.entities.provider_entities import QuotaUnit | from core.entities.provider_entities import QuotaUnit | ||||
| from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError | |||||
| from core.file import FileType, file_manager | from core.file import FileType, file_manager | ||||
| from core.helper.code_executor import CodeExecutor, CodeLanguage | from core.helper.code_executor import CodeExecutor, CodeLanguage | ||||
| from core.memory.token_buffer_memory import TokenBufferMemory | from core.memory.token_buffer_memory import TokenBufferMemory | ||||
| from core.workflow.utils.structured_output.entities import ( | from core.workflow.utils.structured_output.entities import ( | ||||
| ResponseFormat, | ResponseFormat, | ||||
| SpecialModelType, | SpecialModelType, | ||||
| SupportStructuredOutputStatus, | |||||
| ) | ) | ||||
| from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT | from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT | ||||
| from core.workflow.utils.variable_template_parser import VariableTemplateParser | from core.workflow.utils.variable_template_parser import VariableTemplateParser | ||||
| llm_usage=usage, | llm_usage=usage, | ||||
| ) | ) | ||||
| ) | ) | ||||
| except LLMNodeError as e: | |||||
| except ValueError as e: | |||||
| yield RunCompletedEvent( | yield RunCompletedEvent( | ||||
| run_result=NodeRunResult( | run_result=NodeRunResult( | ||||
| status=WorkflowNodeExecutionStatus.FAILED, | status=WorkflowNodeExecutionStatus.FAILED, | ||||
| def _fetch_model_config( | def _fetch_model_config( | ||||
| self, node_data_model: ModelConfig | self, node_data_model: ModelConfig | ||||
| ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: | ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: | ||||
| model_name = node_data_model.name | |||||
| provider_name = node_data_model.provider | |||||
| if not node_data_model.mode: | |||||
| raise LLMModeRequiredError("LLM mode is required.") | |||||
| model_manager = ModelManager() | |||||
| model_instance = model_manager.get_model_instance( | |||||
| tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name | |||||
| model = ModelManager().get_model_instance( | |||||
| tenant_id=self.tenant_id, | |||||
| model_type=ModelType.LLM, | |||||
| provider=node_data_model.provider, | |||||
| model=node_data_model.name, | |||||
| ) | ) | ||||
| provider_model_bundle = model_instance.provider_model_bundle | |||||
| model_type_instance = model_instance.model_type_instance | |||||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||||
| model_credentials = model_instance.credentials | |||||
| model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance) | |||||
| # check model | # check model | ||||
| provider_model = provider_model_bundle.configuration.get_provider_model( | |||||
| model=model_name, model_type=ModelType.LLM | |||||
| provider_model = model.provider_model_bundle.configuration.get_provider_model( | |||||
| model=node_data_model.name, model_type=ModelType.LLM | |||||
| ) | ) | ||||
| if provider_model is None: | if provider_model is None: | ||||
| raise ModelNotExistError(f"Model {model_name} not exist.") | |||||
| if provider_model.status == ModelStatus.NO_CONFIGURE: | |||||
| raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") | |||||
| elif provider_model.status == ModelStatus.NO_PERMISSION: | |||||
| raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") | |||||
| elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: | |||||
| raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") | |||||
| raise ModelNotExistError(f"Model {node_data_model.name} not exist.") | |||||
| provider_model.raise_for_status() | |||||
| # model config | # model config | ||||
| completion_params = node_data_model.completion_params | |||||
| stop = [] | |||||
| if "stop" in completion_params: | |||||
| stop = completion_params["stop"] | |||||
| del completion_params["stop"] | |||||
| # get model mode | |||||
| model_mode = node_data_model.mode | |||||
| if not model_mode: | |||||
| raise LLMModeRequiredError("LLM mode is required.") | |||||
| model_schema = model_type_instance.get_model_schema(model_name, model_credentials) | |||||
| stop: list[str] = [] | |||||
| if "stop" in node_data_model.completion_params: | |||||
| stop = node_data_model.completion_params.pop("stop") | |||||
| model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials) | |||||
| if not model_schema: | if not model_schema: | ||||
| raise ModelNotExistError(f"Model {model_name} not exist.") | |||||
| support_structured_output = self._check_model_structured_output_support() | |||||
| if support_structured_output == SupportStructuredOutputStatus.SUPPORTED: | |||||
| completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules) | |||||
| elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED: | |||||
| # Set appropriate response format based on model capabilities | |||||
| self._set_response_format(completion_params, model_schema.parameter_rules) | |||||
| return model_instance, ModelConfigWithCredentialsEntity( | |||||
| provider=provider_name, | |||||
| model=model_name, | |||||
| raise ModelNotExistError(f"Model {node_data_model.name} not exist.") | |||||
| if self.node_data.structured_output_enabled: | |||||
| if model_schema.support_structure_output: | |||||
| node_data_model.completion_params = self._handle_native_json_schema( | |||||
| node_data_model.completion_params, model_schema.parameter_rules | |||||
| ) | |||||
| else: | |||||
| # Set appropriate response format based on model capabilities | |||||
| self._set_response_format(node_data_model.completion_params, model_schema.parameter_rules) | |||||
| return model, ModelConfigWithCredentialsEntity( | |||||
| provider=node_data_model.provider, | |||||
| model=node_data_model.name, | |||||
| model_schema=model_schema, | model_schema=model_schema, | ||||
| mode=model_mode, | |||||
| provider_model_bundle=provider_model_bundle, | |||||
| credentials=model_credentials, | |||||
| parameters=completion_params, | |||||
| mode=node_data_model.mode, | |||||
| provider_model_bundle=model.provider_model_bundle, | |||||
| credentials=model.credentials, | |||||
| parameters=node_data_model.completion_params, | |||||
| stop=stop, | stop=stop, | ||||
| ) | ) | ||||
| "No prompt found in the LLM configuration. " | "No prompt found in the LLM configuration. " | ||||
| "Please ensure a prompt is properly configured before proceeding." | "Please ensure a prompt is properly configured before proceeding." | ||||
| ) | ) | ||||
| support_structured_output = self._check_model_structured_output_support() | |||||
| if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED: | |||||
| filtered_prompt_messages = self._handle_prompt_based_schema( | |||||
| prompt_messages=filtered_prompt_messages, | |||||
| ) | |||||
| stop = model_config.stop | |||||
| return filtered_prompt_messages, stop | |||||
| model = ModelManager().get_model_instance( | |||||
| tenant_id=self.tenant_id, | |||||
| model_type=ModelType.LLM, | |||||
| provider=self.node_data.model.provider, | |||||
| model=self.node_data.model.name, | |||||
| ) | |||||
| model_schema = model.model_type_instance.get_model_schema( | |||||
| model=self.node_data.model.name, | |||||
| credentials=model.credentials, | |||||
| ) | |||||
| if not model_schema: | |||||
| raise ModelNotExistError(f"Model {self.node_data.model.name} not exist.") | |||||
| if self.node_data.structured_output_enabled: | |||||
| if not model_schema.support_structure_output: | |||||
| filtered_prompt_messages = self._handle_prompt_based_schema( | |||||
| prompt_messages=filtered_prompt_messages, | |||||
| ) | |||||
| return filtered_prompt_messages, model_config.stop | |||||
| def _parse_structured_output(self, result_text: str) -> dict[str, Any]: | def _parse_structured_output(self, result_text: str) -> dict[str, Any]: | ||||
| structured_output: dict[str, Any] = {} | structured_output: dict[str, Any] = {} | ||||
| except json.JSONDecodeError: | except json.JSONDecodeError: | ||||
| raise LLMNodeError("structured_output_schema is not valid JSON format") | raise LLMNodeError("structured_output_schema is not valid JSON format") | ||||
| def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus: | |||||
| """ | |||||
| Check if the current model supports structured output. | |||||
| Returns: | |||||
| SupportStructuredOutput: The support status of structured output | |||||
| """ | |||||
| # Early return if structured output is disabled | |||||
| if ( | |||||
| not isinstance(self.node_data, LLMNodeData) | |||||
| or not self.node_data.structured_output_enabled | |||||
| or not self.node_data.structured_output | |||||
| ): | |||||
| return SupportStructuredOutputStatus.DISABLED | |||||
| # Get model schema and check if it exists | |||||
| model_schema = self._fetch_model_schema(self.node_data.model.provider) | |||||
| if not model_schema: | |||||
| return SupportStructuredOutputStatus.DISABLED | |||||
| # Check if model supports structured output feature | |||||
| return ( | |||||
| SupportStructuredOutputStatus.SUPPORTED | |||||
| if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features) | |||||
| else SupportStructuredOutputStatus.UNSUPPORTED | |||||
| ) | |||||
| def _save_multimodal_output_and_convert_result_to_markdown( | def _save_multimodal_output_and_convert_result_to_markdown( | ||||
| self, | self, | ||||
| contents: str | list[PromptMessageContentUnionTypes] | None, | contents: str | list[PromptMessageContentUnionTypes] | None, |
| GEMINI = "gemini" | GEMINI = "gemini" | ||||
| OLLAMA = "ollama" | OLLAMA = "ollama" | ||||
| class SupportStructuredOutputStatus(StrEnum): | |||||
| """Constants for structured output support status""" | |||||
| SUPPORTED = "supported" | |||||
| UNSUPPORTED = "unsupported" | |||||
| DISABLED = "disabled" |
| from datetime import datetime | |||||
| from enum import Enum | from enum import Enum | ||||
| from typing import Optional | |||||
| from sqlalchemy import func | |||||
| from sqlalchemy import func, text | |||||
| from sqlalchemy.orm import Mapped, mapped_column | |||||
| from .base import Base | from .base import Base | ||||
| from .engine import db | from .engine import db | ||||
| ), | ), | ||||
| ) | ) | ||||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||||
| tenant_id = db.Column(StringUUID, nullable=False) | |||||
| provider_name = db.Column(db.String(255), nullable=False) | |||||
| provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying")) | |||||
| encrypted_config = db.Column(db.Text, nullable=True) | |||||
| is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) | |||||
| last_used = db.Column(db.DateTime, nullable=True) | |||||
| id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) | |||||
| tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||||
| provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||||
| provider_type: Mapped[str] = mapped_column( | |||||
| db.String(40), nullable=False, server_default=text("'custom'::character varying") | |||||
| ) | |||||
| encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) | |||||
| is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) | |||||
| last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) | |||||
| quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying")) | |||||
| quota_limit = db.Column(db.BigInteger, nullable=True) | |||||
| quota_used = db.Column(db.BigInteger, default=0) | |||||
| quota_type: Mapped[Optional[str]] = mapped_column( | |||||
| db.String(40), nullable=True, server_default=text("''::character varying") | |||||
| ) | |||||
| quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True) | |||||
| quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0) | |||||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| def __repr__(self): | def __repr__(self): | ||||
| return ( | return ( | ||||
| ), | ), | ||||
| ) | ) | ||||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||||
| tenant_id = db.Column(StringUUID, nullable=False) | |||||
| provider_name = db.Column(db.String(255), nullable=False) | |||||
| model_name = db.Column(db.String(255), nullable=False) | |||||
| model_type = db.Column(db.String(40), nullable=False) | |||||
| encrypted_config = db.Column(db.Text, nullable=True) | |||||
| is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) | |||||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) | |||||
| tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||||
| provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||||
| model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||||
| model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) | |||||
| encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) | |||||
| is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) | |||||
| created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| class TenantDefaultModel(Base): | class TenantDefaultModel(Base): | ||||
| db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), | db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), | ||||
| ) | ) | ||||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||||
| tenant_id = db.Column(StringUUID, nullable=False) | |||||
| provider_name = db.Column(db.String(255), nullable=False) | |||||
| model_name = db.Column(db.String(255), nullable=False) | |||||
| model_type = db.Column(db.String(40), nullable=False) | |||||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) | |||||
| tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||||
| provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||||
| model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||||
| model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) | |||||
| created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| class TenantPreferredModelProvider(Base): | class TenantPreferredModelProvider(Base): | ||||
| db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), | db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), | ||||
| ) | ) | ||||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||||
| tenant_id = db.Column(StringUUID, nullable=False) | |||||
| provider_name = db.Column(db.String(255), nullable=False) | |||||
| preferred_provider_type = db.Column(db.String(40), nullable=False) | |||||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) | |||||
| tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||||
| provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||||
| preferred_provider_type: Mapped[str] = mapped_column(db.String(40), nullable=False) | |||||
| created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| class ProviderOrder(Base): | class ProviderOrder(Base): | ||||
| db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), | db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), | ||||
| ) | ) | ||||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||||
| tenant_id = db.Column(StringUUID, nullable=False) | |||||
| provider_name = db.Column(db.String(255), nullable=False) | |||||
| account_id = db.Column(StringUUID, nullable=False) | |||||
| payment_product_id = db.Column(db.String(191), nullable=False) | |||||
| payment_id = db.Column(db.String(191)) | |||||
| transaction_id = db.Column(db.String(191)) | |||||
| quantity = db.Column(db.Integer, nullable=False, server_default=db.text("1")) | |||||
| currency = db.Column(db.String(40)) | |||||
| total_amount = db.Column(db.Integer) | |||||
| payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying")) | |||||
| paid_at = db.Column(db.DateTime) | |||||
| pay_failed_at = db.Column(db.DateTime) | |||||
| refunded_at = db.Column(db.DateTime) | |||||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) | |||||
| tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||||
| provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||||
| account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||||
| payment_product_id: Mapped[str] = mapped_column(db.String(191), nullable=False) | |||||
| payment_id: Mapped[Optional[str]] = mapped_column(db.String(191)) | |||||
| transaction_id: Mapped[Optional[str]] = mapped_column(db.String(191)) | |||||
| quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1")) | |||||
| currency: Mapped[Optional[str]] = mapped_column(db.String(40)) | |||||
| total_amount: Mapped[Optional[int]] = mapped_column(db.Integer) | |||||
| payment_status: Mapped[str] = mapped_column( | |||||
| db.String(40), nullable=False, server_default=text("'wait_pay'::character varying") | |||||
| ) | |||||
| paid_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) | |||||
| pay_failed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) | |||||
| refunded_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) | |||||
| created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| class ProviderModelSetting(Base): | class ProviderModelSetting(Base): | ||||
| db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), | db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), | ||||
| ) | ) | ||||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||||
| tenant_id = db.Column(StringUUID, nullable=False) | |||||
| provider_name = db.Column(db.String(255), nullable=False) | |||||
| model_name = db.Column(db.String(255), nullable=False) | |||||
| model_type = db.Column(db.String(40), nullable=False) | |||||
| enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) | |||||
| load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) | |||||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) | |||||
| tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||||
| provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||||
| model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||||
| model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) | |||||
| enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) | |||||
| load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) | |||||
| created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| class LoadBalancingModelConfig(Base): | class LoadBalancingModelConfig(Base): | ||||
| db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), | db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), | ||||
| ) | ) | ||||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||||
| tenant_id = db.Column(StringUUID, nullable=False) | |||||
| provider_name = db.Column(db.String(255), nullable=False) | |||||
| model_name = db.Column(db.String(255), nullable=False) | |||||
| model_type = db.Column(db.String(40), nullable=False) | |||||
| name = db.Column(db.String(255), nullable=False) | |||||
| encrypted_config = db.Column(db.Text, nullable=True) | |||||
| enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) | |||||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) | |||||
| tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||||
| provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||||
| model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||||
| model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) | |||||
| name: Mapped[str] = mapped_column(db.String(255), nullable=False) | |||||
| encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) | |||||
| enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) | |||||
| created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||||
| updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) |
| import time | import time | ||||
| import uuid | import uuid | ||||
| from collections.abc import Generator | from collections.abc import Generator | ||||
| from unittest.mock import MagicMock | |||||
| from decimal import Decimal | |||||
| from unittest.mock import MagicMock, patch | |||||
| import pytest | import pytest | ||||
| from app_factory import create_app | |||||
| from configs import dify_config | |||||
| from core.app.entities.app_invoke_entities import InvokeFrom | from core.app.entities.app_invoke_entities import InvokeFrom | ||||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage | |||||
| from core.model_runtime.entities.message_entities import AssistantPromptMessage | |||||
| from core.workflow.entities.variable_pool import VariablePool | from core.workflow.entities.variable_pool import VariablePool | ||||
| from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus | ||||
| from core.workflow.enums import SystemVariableKey | from core.workflow.enums import SystemVariableKey | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.enums import UserFrom | from models.enums import UserFrom | ||||
| from models.workflow import WorkflowType | from models.workflow import WorkflowType | ||||
| from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config | |||||
| """FOR MOCK FIXTURES, DO NOT REMOVE""" | """FOR MOCK FIXTURES, DO NOT REMOVE""" | ||||
| from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock | from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock | ||||
| from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock | from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock | ||||
| @pytest.fixture(scope="session") | |||||
| def app(): | |||||
| # Set up storage configuration | |||||
| os.environ["STORAGE_TYPE"] = "opendal" | |||||
| os.environ["OPENDAL_SCHEME"] = "fs" | |||||
| os.environ["OPENDAL_FS_ROOT"] = "storage" | |||||
| # Ensure storage directory exists | |||||
| os.makedirs("storage", exist_ok=True) | |||||
| app = create_app() | |||||
| dify_config.LOGIN_DISABLED = True | |||||
| return app | |||||
| def init_llm_node(config: dict) -> LLMNode: | def init_llm_node(config: dict) -> LLMNode: | ||||
| graph_config = { | graph_config = { | ||||
| "edges": [ | "edges": [ | ||||
| graph = Graph.init(graph_config=graph_config) | graph = Graph.init(graph_config=graph_config) | ||||
| # Use proper UUIDs for database compatibility | |||||
| tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" | |||||
| app_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056c" | |||||
| workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d" | |||||
| user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e" | |||||
| init_params = GraphInitParams( | init_params = GraphInitParams( | ||||
| tenant_id="1", | |||||
| app_id="1", | |||||
| tenant_id=tenant_id, | |||||
| app_id=app_id, | |||||
| workflow_type=WorkflowType.WORKFLOW, | workflow_type=WorkflowType.WORKFLOW, | ||||
| workflow_id="1", | |||||
| workflow_id=workflow_id, | |||||
| graph_config=graph_config, | graph_config=graph_config, | ||||
| user_id="1", | |||||
| user_id=user_id, | |||||
| user_from=UserFrom.ACCOUNT, | user_from=UserFrom.ACCOUNT, | ||||
| invoke_from=InvokeFrom.DEBUGGER, | invoke_from=InvokeFrom.DEBUGGER, | ||||
| call_depth=0, | call_depth=0, | ||||
| return node | return node | ||||
| def test_execute_llm(setup_model_mock): | |||||
| node = init_llm_node( | |||||
| config={ | |||||
| "id": "llm", | |||||
| "data": { | |||||
| "title": "123", | |||||
| "type": "llm", | |||||
| "model": { | |||||
| "provider": "langgenius/openai/openai", | |||||
| "name": "gpt-3.5-turbo", | |||||
| "mode": "chat", | |||||
| "completion_params": {}, | |||||
| def test_execute_llm(app): | |||||
| with app.app_context(): | |||||
| node = init_llm_node( | |||||
| config={ | |||||
| "id": "llm", | |||||
| "data": { | |||||
| "title": "123", | |||||
| "type": "llm", | |||||
| "model": { | |||||
| "provider": "langgenius/openai/openai", | |||||
| "name": "gpt-3.5-turbo", | |||||
| "mode": "chat", | |||||
| "completion_params": {}, | |||||
| }, | |||||
| "prompt_template": [ | |||||
| { | |||||
| "role": "system", | |||||
| "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.", | |||||
| }, | |||||
| {"role": "user", "text": "{{#sys.query#}}"}, | |||||
| ], | |||||
| "memory": None, | |||||
| "context": {"enabled": False}, | |||||
| "vision": {"enabled": False}, | |||||
| }, | }, | ||||
| "prompt_template": [ | |||||
| {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."}, | |||||
| {"role": "user", "text": "{{#sys.query#}}"}, | |||||
| ], | |||||
| "memory": None, | |||||
| "context": {"enabled": False}, | |||||
| "vision": {"enabled": False}, | |||||
| }, | }, | ||||
| }, | |||||
| ) | |||||
| ) | |||||
| credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} | |||||
| credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} | |||||
| # Mock db.session.close() | |||||
| db.session.close = MagicMock() | |||||
| # Create a proper LLM result with real entities | |||||
| mock_usage = LLMUsage( | |||||
| prompt_tokens=30, | |||||
| prompt_unit_price=Decimal("0.001"), | |||||
| prompt_price_unit=Decimal("1000"), | |||||
| prompt_price=Decimal("0.00003"), | |||||
| completion_tokens=20, | |||||
| completion_unit_price=Decimal("0.002"), | |||||
| completion_price_unit=Decimal("1000"), | |||||
| completion_price=Decimal("0.00004"), | |||||
| total_tokens=50, | |||||
| total_price=Decimal("0.00007"), | |||||
| currency="USD", | |||||
| latency=0.5, | |||||
| ) | |||||
| node._fetch_model_config = get_mocked_fetch_model_config( | |||||
| provider="langgenius/openai/openai", | |||||
| model="gpt-3.5-turbo", | |||||
| mode="chat", | |||||
| credentials=credentials, | |||||
| ) | |||||
| mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.") | |||||
| mock_llm_result = LLMResult( | |||||
| model="gpt-3.5-turbo", | |||||
| prompt_messages=[], | |||||
| message=mock_message, | |||||
| usage=mock_usage, | |||||
| ) | |||||
| # Create a simple mock model instance that doesn't call real providers | |||||
| mock_model_instance = MagicMock() | |||||
| mock_model_instance.invoke_llm.return_value = mock_llm_result | |||||
| # Create a simple mock model config with required attributes | |||||
| mock_model_config = MagicMock() | |||||
| mock_model_config.mode = "chat" | |||||
| mock_model_config.provider = "langgenius/openai/openai" | |||||
| mock_model_config.model = "gpt-3.5-turbo" | |||||
| mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" | |||||
| # Mock the _fetch_model_config method | |||||
| def mock_fetch_model_config_func(_node_data_model): | |||||
| return mock_model_instance, mock_model_config | |||||
| # Also mock ModelManager.get_model_instance to avoid database calls | |||||
| def mock_get_model_instance(_self, **kwargs): | |||||
| return mock_model_instance | |||||
| # execute node | |||||
| result = node._run() | |||||
| assert isinstance(result, Generator) | |||||
| with ( | |||||
| patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), | |||||
| patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), | |||||
| ): | |||||
| # execute node | |||||
| result = node._run() | |||||
| assert isinstance(result, Generator) | |||||
| for item in result: | |||||
| if isinstance(item, RunCompletedEvent): | |||||
| assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||||
| assert item.run_result.process_data is not None | |||||
| assert item.run_result.outputs is not None | |||||
| assert item.run_result.outputs.get("text") is not None | |||||
| assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 | |||||
| for item in result: | |||||
| if isinstance(item, RunCompletedEvent): | |||||
| assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||||
| assert item.run_result.process_data is not None | |||||
| assert item.run_result.outputs is not None | |||||
| assert item.run_result.outputs.get("text") is not None | |||||
| assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 | |||||
| @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) | @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) | ||||
| def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_model_mock): | |||||
| def test_execute_llm_with_jinja2(app, setup_code_executor_mock): | |||||
| """ | """ | ||||
| Test execute LLM node with jinja2 | Test execute LLM node with jinja2 | ||||
| """ | """ | ||||
| node = init_llm_node( | |||||
| config={ | |||||
| "id": "llm", | |||||
| "data": { | |||||
| "title": "123", | |||||
| "type": "llm", | |||||
| "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, | |||||
| "prompt_config": { | |||||
| "jinja2_variables": [ | |||||
| {"variable": "sys_query", "value_selector": ["sys", "query"]}, | |||||
| {"variable": "output", "value_selector": ["abc", "output"]}, | |||||
| ] | |||||
| }, | |||||
| "prompt_template": [ | |||||
| { | |||||
| "role": "system", | |||||
| "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}", | |||||
| "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.", | |||||
| "edition_type": "jinja2", | |||||
| with app.app_context(): | |||||
| node = init_llm_node( | |||||
| config={ | |||||
| "id": "llm", | |||||
| "data": { | |||||
| "title": "123", | |||||
| "type": "llm", | |||||
| "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, | |||||
| "prompt_config": { | |||||
| "jinja2_variables": [ | |||||
| {"variable": "sys_query", "value_selector": ["sys", "query"]}, | |||||
| {"variable": "output", "value_selector": ["abc", "output"]}, | |||||
| ] | |||||
| }, | }, | ||||
| { | |||||
| "role": "user", | |||||
| "text": "{{#sys.query#}}", | |||||
| "jinja2_text": "{{sys_query}}", | |||||
| "edition_type": "basic", | |||||
| }, | |||||
| ], | |||||
| "memory": None, | |||||
| "context": {"enabled": False}, | |||||
| "vision": {"enabled": False}, | |||||
| "prompt_template": [ | |||||
| { | |||||
| "role": "system", | |||||
| "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}", | |||||
| "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.", | |||||
| "edition_type": "jinja2", | |||||
| }, | |||||
| { | |||||
| "role": "user", | |||||
| "text": "{{#sys.query#}}", | |||||
| "jinja2_text": "{{sys_query}}", | |||||
| "edition_type": "basic", | |||||
| }, | |||||
| ], | |||||
| "memory": None, | |||||
| "context": {"enabled": False}, | |||||
| "vision": {"enabled": False}, | |||||
| }, | |||||
| }, | }, | ||||
| }, | |||||
| ) | |||||
| ) | |||||
| credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} | |||||
| # Mock db.session.close() | |||||
| db.session.close = MagicMock() | |||||
| # Mock db.session.close() | |||||
| db.session.close = MagicMock() | |||||
| # Create a proper LLM result with real entities | |||||
| mock_usage = LLMUsage( | |||||
| prompt_tokens=30, | |||||
| prompt_unit_price=Decimal("0.001"), | |||||
| prompt_price_unit=Decimal("1000"), | |||||
| prompt_price=Decimal("0.00003"), | |||||
| completion_tokens=20, | |||||
| completion_unit_price=Decimal("0.002"), | |||||
| completion_price_unit=Decimal("1000"), | |||||
| completion_price=Decimal("0.00004"), | |||||
| total_tokens=50, | |||||
| total_price=Decimal("0.00007"), | |||||
| currency="USD", | |||||
| latency=0.5, | |||||
| ) | |||||
| node._fetch_model_config = get_mocked_fetch_model_config( | |||||
| provider="langgenius/openai/openai", | |||||
| model="gpt-3.5-turbo", | |||||
| mode="chat", | |||||
| credentials=credentials, | |||||
| ) | |||||
| mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?") | |||||
| mock_llm_result = LLMResult( | |||||
| model="gpt-3.5-turbo", | |||||
| prompt_messages=[], | |||||
| message=mock_message, | |||||
| usage=mock_usage, | |||||
| ) | |||||
| # Create a simple mock model instance that doesn't call real providers | |||||
| mock_model_instance = MagicMock() | |||||
| mock_model_instance.invoke_llm.return_value = mock_llm_result | |||||
| # Create a simple mock model config with required attributes | |||||
| mock_model_config = MagicMock() | |||||
| mock_model_config.mode = "chat" | |||||
| mock_model_config.provider = "openai" | |||||
| mock_model_config.model = "gpt-3.5-turbo" | |||||
| mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" | |||||
| # Mock the _fetch_model_config method | |||||
| def mock_fetch_model_config_func(_node_data_model): | |||||
| return mock_model_instance, mock_model_config | |||||
| # Also mock ModelManager.get_model_instance to avoid database calls | |||||
| def mock_get_model_instance(_self, **kwargs): | |||||
| return mock_model_instance | |||||
| # execute node | |||||
| result = node._run() | |||||
| with ( | |||||
| patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), | |||||
| patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), | |||||
| ): | |||||
| # execute node | |||||
| result = node._run() | |||||
| for item in result: | |||||
| if isinstance(item, RunCompletedEvent): | |||||
| assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||||
| assert item.run_result.process_data is not None | |||||
| assert "sunny" in json.dumps(item.run_result.process_data) | |||||
| assert "what's the weather today?" in json.dumps(item.run_result.process_data) | |||||
| for item in result: | |||||
| if isinstance(item, RunCompletedEvent): | |||||
| assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED | |||||
| assert item.run_result.process_data is not None | |||||
| assert "sunny" in json.dumps(item.run_result.process_data) | |||||
| assert "what's the weather today?" in json.dumps(item.run_result.process_data) | |||||
| def test_extract_json(): | def test_extract_json(): |