浏览代码

refactor: Improve model status handling and structured output (#20586)

Signed-off-by: -LAN- <laipz8200@outlook.com>
tags/1.4.2
-LAN- 5 个月前
父节点
当前提交
5ccfb1f4ba
没有帐户链接到提交者的电子邮件

+ 19
- 0
api/core/entities/model_entities.py 查看文件

status: ModelStatus status: ModelStatus
load_balancing_enabled: bool = False load_balancing_enabled: bool = False


def raise_for_status(self) -> None:
"""
Check model status and raise ValueError if not active.

:raises ValueError: When model status is not active, with a descriptive message
"""
if self.status == ModelStatus.ACTIVE:
return

error_messages = {
ModelStatus.NO_CONFIGURE: "Model is not configured",
ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded",
ModelStatus.NO_PERMISSION: "No permission to use this model",
ModelStatus.DISABLED: "Model is disabled",
}

if self.status in error_messages:
raise ValueError(error_messages[self.status])



class ModelWithProviderEntity(ProviderModelWithStatusEntity): class ModelWithProviderEntity(ProviderModelWithStatusEntity):
""" """

+ 44
- 31
api/core/extension/extensible.py 查看文件

extensions = [] extensions = []
position_map: dict[str, int] = {} position_map: dict[str, int] = {}


# get the path of the current class
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py")
current_dir_path = os.path.dirname(current_path)

# traverse subdirectories
for subdir_name in os.listdir(current_dir_path):
if subdir_name.startswith("__"):
continue

subdir_path = os.path.join(current_dir_path, subdir_name)
extension_name = subdir_name
if os.path.isdir(subdir_path):
# Get the package name from the module path
package_name = ".".join(cls.__module__.split(".")[:-1])

try:
# Get package directory path
package_spec = importlib.util.find_spec(package_name)
if not package_spec or not package_spec.origin:
raise ImportError(f"Could not find package {package_name}")

package_dir = os.path.dirname(package_spec.origin)

# Traverse subdirectories
for subdir_name in os.listdir(package_dir):
if subdir_name.startswith("__"):
continue

subdir_path = os.path.join(package_dir, subdir_name)
if not os.path.isdir(subdir_path):
continue

extension_name = subdir_name
file_names = os.listdir(subdir_path) file_names = os.listdir(subdir_path)


# is builtin extension, builtin extension
# in the front-end page and business logic, there are special treatments.
# Check for extension module file
if (extension_name + ".py") not in file_names:
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
continue

# Check for builtin flag and position
builtin = False builtin = False
# default position is 0 can not be None for sort_to_dict_by_position_map
position = 0 position = 0
if "__builtin__" in file_names: if "__builtin__" in file_names:
builtin = True builtin = True

builtin_file_path = os.path.join(subdir_path, "__builtin__") builtin_file_path = os.path.join(subdir_path, "__builtin__")
if os.path.exists(builtin_file_path): if os.path.exists(builtin_file_path):
position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip()) position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
position_map[extension_name] = position position_map[extension_name] = position


if (extension_name + ".py") not in file_names:
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
continue

# Dynamic loading {subdir_name}.py file and find the subclass of Extensible
py_path = os.path.join(subdir_path, extension_name + ".py")
spec = importlib.util.spec_from_file_location(extension_name, py_path)
# Import the extension module
module_name = f"{package_name}.{extension_name}.{extension_name}"
spec = importlib.util.find_spec(module_name)
if not spec or not spec.loader: if not spec or not spec.loader:
raise Exception(f"Failed to load module {extension_name} from {py_path}")
raise ImportError(f"Failed to load module {module_name}")
mod = importlib.util.module_from_spec(spec) mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) spec.loader.exec_module(mod)


# Find extension class
extension_class = None extension_class = None
for name, obj in vars(mod).items(): for name, obj in vars(mod).items():
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls: if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
break break


if not extension_class: if not extension_class:
logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
logging.warning(f"Missing subclass of {cls.__name__} in {module_name}, Skip.")
continue continue


# Load schema if not builtin
json_data: dict[str, Any] = {} json_data: dict[str, Any] = {}
if not builtin: if not builtin:
if "schema.json" not in file_names:
json_path = os.path.join(subdir_path, "schema.json")
if not os.path.exists(json_path):
logging.warning(f"Missing schema.json file in {subdir_path}, Skip.") logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
continue continue


