Selaa lähdekoodia

[Chore/Refactor] Improve type annotations in models module (#25281)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
tags/2.0.0-beta.2^2
-LAN- 1 kuukausi sitten
vanhempi
commit
9b8a03b53b
No account linked to committer's email address

+ 1
- 1
api/controllers/console/apikey.py Näytä tiedosto

custom="max_keys_exceeded", custom="max_keys_exceeded",
) )


key = ApiToken.generate_api_key(self.token_prefix, 24)
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
api_token = ApiToken() api_token = ApiToken()
setattr(api_token, self.resource_id_field, resource_id) setattr(api_token, self.resource_id_field, resource_id)
api_token.tenant_id = current_user.current_tenant_id api_token.tenant_id = current_user.current_tenant_id

+ 6
- 0
api/controllers/console/datasets/datasets_document.py Näytä tiedosto

data_source_info = document.data_source_info_dict data_source_info = document.data_source_info_dict


if document.data_source_type == "upload_file": if document.data_source_type == "upload_file":
if not data_source_info:
continue
file_id = data_source_info["upload_file_id"] file_id = data_source_info["upload_file_id"]
file_detail = ( file_detail = (
db.session.query(UploadFile) db.session.query(UploadFile)
extract_settings.append(extract_setting) extract_settings.append(extract_setting)


elif document.data_source_type == "notion_import": elif document.data_source_type == "notion_import":
if not data_source_info:
continue
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value, datasource_type=DatasourceType.NOTION.value,
notion_info={ notion_info={
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
elif document.data_source_type == "website_crawl": elif document.data_source_type == "website_crawl":
if not data_source_info:
continue
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value, datasource_type=DatasourceType.WEBSITE.value,
website_info={ website_info={

+ 2
- 0
api/controllers/console/explore/parameter.py Näytä tiedosto

def get(self, installed_app: InstalledApp): def get(self, installed_app: InstalledApp):
"""Get app meta""" """Get app meta"""
app_model = installed_app.app app_model = installed_app.app
if not app_model:
raise ValueError("App not found")
return AppService().get_app_meta(app_model) return AppService().get_app_meta(app_model)





+ 4
- 0
api/controllers/console/explore/workflow.py Näytä tiedosto

Run workflow Run workflow
""" """
app_model = installed_app.app app_model = installed_app.app
if not app_model:
raise NotWorkflowAppError()
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW: if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError() raise NotWorkflowAppError()
Stop workflow task Stop workflow task
""" """
app_model = installed_app.app app_model = installed_app.app
if not app_model:
raise NotWorkflowAppError()
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW: if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError() raise NotWorkflowAppError()

+ 3
- 0
api/core/app/apps/completion/app_generator.py Näytä tiedosto

raise MessageNotExistsError() raise MessageNotExistsError()


current_app_model_config = app_model.app_model_config current_app_model_config = app_model.app_model_config
if not current_app_model_config:
raise MoreLikeThisDisabledError()

more_like_this = current_app_model_config.more_like_this_dict more_like_this = current_app_model_config.more_like_this_dict


if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False: if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:

+ 2
- 1
api/core/rag/extractor/notion_extractor.py Näytä tiedosto



last_edited_time = self.get_notion_last_edited_time() last_edited_time = self.get_notion_last_edited_time()
data_source_info = document_model.data_source_info_dict data_source_info = document_model.data_source_info_dict
data_source_info["last_edited_time"] = last_edited_time
if data_source_info:
data_source_info["last_edited_time"] = last_edited_time


db.session.query(DocumentModel).filter_by(id=document_model.id).update( db.session.query(DocumentModel).filter_by(id=document_model.id).update(
{DocumentModel.data_source_info: json.dumps(data_source_info)} {DocumentModel.data_source_info: json.dumps(data_source_info)}

+ 2
- 2
api/core/tools/mcp_tool/provider.py Näytä tiedosto

import json import json
from typing import Any, Optional
from typing import Any, Optional, Self


from core.mcp.types import Tool as RemoteMCPTool from core.mcp.types import Tool as RemoteMCPTool
from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_provider import ToolProviderController
return ToolProviderType.MCP return ToolProviderType.MCP


@classmethod @classmethod
def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController":
def from_db(cls, db_provider: MCPToolProvider) -> Self:
""" """
from db provider from db provider
""" """

+ 2
- 2
api/core/tools/tool_manager.py Näytä tiedosto

if provider is None: if provider is None:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")


controller = MCPToolProviderController._from_db(provider)
controller = MCPToolProviderController.from_db(provider)


return controller return controller


tenant_id: str, tenant_id: str,
provider_type: ToolProviderType, provider_type: ToolProviderType,
provider_id: str, provider_id: str,
) -> Union[str, dict]:
) -> Union[str, dict[str, Any]]:
""" """
get the tool icon get the tool icon



+ 4
- 4
api/models/account.py Näytä tiedosto

import enum import enum
import json import json
from datetime import datetime from datetime import datetime
from typing import Optional
from typing import Any, Optional


import sqlalchemy as sa import sqlalchemy as sa
from flask_login import UserMixin
from flask_login import UserMixin # type: ignore[import-untyped]
from sqlalchemy import DateTime, String, func, select from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor


) )


@property @property
def custom_config_dict(self):
def custom_config_dict(self) -> dict[str, Any]:
return json.loads(self.custom_config) if self.custom_config else {} return json.loads(self.custom_config) if self.custom_config else {}


@custom_config_dict.setter @custom_config_dict.setter
def custom_config_dict(self, value: dict):
def custom_config_dict(self, value: dict[str, Any]) -> None:
self.custom_config = json.dumps(value) self.custom_config = json.dumps(value)





+ 72
- 64
api/models/dataset.py Näytä tiedosto

"segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50},
} }


def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return { return {
"id": self.id, "id": self.id,
"dataset_id": self.dataset_id, "dataset_id": self.dataset_id,
} }


@property @property
def rules_dict(self):
def rules_dict(self) -> dict[str, Any] | None:
try: try:
return json.loads(self.rules) if self.rules else None return json.loads(self.rules) if self.rules else None
except JSONDecodeError: except JSONDecodeError:
return status return status


@property @property
def data_source_info_dict(self):
def data_source_info_dict(self) -> dict[str, Any] | None:
if self.data_source_info: if self.data_source_info:
try: try:
data_source_info_dict = json.loads(self.data_source_info)
data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
except JSONDecodeError: except JSONDecodeError:
data_source_info_dict = {} data_source_info_dict = {}


return None return None


@property @property
def data_source_detail_dict(self):
def data_source_detail_dict(self) -> dict[str, Any]:
if self.data_source_info: if self.data_source_info:
if self.data_source_type == "upload_file": if self.data_source_type == "upload_file":
data_source_info_dict = json.loads(self.data_source_info)
data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
file_detail = ( file_detail = (
db.session.query(UploadFile) db.session.query(UploadFile)
.where(UploadFile.id == data_source_info_dict["upload_file_id"]) .where(UploadFile.id == data_source_info_dict["upload_file_id"])
} }
} }
elif self.data_source_type in {"notion_import", "website_crawl"}: elif self.data_source_type in {"notion_import", "website_crawl"}:
return json.loads(self.data_source_info)
result: dict[str, Any] = json.loads(self.data_source_info)
return result
return {} return {}


