Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>tags/2.0.0-beta.2^2
| @@ -87,7 +87,7 @@ class BaseApiKeyListResource(Resource): | |||
| custom="max_keys_exceeded", | |||
| ) | |||
| key = ApiToken.generate_api_key(self.token_prefix, 24) | |||
| key = ApiToken.generate_api_key(self.token_prefix or "", 24) | |||
| api_token = ApiToken() | |||
| setattr(api_token, self.resource_id_field, resource_id) | |||
| api_token.tenant_id = current_user.current_tenant_id | |||
| @@ -475,6 +475,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| data_source_info = document.data_source_info_dict | |||
| if document.data_source_type == "upload_file": | |||
| if not data_source_info: | |||
| continue | |||
| file_id = data_source_info["upload_file_id"] | |||
| file_detail = ( | |||
| db.session.query(UploadFile) | |||
| @@ -491,6 +493,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| extract_settings.append(extract_setting) | |||
| elif document.data_source_type == "notion_import": | |||
| if not data_source_info: | |||
| continue | |||
| extract_setting = ExtractSetting( | |||
| datasource_type=DatasourceType.NOTION.value, | |||
| notion_info={ | |||
| @@ -503,6 +507,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| ) | |||
| extract_settings.append(extract_setting) | |||
| elif document.data_source_type == "website_crawl": | |||
| if not data_source_info: | |||
| continue | |||
| extract_setting = ExtractSetting( | |||
| datasource_type=DatasourceType.WEBSITE.value, | |||
| website_info={ | |||
| @@ -43,6 +43,8 @@ class ExploreAppMetaApi(InstalledAppResource): | |||
| def get(self, installed_app: InstalledApp): | |||
| """Get app meta""" | |||
| app_model = installed_app.app | |||
| if not app_model: | |||
| raise ValueError("App not found") | |||
| return AppService().get_app_meta(app_model) | |||
| @@ -35,6 +35,8 @@ class InstalledAppWorkflowRunApi(InstalledAppResource): | |||
| Run workflow | |||
| """ | |||
| app_model = installed_app.app | |||
| if not app_model: | |||
| raise NotWorkflowAppError() | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode != AppMode.WORKFLOW: | |||
| raise NotWorkflowAppError() | |||
| @@ -73,6 +75,8 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource): | |||
| Stop workflow task | |||
| """ | |||
| app_model = installed_app.app | |||
| if not app_model: | |||
| raise NotWorkflowAppError() | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode != AppMode.WORKFLOW: | |||
| raise NotWorkflowAppError() | |||
| @@ -262,6 +262,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator): | |||
| raise MessageNotExistsError() | |||
| current_app_model_config = app_model.app_model_config | |||
| if not current_app_model_config: | |||
| raise MoreLikeThisDisabledError() | |||
| more_like_this = current_app_model_config.more_like_this_dict | |||
| if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False: | |||
| @@ -334,7 +334,8 @@ class NotionExtractor(BaseExtractor): | |||
| last_edited_time = self.get_notion_last_edited_time() | |||
| data_source_info = document_model.data_source_info_dict | |||
| data_source_info["last_edited_time"] = last_edited_time | |||
| if data_source_info: | |||
| data_source_info["last_edited_time"] = last_edited_time | |||
| db.session.query(DocumentModel).filter_by(id=document_model.id).update( | |||
| {DocumentModel.data_source_info: json.dumps(data_source_info)} | |||
| @@ -1,5 +1,5 @@ | |||
| import json | |||
| from typing import Any, Optional | |||
| from typing import Any, Optional, Self | |||
| from core.mcp.types import Tool as RemoteMCPTool | |||
| from core.tools.__base.tool_provider import ToolProviderController | |||
| @@ -48,7 +48,7 @@ class MCPToolProviderController(ToolProviderController): | |||
| return ToolProviderType.MCP | |||
| @classmethod | |||
| def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController": | |||
| def from_db(cls, db_provider: MCPToolProvider) -> Self: | |||
| """ | |||
| from db provider | |||
| """ | |||
| @@ -773,7 +773,7 @@ class ToolManager: | |||
| if provider is None: | |||
| raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") | |||
| controller = MCPToolProviderController._from_db(provider) | |||
| controller = MCPToolProviderController.from_db(provider) | |||
| return controller | |||
| @@ -928,7 +928,7 @@ class ToolManager: | |||
| tenant_id: str, | |||
| provider_type: ToolProviderType, | |||
| provider_id: str, | |||
| ) -> Union[str, dict]: | |||
| ) -> Union[str, dict[str, Any]]: | |||
| """ | |||
| get the tool icon | |||
| @@ -1,10 +1,10 @@ | |||
| import enum | |||
| import json | |||
| from datetime import datetime | |||
| from typing import Optional | |||
| from typing import Any, Optional | |||
| import sqlalchemy as sa | |||
| from flask_login import UserMixin | |||
| from flask_login import UserMixin # type: ignore[import-untyped] | |||
| from sqlalchemy import DateTime, String, func, select | |||
| from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor | |||
| @@ -225,11 +225,11 @@ class Tenant(Base): | |||
| ) | |||
| @property | |||
| def custom_config_dict(self): | |||
| def custom_config_dict(self) -> dict[str, Any]: | |||
| return json.loads(self.custom_config) if self.custom_config else {} | |||
| @custom_config_dict.setter | |||
| def custom_config_dict(self, value: dict): | |||
| def custom_config_dict(self, value: dict[str, Any]) -> None: | |||
| self.custom_config = json.dumps(value) | |||
| @@ -286,7 +286,7 @@ class DatasetProcessRule(Base): | |||
| "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, | |||
| } | |||
| def to_dict(self): | |||
| def to_dict(self) -> dict[str, Any]: | |||
| return { | |||
| "id": self.id, | |||
| "dataset_id": self.dataset_id, | |||
| @@ -295,7 +295,7 @@ class DatasetProcessRule(Base): | |||
| } | |||
| @property | |||
| def rules_dict(self): | |||
| def rules_dict(self) -> dict[str, Any] | None: | |||
| try: | |||
| return json.loads(self.rules) if self.rules else None | |||
| except JSONDecodeError: | |||
| @@ -392,10 +392,10 @@ class Document(Base): | |||
| return status | |||
| @property | |||
| def data_source_info_dict(self): | |||
| def data_source_info_dict(self) -> dict[str, Any] | None: | |||
| if self.data_source_info: | |||
| try: | |||
| data_source_info_dict = json.loads(self.data_source_info) | |||
| data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info) | |||
| except JSONDecodeError: | |||
| data_source_info_dict = {} | |||
| @@ -403,10 +403,10 @@ class Document(Base): | |||
| return None | |||
| @property | |||
| def data_source_detail_dict(self): | |||
| def data_source_detail_dict(self) -> dict[str, Any]: | |||
| if self.data_source_info: | |||
| if self.data_source_type == "upload_file": | |||
| data_source_info_dict = json.loads(self.data_source_info) | |||
| data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info) | |||
| file_detail = ( | |||
| db.session.query(UploadFile) | |||
| .where(UploadFile.id == data_source_info_dict["upload_file_id"]) | |||
| @@ -425,7 +425,8 @@ class Document(Base): | |||
| } | |||
| } | |||
| elif self.data_source_type in {"notion_import", "website_crawl"}: | |||
| return json.loads(self.data_source_info) | |||
| result: dict[str, Any] = json.loads(self.data_source_info) | |||
| return result | |||
| return {} | |||
| @property | |||
| @@ -471,7 +472,7 @@ class Document(Base): | |||
| return self.updated_at | |||
| @property | |||
| def doc_metadata_details(self): | |||
| def doc_metadata_details(self) -> list[dict[str, Any]] | None: | |||
| if self.doc_metadata: | |||
| document_metadatas = ( | |||
| db.session.query(DatasetMetadata) | |||
| @@ -481,9 +482,9 @@ class Document(Base): | |||
| ) | |||
| .all() | |||
| ) | |||
| metadata_list = [] | |||
| metadata_list: list[dict[str, Any]] = [] | |||
| for metadata in document_metadatas: | |||
| metadata_dict = { | |||
| metadata_dict: dict[str, Any] = { | |||
| "id": metadata.id, | |||
| "name": metadata.name, | |||
| "type": metadata.type, | |||
| @@ -497,13 +498,13 @@ class Document(Base): | |||
| return None | |||
| @property | |||
| def process_rule_dict(self): | |||
| if self.dataset_process_rule_id: | |||
| def process_rule_dict(self) -> dict[str, Any] | None: | |||
| if self.dataset_process_rule_id and self.dataset_process_rule: | |||
| return self.dataset_process_rule.to_dict() | |||
| return None | |||
| def get_built_in_fields(self): | |||
| built_in_fields = [] | |||
| def get_built_in_fields(self) -> list[dict[str, Any]]: | |||
| built_in_fields: list[dict[str, Any]] = [] | |||
| built_in_fields.append( | |||
| { | |||
| "id": "built-in", | |||
| @@ -546,7 +547,7 @@ class Document(Base): | |||
| ) | |||
| return built_in_fields | |||
| def to_dict(self): | |||
| def to_dict(self) -> dict[str, Any]: | |||
| return { | |||
| "id": self.id, | |||
| "tenant_id": self.tenant_id, | |||
| @@ -592,13 +593,13 @@ class Document(Base): | |||
| "data_source_info_dict": self.data_source_info_dict, | |||
| "average_segment_length": self.average_segment_length, | |||
| "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None, | |||
| "dataset": self.dataset.to_dict() if self.dataset else None, | |||
| "dataset": None, # Dataset class doesn't have a to_dict method | |||
| "segment_count": self.segment_count, | |||
| "hit_count": self.hit_count, | |||
| } | |||
| @classmethod | |||
| def from_dict(cls, data: dict): | |||
| def from_dict(cls, data: dict[str, Any]): | |||
| return cls( | |||
| id=data.get("id"), | |||
| tenant_id=data.get("tenant_id"), | |||
| @@ -711,46 +712,48 @@ class DocumentSegment(Base): | |||
| ) | |||
| @property | |||
| def child_chunks(self): | |||
| process_rule = self.document.dataset_process_rule | |||
| if process_rule.mode == "hierarchical": | |||
| rules = Rule(**process_rule.rules_dict) | |||
| if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: | |||
| child_chunks = ( | |||
| db.session.query(ChildChunk) | |||
| .where(ChildChunk.segment_id == self.id) | |||
| .order_by(ChildChunk.position.asc()) | |||
| .all() | |||
| ) | |||
| return child_chunks or [] | |||
| else: | |||
| return [] | |||
| else: | |||
| def child_chunks(self) -> list[Any]: | |||
| if not self.document: | |||
| return [] | |||
| def get_child_chunks(self): | |||
| process_rule = self.document.dataset_process_rule | |||
| if process_rule.mode == "hierarchical": | |||
| rules = Rule(**process_rule.rules_dict) | |||
| if rules.parent_mode: | |||
| child_chunks = ( | |||
| db.session.query(ChildChunk) | |||
| .where(ChildChunk.segment_id == self.id) | |||
| .order_by(ChildChunk.position.asc()) | |||
| .all() | |||
| ) | |||
| return child_chunks or [] | |||
| else: | |||
| return [] | |||
| else: | |||
| if process_rule and process_rule.mode == "hierarchical": | |||
| rules_dict = process_rule.rules_dict | |||
| if rules_dict: | |||
| rules = Rule(**rules_dict) | |||
| if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: | |||
| child_chunks = ( | |||
| db.session.query(ChildChunk) | |||
| .where(ChildChunk.segment_id == self.id) | |||
| .order_by(ChildChunk.position.asc()) | |||
| .all() | |||
| ) | |||
| return child_chunks or [] | |||
| return [] | |||
| def get_child_chunks(self) -> list[Any]: | |||
| if not self.document: | |||
| return [] | |||
| process_rule = self.document.dataset_process_rule | |||
| if process_rule and process_rule.mode == "hierarchical": | |||
| rules_dict = process_rule.rules_dict | |||
| if rules_dict: | |||
| rules = Rule(**rules_dict) | |||
| if rules.parent_mode: | |||
| child_chunks = ( | |||
| db.session.query(ChildChunk) | |||
| .where(ChildChunk.segment_id == self.id) | |||
| .order_by(ChildChunk.position.asc()) | |||
| .all() | |||
| ) | |||
| return child_chunks or [] | |||
| return [] | |||
| @property | |||
| def sign_content(self): | |||
| def sign_content(self) -> str: | |||
| return self.get_sign_content() | |||
| def get_sign_content(self): | |||
| signed_urls = [] | |||
| def get_sign_content(self) -> str: | |||
| signed_urls: list[tuple[int, int, str]] = [] | |||
| text = self.content | |||
| # For data before v0.10.0 | |||
| @@ -890,17 +893,22 @@ class DatasetKeywordTable(Base): | |||
| ) | |||
| @property | |||
| def keyword_table_dict(self): | |||
| def keyword_table_dict(self) -> dict[str, set[Any]] | None: | |||
| class SetDecoder(json.JSONDecoder): | |||
| def __init__(self, *args, **kwargs): | |||
| super().__init__(object_hook=self.object_hook, *args, **kwargs) | |||
| def object_hook(self, dct): | |||
| if isinstance(dct, dict): | |||
| for keyword, node_idxs in dct.items(): | |||
| if isinstance(node_idxs, list): | |||
| dct[keyword] = set(node_idxs) | |||
| return dct | |||
| def __init__(self, *args: Any, **kwargs: Any) -> None: | |||
| def object_hook(dct: Any) -> Any: | |||
| if isinstance(dct, dict): | |||
| result: dict[str, Any] = {} | |||
| items = cast(dict[str, Any], dct).items() | |||
| for keyword, node_idxs in items: | |||
| if isinstance(node_idxs, list): | |||
| result[keyword] = set(cast(list[Any], node_idxs)) | |||
| else: | |||
| result[keyword] = node_idxs | |||
| return result | |||
| return dct | |||
| super().__init__(object_hook=object_hook, *args, **kwargs) | |||
| # get dataset | |||
| dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first() | |||
| @@ -1026,7 +1034,7 @@ class ExternalKnowledgeApis(Base): | |||
| updated_by = mapped_column(StringUUID, nullable=True) | |||
| updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| def to_dict(self): | |||
| def to_dict(self) -> dict[str, Any]: | |||
| return { | |||
| "id": self.id, | |||
| "tenant_id": self.tenant_id, | |||
| @@ -1039,14 +1047,14 @@ class ExternalKnowledgeApis(Base): | |||
| } | |||
| @property | |||
| def settings_dict(self): | |||
| def settings_dict(self) -> dict[str, Any] | None: | |||
| try: | |||
| return json.loads(self.settings) if self.settings else None | |||
| except JSONDecodeError: | |||
| return None | |||
| @property | |||
| def dataset_bindings(self): | |||
| def dataset_bindings(self) -> list[dict[str, Any]]: | |||
| external_knowledge_bindings = ( | |||
| db.session.query(ExternalKnowledgeBindings) | |||
| .where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) | |||
| @@ -1054,7 +1062,7 @@ class ExternalKnowledgeApis(Base): | |||
| ) | |||
| dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings] | |||
| datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all() | |||
| dataset_bindings = [] | |||
| dataset_bindings: list[dict[str, Any]] = [] | |||
| for dataset in datasets: | |||
| dataset_bindings.append({"id": dataset.id, "name": dataset.name}) | |||
| @@ -16,7 +16,7 @@ if TYPE_CHECKING: | |||
| import sqlalchemy as sa | |||
| from flask import request | |||
| from flask_login import UserMixin | |||
| from flask_login import UserMixin # type: ignore[import-untyped] | |||
| from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text | |||
| from sqlalchemy.orm import Mapped, Session, mapped_column | |||
| @@ -24,7 +24,7 @@ from configs import dify_config | |||
| from constants import DEFAULT_FILE_NUMBER_LIMITS | |||
| from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType | |||
| from core.file import helpers as file_helpers | |||
| from libs.helper import generate_string | |||
| from libs.helper import generate_string # type: ignore[import-not-found] | |||
| from .account import Account, Tenant | |||
| from .base import Base | |||
| @@ -98,7 +98,7 @@ class App(Base): | |||
| use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) | |||
| @property | |||
| def desc_or_prompt(self): | |||
| def desc_or_prompt(self) -> str: | |||
| if self.description: | |||
| return self.description | |||
| else: | |||
| @@ -109,12 +109,12 @@ class App(Base): | |||
| return "" | |||
| @property | |||
| def site(self): | |||
| def site(self) -> Optional["Site"]: | |||
| site = db.session.query(Site).where(Site.app_id == self.id).first() | |||
| return site | |||
| @property | |||
| def app_model_config(self): | |||
| def app_model_config(self) -> Optional["AppModelConfig"]: | |||
| if self.app_model_config_id: | |||
| return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() | |||
| @@ -130,11 +130,11 @@ class App(Base): | |||
| return None | |||
| @property | |||
| def api_base_url(self): | |||
| def api_base_url(self) -> str: | |||
| return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1" | |||
| @property | |||
| def tenant(self): | |||
| def tenant(self) -> Optional[Tenant]: | |||
| tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() | |||
| return tenant | |||
| @@ -162,7 +162,7 @@ class App(Base): | |||
| return str(self.mode) | |||
| @property | |||
| def deleted_tools(self): | |||
| def deleted_tools(self) -> list[dict[str, str]]: | |||
| from core.tools.tool_manager import ToolManager | |||
| from services.plugin.plugin_service import PluginService | |||
| @@ -242,7 +242,7 @@ class App(Base): | |||
| provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids) | |||
| } | |||
| deleted_tools = [] | |||
| deleted_tools: list[dict[str, str]] = [] | |||
| for tool in tools: | |||
| keys = list(tool.keys()) | |||
| @@ -275,7 +275,7 @@ class App(Base): | |||
| return deleted_tools | |||
| @property | |||
| def tags(self): | |||
| def tags(self) -> list["Tag"]: | |||
| tags = ( | |||
| db.session.query(Tag) | |||
| .join(TagBinding, Tag.id == TagBinding.tag_id) | |||
| @@ -291,7 +291,7 @@ class App(Base): | |||
| return tags or [] | |||
| @property | |||
| def author_name(self): | |||
| def author_name(self) -> Optional[str]: | |||
| if self.created_by: | |||
| account = db.session.query(Account).where(Account.id == self.created_by).first() | |||
| if account: | |||
| @@ -334,20 +334,20 @@ class AppModelConfig(Base): | |||
| file_upload = mapped_column(sa.Text) | |||
| @property | |||
| def app(self): | |||
| def app(self) -> Optional[App]: | |||
| app = db.session.query(App).where(App.id == self.app_id).first() | |||
| return app | |||
| @property | |||
| def model_dict(self): | |||
| def model_dict(self) -> dict[str, Any]: | |||
| return json.loads(self.model) if self.model else {} | |||
| @property | |||
| def suggested_questions_list(self): | |||
| def suggested_questions_list(self) -> list[str]: | |||
| return json.loads(self.suggested_questions) if self.suggested_questions else [] | |||
| @property | |||
| def suggested_questions_after_answer_dict(self): | |||
| def suggested_questions_after_answer_dict(self) -> dict[str, Any]: | |||
| return ( | |||
| json.loads(self.suggested_questions_after_answer) | |||
| if self.suggested_questions_after_answer | |||
| @@ -355,19 +355,19 @@ class AppModelConfig(Base): | |||
| ) | |||
| @property | |||
| def speech_to_text_dict(self): | |||
| def speech_to_text_dict(self) -> dict[str, Any]: | |||
| return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False} | |||
| @property | |||
| def text_to_speech_dict(self): | |||
| def text_to_speech_dict(self) -> dict[str, Any]: | |||
| return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False} | |||
| @property | |||
| def retriever_resource_dict(self): | |||
| def retriever_resource_dict(self) -> dict[str, Any]: | |||
| return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} | |||
| @property | |||
| def annotation_reply_dict(self): | |||
| def annotation_reply_dict(self) -> dict[str, Any]: | |||
| annotation_setting = ( | |||
| db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first() | |||
| ) | |||
| @@ -390,11 +390,11 @@ class AppModelConfig(Base): | |||
| return {"enabled": False} | |||
| @property | |||
| def more_like_this_dict(self): | |||
| def more_like_this_dict(self) -> dict[str, Any]: | |||
| return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False} | |||
| @property | |||
| def sensitive_word_avoidance_dict(self): | |||
| def sensitive_word_avoidance_dict(self) -> dict[str, Any]: | |||
| return ( | |||
| json.loads(self.sensitive_word_avoidance) | |||
| if self.sensitive_word_avoidance | |||
| @@ -402,15 +402,15 @@ class AppModelConfig(Base): | |||
| ) | |||
| @property | |||
| def external_data_tools_list(self) -> list[dict]: | |||
| def external_data_tools_list(self) -> list[dict[str, Any]]: | |||
| return json.loads(self.external_data_tools) if self.external_data_tools else [] | |||
| @property | |||
| def user_input_form_list(self): | |||
| def user_input_form_list(self) -> list[dict[str, Any]]: | |||
| return json.loads(self.user_input_form) if self.user_input_form else [] | |||
| @property | |||
| def agent_mode_dict(self): | |||
| def agent_mode_dict(self) -> dict[str, Any]: | |||
| return ( | |||
| json.loads(self.agent_mode) | |||
| if self.agent_mode | |||
| @@ -418,17 +418,17 @@ class AppModelConfig(Base): | |||
| ) | |||
| @property | |||
| def chat_prompt_config_dict(self): | |||
| def chat_prompt_config_dict(self) -> dict[str, Any]: | |||
| return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {} | |||
| @property | |||
| def completion_prompt_config_dict(self): | |||
| def completion_prompt_config_dict(self) -> dict[str, Any]: | |||
| return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {} | |||
| @property | |||
| def dataset_configs_dict(self): | |||
| def dataset_configs_dict(self) -> dict[str, Any]: | |||
| if self.dataset_configs: | |||
| dataset_configs: dict = json.loads(self.dataset_configs) | |||
| dataset_configs: dict[str, Any] = json.loads(self.dataset_configs) | |||
| if "retrieval_model" not in dataset_configs: | |||
| return {"retrieval_model": "single"} | |||
| else: | |||
| @@ -438,7 +438,7 @@ class AppModelConfig(Base): | |||
| } | |||
| @property | |||
| def file_upload_dict(self): | |||
| def file_upload_dict(self) -> dict[str, Any]: | |||
| return ( | |||
| json.loads(self.file_upload) | |||
| if self.file_upload | |||
| @@ -452,7 +452,7 @@ class AppModelConfig(Base): | |||
| } | |||
| ) | |||
| def to_dict(self): | |||
| def to_dict(self) -> dict[str, Any]: | |||
| return { | |||
| "opening_statement": self.opening_statement, | |||
| "suggested_questions": self.suggested_questions_list, | |||
| @@ -546,7 +546,7 @@ class RecommendedApp(Base): | |||
| updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| @property | |||
| def app(self): | |||
| def app(self) -> Optional[App]: | |||
| app = db.session.query(App).where(App.id == self.app_id).first() | |||
| return app | |||
| @@ -570,12 +570,12 @@ class InstalledApp(Base): | |||
| created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| @property | |||
| def app(self): | |||
| def app(self) -> Optional[App]: | |||
| app = db.session.query(App).where(App.id == self.app_id).first() | |||
| return app | |||
| @property | |||
| def tenant(self): | |||
| def tenant(self) -> Optional[Tenant]: | |||
| tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() | |||
| return tenant | |||
| @@ -622,7 +622,7 @@ class Conversation(Base): | |||
| mode: Mapped[str] = mapped_column(String(255)) | |||
| name: Mapped[str] = mapped_column(String(255), nullable=False) | |||
| summary = mapped_column(sa.Text) | |||
| _inputs: Mapped[dict] = mapped_column("inputs", sa.JSON) | |||
| _inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON) | |||
| introduction = mapped_column(sa.Text) | |||
| system_instruction = mapped_column(sa.Text) | |||
| system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) | |||
| @@ -652,7 +652,7 @@ class Conversation(Base): | |||
| is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) | |||
| @property | |||
| def inputs(self): | |||
| def inputs(self) -> dict[str, Any]: | |||
| inputs = self._inputs.copy() | |||
| # Convert file mapping to File object | |||
| @@ -660,22 +660,39 @@ class Conversation(Base): | |||
| # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. | |||
| from factories import file_factory | |||
| if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: | |||
| if value["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||
| value["tool_file_id"] = value["related_id"] | |||
| elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||
| value["upload_file_id"] = value["related_id"] | |||
| inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"]) | |||
| elif isinstance(value, list) and all( | |||
| isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value | |||
| if ( | |||
| isinstance(value, dict) | |||
| and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY | |||
| ): | |||
| inputs[key] = [] | |||
| for item in value: | |||
| if item["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||
| item["tool_file_id"] = item["related_id"] | |||
| elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||
| item["upload_file_id"] = item["related_id"] | |||
| inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"])) | |||
| value_dict = cast(dict[str, Any], value) | |||
| if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||
| value_dict["tool_file_id"] = value_dict["related_id"] | |||
| elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||
| value_dict["upload_file_id"] = value_dict["related_id"] | |||
| tenant_id = cast(str, value_dict.get("tenant_id", "")) | |||
| inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) | |||
| elif isinstance(value, list): | |||
| value_list = cast(list[Any], value) | |||
| if all( | |||
| isinstance(item, dict) | |||
| and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY | |||
| for item in value_list | |||
| ): | |||
| file_list: list[File] = [] | |||
| for item in value_list: | |||
| if not isinstance(item, dict): | |||
| continue | |||
| item_dict = cast(dict[str, Any], item) | |||
| if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||
| item_dict["tool_file_id"] = item_dict["related_id"] | |||
| elif item_dict["transfer_method"] in [ | |||
| FileTransferMethod.LOCAL_FILE, | |||
| FileTransferMethod.REMOTE_URL, | |||
| ]: | |||
| item_dict["upload_file_id"] = item_dict["related_id"] | |||
| tenant_id = cast(str, item_dict.get("tenant_id", "")) | |||
| file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) | |||
| inputs[key] = file_list | |||
| return inputs | |||
| @@ -685,8 +702,10 @@ class Conversation(Base): | |||
| for k, v in inputs.items(): | |||
| if isinstance(v, File): | |||
| inputs[k] = v.model_dump() | |||
| elif isinstance(v, list) and all(isinstance(item, File) for item in v): | |||
| inputs[k] = [item.model_dump() for item in v] | |||
| elif isinstance(v, list): | |||
| v_list = cast(list[Any], v) | |||
| if all(isinstance(item, File) for item in v_list): | |||
| inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)] | |||
| self._inputs = inputs | |||
| @property | |||
| @@ -826,7 +845,7 @@ class Conversation(Base): | |||
| ) | |||
| @property | |||
| def app(self): | |||
| def app(self) -> Optional[App]: | |||
| return db.session.query(App).where(App.id == self.app_id).first() | |||
| @property | |||
| @@ -839,7 +858,7 @@ class Conversation(Base): | |||
| return None | |||
| @property | |||
| def from_account_name(self): | |||
| def from_account_name(self) -> Optional[str]: | |||
| if self.from_account_id: | |||
| account = db.session.query(Account).where(Account.id == self.from_account_id).first() | |||
| if account: | |||
| @@ -848,10 +867,10 @@ class Conversation(Base): | |||
| return None | |||
| @property | |||
| def in_debug_mode(self): | |||
| def in_debug_mode(self) -> bool: | |||
| return self.override_model_configs is not None | |||
| def to_dict(self): | |||
| def to_dict(self) -> dict[str, Any]: | |||
| return { | |||
| "id": self.id, | |||
| "app_id": self.app_id, | |||
| @@ -897,7 +916,7 @@ class Message(Base): | |||
| model_id = mapped_column(String(255), nullable=True) | |||
| override_model_configs = mapped_column(sa.Text) | |||
| conversation_id = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False) | |||
| _inputs: Mapped[dict] = mapped_column("inputs", sa.JSON) | |||
| _inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON) | |||
| query: Mapped[str] = mapped_column(sa.Text, nullable=False) | |||
| message = mapped_column(sa.JSON, nullable=False) | |||
| message_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) | |||
| @@ -924,28 +943,45 @@ class Message(Base): | |||
| workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) | |||
| @property | |||
| def inputs(self): | |||
| def inputs(self) -> dict[str, Any]: | |||
| inputs = self._inputs.copy() | |||
| for key, value in inputs.items(): | |||
| # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. | |||
| from factories import file_factory | |||
| if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: | |||
| if value["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||
| value["tool_file_id"] = value["related_id"] | |||
| elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||
| value["upload_file_id"] = value["related_id"] | |||
| inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"]) | |||
| elif isinstance(value, list) and all( | |||
| isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value | |||
| if ( | |||
| isinstance(value, dict) | |||
| and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY | |||
| ): | |||
| inputs[key] = [] | |||
| for item in value: | |||
| if item["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||
| item["tool_file_id"] = item["related_id"] | |||
| elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||
| item["upload_file_id"] = item["related_id"] | |||
| inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"])) | |||
| value_dict = cast(dict[str, Any], value) | |||
| if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||
| value_dict["tool_file_id"] = value_dict["related_id"] | |||
| elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||
| value_dict["upload_file_id"] = value_dict["related_id"] | |||
| tenant_id = cast(str, value_dict.get("tenant_id", "")) | |||
| inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) | |||
| elif isinstance(value, list): | |||
| value_list = cast(list[Any], value) | |||
| if all( | |||
| isinstance(item, dict) | |||
| and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY | |||
| for item in value_list | |||
| ): | |||
| file_list: list[File] = [] | |||
| for item in value_list: | |||
| if not isinstance(item, dict): | |||
| continue | |||
| item_dict = cast(dict[str, Any], item) | |||
| if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||
| item_dict["tool_file_id"] = item_dict["related_id"] | |||
| elif item_dict["transfer_method"] in [ | |||
| FileTransferMethod.LOCAL_FILE, | |||
| FileTransferMethod.REMOTE_URL, | |||
| ]: | |||
| item_dict["upload_file_id"] = item_dict["related_id"] | |||
| tenant_id = cast(str, item_dict.get("tenant_id", "")) | |||
| file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) | |||
| inputs[key] = file_list | |||
| return inputs | |||
| @inputs.setter | |||
| @@ -954,8 +990,10 @@ class Message(Base): | |||
| for k, v in inputs.items(): | |||
| if isinstance(v, File): | |||
| inputs[k] = v.model_dump() | |||
| elif isinstance(v, list) and all(isinstance(item, File) for item in v): | |||
| inputs[k] = [item.model_dump() for item in v] | |||
| elif isinstance(v, list): | |||
| v_list = cast(list[Any], v) | |||
| if all(isinstance(item, File) for item in v_list): | |||
| inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)] | |||
| self._inputs = inputs | |||
| @property | |||
| @@ -1083,15 +1121,15 @@ class Message(Base): | |||
| return None | |||
| @property | |||
| def in_debug_mode(self): | |||
| def in_debug_mode(self) -> bool: | |||
| return self.override_model_configs is not None | |||
| @property | |||
| def message_metadata_dict(self): | |||
| def message_metadata_dict(self) -> dict[str, Any]: | |||
| return json.loads(self.message_metadata) if self.message_metadata else {} | |||
| @property | |||
| def agent_thoughts(self): | |||
| def agent_thoughts(self) -> list["MessageAgentThought"]: | |||
| return ( | |||
| db.session.query(MessageAgentThought) | |||
| .where(MessageAgentThought.message_id == self.id) | |||
| @@ -1100,11 +1138,11 @@ class Message(Base): | |||
| ) | |||
| @property | |||
| def retriever_resources(self): | |||
| def retriever_resources(self) -> Any | list[Any]: | |||
| return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else [] | |||
| @property | |||
| def message_files(self): | |||
| def message_files(self) -> list[dict[str, Any]]: | |||
| from factories import file_factory | |||
| message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all() | |||
| @@ -1112,7 +1150,7 @@ class Message(Base): | |||
| if not current_app: | |||
| raise ValueError(f"App {self.app_id} not found") | |||
| files = [] | |||
| files: list[File] = [] | |||
| for message_file in message_files: | |||
| if message_file.transfer_method == FileTransferMethod.LOCAL_FILE.value: | |||
| if message_file.upload_file_id is None: | |||
| @@ -1159,7 +1197,7 @@ class Message(Base): | |||
| ) | |||
| files.append(file) | |||
| result = [ | |||
| result: list[dict[str, Any]] = [ | |||
| {"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()} | |||
| for (file, message_file) in zip(files, message_files) | |||
| ] | |||
| @@ -1176,7 +1214,7 @@ class Message(Base): | |||
| return None | |||
| def to_dict(self): | |||
| def to_dict(self) -> dict[str, Any]: | |||
| return { | |||
| "id": self.id, | |||
| "app_id": self.app_id, | |||
| @@ -1200,7 +1238,7 @@ class Message(Base): | |||
| } | |||
| @classmethod | |||
| def from_dict(cls, data: dict): | |||
| def from_dict(cls, data: dict[str, Any]) -> "Message": | |||
| return cls( | |||
| id=data["id"], | |||
| app_id=data["app_id"], | |||
| @@ -1250,7 +1288,7 @@ class MessageFeedback(Base): | |||
| account = db.session.query(Account).where(Account.id == self.from_account_id).first() | |||
| return account | |||
| def to_dict(self): | |||
| def to_dict(self) -> dict[str, Any]: | |||
| return { | |||
| "id": str(self.id), | |||
| "app_id": str(self.app_id), | |||
| @@ -1435,7 +1473,18 @@ class EndUser(Base, UserMixin): | |||
| type: Mapped[str] = mapped_column(String(255), nullable=False) | |||
| external_user_id = mapped_column(String(255), nullable=True) | |||
| name = mapped_column(String(255)) | |||
| is_anonymous: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) | |||
| _is_anonymous: Mapped[bool] = mapped_column( | |||
| "is_anonymous", sa.Boolean, nullable=False, server_default=sa.text("true") | |||
| ) | |||
| @property | |||
| def is_anonymous(self) -> Literal[False]: | |||
| return False | |||
| @is_anonymous.setter | |||
| def is_anonymous(self, value: bool) -> None: | |||
| self._is_anonymous = value | |||
| session_id: Mapped[str] = mapped_column() | |||
| created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| @@ -1461,7 +1510,7 @@ class AppMCPServer(Base): | |||
| updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| @staticmethod | |||
| def generate_server_code(n): | |||
| def generate_server_code(n: int) -> str: | |||
| while True: | |||
| result = generate_string(n) | |||
| while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0: | |||
| @@ -1518,7 +1567,7 @@ class Site(Base): | |||
| self._custom_disclaimer = value | |||
| @staticmethod | |||
| def generate_code(n): | |||
| def generate_code(n: int) -> str: | |||
| while True: | |||
| result = generate_string(n) | |||
| while db.session.query(Site).where(Site.code == result).count() > 0: | |||
| @@ -1549,7 +1598,7 @@ class ApiToken(Base): | |||
| created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| @staticmethod | |||
| def generate_api_key(prefix, n): | |||
| def generate_api_key(prefix: str, n: int) -> str: | |||
| while True: | |||
| result = prefix + generate_string(n) | |||
| if db.session.scalar(select(exists().where(ApiToken.token == result))): | |||
| @@ -1689,7 +1738,7 @@ class MessageAgentThought(Base): | |||
| created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) | |||
| @property | |||
| def files(self): | |||
| def files(self) -> list[Any]: | |||
| if self.message_files: | |||
| return cast(list[Any], json.loads(self.message_files)) | |||
| else: | |||
| @@ -1700,32 +1749,32 @@ class MessageAgentThought(Base): | |||
| return self.tool.split(";") if self.tool else [] | |||
| @property | |||
| def tool_labels(self): | |||
| def tool_labels(self) -> dict[str, Any]: | |||
| try: | |||
| if self.tool_labels_str: | |||
| return cast(dict, json.loads(self.tool_labels_str)) | |||
| return cast(dict[str, Any], json.loads(self.tool_labels_str)) | |||
| else: | |||
| return {} | |||
| except Exception: | |||
| return {} | |||
| @property | |||
| def tool_meta(self): | |||
| def tool_meta(self) -> dict[str, Any]: | |||
| try: | |||
| if self.tool_meta_str: | |||
| return cast(dict, json.loads(self.tool_meta_str)) | |||
| return cast(dict[str, Any], json.loads(self.tool_meta_str)) | |||
| else: | |||
| return {} | |||
| except Exception: | |||
| return {} | |||
| @property | |||
| def tool_inputs_dict(self): | |||
| def tool_inputs_dict(self) -> dict[str, Any]: | |||
| tools = self.tools | |||
| try: | |||
| if self.tool_input: | |||
| data = json.loads(self.tool_input) | |||
| result = {} | |||
| result: dict[str, Any] = {} | |||
| for tool in tools: | |||
| if tool in data: | |||
| result[tool] = data[tool] | |||
| @@ -1741,12 +1790,12 @@ class MessageAgentThought(Base): | |||
| return {} | |||
| @property | |||
| def tool_outputs_dict(self): | |||
| def tool_outputs_dict(self) -> dict[str, Any]: | |||
| tools = self.tools | |||
| try: | |||
| if self.observation: | |||
| data = json.loads(self.observation) | |||
| result = {} | |||
| result: dict[str, Any] = {} | |||
| for tool in tools: | |||
| if tool in data: | |||
| result[tool] = data[tool] | |||
| @@ -1844,14 +1893,14 @@ class TraceAppConfig(Base): | |||
| is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) | |||
| @property | |||
| def tracing_config_dict(self): | |||
| def tracing_config_dict(self) -> dict[str, Any]: | |||
| return self.tracing_config or {} | |||
| @property | |||
| def tracing_config_str(self): | |||
| def tracing_config_str(self) -> str: | |||
| return json.dumps(self.tracing_config_dict) | |||
| def to_dict(self): | |||
| def to_dict(self) -> dict[str, Any]: | |||
| return { | |||
| "id": self.id, | |||
| "app_id": self.app_id, | |||
| @@ -17,7 +17,7 @@ class ProviderType(Enum): | |||
| SYSTEM = "system" | |||
| @staticmethod | |||
| def value_of(value): | |||
| def value_of(value: str) -> "ProviderType": | |||
| for member in ProviderType: | |||
| if member.value == value: | |||
| return member | |||
| @@ -35,7 +35,7 @@ class ProviderQuotaType(Enum): | |||
| """hosted trial quota""" | |||
| @staticmethod | |||
| def value_of(value): | |||
| def value_of(value: str) -> "ProviderQuotaType": | |||
| for member in ProviderQuotaType: | |||
| if member.value == value: | |||
| return member | |||
| @@ -1,6 +1,6 @@ | |||
| import json | |||
| from datetime import datetime | |||
| from typing import Optional, cast | |||
| from typing import Any, Optional, cast | |||
| from urllib.parse import urlparse | |||
| import sqlalchemy as sa | |||
| @@ -54,8 +54,8 @@ class ToolOAuthTenantClient(Base): | |||
| encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False) | |||
| @property | |||
| def oauth_params(self): | |||
| return cast(dict, json.loads(self.encrypted_oauth_params or "{}")) | |||
| def oauth_params(self) -> dict[str, Any]: | |||
| return cast(dict[str, Any], json.loads(self.encrypted_oauth_params or "{}")) | |||
| class BuiltinToolProvider(Base): | |||
| @@ -96,8 +96,8 @@ class BuiltinToolProvider(Base): | |||
| expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1")) | |||
| @property | |||
| def credentials(self): | |||
| return cast(dict, json.loads(self.encrypted_credentials)) | |||
| def credentials(self) -> dict[str, Any]: | |||
| return cast(dict[str, Any], json.loads(self.encrypted_credentials)) | |||
| class ApiToolProvider(Base): | |||
| @@ -146,8 +146,8 @@ class ApiToolProvider(Base): | |||
| return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)] | |||
| @property | |||
| def credentials(self): | |||
| return dict(json.loads(self.credentials_str)) | |||
| def credentials(self) -> dict[str, Any]: | |||
| return dict[str, Any](json.loads(self.credentials_str)) | |||
| @property | |||
| def user(self) -> Account | None: | |||
| @@ -289,9 +289,9 @@ class MCPToolProvider(Base): | |||
| return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() | |||
| @property | |||
| def credentials(self): | |||
| def credentials(self) -> dict[str, Any]: | |||
| try: | |||
| return cast(dict, json.loads(self.encrypted_credentials)) or {} | |||
| return cast(dict[str, Any], json.loads(self.encrypted_credentials)) or {} | |||
| except Exception: | |||
| return {} | |||
| @@ -327,12 +327,12 @@ class MCPToolProvider(Base): | |||
| return mask_url(self.decrypted_server_url) | |||
| @property | |||
| def decrypted_credentials(self): | |||
| def decrypted_credentials(self) -> dict[str, Any]: | |||
| from core.helper.provider_cache import NoOpProviderCredentialCache | |||
| from core.tools.mcp_tool.provider import MCPToolProviderController | |||
| from core.tools.utils.encryption import create_provider_encrypter | |||
| provider_controller = MCPToolProviderController._from_db(self) | |||
| provider_controller = MCPToolProviderController.from_db(self) | |||
| encrypter, _ = create_provider_encrypter( | |||
| tenant_id=self.tenant_id, | |||
| @@ -340,7 +340,7 @@ class MCPToolProvider(Base): | |||
| cache=NoOpProviderCredentialCache(), | |||
| ) | |||
| return encrypter.decrypt(self.credentials) # type: ignore | |||
| return encrypter.decrypt(self.credentials) | |||
| class ToolModelInvoke(Base): | |||
| @@ -1,29 +1,34 @@ | |||
| import enum | |||
| from typing import Generic, TypeVar | |||
| import uuid | |||
| from typing import Any, Generic, TypeVar | |||
| from sqlalchemy import CHAR, VARCHAR, TypeDecorator | |||
| from sqlalchemy.dialects.postgresql import UUID | |||
| from sqlalchemy.engine.interfaces import Dialect | |||
| from sqlalchemy.sql.type_api import TypeEngine | |||
| class StringUUID(TypeDecorator): | |||
| class StringUUID(TypeDecorator[uuid.UUID | str | None]): | |||
| impl = CHAR | |||
| cache_ok = True | |||
| def process_bind_param(self, value, dialect): | |||
| def process_bind_param(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None: | |||
| if value is None: | |||
| return value | |||
| elif dialect.name == "postgresql": | |||
| return str(value) | |||
| else: | |||
| return value.hex | |||
| if isinstance(value, uuid.UUID): | |||
| return value.hex | |||
| return value | |||
| def load_dialect_impl(self, dialect): | |||
| def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: | |||
| if dialect.name == "postgresql": | |||
| return dialect.type_descriptor(UUID()) | |||
| else: | |||
| return dialect.type_descriptor(CHAR(36)) | |||
| def process_result_value(self, value, dialect): | |||
| def process_result_value(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None: | |||
| if value is None: | |||
| return value | |||
| return str(value) | |||
| @@ -32,7 +37,7 @@ class StringUUID(TypeDecorator): | |||
| _E = TypeVar("_E", bound=enum.StrEnum) | |||
| class EnumText(TypeDecorator, Generic[_E]): | |||
| class EnumText(TypeDecorator[_E | None], Generic[_E]): | |||
| impl = VARCHAR | |||
| cache_ok = True | |||
| @@ -50,28 +55,25 @@ class EnumText(TypeDecorator, Generic[_E]): | |||
| # leave some rooms for future longer enum values. | |||
| self._length = max(max_enum_value_len, 20) | |||
| def process_bind_param(self, value: _E | str | None, dialect): | |||
| def process_bind_param(self, value: _E | str | None, dialect: Dialect) -> str | None: | |||
| if value is None: | |||
| return value | |||
| if isinstance(value, self._enum_class): | |||
| return value.value | |||
| elif isinstance(value, str): | |||
| self._enum_class(value) | |||
| return value | |||
| else: | |||
| raise TypeError(f"expected str or {self._enum_class}, got {type(value)}") | |||
| # Since _E is bound to StrEnum which inherits from str, at this point value must be str | |||
| self._enum_class(value) | |||
| return value | |||
| def load_dialect_impl(self, dialect): | |||
| def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: | |||
| return dialect.type_descriptor(VARCHAR(self._length)) | |||
| def process_result_value(self, value, dialect) -> _E | None: | |||
| def process_result_value(self, value: str | None, dialect: Dialect) -> _E | None: | |||
| if value is None: | |||
| return value | |||
| if not isinstance(value, str): | |||
| raise TypeError(f"expected str, got {type(value)}") | |||
| # Type annotation guarantees value is str at this point | |||
| return self._enum_class(value) | |||
| def compare_values(self, x, y): | |||
| def compare_values(self, x: _E | None, y: _E | None) -> bool: | |||
| if x is None or y is None: | |||
| return x is y | |||
| return x == y | |||
| @@ -3,7 +3,7 @@ import logging | |||
| from collections.abc import Mapping, Sequence | |||
| from datetime import datetime | |||
| from enum import Enum, StrEnum | |||
| from typing import TYPE_CHECKING, Any, Optional, Union | |||
| from typing import TYPE_CHECKING, Any, Optional, Union, cast | |||
| from uuid import uuid4 | |||
| import sqlalchemy as sa | |||
| @@ -224,7 +224,7 @@ class Workflow(Base): | |||
| raise WorkflowDataError("nodes not found in workflow graph") | |||
| try: | |||
| node_config = next(filter(lambda node: node["id"] == node_id, nodes)) | |||
| node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes)) | |||
| except StopIteration: | |||
| raise NodeNotFoundError(node_id) | |||
| assert isinstance(node_config, dict) | |||
| @@ -289,7 +289,7 @@ class Workflow(Base): | |||
| def features_dict(self) -> dict[str, Any]: | |||
| return json.loads(self.features) if self.features else {} | |||
| def user_input_form(self, to_old_structure: bool = False): | |||
| def user_input_form(self, to_old_structure: bool = False) -> list[Any]: | |||
| # get start node from graph | |||
| if not self.graph: | |||
| return [] | |||
| @@ -306,7 +306,7 @@ class Workflow(Base): | |||
| variables: list[Any] = start_node.get("data", {}).get("variables", []) | |||
| if to_old_structure: | |||
| old_structure_variables = [] | |||
| old_structure_variables: list[dict[str, Any]] = [] | |||
| for variable in variables: | |||
| old_structure_variables.append({variable["type"]: variable}) | |||
| @@ -346,9 +346,7 @@ class Workflow(Base): | |||
| @property | |||
| def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: | |||
| # TODO: find some way to init `self._environment_variables` when instance created. | |||
| if self._environment_variables is None: | |||
| self._environment_variables = "{}" | |||
| # _environment_variables is guaranteed to be non-None due to server_default="{}" | |||
| # Use workflow.tenant_id to avoid relying on request user in background threads | |||
| tenant_id = self.tenant_id | |||
| @@ -362,17 +360,18 @@ class Workflow(Base): | |||
| ] | |||
| # decrypt secret variables value | |||
| def decrypt_func(var): | |||
| def decrypt_func(var: Variable) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable: | |||
| if isinstance(var, SecretVariable): | |||
| return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) | |||
| elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)): | |||
| return var | |||
| else: | |||
| raise AssertionError("this statement should be unreachable.") | |||
| # Other variable types are not supported for environment variables | |||
| raise AssertionError(f"Unexpected variable type for environment variable: {type(var)}") | |||
| decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = list( | |||
| map(decrypt_func, results) | |||
| ) | |||
| decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = [ | |||
| decrypt_func(var) for var in results | |||
| ] | |||
| return decrypted_results | |||
| @environment_variables.setter | |||
| @@ -400,7 +399,7 @@ class Workflow(Base): | |||
| value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name}) | |||
| # encrypt secret variables value | |||
| def encrypt_func(var): | |||
| def encrypt_func(var: Variable) -> Variable: | |||
| if isinstance(var, SecretVariable): | |||
| return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)}) | |||
| else: | |||
| @@ -430,9 +429,7 @@ class Workflow(Base): | |||
| @property | |||
| def conversation_variables(self) -> Sequence[Variable]: | |||
| # TODO: find some way to init `self._conversation_variables` when instance created. | |||
| if self._conversation_variables is None: | |||
| self._conversation_variables = "{}" | |||
| # _conversation_variables is guaranteed to be non-None due to server_default="{}" | |||
| variables_dict: dict[str, Any] = json.loads(self._conversation_variables) | |||
| results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()] | |||
| @@ -577,7 +574,7 @@ class WorkflowRun(Base): | |||
| } | |||
| @classmethod | |||
| def from_dict(cls, data: dict) -> "WorkflowRun": | |||
| def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun": | |||
| return cls( | |||
| id=data.get("id"), | |||
| tenant_id=data.get("tenant_id"), | |||
| @@ -662,7 +659,8 @@ class WorkflowNodeExecutionModel(Base): | |||
| __tablename__ = "workflow_node_executions" | |||
| @declared_attr | |||
| def __table_args__(cls): # noqa | |||
| @classmethod | |||
| def __table_args__(cls) -> Any: | |||
| return ( | |||
| PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"), | |||
| Index( | |||
| @@ -699,7 +697,7 @@ class WorkflowNodeExecutionModel(Base): | |||
| # MyPy may flag the following line because it doesn't recognize that | |||
| # the `declared_attr` decorator passes the receiving class as the first | |||
| # argument to this method, allowing us to reference class attributes. | |||
| cls.created_at.desc(), # type: ignore | |||
| cls.created_at.desc(), | |||
| ), | |||
| ) | |||
| @@ -761,15 +759,15 @@ class WorkflowNodeExecutionModel(Base): | |||
| return json.loads(self.execution_metadata) if self.execution_metadata else {} | |||
| @property | |||
| def extras(self): | |||
| def extras(self) -> dict[str, Any]: | |||
| from core.tools.tool_manager import ToolManager | |||
| extras = {} | |||
| extras: dict[str, Any] = {} | |||
| if self.execution_metadata_dict: | |||
| from core.workflow.nodes import NodeType | |||
| if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict: | |||
| tool_info = self.execution_metadata_dict["tool_info"] | |||
| tool_info: dict[str, Any] = self.execution_metadata_dict["tool_info"] | |||
| extras["icon"] = ToolManager.get_tool_icon( | |||
| tenant_id=self.tenant_id, | |||
| provider_type=tool_info["provider_type"], | |||
| @@ -1037,7 +1035,7 @@ class WorkflowDraftVariable(Base): | |||
| # making this attribute harder to access from outside the class. | |||
| __value: Segment | None | |||
| def __init__(self, *args, **kwargs): | |||
| def __init__(self, *args: Any, **kwargs: Any) -> None: | |||
| """ | |||
| The constructor of `WorkflowDraftVariable` is not intended for | |||
| direct use outside this file. Its solo purpose is setup private state | |||
| @@ -1055,15 +1053,15 @@ class WorkflowDraftVariable(Base): | |||
| self.__value = None | |||
| def get_selector(self) -> list[str]: | |||
| selector = json.loads(self.selector) | |||
| selector: Any = json.loads(self.selector) | |||
| if not isinstance(selector, list): | |||
| logger.error( | |||
| "invalid selector loaded from database, type=%s, value=%s", | |||
| type(selector), | |||
| type(selector).__name__, | |||
| self.selector, | |||
| ) | |||
| raise ValueError("invalid selector.") | |||
| return selector | |||
| return cast(list[str], selector) | |||
| def _set_selector(self, value: list[str]): | |||
| self.selector = json.dumps(value) | |||
| @@ -1086,15 +1084,17 @@ class WorkflowDraftVariable(Base): | |||
| # `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging. | |||
| if isinstance(value, dict): | |||
| if not maybe_file_object(value): | |||
| return value | |||
| return cast(Any, value) | |||
| return File.model_validate(value) | |||
| elif isinstance(value, list) and value: | |||
| first = value[0] | |||
| value_list = cast(list[Any], value) | |||
| first: Any = value_list[0] | |||
| if not maybe_file_object(first): | |||
| return value | |||
| return [File.model_validate(i) for i in value] | |||
| return cast(Any, value) | |||
| file_list: list[File] = [File.model_validate(cast(dict[str, Any], i)) for i in value_list] | |||
| return cast(Any, file_list) | |||
| else: | |||
| return value | |||
| return cast(Any, value) | |||
| @classmethod | |||
| def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment: | |||
| @@ -6,7 +6,6 @@ | |||
| "tests/", | |||
| "migrations/", | |||
| ".venv/", | |||
| "models/", | |||
| "core/", | |||
| "controllers/", | |||
| "tasks/", | |||
| @@ -1,5 +1,5 @@ | |||
| import threading | |||
| from typing import Optional | |||
| from typing import Any, Optional | |||
| import pytz | |||
| from flask_login import current_user | |||
| @@ -68,7 +68,7 @@ class AgentService: | |||
| if not app_model_config: | |||
| raise ValueError("App model config not found") | |||
| result = { | |||
| result: dict[str, Any] = { | |||
| "meta": { | |||
| "status": "success", | |||
| "executor": executor, | |||
| @@ -171,6 +171,8 @@ class AppService: | |||
| # get original app model config | |||
| if app.mode == AppMode.AGENT_CHAT.value or app.is_agent: | |||
| model_config = app.app_model_config | |||
| if not model_config: | |||
| return app | |||
| agent_mode = model_config.agent_mode_dict | |||
| # decrypt agent tool parameters if it's secret-input | |||
| for tool in agent_mode.get("tools") or []: | |||
| @@ -205,7 +207,8 @@ class AppService: | |||
| pass | |||
| # override agent mode | |||
| model_config.agent_mode = json.dumps(agent_mode) | |||
| if model_config: | |||
| model_config.agent_mode = json.dumps(agent_mode) | |||
| class ModifiedApp(App): | |||
| """ | |||
| @@ -12,7 +12,7 @@ from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from extensions.ext_database import db | |||
| from models.enums import MessageStatus | |||
| from models.model import App, AppMode, AppModelConfig, Message | |||
| from models.model import App, AppMode, Message | |||
| from services.errors.audio import ( | |||
| AudioTooLargeServiceError, | |||
| NoAudioUploadedServiceError, | |||
| @@ -40,7 +40,9 @@ class AudioService: | |||
| if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"): | |||
| raise ValueError("Speech to text is not enabled") | |||
| else: | |||
| app_model_config: AppModelConfig = app_model.app_model_config | |||
| app_model_config = app_model.app_model_config | |||
| if not app_model_config: | |||
| raise ValueError("Speech to text is not enabled") | |||
| if not app_model_config.speech_to_text_dict["enabled"]: | |||
| raise ValueError("Speech to text is not enabled") | |||
| @@ -973,7 +973,7 @@ class DocumentService: | |||
| file_ids = [ | |||
| document.data_source_info_dict["upload_file_id"] | |||
| for document in documents | |||
| if document.data_source_type == "upload_file" | |||
| if document.data_source_type == "upload_file" and document.data_source_info_dict | |||
| ] | |||
| batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) | |||
| @@ -1067,8 +1067,9 @@ class DocumentService: | |||
| # sync document indexing | |||
| document.indexing_status = "waiting" | |||
| data_source_info = document.data_source_info_dict | |||
| data_source_info["mode"] = "scrape" | |||
| document.data_source_info = json.dumps(data_source_info, ensure_ascii=False) | |||
| if data_source_info: | |||
| data_source_info["mode"] = "scrape" | |||
| document.data_source_info = json.dumps(data_source_info, ensure_ascii=False) | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| @@ -114,8 +114,9 @@ class ExternalDatasetService: | |||
| ) | |||
| if external_knowledge_api is None: | |||
| raise ValueError("api template not found") | |||
| if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE: | |||
| args.get("settings")["api_key"] = external_knowledge_api.settings_dict.get("api_key") | |||
| settings = args.get("settings") | |||
| if settings and settings.get("api_key") == HIDDEN_VALUE and external_knowledge_api.settings_dict: | |||
| settings["api_key"] = external_knowledge_api.settings_dict.get("api_key") | |||
| external_knowledge_api.name = args.get("name") | |||
| external_knowledge_api.description = args.get("description", "") | |||
| @@ -226,7 +226,7 @@ class MCPToolManageService: | |||
| def update_mcp_provider_credentials( | |||
| cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False | |||
| ): | |||
| provider_controller = MCPToolProviderController._from_db(mcp_provider) | |||
| provider_controller = MCPToolProviderController.from_db(mcp_provider) | |||
| tool_configuration = ProviderConfigEncrypter( | |||
| tenant_id=mcp_provider.tenant_id, | |||
| config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type] | |||
| @@ -154,7 +154,7 @@ class TestEnumText: | |||
| TestCase( | |||
| name="session insert with invalid type", | |||
| action=lambda s: _session_insert_with_value(s, 1), | |||
| exc_type=TypeError, | |||
| exc_type=ValueError, | |||
| ), | |||
| TestCase( | |||
| name="insert with invalid value", | |||
| @@ -164,7 +164,7 @@ class TestEnumText: | |||
| TestCase( | |||
| name="insert with invalid type", | |||
| action=lambda s: _insert_with_user(s, 1), | |||
| exc_type=TypeError, | |||
| exc_type=ValueError, | |||
| ), | |||
| ] | |||
| for idx, c in enumerate(cases, 1): | |||