json_path = os.path.join(subdir_path, "schema.json")
json_data = {}
if os.path.exists(json_path):
with open(json_path, encoding="utf-8") as f:
json_data = json.load(f)
with open(json_path, encoding="utf-8") as f:
json_data = json.load(f)


# Create extension
extensions.append( extensions.append(
ModuleExtension( ModuleExtension(
extension_class=extension_class, extension_class=extension_class,
) )
) )


except Exception as e:
logging.exception("Error scanning extensions")
raise

# Sort extensions by position
sorted_extensions = sort_to_dict_by_position_map( sorted_extensions = sort_to_dict_by_position_map(
position_map=position_map, data=extensions, name_func=lambda x: x.name position_map=position_map, data=extensions, name_func=lambda x: x.name
) )

+ 4
- 0
api/core/model_runtime/entities/model_entities.py 查看文件

deprecated: bool = False deprecated: bool = False
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())


@property
def support_structure_output(self) -> bool:
return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features



class ParameterRule(BaseModel): class ParameterRule(BaseModel):
""" """

+ 44
- 58
api/core/provider_manager.py 查看文件

from json import JSONDecodeError from json import JSONDecodeError
from typing import Any, Optional, cast from typing import Any, Optional, cast


from sqlalchemy import select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session


from configs import dify_config from configs import dify_config
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity


@staticmethod @staticmethod
def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]: def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
"""
Get all provider records of the workspace.

:param tenant_id: workspace id
:return:
"""
providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all()

provider_name_to_provider_records_dict = defaultdict(list) provider_name_to_provider_records_dict = defaultdict(list)
for provider in providers:
# TODO: Use provider name with prefix after the data migration
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)

with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
providers = session.scalars(stmt)
for provider in providers:
# Use provider name with prefix after the data migration
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
return provider_name_to_provider_records_dict return provider_name_to_provider_records_dict


@staticmethod @staticmethod
:param tenant_id: workspace id :param tenant_id: workspace id
:return: :return:
""" """
# Get all provider model records of the workspace
provider_models = (
db.session.query(ProviderModel)
.filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
.all()
)

provider_name_to_provider_model_records_dict = defaultdict(list) provider_name_to_provider_model_records_dict = defaultdict(list)
for provider_model in provider_models:
provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)

with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
provider_models = session.scalars(stmt)
for provider_model in provider_models:
provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)
return provider_name_to_provider_model_records_dict return provider_name_to_provider_model_records_dict


@staticmethod @staticmethod
:param tenant_id: workspace id :param tenant_id: workspace id
:return: :return:
""" """
preferred_provider_types = (
db.session.query(TenantPreferredModelProvider)
.filter(TenantPreferredModelProvider.tenant_id == tenant_id)
.all()
)

provider_name_to_preferred_provider_type_records_dict = {
preferred_provider_type.provider_name: preferred_provider_type
for preferred_provider_type in preferred_provider_types
}

provider_name_to_preferred_provider_type_records_dict = {}
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id)
preferred_provider_types = session.scalars(stmt)
provider_name_to_preferred_provider_type_records_dict = {
preferred_provider_type.provider_name: preferred_provider_type
for preferred_provider_type in preferred_provider_types
}
return provider_name_to_preferred_provider_type_records_dict return provider_name_to_preferred_provider_type_records_dict


@staticmethod @staticmethod
:param tenant_id: workspace id :param tenant_id: workspace id
:return: :return:
""" """
provider_model_settings = (
db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all()
)

provider_name_to_provider_model_settings_dict = defaultdict(list) provider_name_to_provider_model_settings_dict = defaultdict(list)
for provider_model_setting in provider_model_settings:
(
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id)
provider_model_settings = session.scalars(stmt)
for provider_model_setting in provider_model_settings:
provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append( provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append(
provider_model_setting provider_model_setting
) )
)

return provider_name_to_provider_model_settings_dict return provider_name_to_provider_model_settings_dict


@staticmethod @staticmethod
if not model_load_balancing_enabled: if not model_load_balancing_enabled:
return {} return {}


provider_load_balancing_configs = (
db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all()
)

provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list) provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
for provider_load_balancing_config in provider_load_balancing_configs:
provider_name_to_provider_load_balancing_model_configs_dict[
provider_load_balancing_config.provider_name
].append(provider_load_balancing_config)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id)
provider_load_balancing_configs = session.scalars(stmt)
for provider_load_balancing_config in provider_load_balancing_configs:
provider_name_to_provider_load_balancing_model_configs_dict[
provider_load_balancing_config.provider_name
].append(provider_load_balancing_config)