@property @property
return self.updated_at return self.updated_at


@property @property
def doc_metadata_details(self):
def doc_metadata_details(self) -> list[dict[str, Any]] | None:
if self.doc_metadata: if self.doc_metadata:
document_metadatas = ( document_metadatas = (
db.session.query(DatasetMetadata) db.session.query(DatasetMetadata)
) )
.all() .all()
) )
metadata_list = []
metadata_list: list[dict[str, Any]] = []
for metadata in document_metadatas: for metadata in document_metadatas:
metadata_dict = {
metadata_dict: dict[str, Any] = {
"id": metadata.id, "id": metadata.id,
"name": metadata.name, "name": metadata.name,
"type": metadata.type, "type": metadata.type,
return None return None


@property @property
def process_rule_dict(self):
if self.dataset_process_rule_id:
def process_rule_dict(self) -> dict[str, Any] | None:
if self.dataset_process_rule_id and self.dataset_process_rule:
return self.dataset_process_rule.to_dict() return self.dataset_process_rule.to_dict()
return None return None


def get_built_in_fields(self):
built_in_fields = []
def get_built_in_fields(self) -> list[dict[str, Any]]:
built_in_fields: list[dict[str, Any]] = []
built_in_fields.append( built_in_fields.append(
{ {
"id": "built-in", "id": "built-in",
) )
return built_in_fields return built_in_fields


def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return { return {
"id": self.id, "id": self.id,
"tenant_id": self.tenant_id, "tenant_id": self.tenant_id,
"data_source_info_dict": self.data_source_info_dict, "data_source_info_dict": self.data_source_info_dict,
"average_segment_length": self.average_segment_length, "average_segment_length": self.average_segment_length,
"dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None, "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
"dataset": self.dataset.to_dict() if self.dataset else None,
"dataset": None, # Dataset class doesn't have a to_dict method
"segment_count": self.segment_count, "segment_count": self.segment_count,
"hit_count": self.hit_count, "hit_count": self.hit_count,
} }


@classmethod @classmethod
def from_dict(cls, data: dict):
def from_dict(cls, data: dict[str, Any]):
return cls( return cls(
id=data.get("id"), id=data.get("id"),
tenant_id=data.get("tenant_id"), tenant_id=data.get("tenant_id"),
) )


@property @property
def child_chunks(self):
process_rule = self.document.dataset_process_rule
if process_rule.mode == "hierarchical":
rules = Rule(**process_rule.rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
else:
return []
else:
def child_chunks(self) -> list[Any]:
if not self.document:
return [] return []

def get_child_chunks(self):
process_rule = self.document.dataset_process_rule process_rule = self.document.dataset_process_rule
if process_rule.mode == "hierarchical":
rules = Rule(**process_rule.rules_dict)
if rules.parent_mode:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
else:
return []
else:
if process_rule and process_rule.mode == "hierarchical":
rules_dict = process_rule.rules_dict
if rules_dict:
rules = Rule(**rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
return []

def get_child_chunks(self) -> list[Any]:
if not self.document:
return [] return []
process_rule = self.document.dataset_process_rule
if process_rule and process_rule.mode == "hierarchical":
rules_dict = process_rule.rules_dict
if rules_dict:
rules = Rule(**rules_dict)
if rules.parent_mode:
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
return []


@property @property
def sign_content(self):
def sign_content(self) -> str:
return self.get_sign_content() return self.get_sign_content()


def get_sign_content(self):
signed_urls = []
def get_sign_content(self) -> str:
signed_urls: list[tuple[int, int, str]] = []
text = self.content text = self.content


# For data before v0.10.0 # For data before v0.10.0
) )


@property @property
def keyword_table_dict(self):
def keyword_table_dict(self) -> dict[str, set[Any]] | None:
class SetDecoder(json.JSONDecoder): class SetDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
super().__init__(object_hook=self.object_hook, *args, **kwargs)

def object_hook(self, dct):
if isinstance(dct, dict):
for keyword, node_idxs in dct.items():
if isinstance(node_idxs, list):
dct[keyword] = set(node_idxs)
return dct
def __init__(self, *args: Any, **kwargs: Any) -> None:
def object_hook(dct: Any) -> Any:
if isinstance(dct, dict):
result: dict[str, Any] = {}
items = cast(dict[str, Any], dct).items()
for keyword, node_idxs in items:
if isinstance(node_idxs, list):
result[keyword] = set(cast(list[Any], node_idxs))
else:
result[keyword] = node_idxs
return result
return dct

super().__init__(object_hook=object_hook, *args, **kwargs)


# get dataset # get dataset
dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
updated_by = mapped_column(StringUUID, nullable=True) updated_by = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())


def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return { return {
"id": self.id, "id": self.id,
"tenant_id": self.tenant_id, "tenant_id": self.tenant_id,
} }


@property @property
def settings_dict(self):
def settings_dict(self) -> dict[str, Any] | None:
try: try:
return json.loads(self.settings) if self.settings else None return json.loads(self.settings) if self.settings else None
except JSONDecodeError: except JSONDecodeError:
return None return None


@property @property
def dataset_bindings(self):
def dataset_bindings(self) -> list[dict[str, Any]]:
external_knowledge_bindings = ( external_knowledge_bindings = (
db.session.query(ExternalKnowledgeBindings) db.session.query(ExternalKnowledgeBindings)
.where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) .where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
) )
dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings] dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all() datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
dataset_bindings = []
dataset_bindings: list[dict[str, Any]] = []
for dataset in datasets: for dataset in datasets:
dataset_bindings.append({"id": dataset.id, "name": dataset.name}) dataset_bindings.append({"id": dataset.id, "name": dataset.name})



+ 150
- 101
api/models/model.py Näytä tiedosto



import sqlalchemy as sa import sqlalchemy as sa
from flask import request from flask import request
from flask_login import UserMixin
from flask_login import UserMixin # type: ignore[import-untyped]
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
from sqlalchemy.orm import Mapped, Session, mapped_column from sqlalchemy.orm import Mapped, Session, mapped_column


from constants import DEFAULT_FILE_NUMBER_LIMITS from constants import DEFAULT_FILE_NUMBER_LIMITS
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from core.file import helpers as file_helpers from core.file import helpers as file_helpers
from libs.helper import generate_string
from libs.helper import generate_string # type: ignore[import-not-found]


from .account import Account, Tenant from .account import Account, Tenant
from .base import Base from .base import Base
use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))


@property @property
def desc_or_prompt(self):
def desc_or_prompt(self) -> str:
if self.description: if self.description:
return self.description return self.description
else: else:
return "" return ""


@property @property
def site(self):
def site(self) -> Optional["Site"]:
site = db.session.query(Site).where(Site.app_id == self.id).first() site = db.session.query(Site).where(Site.app_id == self.id).first()
return site return site


@property @property
def app_model_config(self):
def app_model_config(self) -> Optional["AppModelConfig"]:
if self.app_model_config_id: if self.app_model_config_id:
return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()


return None return None


@property @property
def api_base_url(self):
def api_base_url(self) -> str:
return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1" return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"


@property @property
def tenant(self):
def tenant(self) -> Optional[Tenant]:
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant return tenant


return str(self.mode) return str(self.mode)


@property @property
def deleted_tools(self):
def deleted_tools(self) -> list[dict[str, str]]:
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from services.plugin.plugin_service import PluginService from services.plugin.plugin_service import PluginService


provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids) provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids)
} }


deleted_tools = []
deleted_tools: list[dict[str, str]] = []


for tool in tools: for tool in tools:
keys = list(tool.keys()) keys = list(tool.keys())
return deleted_tools return deleted_tools


@property @property
def tags(self):
def tags(self) -> list["Tag"]:
tags = ( tags = (
db.session.query(Tag) db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id) .join(TagBinding, Tag.id == TagBinding.tag_id)
return tags or [] return tags or []


@property @property
def author_name(self):
def author_name(self) -> Optional[str]:
if self.created_by: if self.created_by:
account = db.session.query(Account).where(Account.id == self.created_by).first() account = db.session.query(Account).where(Account.id == self.created_by).first()
if account: if account:
file_upload = mapped_column(sa.Text) file_upload = mapped_column(sa.Text)


@property @property
def app(self):
def app(self) -> Optional[App]:
app = db.session.query(App).where(App.id == self.app_id).first() app = db.session.query(App).where(App.id == self.app_id).first()
return app return app


@property @property
def model_dict(self):
def model_dict(self) -> dict[str, Any]:
return json.loads(self.model) if self.model else {} return json.loads(self.model) if self.model else {}


@property @property
def suggested_questions_list(self):
def suggested_questions_list(self) -> list[str]:
return json.loads(self.suggested_questions) if self.suggested_questions else [] return json.loads(self.suggested_questions) if self.suggested_questions else []


@property @property
def suggested_questions_after_answer_dict(self):
def suggested_questions_after_answer_dict(self) -> dict[str, Any]:
return ( return (
json.loads(self.suggested_questions_after_answer) json.loads(self.suggested_questions_after_answer)
if self.suggested_questions_after_answer if self.suggested_questions_after_answer
) )


@property @property
def speech_to_text_dict(self):
def speech_to_text_dict(self) -> dict[str, Any]:
return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False} return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False}


@property @property
def text_to_speech_dict(self):
def text_to_speech_dict(self) -> dict[str, Any]:
return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False} return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False}


@property @property
def retriever_resource_dict(self):
def retriever_resource_dict(self) -> dict[str, Any]:
return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}


@property @property
def annotation_reply_dict(self):
def annotation_reply_dict(self) -> dict[str, Any]:
annotation_setting = ( annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first() db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
) )
return {"enabled": False} return {"enabled": False}


@property @property
def more_like_this_dict(self):
def more_like_this_dict(self) -> dict[str, Any]:
return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False} return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}


@property @property
def sensitive_word_avoidance_dict(self):
def sensitive_word_avoidance_dict(self) -> dict[str, Any]:
return ( return (
json.loads(self.sensitive_word_avoidance) json.loads(self.sensitive_word_avoidance)
if self.sensitive_word_avoidance if self.sensitive_word_avoidance
) )


@property @property
def external_data_tools_list(self) -> list[dict]:
def external_data_tools_list(self) -> list[dict[str, Any]]:
return json.loads(self.external_data_tools) if self.external_data_tools else [] return json.loads(self.external_data_tools) if self.external_data_tools else []


@property @property
def user_input_form_list(self):
def user_input_form_list(self) -> list[dict[str, Any]]:
return json.loads(self.user_input_form) if self.user_input_form else [] return json.loads(self.user_input_form) if self.user_input_form else []


@property @property
def agent_mode_dict(self):
def agent_mode_dict(self) -> dict[str, Any]:
return ( return (
json.loads(self.agent_mode) json.loads(self.agent_mode)
if self.agent_mode if self.agent_mode
) )


@property @property
def chat_prompt_config_dict(self):
def chat_prompt_config_dict(self) -> dict[str, Any]:
return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {} return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {}


@property @property
def completion_prompt_config_dict(self):
def completion_prompt_config_dict(self) -> dict[str, Any]:
return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {} return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {}


@property @property
def dataset_configs_dict(self):
def dataset_configs_dict(self) -> dict[str, Any]:
if self.dataset_configs: if self.dataset_configs:
dataset_configs: dict = json.loads(self.dataset_configs)
dataset_configs: dict[str, Any] = json.loads(self.dataset_configs)
if "retrieval_model" not in dataset_configs: if "retrieval_model" not in dataset_configs:
return {"retrieval_model": "single"} return {"retrieval_model": "single"}
else: else:
} }