return provider_name_to_provider_load_balancing_model_configs_dict return provider_name_to_provider_load_balancing_model_configs_dict


if not cached_provider_credentials: if not cached_provider_credentials:
try: try:
# fix origin data # fix origin data
if (
custom_provider_record.encrypted_config
and not custom_provider_record.encrypted_config.startswith("{")
):
if custom_provider_record.encrypted_config is None:
raise ValueError("No credentials found")
if not custom_provider_record.encrypted_config.startswith("{"):
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
else: else:
provider_credentials = json.loads(custom_provider_record.encrypted_config) provider_credentials = json.loads(custom_provider_record.encrypted_config)
return SystemConfiguration(enabled=False) return SystemConfiguration(enabled=False)


# Convert provider_records to dict # Convert provider_records to dict
quota_type_to_provider_records_dict = {}
quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {}
for provider_record in provider_records: for provider_record in provider_records:
if provider_record.provider_type != ProviderType.SYSTEM.value: if provider_record.provider_type != ProviderType.SYSTEM.value:
continue continue
else: else:
provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type] provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type]


if provider_record.quota_used is None:
raise ValueError("quota_used is None")
if provider_record.quota_limit is None:
raise ValueError("quota_limit is None")

quota_configuration = QuotaConfiguration( quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type, quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
cached_provider_credentials = provider_credentials_cache.get() cached_provider_credentials = provider_credentials_cache.get()


if not cached_provider_credentials: if not cached_provider_credentials:
try:
provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config)
except JSONDecodeError:
provider_credentials = {}
provider_credentials: dict[str, Any] = {}
if provider_records and provider_records[0].encrypted_config:
provider_credentials = json.loads(provider_records[0].encrypted_config)


# Get provider credential secret variables # Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables( provider_credential_secret_variables = self._extract_secret_variables(

+ 6
- 1
api/core/workflow/nodes/llm/entities.py 查看文件

context: ContextConfig context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig) vision: VisionConfig = Field(default_factory=VisionConfig)
structured_output: dict | None = None structured_output: dict | None = None
structured_output_enabled: bool = False
# We used 'structured_output_enabled' in the past, but it's not a good name.
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")


@field_validator("prompt_config", mode="before") @field_validator("prompt_config", mode="before")
@classmethod @classmethod
if v is None: if v is None:
return PromptConfig() return PromptConfig()
return v return v

@property
def structured_output_enabled(self) -> bool:
return self.structured_output_switch_on and self.structured_output is not None

+ 54
- 83
api/core/workflow/nodes/llm/node.py 查看文件



from configs import dify_config from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file import FileType, file_manager from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.workflow.utils.structured_output.entities import ( from core.workflow.utils.structured_output.entities import (
ResponseFormat, ResponseFormat,
SpecialModelType, SpecialModelType,
SupportStructuredOutputStatus,
) )
from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
from core.workflow.utils.variable_template_parser import VariableTemplateParser from core.workflow.utils.variable_template_parser import VariableTemplateParser
llm_usage=usage, llm_usage=usage,
) )
) )
except LLMNodeError as e:
except ValueError as e:
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
def _fetch_model_config( def _fetch_model_config(
self, node_data_model: ModelConfig self, node_data_model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
model_name = node_data_model.name
provider_name = node_data_model.provider
if not node_data_model.mode:
raise LLMModeRequiredError("LLM mode is required.")


model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
model = ModelManager().get_model_instance(
tenant_id=self.tenant_id,
model_type=ModelType.LLM,
provider=node_data_model.provider,
model=node_data_model.name,
) )


provider_model_bundle = model_instance.provider_model_bundle
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)

model_credentials = model_instance.credentials
model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)


# check model # check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_name, model_type=ModelType.LLM
provider_model = model.provider_model_bundle.configuration.get_provider_model(
model=node_data_model.name, model_type=ModelType.LLM
) )


if provider_model is None: if provider_model is None:
raise ModelNotExistError(f"Model {model_name} not exist.")

if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
provider_model.raise_for_status()


# model config # model config
completion_params = node_data_model.completion_params
stop = []
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]

# get model mode
model_mode = node_data_model.mode
if not model_mode:
raise LLMModeRequiredError("LLM mode is required.")

model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
stop: list[str] = []
if "stop" in node_data_model.completion_params:
stop = node_data_model.completion_params.pop("stop")


model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
if not model_schema: if not model_schema:
raise ModelNotExistError(f"Model {model_name} not exist.")
support_structured_output = self._check_model_structured_output_support()
if support_structured_output == SupportStructuredOutputStatus.SUPPORTED:
completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
# Set appropriate response format based on model capabilities
self._set_response_format(completion_params, model_schema.parameter_rules)
return model_instance, ModelConfigWithCredentialsEntity(
provider=provider_name,
model=model_name,
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")

if self.node_data.structured_output_enabled:
if model_schema.support_structure_output:
node_data_model.completion_params = self._handle_native_json_schema(
node_data_model.completion_params, model_schema.parameter_rules
)
else:
# Set appropriate response format based on model capabilities
self._set_response_format(node_data_model.completion_params, model_schema.parameter_rules)

return model, ModelConfigWithCredentialsEntity(
provider=node_data_model.provider,
model=node_data_model.name,
model_schema=model_schema, model_schema=model_schema,
mode=model_mode,
provider_model_bundle=provider_model_bundle,
credentials=model_credentials,
parameters=completion_params,
mode=node_data_model.mode,
provider_model_bundle=model.provider_model_bundle,
credentials=model.credentials,
parameters=node_data_model.completion_params,
stop=stop, stop=stop,
) )


"No prompt found in the LLM configuration. " "No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding." "Please ensure a prompt is properly configured before proceeding."
) )
support_structured_output = self._check_model_structured_output_support()
if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
filtered_prompt_messages = self._handle_prompt_based_schema(
prompt_messages=filtered_prompt_messages,
)
stop = model_config.stop
return filtered_prompt_messages, stop

model = ModelManager().get_model_instance(
tenant_id=self.tenant_id,
model_type=ModelType.LLM,
provider=self.node_data.model.provider,
model=self.node_data.model.name,
)
model_schema = model.model_type_instance.get_model_schema(
model=self.node_data.model.name,
credentials=model.credentials,
)
if not model_schema:
raise ModelNotExistError(f"Model {self.node_data.model.name} not exist.")
if self.node_data.structured_output_enabled:
if not model_schema.support_structure_output:
filtered_prompt_messages = self._handle_prompt_based_schema(
prompt_messages=filtered_prompt_messages,
)
return filtered_prompt_messages, model_config.stop


def _parse_structured_output(self, result_text: str) -> dict[str, Any]: def _parse_structured_output(self, result_text: str) -> dict[str, Any]:
structured_output: dict[str, Any] = {} structured_output: dict[str, Any] = {}
except json.JSONDecodeError: except json.JSONDecodeError:
raise LLMNodeError("structured_output_schema is not valid JSON format") raise LLMNodeError("structured_output_schema is not valid JSON format")


def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus:
"""
Check if the current model supports structured output.

Returns:
SupportStructuredOutput: The support status of structured output
"""
# Early return if structured output is disabled
if (
not isinstance(self.node_data, LLMNodeData)
or not self.node_data.structured_output_enabled
or not self.node_data.structured_output
):
return SupportStructuredOutputStatus.DISABLED
# Get model schema and check if it exists
model_schema = self._fetch_model_schema(self.node_data.model.provider)
if not model_schema:
return SupportStructuredOutputStatus.DISABLED

# Check if model supports structured output feature
return (
SupportStructuredOutputStatus.SUPPORTED
if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features)
else SupportStructuredOutputStatus.UNSUPPORTED
)

def _save_multimodal_output_and_convert_result_to_markdown( def _save_multimodal_output_and_convert_result_to_markdown(
self, self,
contents: str | list[PromptMessageContentUnionTypes] | None, contents: str | list[PromptMessageContentUnionTypes] | None,

+ 0
- 8
api/core/workflow/utils/structured_output/entities.py 查看文件



GEMINI = "gemini" GEMINI = "gemini"
OLLAMA = "ollama" OLLAMA = "ollama"


class SupportStructuredOutputStatus(StrEnum):
"""Constants for structured output support status"""

SUPPORTED = "supported"
UNSUPPORTED = "unsupported"
DISABLED = "disabled"

+ 79
- 70
api/models/provider.py 查看文件

from datetime import datetime
from enum import Enum from enum import Enum
from typing import Optional


from sqlalchemy import func
from sqlalchemy import func, text
from sqlalchemy.orm import Mapped, mapped_column


from .base import Base from .base import Base
from .engine import db from .engine import db
), ),
) )