@property @property
def file_upload_dict(self):
def file_upload_dict(self) -> dict[str, Any]:
return ( return (
json.loads(self.file_upload) json.loads(self.file_upload)
if self.file_upload if self.file_upload
} }
) )


def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return { return {
"opening_statement": self.opening_statement, "opening_statement": self.opening_statement,
"suggested_questions": self.suggested_questions_list, "suggested_questions": self.suggested_questions_list,
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())


@property @property
def app(self):
def app(self) -> Optional[App]:
app = db.session.query(App).where(App.id == self.app_id).first() app = db.session.query(App).where(App.id == self.app_id).first()
return app return app


created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())


@property @property
def app(self):
def app(self) -> Optional[App]:
app = db.session.query(App).where(App.id == self.app_id).first() app = db.session.query(App).where(App.id == self.app_id).first()
return app return app


@property @property
def tenant(self):
def tenant(self) -> Optional[Tenant]:
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant return tenant


mode: Mapped[str] = mapped_column(String(255)) mode: Mapped[str] = mapped_column(String(255))
name: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False)
summary = mapped_column(sa.Text) summary = mapped_column(sa.Text)
_inputs: Mapped[dict] = mapped_column("inputs", sa.JSON)
_inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON)
introduction = mapped_column(sa.Text) introduction = mapped_column(sa.Text)
system_instruction = mapped_column(sa.Text) system_instruction = mapped_column(sa.Text)
system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))


@property @property
def inputs(self):
def inputs(self) -> dict[str, Any]:
inputs = self._inputs.copy() inputs = self._inputs.copy()


# Convert file mapping to File object # Convert file mapping to File object
# NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now.
from factories import file_factory from factories import file_factory


if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
if value["transfer_method"] == FileTransferMethod.TOOL_FILE:
value["tool_file_id"] = value["related_id"]
elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
value["upload_file_id"] = value["related_id"]
inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"])
elif isinstance(value, list) and all(
isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value
if (
isinstance(value, dict)
and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY
): ):
inputs[key] = []
for item in value:
if item["transfer_method"] == FileTransferMethod.TOOL_FILE:
item["tool_file_id"] = item["related_id"]
elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
item["upload_file_id"] = item["related_id"]
inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"]))
value_dict = cast(dict[str, Any], value)
if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
value_dict["tool_file_id"] = value_dict["related_id"]
elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
value_dict["upload_file_id"] = value_dict["related_id"]
tenant_id = cast(str, value_dict.get("tenant_id", ""))
inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id)
elif isinstance(value, list):
value_list = cast(list[Any], value)
if all(
isinstance(item, dict)
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
for item in value_list
):
file_list: list[File] = []
for item in value_list:
if not isinstance(item, dict):
continue
item_dict = cast(dict[str, Any], item)
if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
item_dict["tool_file_id"] = item_dict["related_id"]
elif item_dict["transfer_method"] in [
FileTransferMethod.LOCAL_FILE,
FileTransferMethod.REMOTE_URL,
]:
item_dict["upload_file_id"] = item_dict["related_id"]
tenant_id = cast(str, item_dict.get("tenant_id", ""))
file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id))
inputs[key] = file_list


return inputs return inputs


for k, v in inputs.items(): for k, v in inputs.items():
if isinstance(v, File): if isinstance(v, File):
inputs[k] = v.model_dump() inputs[k] = v.model_dump()
elif isinstance(v, list) and all(isinstance(item, File) for item in v):
inputs[k] = [item.model_dump() for item in v]
elif isinstance(v, list):
v_list = cast(list[Any], v)
if all(isinstance(item, File) for item in v_list):
inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)]
self._inputs = inputs self._inputs = inputs


@property @property
) )


@property @property
def app(self):
def app(self) -> Optional[App]:
return db.session.query(App).where(App.id == self.app_id).first() return db.session.query(App).where(App.id == self.app_id).first()


@property @property
return None return None


@property @property
def from_account_name(self):
def from_account_name(self) -> Optional[str]:
if self.from_account_id: if self.from_account_id:
account = db.session.query(Account).where(Account.id == self.from_account_id).first() account = db.session.query(Account).where(Account.id == self.from_account_id).first()
if account: if account:
return None return None


@property @property
def in_debug_mode(self):
def in_debug_mode(self) -> bool:
return self.override_model_configs is not None return self.override_model_configs is not None


def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return { return {
"id": self.id, "id": self.id,
"app_id": self.app_id, "app_id": self.app_id,
model_id = mapped_column(String(255), nullable=True) model_id = mapped_column(String(255), nullable=True)
override_model_configs = mapped_column(sa.Text) override_model_configs = mapped_column(sa.Text)
conversation_id = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False) conversation_id = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False)
_inputs: Mapped[dict] = mapped_column("inputs", sa.JSON)
_inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON)
query: Mapped[str] = mapped_column(sa.Text, nullable=False) query: Mapped[str] = mapped_column(sa.Text, nullable=False)
message = mapped_column(sa.JSON, nullable=False) message = mapped_column(sa.JSON, nullable=False)
message_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) message_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID)


@property @property
def inputs(self):
def inputs(self) -> dict[str, Any]:
inputs = self._inputs.copy() inputs = self._inputs.copy()
for key, value in inputs.items(): for key, value in inputs.items():
# NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now.
from factories import file_factory from factories import file_factory


if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
if value["transfer_method"] == FileTransferMethod.TOOL_FILE:
value["tool_file_id"] = value["related_id"]
elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
value["upload_file_id"] = value["related_id"]
inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"])
elif isinstance(value, list) and all(
isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value
if (
isinstance(value, dict)
and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY
): ):
inputs[key] = []
for item in value:
if item["transfer_method"] == FileTransferMethod.TOOL_FILE:
item["tool_file_id"] = item["related_id"]
elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
item["upload_file_id"] = item["related_id"]
inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"]))
value_dict = cast(dict[str, Any], value)
if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
value_dict["tool_file_id"] = value_dict["related_id"]
elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
value_dict["upload_file_id"] = value_dict["related_id"]
tenant_id = cast(str, value_dict.get("tenant_id", ""))
inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id)
elif isinstance(value, list):
value_list = cast(list[Any], value)
if all(
isinstance(item, dict)
and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY
for item in value_list
):
file_list: list[File] = []
for item in value_list:
if not isinstance(item, dict):
continue
item_dict = cast(dict[str, Any], item)
if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE:
item_dict["tool_file_id"] = item_dict["related_id"]
elif item_dict["transfer_method"] in [
FileTransferMethod.LOCAL_FILE,
FileTransferMethod.REMOTE_URL,
]:
item_dict["upload_file_id"] = item_dict["related_id"]
tenant_id = cast(str, item_dict.get("tenant_id", ""))
file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id))
inputs[key] = file_list
return inputs return inputs


@inputs.setter @inputs.setter
for k, v in inputs.items(): for k, v in inputs.items():
if isinstance(v, File): if isinstance(v, File):
inputs[k] = v.model_dump() inputs[k] = v.model_dump()
elif isinstance(v, list) and all(isinstance(item, File) for item in v):
inputs[k] = [item.model_dump() for item in v]
elif isinstance(v, list):
v_list = cast(list[Any], v)
if all(isinstance(item, File) for item in v_list):
inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)]
self._inputs = inputs self._inputs = inputs


@property @property
return None return None


@property @property
def in_debug_mode(self):
def in_debug_mode(self) -> bool:
return self.override_model_configs is not None return self.override_model_configs is not None


@property @property
def message_metadata_dict(self):
def message_metadata_dict(self) -> dict[str, Any]:
return json.loads(self.message_metadata) if self.message_metadata else {} return json.loads(self.message_metadata) if self.message_metadata else {}


@property @property
def agent_thoughts(self):
def agent_thoughts(self) -> list["MessageAgentThought"]:
return ( return (
db.session.query(MessageAgentThought) db.session.query(MessageAgentThought)
.where(MessageAgentThought.message_id == self.id) .where(MessageAgentThought.message_id == self.id)
) )


@property @property
def retriever_resources(self):
def retriever_resources(self) -> Any | list[Any]:
return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else [] return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else []


@property @property
def message_files(self):
def message_files(self) -> list[dict[str, Any]]:
from factories import file_factory from factories import file_factory


message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all() message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all()
if not current_app: if not current_app:
raise ValueError(f"App {self.app_id} not found") raise ValueError(f"App {self.app_id} not found")


files = []
files: list[File] = []
for message_file in message_files: for message_file in message_files:
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE.value: if message_file.transfer_method == FileTransferMethod.LOCAL_FILE.value:
if message_file.upload_file_id is None: if message_file.upload_file_id is None:
) )
files.append(file) files.append(file)


result = [
result: list[dict[str, Any]] = [
{"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()} {"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()}
for (file, message_file) in zip(files, message_files) for (file, message_file) in zip(files, message_files)
] ]


return None return None


def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return { return {
"id": self.id, "id": self.id,
"app_id": self.app_id, "app_id": self.app_id,
} }


@classmethod @classmethod
def from_dict(cls, data: dict):
def from_dict(cls, data: dict[str, Any]) -> "Message":
return cls( return cls(
id=data["id"], id=data["id"],
app_id=data["app_id"], app_id=data["app_id"],
account = db.session.query(Account).where(Account.id == self.from_account_id).first() account = db.session.query(Account).where(Account.id == self.from_account_id).first()
return account return account


def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return { return {
"id": str(self.id), "id": str(self.id),
"app_id": str(self.app_id), "app_id": str(self.app_id),
type: Mapped[str] = mapped_column(String(255), nullable=False) type: Mapped[str] = mapped_column(String(255), nullable=False)
external_user_id = mapped_column(String(255), nullable=True) external_user_id = mapped_column(String(255), nullable=True)
name = mapped_column(String(255)) name = mapped_column(String(255))
is_anonymous: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
_is_anonymous: Mapped[bool] = mapped_column(
"is_anonymous", sa.Boolean, nullable=False, server_default=sa.text("true")
)

@property
def is_anonymous(self) -> Literal[False]:
return False

@is_anonymous.setter
def is_anonymous(self, value: bool) -> None:
self._is_anonymous = value

session_id: Mapped[str] = mapped_column() session_id: Mapped[str] = mapped_column()
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())


@staticmethod @staticmethod
def generate_server_code(n):
def generate_server_code(n: int) -> str:
while True: while True:
result = generate_string(n) result = generate_string(n)
while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0: while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0:
self._custom_disclaimer = value self._custom_disclaimer = value


@staticmethod @staticmethod
def generate_code(n):
def generate_code(n: int) -> str:
while True: while True:
result = generate_string(n) result = generate_string(n)
while db.session.query(Site).where(Site.code == result).count() > 0: while db.session.query(Site).where(Site.code == result).count() > 0:
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())


@staticmethod @staticmethod
def generate_api_key(prefix, n):
def generate_api_key(prefix: str, n: int) -> str:
while True: while True:
result = prefix + generate_string(n) result = prefix + generate_string(n)
if db.session.scalar(select(exists().where(ApiToken.token == result))): if db.session.scalar(select(exists().where(ApiToken.token == result))):
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())


@property @property
def files(self):
def files(self) -> list[Any]:
if self.message_files: if self.message_files:
return cast(list[Any], json.loads(self.message_files)) return cast(list[Any], json.loads(self.message_files))
else: else:
return self.tool.split(";") if self.tool else [] return self.tool.split(";") if self.tool else []


@property @property
def tool_labels(self):
def tool_labels(self) -> dict[str, Any]:
try: try:
if self.tool_labels_str: if self.tool_labels_str:
return cast(dict, json.loads(self.tool_labels_str))
return cast(dict[str, Any], json.loads(self.tool_labels_str))
else: else:
return {} return {}
except Exception: except Exception:
return {} return {}


@property @property
def tool_meta(self):
def tool_meta(self) -> dict[str, Any]:
try: try:
if self.tool_meta_str: if self.tool_meta_str:
return cast(dict, json.loads(self.tool_meta_str))
return cast(dict[str, Any], json.loads(self.tool_meta_str))
else: else:
return {} return {}
except Exception: except Exception:
return {} return {}


@property @property
def tool_inputs_dict(self):
def tool_inputs_dict(self) -> dict[str, Any]:
tools = self.tools tools = self.tools
try: try:
if self.tool_input: if self.tool_input:
data = json.loads(self.tool_input) data = json.loads(self.tool_input)
result = {}
result: dict[str, Any] = {}
for tool in tools: for tool in tools:
if tool in data: if tool in data:
result[tool] = data[tool] result[tool] = data[tool]
return {} return {}


@property @property
def tool_outputs_dict(self):
def tool_outputs_dict(self) -> dict[str, Any]:
tools = self.tools tools = self.tools
try: try:
if self.observation: if self.observation:
data = json.loads(self.observation) data = json.loads(self.observation)
result = {}
result: dict[str, Any] = {}
for tool in tools: for tool in tools:
if tool in data: if tool in data:
result[tool] = data[tool] result[tool] = data[tool]
is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))