id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False)
provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying"))
encrypted_config = db.Column(db.Text, nullable=True)
is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
last_used = db.Column(db.DateTime, nullable=True)
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
provider_type: Mapped[str] = mapped_column(
db.String(40), nullable=False, server_default=text("'custom'::character varying")
)
encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)


quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying"))
quota_limit = db.Column(db.BigInteger, nullable=True)
quota_used = db.Column(db.BigInteger, default=0)
quota_type: Mapped[Optional[str]] = mapped_column(
db.String(40), nullable=True, server_default=text("''::character varying")
)
quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True)
quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0)


created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())


def __repr__(self): def __repr__(self):
return ( return (
), ),
) )


id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False)
model_name = db.Column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False)
encrypted_config = db.Column(db.Text, nullable=True)
is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())




class TenantDefaultModel(Base): class TenantDefaultModel(Base):
db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"),
) )


id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False)
model_name = db.Column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())




class TenantPreferredModelProvider(Base): class TenantPreferredModelProvider(Base):
db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"),
) )


id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False)
preferred_provider_type = db.Column(db.String(40), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
preferred_provider_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())




class ProviderOrder(Base): class ProviderOrder(Base):
db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"),
) )


id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False)
account_id = db.Column(StringUUID, nullable=False)
payment_product_id = db.Column(db.String(191), nullable=False)
payment_id = db.Column(db.String(191))
transaction_id = db.Column(db.String(191))
quantity = db.Column(db.Integer, nullable=False, server_default=db.text("1"))
currency = db.Column(db.String(40))
total_amount = db.Column(db.Integer)
payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying"))
paid_at = db.Column(db.DateTime)
pay_failed_at = db.Column(db.DateTime)
refunded_at = db.Column(db.DateTime)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
payment_product_id: Mapped[str] = mapped_column(db.String(191), nullable=False)
payment_id: Mapped[Optional[str]] = mapped_column(db.String(191))
transaction_id: Mapped[Optional[str]] = mapped_column(db.String(191))
quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1"))
currency: Mapped[Optional[str]] = mapped_column(db.String(40))
total_amount: Mapped[Optional[int]] = mapped_column(db.Integer)
payment_status: Mapped[str] = mapped_column(
db.String(40), nullable=False, server_default=text("'wait_pay'::character varying")
)
paid_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
pay_failed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
refunded_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())




class ProviderModelSetting(Base): class ProviderModelSetting(Base):
db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
) )


id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False)
model_name = db.Column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False)
enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())




class LoadBalancingModelConfig(Base): class LoadBalancingModelConfig(Base):
db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
) )


id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False)
model_name = db.Column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False)
name = db.Column(db.String(255), nullable=False)
encrypted_config = db.Column(db.Text, nullable=True)
enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
name: Mapped[str] = mapped_column(db.String(255), nullable=False)
encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())

+ 202
- 95
api/tests/integration_tests/workflow/nodes/test_llm.py 查看文件

import time import time
import uuid import uuid
from collections.abc import Generator from collections.abc import Generator
from unittest.mock import MagicMock
from decimal import Decimal
from unittest.mock import MagicMock, patch


import pytest import pytest


from app_factory import create_app
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from extensions.ext_database import db from extensions.ext_database import db
from models.enums import UserFrom from models.enums import UserFrom
from models.workflow import WorkflowType from models.workflow import WorkflowType
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config