@property @property
def tracing_config_dict(self):
def tracing_config_dict(self) -> dict[str, Any]:
return self.tracing_config or {} return self.tracing_config or {}


@property @property
def tracing_config_str(self):
def tracing_config_str(self) -> str:
return json.dumps(self.tracing_config_dict) return json.dumps(self.tracing_config_dict)


def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return { return {
"id": self.id, "id": self.id,
"app_id": self.app_id, "app_id": self.app_id,

+ 2
- 2
api/models/provider.py Näytä tiedosto

SYSTEM = "system" SYSTEM = "system"


@staticmethod @staticmethod
def value_of(value):
def value_of(value: str) -> "ProviderType":
for member in ProviderType: for member in ProviderType:
if member.value == value: if member.value == value:
return member return member
"""hosted trial quota""" """hosted trial quota"""


@staticmethod @staticmethod
def value_of(value):
def value_of(value: str) -> "ProviderQuotaType":
for member in ProviderQuotaType: for member in ProviderQuotaType:
if member.value == value: if member.value == value:
return member return member

+ 12
- 12
api/models/tools.py Näytä tiedosto

import json import json
from datetime import datetime from datetime import datetime
from typing import Optional, cast
from typing import Any, Optional, cast
from urllib.parse import urlparse from urllib.parse import urlparse


import sqlalchemy as sa import sqlalchemy as sa
encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False) encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False)


@property @property
def oauth_params(self):
return cast(dict, json.loads(self.encrypted_oauth_params or "{}"))
def oauth_params(self) -> dict[str, Any]:
return cast(dict[str, Any], json.loads(self.encrypted_oauth_params or "{}"))




class BuiltinToolProvider(Base): class BuiltinToolProvider(Base):
expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1")) expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1"))


@property @property
def credentials(self):
return cast(dict, json.loads(self.encrypted_credentials))
def credentials(self) -> dict[str, Any]:
return cast(dict[str, Any], json.loads(self.encrypted_credentials))




class ApiToolProvider(Base): class ApiToolProvider(Base):
return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)] return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)]


@property @property
def credentials(self):
return dict(json.loads(self.credentials_str))
def credentials(self) -> dict[str, Any]:
return dict[str, Any](json.loads(self.credentials_str))


@property @property
def user(self) -> Account | None: def user(self) -> Account | None:
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()


@property @property
def credentials(self):
def credentials(self) -> dict[str, Any]:
try: try:
return cast(dict, json.loads(self.encrypted_credentials)) or {}
return cast(dict[str, Any], json.loads(self.encrypted_credentials)) or {}
except Exception: except Exception:
return {} return {}


return mask_url(self.decrypted_server_url) return mask_url(self.decrypted_server_url)


@property @property
def decrypted_credentials(self):
def decrypted_credentials(self) -> dict[str, Any]:
from core.helper.provider_cache import NoOpProviderCredentialCache from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.mcp_tool.provider import MCPToolProviderController from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.utils.encryption import create_provider_encrypter from core.tools.utils.encryption import create_provider_encrypter


provider_controller = MCPToolProviderController._from_db(self)
provider_controller = MCPToolProviderController.from_db(self)


encrypter, _ = create_provider_encrypter( encrypter, _ = create_provider_encrypter(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
cache=NoOpProviderCredentialCache(), cache=NoOpProviderCredentialCache(),
) )


return encrypter.decrypt(self.credentials) # type: ignore
return encrypter.decrypt(self.credentials)




class ToolModelInvoke(Base): class ToolModelInvoke(Base):

+ 20
- 18
api/models/types.py Näytä tiedosto

import enum import enum
from typing import Generic, TypeVar
import uuid
from typing import Any, Generic, TypeVar


from sqlalchemy import CHAR, VARCHAR, TypeDecorator from sqlalchemy import CHAR, VARCHAR, TypeDecorator
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql.type_api import TypeEngine




class StringUUID(TypeDecorator):
class StringUUID(TypeDecorator[uuid.UUID | str | None]):
impl = CHAR impl = CHAR
cache_ok = True cache_ok = True


def process_bind_param(self, value, dialect):
def process_bind_param(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None:
if value is None: if value is None:
return value return value
elif dialect.name == "postgresql": elif dialect.name == "postgresql":
return str(value) return str(value)
else: else:
return value.hex
if isinstance(value, uuid.UUID):
return value.hex
return value


def load_dialect_impl(self, dialect):
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
if dialect.name == "postgresql": if dialect.name == "postgresql":
return dialect.type_descriptor(UUID()) return dialect.type_descriptor(UUID())
else: else:
return dialect.type_descriptor(CHAR(36)) return dialect.type_descriptor(CHAR(36))


def process_result_value(self, value, dialect):
def process_result_value(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None:
if value is None: if value is None:
return value return value
return str(value) return str(value)
_E = TypeVar("_E", bound=enum.StrEnum) _E = TypeVar("_E", bound=enum.StrEnum)




class EnumText(TypeDecorator, Generic[_E]):
class EnumText(TypeDecorator[_E | None], Generic[_E]):
impl = VARCHAR impl = VARCHAR
cache_ok = True cache_ok = True


# leave some rooms for future longer enum values. # leave some rooms for future longer enum values.
self._length = max(max_enum_value_len, 20) self._length = max(max_enum_value_len, 20)


def process_bind_param(self, value: _E | str | None, dialect):
def process_bind_param(self, value: _E | str | None, dialect: Dialect) -> str | None:
if value is None: if value is None:
return value return value
if isinstance(value, self._enum_class): if isinstance(value, self._enum_class):
return value.value return value.value
elif isinstance(value, str):
self._enum_class(value)
return value
else:
raise TypeError(f"expected str or {self._enum_class}, got {type(value)}")
# Since _E is bound to StrEnum which inherits from str, at this point value must be str
self._enum_class(value)
return value


def load_dialect_impl(self, dialect):
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
return dialect.type_descriptor(VARCHAR(self._length)) return dialect.type_descriptor(VARCHAR(self._length))


def process_result_value(self, value, dialect) -> _E | None:
def process_result_value(self, value: str | None, dialect: Dialect) -> _E | None:
if value is None: if value is None:
return value return value
if not isinstance(value, str):
raise TypeError(f"expected str, got {type(value)}")
# Type annotation guarantees value is str at this point
return self._enum_class(value) return self._enum_class(value)


def compare_values(self, x, y):
def compare_values(self, x: _E | None, y: _E | None) -> bool:
if x is None or y is None: if x is None or y is None:
return x is y return x is y
return x == y return x == y

+ 31
- 31
api/models/workflow.py Näytä tiedosto

from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from datetime import datetime from datetime import datetime
from enum import Enum, StrEnum from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from uuid import uuid4 from uuid import uuid4


import sqlalchemy as sa import sqlalchemy as sa
raise WorkflowDataError("nodes not found in workflow graph") raise WorkflowDataError("nodes not found in workflow graph")


try: try:
node_config = next(filter(lambda node: node["id"] == node_id, nodes))
node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes))
except StopIteration: except StopIteration:
raise NodeNotFoundError(node_id) raise NodeNotFoundError(node_id)
assert isinstance(node_config, dict) assert isinstance(node_config, dict)
def features_dict(self) -> dict[str, Any]: def features_dict(self) -> dict[str, Any]:
return json.loads(self.features) if self.features else {} return json.loads(self.features) if self.features else {}