"""FOR MOCK FIXTURES, DO NOT REMOVE""" """FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock




@pytest.fixture(scope="session")
def app():
# Set up storage configuration
os.environ["STORAGE_TYPE"] = "opendal"
os.environ["OPENDAL_SCHEME"] = "fs"
os.environ["OPENDAL_FS_ROOT"] = "storage"

# Ensure storage directory exists
os.makedirs("storage", exist_ok=True)

app = create_app()
dify_config.LOGIN_DISABLED = True
return app


def init_llm_node(config: dict) -> LLMNode: def init_llm_node(config: dict) -> LLMNode:
graph_config = { graph_config = {
"edges": [ "edges": [


graph = Graph.init(graph_config=graph_config) graph = Graph.init(graph_config=graph_config)


# Use proper UUIDs for database compatibility
tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
app_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056c"
workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d"
user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e"

init_params = GraphInitParams( init_params = GraphInitParams(
tenant_id="1",
app_id="1",
tenant_id=tenant_id,
app_id=app_id,
workflow_type=WorkflowType.WORKFLOW, workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
workflow_id=workflow_id,
graph_config=graph_config, graph_config=graph_config,
user_id="1",
user_id=user_id,
user_from=UserFrom.ACCOUNT, user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
call_depth=0, call_depth=0,
return node return node




def test_execute_llm(setup_model_mock):
node = init_llm_node(
config={
"id": "llm",
"data": {
"title": "123",
"type": "llm",
"model": {
"provider": "langgenius/openai/openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {},
def test_execute_llm(app):
with app.app_context():
node = init_llm_node(
config={
"id": "llm",
"data": {
"title": "123",
"type": "llm",
"model": {
"provider": "langgenius/openai/openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {},
},
"prompt_template": [
{
"role": "system",
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.",
},
{"role": "user", "text": "{{#sys.query#}}"},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
}, },
"prompt_template": [
{"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."},
{"role": "user", "text": "{{#sys.query#}}"},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
}, },
},
)
)


credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}


# Mock db.session.close()
db.session.close = MagicMock()
# Create a proper LLM result with real entities
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("1000"),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("1000"),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)


node._fetch_model_config = get_mocked_fetch_model_config(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo",
mode="chat",
credentials=credentials,
)
mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")

mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)

# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result

# Create a simple mock model config with required attributes
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "langgenius/openai/openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"

# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
return mock_model_instance, mock_model_config

# Also mock ModelManager.get_model_instance to avoid database calls
def mock_get_model_instance(_self, **kwargs):
return mock_model_instance


# execute node
result = node._run()
assert isinstance(result, Generator)
with (
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
):
# execute node
result = node._run()
assert isinstance(result, Generator)


for item in result:
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert item.run_result.outputs is not None
assert item.run_result.outputs.get("text") is not None
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
for item in result:
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert item.run_result.outputs is not None
assert item.run_result.outputs.get("text") is not None
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0




@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_model_mock):
def test_execute_llm_with_jinja2(app, setup_code_executor_mock):
""" """
Test execute LLM node with jinja2 Test execute LLM node with jinja2
""" """
node = init_llm_node(
config={
"id": "llm",
"data": {
"title": "123",
"type": "llm",
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
"prompt_config": {
"jinja2_variables": [
{"variable": "sys_query", "value_selector": ["sys", "query"]},
{"variable": "output", "value_selector": ["abc", "output"]},
]
},
"prompt_template": [
{
"role": "system",
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
"jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
"edition_type": "jinja2",
with app.app_context():
node = init_llm_node(
config={
"id": "llm",
"data": {
"title": "123",
"type": "llm",
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
"prompt_config": {
"jinja2_variables": [
{"variable": "sys_query", "value_selector": ["sys", "query"]},
{"variable": "output", "value_selector": ["abc", "output"]},
]
}, },
{
"role": "user",
"text": "{{#sys.query#}}",
"jinja2_text": "{{sys_query}}",
"edition_type": "basic",
},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
"prompt_template": [
{
"role": "system",
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
"jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
"edition_type": "jinja2",
},
{
"role": "user",
"text": "{{#sys.query#}}",
"jinja2_text": "{{sys_query}}",
"edition_type": "basic",
},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
},
}, },
},
)
)


credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
# Mock db.session.close()
db.session.close = MagicMock()


# Mock db.session.close()
db.session.close = MagicMock()
# Create a proper LLM result with real entities
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("1000"),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("1000"),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)


node._fetch_model_config = get_mocked_fetch_model_config(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo",
mode="chat",
credentials=credentials,
)
mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")

mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)

# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result

# Create a simple mock model config with required attributes
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"

# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
return mock_model_instance, mock_model_config

# Also mock ModelManager.get_model_instance to avoid database calls
def mock_get_model_instance(_self, **kwargs):
return mock_model_instance


# execute node
result = node._run()
with (
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
):
# execute node
result = node._run()


for item in result:
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert "sunny" in json.dumps(item.run_result.process_data)
assert "what's the weather today?" in json.dumps(item.run_result.process_data)
for item in result:
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert "sunny" in json.dumps(item.run_result.process_data)
assert "what's the weather today?" in json.dumps(item.run_result.process_data)




def test_extract_json(): def test_extract_json():

正在加载...
取消
保存