def user_input_form(self, to_old_structure: bool = False):
def user_input_form(self, to_old_structure: bool = False) -> list[Any]:
# get start node from graph # get start node from graph
if not self.graph: if not self.graph:
return [] return []
variables: list[Any] = start_node.get("data", {}).get("variables", []) variables: list[Any] = start_node.get("data", {}).get("variables", [])


if to_old_structure: if to_old_structure:
old_structure_variables = []
old_structure_variables: list[dict[str, Any]] = []
for variable in variables: for variable in variables:
old_structure_variables.append({variable["type"]: variable}) old_structure_variables.append({variable["type"]: variable})




@property @property
def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
# TODO: find some way to init `self._environment_variables` when instance created.
if self._environment_variables is None:
self._environment_variables = "{}"
# _environment_variables is guaranteed to be non-None due to server_default="{}"


# Use workflow.tenant_id to avoid relying on request user in background threads # Use workflow.tenant_id to avoid relying on request user in background threads
tenant_id = self.tenant_id tenant_id = self.tenant_id
] ]


# decrypt secret variables value # decrypt secret variables value
def decrypt_func(var):
def decrypt_func(var: Variable) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
if isinstance(var, SecretVariable): if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)): elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)):
return var return var
else: else:
raise AssertionError("this statement should be unreachable.")
# Other variable types are not supported for environment variables
raise AssertionError(f"Unexpected variable type for environment variable: {type(var)}")


decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = list(
map(decrypt_func, results)
)
decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = [
decrypt_func(var) for var in results
]
return decrypted_results return decrypted_results


@environment_variables.setter @environment_variables.setter
value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name}) value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name})


# encrypt secret variables value # encrypt secret variables value
def encrypt_func(var):
def encrypt_func(var: Variable) -> Variable:
if isinstance(var, SecretVariable): if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)}) return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)})
else: else:


@property @property
def conversation_variables(self) -> Sequence[Variable]: def conversation_variables(self) -> Sequence[Variable]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None:
self._conversation_variables = "{}"
# _conversation_variables is guaranteed to be non-None due to server_default="{}"


variables_dict: dict[str, Any] = json.loads(self._conversation_variables) variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()] results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]
} }


@classmethod @classmethod
def from_dict(cls, data: dict) -> "WorkflowRun":
def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":
return cls( return cls(
id=data.get("id"), id=data.get("id"),
tenant_id=data.get("tenant_id"), tenant_id=data.get("tenant_id"),
__tablename__ = "workflow_node_executions" __tablename__ = "workflow_node_executions"


@declared_attr @declared_attr
def __table_args__(cls): # noqa
@classmethod
def __table_args__(cls) -> Any:
return ( return (
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"), PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
Index( Index(
# MyPy may flag the following line because it doesn't recognize that # MyPy may flag the following line because it doesn't recognize that
# the `declared_attr` decorator passes the receiving class as the first # the `declared_attr` decorator passes the receiving class as the first
# argument to this method, allowing us to reference class attributes. # argument to this method, allowing us to reference class attributes.
cls.created_at.desc(), # type: ignore
cls.created_at.desc(),
), ),
) )


return json.loads(self.execution_metadata) if self.execution_metadata else {} return json.loads(self.execution_metadata) if self.execution_metadata else {}


@property @property
def extras(self):
def extras(self) -> dict[str, Any]:
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager


extras = {}
extras: dict[str, Any] = {}
if self.execution_metadata_dict: if self.execution_metadata_dict:
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType


if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict: if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict:
tool_info = self.execution_metadata_dict["tool_info"]
tool_info: dict[str, Any] = self.execution_metadata_dict["tool_info"]
extras["icon"] = ToolManager.get_tool_icon( extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
provider_type=tool_info["provider_type"], provider_type=tool_info["provider_type"],
# making this attribute harder to access from outside the class. # making this attribute harder to access from outside the class.
__value: Segment | None __value: Segment | None


def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
""" """
The constructor of `WorkflowDraftVariable` is not intended for The constructor of `WorkflowDraftVariable` is not intended for
direct use outside this file. Its solo purpose is setup private state direct use outside this file. Its solo purpose is setup private state
self.__value = None self.__value = None


def get_selector(self) -> list[str]: def get_selector(self) -> list[str]:
selector = json.loads(self.selector)
selector: Any = json.loads(self.selector)
if not isinstance(selector, list): if not isinstance(selector, list):
logger.error( logger.error(
"invalid selector loaded from database, type=%s, value=%s", "invalid selector loaded from database, type=%s, value=%s",
type(selector),
type(selector).__name__,
self.selector, self.selector,
) )
raise ValueError("invalid selector.") raise ValueError("invalid selector.")
return selector
return cast(list[str], selector)


def _set_selector(self, value: list[str]): def _set_selector(self, value: list[str]):
self.selector = json.dumps(value) self.selector = json.dumps(value)
# `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging. # `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging.
if isinstance(value, dict): if isinstance(value, dict):
if not maybe_file_object(value): if not maybe_file_object(value):
return value
return cast(Any, value)
return File.model_validate(value) return File.model_validate(value)
elif isinstance(value, list) and value: elif isinstance(value, list) and value:
first = value[0]
value_list = cast(list[Any], value)
first: Any = value_list[0]
if not maybe_file_object(first): if not maybe_file_object(first):
return value
return [File.model_validate(i) for i in value]
return cast(Any, value)
file_list: list[File] = [File.model_validate(cast(dict[str, Any], i)) for i in value_list]
return cast(Any, file_list)
else: else:
return value
return cast(Any, value)


@classmethod @classmethod
def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment: def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment:

+ 0
- 1
api/pyrightconfig.json Näytä tiedosto

"tests/", "tests/",
"migrations/", "migrations/",
".venv/", ".venv/",
"models/",
"core/", "core/",
"controllers/", "controllers/",
"tasks/", "tasks/",

+ 2
- 2
api/services/agent_service.py Näytä tiedosto

import threading import threading
from typing import Optional
from typing import Any, Optional


import pytz import pytz
from flask_login import current_user from flask_login import current_user
if not app_model_config: if not app_model_config:
raise ValueError("App model config not found") raise ValueError("App model config not found")


result = {
result: dict[str, Any] = {
"meta": { "meta": {
"status": "success", "status": "success",
"executor": executor, "executor": executor,

+ 4
- 1
api/services/app_service.py Näytä tiedosto

# get original app model config # get original app model config
if app.mode == AppMode.AGENT_CHAT.value or app.is_agent: if app.mode == AppMode.AGENT_CHAT.value or app.is_agent:
model_config = app.app_model_config model_config = app.app_model_config
if not model_config:
return app
agent_mode = model_config.agent_mode_dict agent_mode = model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input # decrypt agent tool parameters if it's secret-input
for tool in agent_mode.get("tools") or []: for tool in agent_mode.get("tools") or []:
pass pass


# override agent mode # override agent mode
model_config.agent_mode = json.dumps(agent_mode)
if model_config:
model_config.agent_mode = json.dumps(agent_mode)


class ModifiedApp(App): class ModifiedApp(App):
""" """

+ 4
- 2
api/services/audio_service.py Näytä tiedosto

from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db from extensions.ext_database import db
from models.enums import MessageStatus from models.enums import MessageStatus
from models.model import App, AppMode, AppModelConfig, Message
from models.model import App, AppMode, Message
from services.errors.audio import ( from services.errors.audio import (
AudioTooLargeServiceError, AudioTooLargeServiceError,
NoAudioUploadedServiceError, NoAudioUploadedServiceError,
if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"): if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"):
raise ValueError("Speech to text is not enabled") raise ValueError("Speech to text is not enabled")
else: else:
app_model_config: AppModelConfig = app_model.app_model_config
app_model_config = app_model.app_model_config
if not app_model_config:
raise ValueError("Speech to text is not enabled")


if not app_model_config.speech_to_text_dict["enabled"]: if not app_model_config.speech_to_text_dict["enabled"]:
raise ValueError("Speech to text is not enabled") raise ValueError("Speech to text is not enabled")

+ 4
- 3
api/services/dataset_service.py Näytä tiedosto

file_ids = [ file_ids = [
document.data_source_info_dict["upload_file_id"] document.data_source_info_dict["upload_file_id"]
for document in documents for document in documents
if document.data_source_type == "upload_file"
if document.data_source_type == "upload_file" and document.data_source_info_dict
] ]
batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids)


# sync document indexing # sync document indexing
document.indexing_status = "waiting" document.indexing_status = "waiting"
data_source_info = document.data_source_info_dict data_source_info = document.data_source_info_dict
data_source_info["mode"] = "scrape"
document.data_source_info = json.dumps(data_source_info, ensure_ascii=False)
if data_source_info:
data_source_info["mode"] = "scrape"
document.data_source_info = json.dumps(data_source_info, ensure_ascii=False)
db.session.add(document) db.session.add(document)
db.session.commit() db.session.commit()



+ 3
- 2
api/services/external_knowledge_service.py Näytä tiedosto

) )
if external_knowledge_api is None: if external_knowledge_api is None:
raise ValueError("api template not found") raise ValueError("api template not found")
if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE:
args.get("settings")["api_key"] = external_knowledge_api.settings_dict.get("api_key")
settings = args.get("settings")
if settings and settings.get("api_key") == HIDDEN_VALUE and external_knowledge_api.settings_dict:
settings["api_key"] = external_knowledge_api.settings_dict.get("api_key")


external_knowledge_api.name = args.get("name") external_knowledge_api.name = args.get("name")
external_knowledge_api.description = args.get("description", "") external_knowledge_api.description = args.get("description", "")

+ 1
- 1
api/services/tools/mcp_tools_manage_service.py Näytä tiedosto

def update_mcp_provider_credentials( def update_mcp_provider_credentials(
cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
): ):
provider_controller = MCPToolProviderController._from_db(mcp_provider)
provider_controller = MCPToolProviderController.from_db(mcp_provider)
tool_configuration = ProviderConfigEncrypter( tool_configuration = ProviderConfigEncrypter(
tenant_id=mcp_provider.tenant_id, tenant_id=mcp_provider.tenant_id,
config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type] config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type]

+ 2
- 2
api/tests/unit_tests/models/test_types_enum_text.py Näytä tiedosto

TestCase( TestCase(
name="session insert with invalid type", name="session insert with invalid type",
action=lambda s: _session_insert_with_value(s, 1), action=lambda s: _session_insert_with_value(s, 1),
exc_type=TypeError,
exc_type=ValueError,
), ),
TestCase( TestCase(
name="insert with invalid value", name="insert with invalid value",
TestCase( TestCase(
name="insert with invalid type", name="insert with invalid type",
action=lambda s: _insert_with_user(s, 1), action=lambda s: _insert_with_user(s, 1),
exc_type=TypeError,
exc_type=ValueError,
), ),
] ]
for idx, c in enumerate(cases, 1): for idx, c in enumerate(cases, 1):

Loading…
Peruuta
Tallenna