Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>tags/2.0.0-beta.2^2
| custom="max_keys_exceeded", | custom="max_keys_exceeded", | ||||
| ) | ) | ||||
| key = ApiToken.generate_api_key(self.token_prefix, 24) | |||||
| key = ApiToken.generate_api_key(self.token_prefix or "", 24) | |||||
| api_token = ApiToken() | api_token = ApiToken() | ||||
| setattr(api_token, self.resource_id_field, resource_id) | setattr(api_token, self.resource_id_field, resource_id) | ||||
| api_token.tenant_id = current_user.current_tenant_id | api_token.tenant_id = current_user.current_tenant_id |
| data_source_info = document.data_source_info_dict | data_source_info = document.data_source_info_dict | ||||
| if document.data_source_type == "upload_file": | if document.data_source_type == "upload_file": | ||||
| if not data_source_info: | |||||
| continue | |||||
| file_id = data_source_info["upload_file_id"] | file_id = data_source_info["upload_file_id"] | ||||
| file_detail = ( | file_detail = ( | ||||
| db.session.query(UploadFile) | db.session.query(UploadFile) | ||||
| extract_settings.append(extract_setting) | extract_settings.append(extract_setting) | ||||
| elif document.data_source_type == "notion_import": | elif document.data_source_type == "notion_import": | ||||
| if not data_source_info: | |||||
| continue | |||||
| extract_setting = ExtractSetting( | extract_setting = ExtractSetting( | ||||
| datasource_type=DatasourceType.NOTION.value, | datasource_type=DatasourceType.NOTION.value, | ||||
| notion_info={ | notion_info={ | ||||
| ) | ) | ||||
| extract_settings.append(extract_setting) | extract_settings.append(extract_setting) | ||||
| elif document.data_source_type == "website_crawl": | elif document.data_source_type == "website_crawl": | ||||
| if not data_source_info: | |||||
| continue | |||||
| extract_setting = ExtractSetting( | extract_setting = ExtractSetting( | ||||
| datasource_type=DatasourceType.WEBSITE.value, | datasource_type=DatasourceType.WEBSITE.value, | ||||
| website_info={ | website_info={ |
| def get(self, installed_app: InstalledApp): | def get(self, installed_app: InstalledApp): | ||||
| """Get app meta""" | """Get app meta""" | ||||
| app_model = installed_app.app | app_model = installed_app.app | ||||
| if not app_model: | |||||
| raise ValueError("App not found") | |||||
| return AppService().get_app_meta(app_model) | return AppService().get_app_meta(app_model) | ||||
| Run workflow | Run workflow | ||||
| """ | """ | ||||
| app_model = installed_app.app | app_model = installed_app.app | ||||
| if not app_model: | |||||
| raise NotWorkflowAppError() | |||||
| app_mode = AppMode.value_of(app_model.mode) | app_mode = AppMode.value_of(app_model.mode) | ||||
| if app_mode != AppMode.WORKFLOW: | if app_mode != AppMode.WORKFLOW: | ||||
| raise NotWorkflowAppError() | raise NotWorkflowAppError() | ||||
| Stop workflow task | Stop workflow task | ||||
| """ | """ | ||||
| app_model = installed_app.app | app_model = installed_app.app | ||||
| if not app_model: | |||||
| raise NotWorkflowAppError() | |||||
| app_mode = AppMode.value_of(app_model.mode) | app_mode = AppMode.value_of(app_model.mode) | ||||
| if app_mode != AppMode.WORKFLOW: | if app_mode != AppMode.WORKFLOW: | ||||
| raise NotWorkflowAppError() | raise NotWorkflowAppError() |
| raise MessageNotExistsError() | raise MessageNotExistsError() | ||||
| current_app_model_config = app_model.app_model_config | current_app_model_config = app_model.app_model_config | ||||
| if not current_app_model_config: | |||||
| raise MoreLikeThisDisabledError() | |||||
| more_like_this = current_app_model_config.more_like_this_dict | more_like_this = current_app_model_config.more_like_this_dict | ||||
| if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False: | if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False: |
| last_edited_time = self.get_notion_last_edited_time() | last_edited_time = self.get_notion_last_edited_time() | ||||
| data_source_info = document_model.data_source_info_dict | data_source_info = document_model.data_source_info_dict | ||||
| data_source_info["last_edited_time"] = last_edited_time | |||||
| if data_source_info: | |||||
| data_source_info["last_edited_time"] = last_edited_time | |||||
| db.session.query(DocumentModel).filter_by(id=document_model.id).update( | db.session.query(DocumentModel).filter_by(id=document_model.id).update( | ||||
| {DocumentModel.data_source_info: json.dumps(data_source_info)} | {DocumentModel.data_source_info: json.dumps(data_source_info)} |
| import json | import json | ||||
| from typing import Any, Optional | |||||
| from typing import Any, Optional, Self | |||||
| from core.mcp.types import Tool as RemoteMCPTool | from core.mcp.types import Tool as RemoteMCPTool | ||||
| from core.tools.__base.tool_provider import ToolProviderController | from core.tools.__base.tool_provider import ToolProviderController | ||||
| return ToolProviderType.MCP | return ToolProviderType.MCP | ||||
| @classmethod | @classmethod | ||||
| def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController": | |||||
| def from_db(cls, db_provider: MCPToolProvider) -> Self: | |||||
| """ | """ | ||||
| from db provider | from db provider | ||||
| """ | """ |
| if provider is None: | if provider is None: | ||||
| raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") | raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found") | ||||
| controller = MCPToolProviderController._from_db(provider) | |||||
| controller = MCPToolProviderController.from_db(provider) | |||||
| return controller | return controller | ||||
| tenant_id: str, | tenant_id: str, | ||||
| provider_type: ToolProviderType, | provider_type: ToolProviderType, | ||||
| provider_id: str, | provider_id: str, | ||||
| ) -> Union[str, dict]: | |||||
| ) -> Union[str, dict[str, Any]]: | |||||
| """ | """ | ||||
| get the tool icon | get the tool icon | ||||
| import enum | import enum | ||||
| import json | import json | ||||
| from datetime import datetime | from datetime import datetime | ||||
| from typing import Optional | |||||
| from typing import Any, Optional | |||||
| import sqlalchemy as sa | import sqlalchemy as sa | ||||
| from flask_login import UserMixin | |||||
| from flask_login import UserMixin # type: ignore[import-untyped] | |||||
| from sqlalchemy import DateTime, String, func, select | from sqlalchemy import DateTime, String, func, select | ||||
| from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor | from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor | ||||
| ) | ) | ||||
| @property | @property | ||||
| def custom_config_dict(self): | |||||
| def custom_config_dict(self) -> dict[str, Any]: | |||||
| return json.loads(self.custom_config) if self.custom_config else {} | return json.loads(self.custom_config) if self.custom_config else {} | ||||
| @custom_config_dict.setter | @custom_config_dict.setter | ||||
| def custom_config_dict(self, value: dict): | |||||
| def custom_config_dict(self, value: dict[str, Any]) -> None: | |||||
| self.custom_config = json.dumps(value) | self.custom_config = json.dumps(value) | ||||
| "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, | "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, | ||||
| } | } | ||||
| def to_dict(self): | |||||
| def to_dict(self) -> dict[str, Any]: | |||||
| return { | return { | ||||
| "id": self.id, | "id": self.id, | ||||
| "dataset_id": self.dataset_id, | "dataset_id": self.dataset_id, | ||||
| } | } | ||||
| @property | @property | ||||
| def rules_dict(self): | |||||
| def rules_dict(self) -> dict[str, Any] | None: | |||||
| try: | try: | ||||
| return json.loads(self.rules) if self.rules else None | return json.loads(self.rules) if self.rules else None | ||||
| except JSONDecodeError: | except JSONDecodeError: | ||||
| return status | return status | ||||
| @property | @property | ||||
| def data_source_info_dict(self): | |||||
| def data_source_info_dict(self) -> dict[str, Any] | None: | |||||
| if self.data_source_info: | if self.data_source_info: | ||||
| try: | try: | ||||
| data_source_info_dict = json.loads(self.data_source_info) | |||||
| data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info) | |||||
| except JSONDecodeError: | except JSONDecodeError: | ||||
| data_source_info_dict = {} | data_source_info_dict = {} | ||||
| return None | return None | ||||
| @property | @property | ||||
| def data_source_detail_dict(self): | |||||
| def data_source_detail_dict(self) -> dict[str, Any]: | |||||
| if self.data_source_info: | if self.data_source_info: | ||||
| if self.data_source_type == "upload_file": | if self.data_source_type == "upload_file": | ||||
| data_source_info_dict = json.loads(self.data_source_info) | |||||
| data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info) | |||||
| file_detail = ( | file_detail = ( | ||||
| db.session.query(UploadFile) | db.session.query(UploadFile) | ||||
| .where(UploadFile.id == data_source_info_dict["upload_file_id"]) | .where(UploadFile.id == data_source_info_dict["upload_file_id"]) | ||||
| } | } | ||||
| } | } | ||||
| elif self.data_source_type in {"notion_import", "website_crawl"}: | elif self.data_source_type in {"notion_import", "website_crawl"}: | ||||
| return json.loads(self.data_source_info) | |||||
| result: dict[str, Any] = json.loads(self.data_source_info) | |||||
| return result | |||||
| return {} | return {} | ||||
| @property | @property | ||||
| return self.updated_at | return self.updated_at | ||||
| @property | @property | ||||
| def doc_metadata_details(self): | |||||
| def doc_metadata_details(self) -> list[dict[str, Any]] | None: | |||||
| if self.doc_metadata: | if self.doc_metadata: | ||||
| document_metadatas = ( | document_metadatas = ( | ||||
| db.session.query(DatasetMetadata) | db.session.query(DatasetMetadata) | ||||
| ) | ) | ||||
| .all() | .all() | ||||
| ) | ) | ||||
| metadata_list = [] | |||||
| metadata_list: list[dict[str, Any]] = [] | |||||
| for metadata in document_metadatas: | for metadata in document_metadatas: | ||||
| metadata_dict = { | |||||
| metadata_dict: dict[str, Any] = { | |||||
| "id": metadata.id, | "id": metadata.id, | ||||
| "name": metadata.name, | "name": metadata.name, | ||||
| "type": metadata.type, | "type": metadata.type, | ||||
| return None | return None | ||||
| @property | @property | ||||
| def process_rule_dict(self): | |||||
| if self.dataset_process_rule_id: | |||||
| def process_rule_dict(self) -> dict[str, Any] | None: | |||||
| if self.dataset_process_rule_id and self.dataset_process_rule: | |||||
| return self.dataset_process_rule.to_dict() | return self.dataset_process_rule.to_dict() | ||||
| return None | return None | ||||
| def get_built_in_fields(self): | |||||
| built_in_fields = [] | |||||
| def get_built_in_fields(self) -> list[dict[str, Any]]: | |||||
| built_in_fields: list[dict[str, Any]] = [] | |||||
| built_in_fields.append( | built_in_fields.append( | ||||
| { | { | ||||
| "id": "built-in", | "id": "built-in", | ||||
| ) | ) | ||||
| return built_in_fields | return built_in_fields | ||||
| def to_dict(self): | |||||
| def to_dict(self) -> dict[str, Any]: | |||||
| return { | return { | ||||
| "id": self.id, | "id": self.id, | ||||
| "tenant_id": self.tenant_id, | "tenant_id": self.tenant_id, | ||||
| "data_source_info_dict": self.data_source_info_dict, | "data_source_info_dict": self.data_source_info_dict, | ||||
| "average_segment_length": self.average_segment_length, | "average_segment_length": self.average_segment_length, | ||||
| "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None, | "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None, | ||||
| "dataset": self.dataset.to_dict() if self.dataset else None, | |||||
| "dataset": None, # Dataset class doesn't have a to_dict method | |||||
| "segment_count": self.segment_count, | "segment_count": self.segment_count, | ||||
| "hit_count": self.hit_count, | "hit_count": self.hit_count, | ||||
| } | } | ||||
| @classmethod | @classmethod | ||||
| def from_dict(cls, data: dict): | |||||
| def from_dict(cls, data: dict[str, Any]): | |||||
| return cls( | return cls( | ||||
| id=data.get("id"), | id=data.get("id"), | ||||
| tenant_id=data.get("tenant_id"), | tenant_id=data.get("tenant_id"), | ||||
| ) | ) | ||||
| @property | @property | ||||
| def child_chunks(self): | |||||
| process_rule = self.document.dataset_process_rule | |||||
| if process_rule.mode == "hierarchical": | |||||
| rules = Rule(**process_rule.rules_dict) | |||||
| if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: | |||||
| child_chunks = ( | |||||
| db.session.query(ChildChunk) | |||||
| .where(ChildChunk.segment_id == self.id) | |||||
| .order_by(ChildChunk.position.asc()) | |||||
| .all() | |||||
| ) | |||||
| return child_chunks or [] | |||||
| else: | |||||
| return [] | |||||
| else: | |||||
| def child_chunks(self) -> list[Any]: | |||||
| if not self.document: | |||||
| return [] | return [] | ||||
| def get_child_chunks(self): | |||||
| process_rule = self.document.dataset_process_rule | process_rule = self.document.dataset_process_rule | ||||
| if process_rule.mode == "hierarchical": | |||||
| rules = Rule(**process_rule.rules_dict) | |||||
| if rules.parent_mode: | |||||
| child_chunks = ( | |||||
| db.session.query(ChildChunk) | |||||
| .where(ChildChunk.segment_id == self.id) | |||||
| .order_by(ChildChunk.position.asc()) | |||||
| .all() | |||||
| ) | |||||
| return child_chunks or [] | |||||
| else: | |||||
| return [] | |||||
| else: | |||||
| if process_rule and process_rule.mode == "hierarchical": | |||||
| rules_dict = process_rule.rules_dict | |||||
| if rules_dict: | |||||
| rules = Rule(**rules_dict) | |||||
| if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: | |||||
| child_chunks = ( | |||||
| db.session.query(ChildChunk) | |||||
| .where(ChildChunk.segment_id == self.id) | |||||
| .order_by(ChildChunk.position.asc()) | |||||
| .all() | |||||
| ) | |||||
| return child_chunks or [] | |||||
| return [] | |||||
| def get_child_chunks(self) -> list[Any]: | |||||
| if not self.document: | |||||
| return [] | return [] | ||||
| process_rule = self.document.dataset_process_rule | |||||
| if process_rule and process_rule.mode == "hierarchical": | |||||
| rules_dict = process_rule.rules_dict | |||||
| if rules_dict: | |||||
| rules = Rule(**rules_dict) | |||||
| if rules.parent_mode: | |||||
| child_chunks = ( | |||||
| db.session.query(ChildChunk) | |||||
| .where(ChildChunk.segment_id == self.id) | |||||
| .order_by(ChildChunk.position.asc()) | |||||
| .all() | |||||
| ) | |||||
| return child_chunks or [] | |||||
| return [] | |||||
| @property | @property | ||||
| def sign_content(self): | |||||
| def sign_content(self) -> str: | |||||
| return self.get_sign_content() | return self.get_sign_content() | ||||
| def get_sign_content(self): | |||||
| signed_urls = [] | |||||
| def get_sign_content(self) -> str: | |||||
| signed_urls: list[tuple[int, int, str]] = [] | |||||
| text = self.content | text = self.content | ||||
| # For data before v0.10.0 | # For data before v0.10.0 | ||||
| ) | ) | ||||
| @property | @property | ||||
| def keyword_table_dict(self): | |||||
| def keyword_table_dict(self) -> dict[str, set[Any]] | None: | |||||
| class SetDecoder(json.JSONDecoder): | class SetDecoder(json.JSONDecoder): | ||||
| def __init__(self, *args, **kwargs): | |||||
| super().__init__(object_hook=self.object_hook, *args, **kwargs) | |||||
| def object_hook(self, dct): | |||||
| if isinstance(dct, dict): | |||||
| for keyword, node_idxs in dct.items(): | |||||
| if isinstance(node_idxs, list): | |||||
| dct[keyword] = set(node_idxs) | |||||
| return dct | |||||
| def __init__(self, *args: Any, **kwargs: Any) -> None: | |||||
| def object_hook(dct: Any) -> Any: | |||||
| if isinstance(dct, dict): | |||||
| result: dict[str, Any] = {} | |||||
| items = cast(dict[str, Any], dct).items() | |||||
| for keyword, node_idxs in items: | |||||
| if isinstance(node_idxs, list): | |||||
| result[keyword] = set(cast(list[Any], node_idxs)) | |||||
| else: | |||||
| result[keyword] = node_idxs | |||||
| return result | |||||
| return dct | |||||
| super().__init__(object_hook=object_hook, *args, **kwargs) | |||||
| # get dataset | # get dataset | ||||
| dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first() | dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first() | ||||
| updated_by = mapped_column(StringUUID, nullable=True) | updated_by = mapped_column(StringUUID, nullable=True) | ||||
| updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| def to_dict(self): | |||||
| def to_dict(self) -> dict[str, Any]: | |||||
| return { | return { | ||||
| "id": self.id, | "id": self.id, | ||||
| "tenant_id": self.tenant_id, | "tenant_id": self.tenant_id, | ||||
| } | } | ||||
| @property | @property | ||||
| def settings_dict(self): | |||||
| def settings_dict(self) -> dict[str, Any] | None: | |||||
| try: | try: | ||||
| return json.loads(self.settings) if self.settings else None | return json.loads(self.settings) if self.settings else None | ||||
| except JSONDecodeError: | except JSONDecodeError: | ||||
| return None | return None | ||||
| @property | @property | ||||
| def dataset_bindings(self): | |||||
| def dataset_bindings(self) -> list[dict[str, Any]]: | |||||
| external_knowledge_bindings = ( | external_knowledge_bindings = ( | ||||
| db.session.query(ExternalKnowledgeBindings) | db.session.query(ExternalKnowledgeBindings) | ||||
| .where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) | .where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) | ||||
| ) | ) | ||||
| dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings] | dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings] | ||||
| datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all() | datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all() | ||||
| dataset_bindings = [] | |||||
| dataset_bindings: list[dict[str, Any]] = [] | |||||
| for dataset in datasets: | for dataset in datasets: | ||||
| dataset_bindings.append({"id": dataset.id, "name": dataset.name}) | dataset_bindings.append({"id": dataset.id, "name": dataset.name}) | ||||
| import sqlalchemy as sa | import sqlalchemy as sa | ||||
| from flask import request | from flask import request | ||||
| from flask_login import UserMixin | |||||
| from flask_login import UserMixin # type: ignore[import-untyped] | |||||
| from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text | from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text | ||||
| from sqlalchemy.orm import Mapped, Session, mapped_column | from sqlalchemy.orm import Mapped, Session, mapped_column | ||||
| from constants import DEFAULT_FILE_NUMBER_LIMITS | from constants import DEFAULT_FILE_NUMBER_LIMITS | ||||
| from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType | from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType | ||||
| from core.file import helpers as file_helpers | from core.file import helpers as file_helpers | ||||
| from libs.helper import generate_string | |||||
| from libs.helper import generate_string # type: ignore[import-not-found] | |||||
| from .account import Account, Tenant | from .account import Account, Tenant | ||||
| from .base import Base | from .base import Base | ||||
| use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) | use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) | ||||
| @property | @property | ||||
| def desc_or_prompt(self): | |||||
| def desc_or_prompt(self) -> str: | |||||
| if self.description: | if self.description: | ||||
| return self.description | return self.description | ||||
| else: | else: | ||||
| return "" | return "" | ||||
| @property | @property | ||||
| def site(self): | |||||
| def site(self) -> Optional["Site"]: | |||||
| site = db.session.query(Site).where(Site.app_id == self.id).first() | site = db.session.query(Site).where(Site.app_id == self.id).first() | ||||
| return site | return site | ||||
| @property | @property | ||||
| def app_model_config(self): | |||||
| def app_model_config(self) -> Optional["AppModelConfig"]: | |||||
| if self.app_model_config_id: | if self.app_model_config_id: | ||||
| return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() | return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() | ||||
| return None | return None | ||||
| @property | @property | ||||
| def api_base_url(self): | |||||
| def api_base_url(self) -> str: | |||||
| return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1" | return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1" | ||||
| @property | @property | ||||
| def tenant(self): | |||||
| def tenant(self) -> Optional[Tenant]: | |||||
| tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() | tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() | ||||
| return tenant | return tenant | ||||
| return str(self.mode) | return str(self.mode) | ||||
| @property | @property | ||||
| def deleted_tools(self): | |||||
| def deleted_tools(self) -> list[dict[str, str]]: | |||||
| from core.tools.tool_manager import ToolManager | from core.tools.tool_manager import ToolManager | ||||
| from services.plugin.plugin_service import PluginService | from services.plugin.plugin_service import PluginService | ||||
| provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids) | provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids) | ||||
| } | } | ||||
| deleted_tools = [] | |||||
| deleted_tools: list[dict[str, str]] = [] | |||||
| for tool in tools: | for tool in tools: | ||||
| keys = list(tool.keys()) | keys = list(tool.keys()) | ||||
| return deleted_tools | return deleted_tools | ||||
| @property | @property | ||||
| def tags(self): | |||||
| def tags(self) -> list["Tag"]: | |||||
| tags = ( | tags = ( | ||||
| db.session.query(Tag) | db.session.query(Tag) | ||||
| .join(TagBinding, Tag.id == TagBinding.tag_id) | .join(TagBinding, Tag.id == TagBinding.tag_id) | ||||
| return tags or [] | return tags or [] | ||||
| @property | @property | ||||
| def author_name(self): | |||||
| def author_name(self) -> Optional[str]: | |||||
| if self.created_by: | if self.created_by: | ||||
| account = db.session.query(Account).where(Account.id == self.created_by).first() | account = db.session.query(Account).where(Account.id == self.created_by).first() | ||||
| if account: | if account: | ||||
| file_upload = mapped_column(sa.Text) | file_upload = mapped_column(sa.Text) | ||||
| @property | @property | ||||
| def app(self): | |||||
| def app(self) -> Optional[App]: | |||||
| app = db.session.query(App).where(App.id == self.app_id).first() | app = db.session.query(App).where(App.id == self.app_id).first() | ||||
| return app | return app | ||||
| @property | @property | ||||
| def model_dict(self): | |||||
| def model_dict(self) -> dict[str, Any]: | |||||
| return json.loads(self.model) if self.model else {} | return json.loads(self.model) if self.model else {} | ||||
| @property | @property | ||||
| def suggested_questions_list(self): | |||||
| def suggested_questions_list(self) -> list[str]: | |||||
| return json.loads(self.suggested_questions) if self.suggested_questions else [] | return json.loads(self.suggested_questions) if self.suggested_questions else [] | ||||
| @property | @property | ||||
| def suggested_questions_after_answer_dict(self): | |||||
| def suggested_questions_after_answer_dict(self) -> dict[str, Any]: | |||||
| return ( | return ( | ||||
| json.loads(self.suggested_questions_after_answer) | json.loads(self.suggested_questions_after_answer) | ||||
| if self.suggested_questions_after_answer | if self.suggested_questions_after_answer | ||||
| ) | ) | ||||
| @property | @property | ||||
| def speech_to_text_dict(self): | |||||
| def speech_to_text_dict(self) -> dict[str, Any]: | |||||
| return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False} | return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False} | ||||
| @property | @property | ||||
| def text_to_speech_dict(self): | |||||
| def text_to_speech_dict(self) -> dict[str, Any]: | |||||
| return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False} | return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False} | ||||
| @property | @property | ||||
| def retriever_resource_dict(self): | |||||
| def retriever_resource_dict(self) -> dict[str, Any]: | |||||
| return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} | return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} | ||||
| @property | @property | ||||
| def annotation_reply_dict(self): | |||||
| def annotation_reply_dict(self) -> dict[str, Any]: | |||||
| annotation_setting = ( | annotation_setting = ( | ||||
| db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first() | db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first() | ||||
| ) | ) | ||||
| return {"enabled": False} | return {"enabled": False} | ||||
| @property | @property | ||||
| def more_like_this_dict(self): | |||||
| def more_like_this_dict(self) -> dict[str, Any]: | |||||
| return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False} | return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False} | ||||
| @property | @property | ||||
| def sensitive_word_avoidance_dict(self): | |||||
| def sensitive_word_avoidance_dict(self) -> dict[str, Any]: | |||||
| return ( | return ( | ||||
| json.loads(self.sensitive_word_avoidance) | json.loads(self.sensitive_word_avoidance) | ||||
| if self.sensitive_word_avoidance | if self.sensitive_word_avoidance | ||||
| ) | ) | ||||
| @property | @property | ||||
| def external_data_tools_list(self) -> list[dict]: | |||||
| def external_data_tools_list(self) -> list[dict[str, Any]]: | |||||
| return json.loads(self.external_data_tools) if self.external_data_tools else [] | return json.loads(self.external_data_tools) if self.external_data_tools else [] | ||||
| @property | @property | ||||
| def user_input_form_list(self): | |||||
| def user_input_form_list(self) -> list[dict[str, Any]]: | |||||
| return json.loads(self.user_input_form) if self.user_input_form else [] | return json.loads(self.user_input_form) if self.user_input_form else [] | ||||
| @property | @property | ||||
| def agent_mode_dict(self): | |||||
| def agent_mode_dict(self) -> dict[str, Any]: | |||||
| return ( | return ( | ||||
| json.loads(self.agent_mode) | json.loads(self.agent_mode) | ||||
| if self.agent_mode | if self.agent_mode | ||||
| ) | ) | ||||
| @property | @property | ||||
| def chat_prompt_config_dict(self): | |||||
| def chat_prompt_config_dict(self) -> dict[str, Any]: | |||||
| return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {} | return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {} | ||||
| @property | @property | ||||
| def completion_prompt_config_dict(self): | |||||
| def completion_prompt_config_dict(self) -> dict[str, Any]: | |||||
| return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {} | return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {} | ||||
| @property | @property | ||||
| def dataset_configs_dict(self): | |||||
| def dataset_configs_dict(self) -> dict[str, Any]: | |||||
| if self.dataset_configs: | if self.dataset_configs: | ||||
| dataset_configs: dict = json.loads(self.dataset_configs) | |||||
| dataset_configs: dict[str, Any] = json.loads(self.dataset_configs) | |||||
| if "retrieval_model" not in dataset_configs: | if "retrieval_model" not in dataset_configs: | ||||
| return {"retrieval_model": "single"} | return {"retrieval_model": "single"} | ||||
| else: | else: | ||||
| } | } | ||||
| @property | @property | ||||
| def file_upload_dict(self): | |||||
| def file_upload_dict(self) -> dict[str, Any]: | |||||
| return ( | return ( | ||||
| json.loads(self.file_upload) | json.loads(self.file_upload) | ||||
| if self.file_upload | if self.file_upload | ||||
| } | } | ||||
| ) | ) | ||||
| def to_dict(self): | |||||
| def to_dict(self) -> dict[str, Any]: | |||||
| return { | return { | ||||
| "opening_statement": self.opening_statement, | "opening_statement": self.opening_statement, | ||||
| "suggested_questions": self.suggested_questions_list, | "suggested_questions": self.suggested_questions_list, | ||||
| updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| @property | @property | ||||
| def app(self): | |||||
| def app(self) -> Optional[App]: | |||||
| app = db.session.query(App).where(App.id == self.app_id).first() | app = db.session.query(App).where(App.id == self.app_id).first() | ||||
| return app | return app | ||||
| created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| @property | @property | ||||
| def app(self): | |||||
| def app(self) -> Optional[App]: | |||||
| app = db.session.query(App).where(App.id == self.app_id).first() | app = db.session.query(App).where(App.id == self.app_id).first() | ||||
| return app | return app | ||||
| @property | @property | ||||
| def tenant(self): | |||||
| def tenant(self) -> Optional[Tenant]: | |||||
| tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() | tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() | ||||
| return tenant | return tenant | ||||
| mode: Mapped[str] = mapped_column(String(255)) | mode: Mapped[str] = mapped_column(String(255)) | ||||
| name: Mapped[str] = mapped_column(String(255), nullable=False) | name: Mapped[str] = mapped_column(String(255), nullable=False) | ||||
| summary = mapped_column(sa.Text) | summary = mapped_column(sa.Text) | ||||
| _inputs: Mapped[dict] = mapped_column("inputs", sa.JSON) | |||||
| _inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON) | |||||
| introduction = mapped_column(sa.Text) | introduction = mapped_column(sa.Text) | ||||
| system_instruction = mapped_column(sa.Text) | system_instruction = mapped_column(sa.Text) | ||||
| system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) | system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) | ||||
| is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) | is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) | ||||
| @property | @property | ||||
| def inputs(self): | |||||
| def inputs(self) -> dict[str, Any]: | |||||
| inputs = self._inputs.copy() | inputs = self._inputs.copy() | ||||
| # Convert file mapping to File object | # Convert file mapping to File object | ||||
| # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. | # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. | ||||
| from factories import file_factory | from factories import file_factory | ||||
| if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: | |||||
| if value["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||||
| value["tool_file_id"] = value["related_id"] | |||||
| elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||||
| value["upload_file_id"] = value["related_id"] | |||||
| inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"]) | |||||
| elif isinstance(value, list) and all( | |||||
| isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value | |||||
| if ( | |||||
| isinstance(value, dict) | |||||
| and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY | |||||
| ): | ): | ||||
| inputs[key] = [] | |||||
| for item in value: | |||||
| if item["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||||
| item["tool_file_id"] = item["related_id"] | |||||
| elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||||
| item["upload_file_id"] = item["related_id"] | |||||
| inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"])) | |||||
| value_dict = cast(dict[str, Any], value) | |||||
| if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||||
| value_dict["tool_file_id"] = value_dict["related_id"] | |||||
| elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||||
| value_dict["upload_file_id"] = value_dict["related_id"] | |||||
| tenant_id = cast(str, value_dict.get("tenant_id", "")) | |||||
| inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) | |||||
| elif isinstance(value, list): | |||||
| value_list = cast(list[Any], value) | |||||
| if all( | |||||
| isinstance(item, dict) | |||||
| and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY | |||||
| for item in value_list | |||||
| ): | |||||
| file_list: list[File] = [] | |||||
| for item in value_list: | |||||
| if not isinstance(item, dict): | |||||
| continue | |||||
| item_dict = cast(dict[str, Any], item) | |||||
| if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||||
| item_dict["tool_file_id"] = item_dict["related_id"] | |||||
| elif item_dict["transfer_method"] in [ | |||||
| FileTransferMethod.LOCAL_FILE, | |||||
| FileTransferMethod.REMOTE_URL, | |||||
| ]: | |||||
| item_dict["upload_file_id"] = item_dict["related_id"] | |||||
| tenant_id = cast(str, item_dict.get("tenant_id", "")) | |||||
| file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) | |||||
| inputs[key] = file_list | |||||
| return inputs | return inputs | ||||
| for k, v in inputs.items(): | for k, v in inputs.items(): | ||||
| if isinstance(v, File): | if isinstance(v, File): | ||||
| inputs[k] = v.model_dump() | inputs[k] = v.model_dump() | ||||
| elif isinstance(v, list) and all(isinstance(item, File) for item in v): | |||||
| inputs[k] = [item.model_dump() for item in v] | |||||
| elif isinstance(v, list): | |||||
| v_list = cast(list[Any], v) | |||||
| if all(isinstance(item, File) for item in v_list): | |||||
| inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)] | |||||
| self._inputs = inputs | self._inputs = inputs | ||||
| @property | @property | ||||
| ) | ) | ||||
| @property | @property | ||||
| def app(self): | |||||
| def app(self) -> Optional[App]: | |||||
| return db.session.query(App).where(App.id == self.app_id).first() | return db.session.query(App).where(App.id == self.app_id).first() | ||||
| @property | @property | ||||
| return None | return None | ||||
| @property | @property | ||||
| def from_account_name(self): | |||||
| def from_account_name(self) -> Optional[str]: | |||||
| if self.from_account_id: | if self.from_account_id: | ||||
| account = db.session.query(Account).where(Account.id == self.from_account_id).first() | account = db.session.query(Account).where(Account.id == self.from_account_id).first() | ||||
| if account: | if account: | ||||
| return None | return None | ||||
| @property | @property | ||||
| def in_debug_mode(self): | |||||
| def in_debug_mode(self) -> bool: | |||||
| return self.override_model_configs is not None | return self.override_model_configs is not None | ||||
| def to_dict(self): | |||||
| def to_dict(self) -> dict[str, Any]: | |||||
| return { | return { | ||||
| "id": self.id, | "id": self.id, | ||||
| "app_id": self.app_id, | "app_id": self.app_id, | ||||
| model_id = mapped_column(String(255), nullable=True) | model_id = mapped_column(String(255), nullable=True) | ||||
| override_model_configs = mapped_column(sa.Text) | override_model_configs = mapped_column(sa.Text) | ||||
| conversation_id = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False) | conversation_id = mapped_column(StringUUID, sa.ForeignKey("conversations.id"), nullable=False) | ||||
| _inputs: Mapped[dict] = mapped_column("inputs", sa.JSON) | |||||
| _inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON) | |||||
| query: Mapped[str] = mapped_column(sa.Text, nullable=False) | query: Mapped[str] = mapped_column(sa.Text, nullable=False) | ||||
| message = mapped_column(sa.JSON, nullable=False) | message = mapped_column(sa.JSON, nullable=False) | ||||
| message_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) | message_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) | ||||
| workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) | workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID) | ||||
| @property | @property | ||||
| def inputs(self): | |||||
| def inputs(self) -> dict[str, Any]: | |||||
| inputs = self._inputs.copy() | inputs = self._inputs.copy() | ||||
| for key, value in inputs.items(): | for key, value in inputs.items(): | ||||
| # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. | # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. | ||||
| from factories import file_factory | from factories import file_factory | ||||
| if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: | |||||
| if value["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||||
| value["tool_file_id"] = value["related_id"] | |||||
| elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||||
| value["upload_file_id"] = value["related_id"] | |||||
| inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"]) | |||||
| elif isinstance(value, list) and all( | |||||
| isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value | |||||
| if ( | |||||
| isinstance(value, dict) | |||||
| and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY | |||||
| ): | ): | ||||
| inputs[key] = [] | |||||
| for item in value: | |||||
| if item["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||||
| item["tool_file_id"] = item["related_id"] | |||||
| elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||||
| item["upload_file_id"] = item["related_id"] | |||||
| inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"])) | |||||
| value_dict = cast(dict[str, Any], value) | |||||
| if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||||
| value_dict["tool_file_id"] = value_dict["related_id"] | |||||
| elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: | |||||
| value_dict["upload_file_id"] = value_dict["related_id"] | |||||
| tenant_id = cast(str, value_dict.get("tenant_id", "")) | |||||
| inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) | |||||
| elif isinstance(value, list): | |||||
| value_list = cast(list[Any], value) | |||||
| if all( | |||||
| isinstance(item, dict) | |||||
| and cast(dict[str, Any], item).get("dify_model_identity") == FILE_MODEL_IDENTITY | |||||
| for item in value_list | |||||
| ): | |||||
| file_list: list[File] = [] | |||||
| for item in value_list: | |||||
| if not isinstance(item, dict): | |||||
| continue | |||||
| item_dict = cast(dict[str, Any], item) | |||||
| if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: | |||||
| item_dict["tool_file_id"] = item_dict["related_id"] | |||||
| elif item_dict["transfer_method"] in [ | |||||
| FileTransferMethod.LOCAL_FILE, | |||||
| FileTransferMethod.REMOTE_URL, | |||||
| ]: | |||||
| item_dict["upload_file_id"] = item_dict["related_id"] | |||||
| tenant_id = cast(str, item_dict.get("tenant_id", "")) | |||||
| file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) | |||||
| inputs[key] = file_list | |||||
| return inputs | return inputs | ||||
| @inputs.setter | @inputs.setter | ||||
| for k, v in inputs.items(): | for k, v in inputs.items(): | ||||
| if isinstance(v, File): | if isinstance(v, File): | ||||
| inputs[k] = v.model_dump() | inputs[k] = v.model_dump() | ||||
| elif isinstance(v, list) and all(isinstance(item, File) for item in v): | |||||
| inputs[k] = [item.model_dump() for item in v] | |||||
| elif isinstance(v, list): | |||||
| v_list = cast(list[Any], v) | |||||
| if all(isinstance(item, File) for item in v_list): | |||||
| inputs[k] = [item.model_dump() for item in v_list if isinstance(item, File)] | |||||
| self._inputs = inputs | self._inputs = inputs | ||||
| @property | @property | ||||
| return None | return None | ||||
| @property | @property | ||||
| def in_debug_mode(self): | |||||
| def in_debug_mode(self) -> bool: | |||||
| return self.override_model_configs is not None | return self.override_model_configs is not None | ||||
| @property | @property | ||||
| def message_metadata_dict(self): | |||||
| def message_metadata_dict(self) -> dict[str, Any]: | |||||
| return json.loads(self.message_metadata) if self.message_metadata else {} | return json.loads(self.message_metadata) if self.message_metadata else {} | ||||
| @property | @property | ||||
| def agent_thoughts(self): | |||||
| def agent_thoughts(self) -> list["MessageAgentThought"]: | |||||
| return ( | return ( | ||||
| db.session.query(MessageAgentThought) | db.session.query(MessageAgentThought) | ||||
| .where(MessageAgentThought.message_id == self.id) | .where(MessageAgentThought.message_id == self.id) | ||||
| ) | ) | ||||
| @property | @property | ||||
| def retriever_resources(self): | |||||
| def retriever_resources(self) -> Any | list[Any]: | |||||
| return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else [] | return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else [] | ||||
| @property | @property | ||||
| def message_files(self): | |||||
| def message_files(self) -> list[dict[str, Any]]: | |||||
| from factories import file_factory | from factories import file_factory | ||||
| message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all() | message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all() | ||||
| if not current_app: | if not current_app: | ||||
| raise ValueError(f"App {self.app_id} not found") | raise ValueError(f"App {self.app_id} not found") | ||||
| files = [] | |||||
| files: list[File] = [] | |||||
| for message_file in message_files: | for message_file in message_files: | ||||
| if message_file.transfer_method == FileTransferMethod.LOCAL_FILE.value: | if message_file.transfer_method == FileTransferMethod.LOCAL_FILE.value: | ||||
| if message_file.upload_file_id is None: | if message_file.upload_file_id is None: | ||||
| ) | ) | ||||
| files.append(file) | files.append(file) | ||||
| result = [ | |||||
| result: list[dict[str, Any]] = [ | |||||
| {"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()} | {"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()} | ||||
| for (file, message_file) in zip(files, message_files) | for (file, message_file) in zip(files, message_files) | ||||
| ] | ] | ||||
| return None | return None | ||||
| def to_dict(self): | |||||
| def to_dict(self) -> dict[str, Any]: | |||||
| return { | return { | ||||
| "id": self.id, | "id": self.id, | ||||
| "app_id": self.app_id, | "app_id": self.app_id, | ||||
| } | } | ||||
| @classmethod | @classmethod | ||||
| def from_dict(cls, data: dict): | |||||
| def from_dict(cls, data: dict[str, Any]) -> "Message": | |||||
| return cls( | return cls( | ||||
| id=data["id"], | id=data["id"], | ||||
| app_id=data["app_id"], | app_id=data["app_id"], | ||||
| account = db.session.query(Account).where(Account.id == self.from_account_id).first() | account = db.session.query(Account).where(Account.id == self.from_account_id).first() | ||||
| return account | return account | ||||
| def to_dict(self): | |||||
| def to_dict(self) -> dict[str, Any]: | |||||
| return { | return { | ||||
| "id": str(self.id), | "id": str(self.id), | ||||
| "app_id": str(self.app_id), | "app_id": str(self.app_id), | ||||
| type: Mapped[str] = mapped_column(String(255), nullable=False) | type: Mapped[str] = mapped_column(String(255), nullable=False) | ||||
| external_user_id = mapped_column(String(255), nullable=True) | external_user_id = mapped_column(String(255), nullable=True) | ||||
| name = mapped_column(String(255)) | name = mapped_column(String(255)) | ||||
| is_anonymous: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) | |||||
| _is_anonymous: Mapped[bool] = mapped_column( | |||||
| "is_anonymous", sa.Boolean, nullable=False, server_default=sa.text("true") | |||||
| ) | |||||
| @property | |||||
| def is_anonymous(self) -> Literal[False]: | |||||
| return False | |||||
| @is_anonymous.setter | |||||
| def is_anonymous(self, value: bool) -> None: | |||||
| self._is_anonymous = value | |||||
| session_id: Mapped[str] = mapped_column() | session_id: Mapped[str] = mapped_column() | ||||
| created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | updated_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| @staticmethod | @staticmethod | ||||
| def generate_server_code(n): | |||||
| def generate_server_code(n: int) -> str: | |||||
| while True: | while True: | ||||
| result = generate_string(n) | result = generate_string(n) | ||||
| while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0: | while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0: | ||||
| self._custom_disclaimer = value | self._custom_disclaimer = value | ||||
| @staticmethod | @staticmethod | ||||
| def generate_code(n): | |||||
| def generate_code(n: int) -> str: | |||||
| while True: | while True: | ||||
| result = generate_string(n) | result = generate_string(n) | ||||
| while db.session.query(Site).where(Site.code == result).count() > 0: | while db.session.query(Site).where(Site.code == result).count() > 0: | ||||
| created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| @staticmethod | @staticmethod | ||||
| def generate_api_key(prefix, n): | |||||
| def generate_api_key(prefix: str, n: int) -> str: | |||||
| while True: | while True: | ||||
| result = prefix + generate_string(n) | result = prefix + generate_string(n) | ||||
| if db.session.scalar(select(exists().where(ApiToken.token == result))): | if db.session.scalar(select(exists().where(ApiToken.token == result))): | ||||
| created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) | created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) | ||||
| @property | @property | ||||
| def files(self): | |||||
| def files(self) -> list[Any]: | |||||
| if self.message_files: | if self.message_files: | ||||
| return cast(list[Any], json.loads(self.message_files)) | return cast(list[Any], json.loads(self.message_files)) | ||||
| else: | else: | ||||
| return self.tool.split(";") if self.tool else [] | return self.tool.split(";") if self.tool else [] | ||||
| @property | @property | ||||
| def tool_labels(self): | |||||
| def tool_labels(self) -> dict[str, Any]: | |||||
| try: | try: | ||||
| if self.tool_labels_str: | if self.tool_labels_str: | ||||
| return cast(dict, json.loads(self.tool_labels_str)) | |||||
| return cast(dict[str, Any], json.loads(self.tool_labels_str)) | |||||
| else: | else: | ||||
| return {} | return {} | ||||
| except Exception: | except Exception: | ||||
| return {} | return {} | ||||
| @property | @property | ||||
| def tool_meta(self): | |||||
| def tool_meta(self) -> dict[str, Any]: | |||||
| try: | try: | ||||
| if self.tool_meta_str: | if self.tool_meta_str: | ||||
| return cast(dict, json.loads(self.tool_meta_str)) | |||||
| return cast(dict[str, Any], json.loads(self.tool_meta_str)) | |||||
| else: | else: | ||||
| return {} | return {} | ||||
| except Exception: | except Exception: | ||||
| return {} | return {} | ||||
| @property | @property | ||||
| def tool_inputs_dict(self): | |||||
| def tool_inputs_dict(self) -> dict[str, Any]: | |||||
| tools = self.tools | tools = self.tools | ||||
| try: | try: | ||||
| if self.tool_input: | if self.tool_input: | ||||
| data = json.loads(self.tool_input) | data = json.loads(self.tool_input) | ||||
| result = {} | |||||
| result: dict[str, Any] = {} | |||||
| for tool in tools: | for tool in tools: | ||||
| if tool in data: | if tool in data: | ||||
| result[tool] = data[tool] | result[tool] = data[tool] | ||||
| return {} | return {} | ||||
| @property | @property | ||||
| def tool_outputs_dict(self): | |||||
| def tool_outputs_dict(self) -> dict[str, Any]: | |||||
| tools = self.tools | tools = self.tools | ||||
| try: | try: | ||||
| if self.observation: | if self.observation: | ||||
| data = json.loads(self.observation) | data = json.loads(self.observation) | ||||
| result = {} | |||||
| result: dict[str, Any] = {} | |||||
| for tool in tools: | for tool in tools: | ||||
| if tool in data: | if tool in data: | ||||
| result[tool] = data[tool] | result[tool] = data[tool] | ||||
| is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) | is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) | ||||
| @property | @property | ||||
| def tracing_config_dict(self): | |||||
| def tracing_config_dict(self) -> dict[str, Any]: | |||||
| return self.tracing_config or {} | return self.tracing_config or {} | ||||
| @property | @property | ||||
| def tracing_config_str(self): | |||||
| def tracing_config_str(self) -> str: | |||||
| return json.dumps(self.tracing_config_dict) | return json.dumps(self.tracing_config_dict) | ||||
| def to_dict(self): | |||||
| def to_dict(self) -> dict[str, Any]: | |||||
| return { | return { | ||||
| "id": self.id, | "id": self.id, | ||||
| "app_id": self.app_id, | "app_id": self.app_id, |
| SYSTEM = "system" | SYSTEM = "system" | ||||
| @staticmethod | @staticmethod | ||||
| def value_of(value): | |||||
| def value_of(value: str) -> "ProviderType": | |||||
| for member in ProviderType: | for member in ProviderType: | ||||
| if member.value == value: | if member.value == value: | ||||
| return member | return member | ||||
| """hosted trial quota""" | """hosted trial quota""" | ||||
| @staticmethod | @staticmethod | ||||
| def value_of(value): | |||||
| def value_of(value: str) -> "ProviderQuotaType": | |||||
| for member in ProviderQuotaType: | for member in ProviderQuotaType: | ||||
| if member.value == value: | if member.value == value: | ||||
| return member | return member |
| import json | import json | ||||
| from datetime import datetime | from datetime import datetime | ||||
| from typing import Optional, cast | |||||
| from typing import Any, Optional, cast | |||||
| from urllib.parse import urlparse | from urllib.parse import urlparse | ||||
| import sqlalchemy as sa | import sqlalchemy as sa | ||||
| encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False) | encrypted_oauth_params: Mapped[str] = mapped_column(sa.Text, nullable=False) | ||||
| @property | @property | ||||
| def oauth_params(self): | |||||
| return cast(dict, json.loads(self.encrypted_oauth_params or "{}")) | |||||
| def oauth_params(self) -> dict[str, Any]: | |||||
| return cast(dict[str, Any], json.loads(self.encrypted_oauth_params or "{}")) | |||||
| class BuiltinToolProvider(Base): | class BuiltinToolProvider(Base): | ||||
| expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1")) | expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1")) | ||||
| @property | @property | ||||
| def credentials(self): | |||||
| return cast(dict, json.loads(self.encrypted_credentials)) | |||||
| def credentials(self) -> dict[str, Any]: | |||||
| return cast(dict[str, Any], json.loads(self.encrypted_credentials)) | |||||
| class ApiToolProvider(Base): | class ApiToolProvider(Base): | ||||
| return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)] | return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)] | ||||
| @property | @property | ||||
| def credentials(self): | |||||
| return dict(json.loads(self.credentials_str)) | |||||
| def credentials(self) -> dict[str, Any]: | |||||
| return dict[str, Any](json.loads(self.credentials_str)) | |||||
| @property | @property | ||||
| def user(self) -> Account | None: | def user(self) -> Account | None: | ||||
| return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() | return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() | ||||
| @property | @property | ||||
| def credentials(self): | |||||
| def credentials(self) -> dict[str, Any]: | |||||
| try: | try: | ||||
| return cast(dict, json.loads(self.encrypted_credentials)) or {} | |||||
| return cast(dict[str, Any], json.loads(self.encrypted_credentials)) or {} | |||||
| except Exception: | except Exception: | ||||
| return {} | return {} | ||||
| return mask_url(self.decrypted_server_url) | return mask_url(self.decrypted_server_url) | ||||
| @property | @property | ||||
| def decrypted_credentials(self): | |||||
| def decrypted_credentials(self) -> dict[str, Any]: | |||||
| from core.helper.provider_cache import NoOpProviderCredentialCache | from core.helper.provider_cache import NoOpProviderCredentialCache | ||||
| from core.tools.mcp_tool.provider import MCPToolProviderController | from core.tools.mcp_tool.provider import MCPToolProviderController | ||||
| from core.tools.utils.encryption import create_provider_encrypter | from core.tools.utils.encryption import create_provider_encrypter | ||||
| provider_controller = MCPToolProviderController._from_db(self) | |||||
| provider_controller = MCPToolProviderController.from_db(self) | |||||
| encrypter, _ = create_provider_encrypter( | encrypter, _ = create_provider_encrypter( | ||||
| tenant_id=self.tenant_id, | tenant_id=self.tenant_id, | ||||
| cache=NoOpProviderCredentialCache(), | cache=NoOpProviderCredentialCache(), | ||||
| ) | ) | ||||
| return encrypter.decrypt(self.credentials) # type: ignore | |||||
| return encrypter.decrypt(self.credentials) | |||||
| class ToolModelInvoke(Base): | class ToolModelInvoke(Base): |
| import enum | import enum | ||||
| from typing import Generic, TypeVar | |||||
| import uuid | |||||
| from typing import Any, Generic, TypeVar | |||||
| from sqlalchemy import CHAR, VARCHAR, TypeDecorator | from sqlalchemy import CHAR, VARCHAR, TypeDecorator | ||||
| from sqlalchemy.dialects.postgresql import UUID | from sqlalchemy.dialects.postgresql import UUID | ||||
| from sqlalchemy.engine.interfaces import Dialect | |||||
| from sqlalchemy.sql.type_api import TypeEngine | |||||
| class StringUUID(TypeDecorator): | |||||
| class StringUUID(TypeDecorator[uuid.UUID | str | None]): | |||||
| impl = CHAR | impl = CHAR | ||||
| cache_ok = True | cache_ok = True | ||||
| def process_bind_param(self, value, dialect): | |||||
| def process_bind_param(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None: | |||||
| if value is None: | if value is None: | ||||
| return value | return value | ||||
| elif dialect.name == "postgresql": | elif dialect.name == "postgresql": | ||||
| return str(value) | return str(value) | ||||
| else: | else: | ||||
| return value.hex | |||||
| if isinstance(value, uuid.UUID): | |||||
| return value.hex | |||||
| return value | |||||
| def load_dialect_impl(self, dialect): | |||||
| def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: | |||||
| if dialect.name == "postgresql": | if dialect.name == "postgresql": | ||||
| return dialect.type_descriptor(UUID()) | return dialect.type_descriptor(UUID()) | ||||
| else: | else: | ||||
| return dialect.type_descriptor(CHAR(36)) | return dialect.type_descriptor(CHAR(36)) | ||||
| def process_result_value(self, value, dialect): | |||||
| def process_result_value(self, value: uuid.UUID | str | None, dialect: Dialect) -> str | None: | |||||
| if value is None: | if value is None: | ||||
| return value | return value | ||||
| return str(value) | return str(value) | ||||
| _E = TypeVar("_E", bound=enum.StrEnum) | _E = TypeVar("_E", bound=enum.StrEnum) | ||||
| class EnumText(TypeDecorator, Generic[_E]): | |||||
| class EnumText(TypeDecorator[_E | None], Generic[_E]): | |||||
| impl = VARCHAR | impl = VARCHAR | ||||
| cache_ok = True | cache_ok = True | ||||
| # leave some rooms for future longer enum values. | # leave some rooms for future longer enum values. | ||||
| self._length = max(max_enum_value_len, 20) | self._length = max(max_enum_value_len, 20) | ||||
| def process_bind_param(self, value: _E | str | None, dialect): | |||||
| def process_bind_param(self, value: _E | str | None, dialect: Dialect) -> str | None: | |||||
| if value is None: | if value is None: | ||||
| return value | return value | ||||
| if isinstance(value, self._enum_class): | if isinstance(value, self._enum_class): | ||||
| return value.value | return value.value | ||||
| elif isinstance(value, str): | |||||
| self._enum_class(value) | |||||
| return value | |||||
| else: | |||||
| raise TypeError(f"expected str or {self._enum_class}, got {type(value)}") | |||||
| # Since _E is bound to StrEnum which inherits from str, at this point value must be str | |||||
| self._enum_class(value) | |||||
| return value | |||||
| def load_dialect_impl(self, dialect): | |||||
| def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: | |||||
| return dialect.type_descriptor(VARCHAR(self._length)) | return dialect.type_descriptor(VARCHAR(self._length)) | ||||
| def process_result_value(self, value, dialect) -> _E | None: | |||||
| def process_result_value(self, value: str | None, dialect: Dialect) -> _E | None: | |||||
| if value is None: | if value is None: | ||||
| return value | return value | ||||
| if not isinstance(value, str): | |||||
| raise TypeError(f"expected str, got {type(value)}") | |||||
| # Type annotation guarantees value is str at this point | |||||
| return self._enum_class(value) | return self._enum_class(value) | ||||
| def compare_values(self, x, y): | |||||
| def compare_values(self, x: _E | None, y: _E | None) -> bool: | |||||
| if x is None or y is None: | if x is None or y is None: | ||||
| return x is y | return x is y | ||||
| return x == y | return x == y |
| from collections.abc import Mapping, Sequence | from collections.abc import Mapping, Sequence | ||||
| from datetime import datetime | from datetime import datetime | ||||
| from enum import Enum, StrEnum | from enum import Enum, StrEnum | ||||
| from typing import TYPE_CHECKING, Any, Optional, Union | |||||
| from typing import TYPE_CHECKING, Any, Optional, Union, cast | |||||
| from uuid import uuid4 | from uuid import uuid4 | ||||
| import sqlalchemy as sa | import sqlalchemy as sa | ||||
| raise WorkflowDataError("nodes not found in workflow graph") | raise WorkflowDataError("nodes not found in workflow graph") | ||||
| try: | try: | ||||
| node_config = next(filter(lambda node: node["id"] == node_id, nodes)) | |||||
| node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes)) | |||||
| except StopIteration: | except StopIteration: | ||||
| raise NodeNotFoundError(node_id) | raise NodeNotFoundError(node_id) | ||||
| assert isinstance(node_config, dict) | assert isinstance(node_config, dict) | ||||
| def features_dict(self) -> dict[str, Any]: | def features_dict(self) -> dict[str, Any]: | ||||
| return json.loads(self.features) if self.features else {} | return json.loads(self.features) if self.features else {} | ||||
| def user_input_form(self, to_old_structure: bool = False): | |||||
| def user_input_form(self, to_old_structure: bool = False) -> list[Any]: | |||||
| # get start node from graph | # get start node from graph | ||||
| if not self.graph: | if not self.graph: | ||||
| return [] | return [] | ||||
| variables: list[Any] = start_node.get("data", {}).get("variables", []) | variables: list[Any] = start_node.get("data", {}).get("variables", []) | ||||
| if to_old_structure: | if to_old_structure: | ||||
| old_structure_variables = [] | |||||
| old_structure_variables: list[dict[str, Any]] = [] | |||||
| for variable in variables: | for variable in variables: | ||||
| old_structure_variables.append({variable["type"]: variable}) | old_structure_variables.append({variable["type"]: variable}) | ||||
| @property | @property | ||||
| def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: | def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: | ||||
| # TODO: find some way to init `self._environment_variables` when instance created. | |||||
| if self._environment_variables is None: | |||||
| self._environment_variables = "{}" | |||||
| # _environment_variables is guaranteed to be non-None due to server_default="{}" | |||||
| # Use workflow.tenant_id to avoid relying on request user in background threads | # Use workflow.tenant_id to avoid relying on request user in background threads | ||||
| tenant_id = self.tenant_id | tenant_id = self.tenant_id | ||||
| ] | ] | ||||
| # decrypt secret variables value | # decrypt secret variables value | ||||
| def decrypt_func(var): | |||||
| def decrypt_func(var: Variable) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable: | |||||
| if isinstance(var, SecretVariable): | if isinstance(var, SecretVariable): | ||||
| return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) | return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) | ||||
| elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)): | elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)): | ||||
| return var | return var | ||||
| else: | else: | ||||
| raise AssertionError("this statement should be unreachable.") | |||||
| # Other variable types are not supported for environment variables | |||||
| raise AssertionError(f"Unexpected variable type for environment variable: {type(var)}") | |||||
| decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = list( | |||||
| map(decrypt_func, results) | |||||
| ) | |||||
| decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = [ | |||||
| decrypt_func(var) for var in results | |||||
| ] | |||||
| return decrypted_results | return decrypted_results | ||||
| @environment_variables.setter | @environment_variables.setter | ||||
| value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name}) | value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name}) | ||||
| # encrypt secret variables value | # encrypt secret variables value | ||||
| def encrypt_func(var): | |||||
| def encrypt_func(var: Variable) -> Variable: | |||||
| if isinstance(var, SecretVariable): | if isinstance(var, SecretVariable): | ||||
| return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)}) | return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)}) | ||||
| else: | else: | ||||
| @property | @property | ||||
| def conversation_variables(self) -> Sequence[Variable]: | def conversation_variables(self) -> Sequence[Variable]: | ||||
| # TODO: find some way to init `self._conversation_variables` when instance created. | |||||
| if self._conversation_variables is None: | |||||
| self._conversation_variables = "{}" | |||||
| # _conversation_variables is guaranteed to be non-None due to server_default="{}" | |||||
| variables_dict: dict[str, Any] = json.loads(self._conversation_variables) | variables_dict: dict[str, Any] = json.loads(self._conversation_variables) | ||||
| results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()] | results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()] | ||||
| } | } | ||||
| @classmethod | @classmethod | ||||
| def from_dict(cls, data: dict) -> "WorkflowRun": | |||||
| def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun": | |||||
| return cls( | return cls( | ||||
| id=data.get("id"), | id=data.get("id"), | ||||
| tenant_id=data.get("tenant_id"), | tenant_id=data.get("tenant_id"), | ||||
| __tablename__ = "workflow_node_executions" | __tablename__ = "workflow_node_executions" | ||||
| @declared_attr | @declared_attr | ||||
| def __table_args__(cls): # noqa | |||||
| @classmethod | |||||
| def __table_args__(cls) -> Any: | |||||
| return ( | return ( | ||||
| PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"), | PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"), | ||||
| Index( | Index( | ||||
| # MyPy may flag the following line because it doesn't recognize that | # MyPy may flag the following line because it doesn't recognize that | ||||
| # the `declared_attr` decorator passes the receiving class as the first | # the `declared_attr` decorator passes the receiving class as the first | ||||
| # argument to this method, allowing us to reference class attributes. | # argument to this method, allowing us to reference class attributes. | ||||
| cls.created_at.desc(), # type: ignore | |||||
| cls.created_at.desc(), | |||||
| ), | ), | ||||
| ) | ) | ||||
| return json.loads(self.execution_metadata) if self.execution_metadata else {} | return json.loads(self.execution_metadata) if self.execution_metadata else {} | ||||
| @property | @property | ||||
| def extras(self): | |||||
| def extras(self) -> dict[str, Any]: | |||||
| from core.tools.tool_manager import ToolManager | from core.tools.tool_manager import ToolManager | ||||
| extras = {} | |||||
| extras: dict[str, Any] = {} | |||||
| if self.execution_metadata_dict: | if self.execution_metadata_dict: | ||||
| from core.workflow.nodes import NodeType | from core.workflow.nodes import NodeType | ||||
| if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict: | if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict: | ||||
| tool_info = self.execution_metadata_dict["tool_info"] | |||||
| tool_info: dict[str, Any] = self.execution_metadata_dict["tool_info"] | |||||
| extras["icon"] = ToolManager.get_tool_icon( | extras["icon"] = ToolManager.get_tool_icon( | ||||
| tenant_id=self.tenant_id, | tenant_id=self.tenant_id, | ||||
| provider_type=tool_info["provider_type"], | provider_type=tool_info["provider_type"], | ||||
| # making this attribute harder to access from outside the class. | # making this attribute harder to access from outside the class. | ||||
| __value: Segment | None | __value: Segment | None | ||||
| def __init__(self, *args, **kwargs): | |||||
| def __init__(self, *args: Any, **kwargs: Any) -> None: | |||||
| """ | """ | ||||
| The constructor of `WorkflowDraftVariable` is not intended for | The constructor of `WorkflowDraftVariable` is not intended for | ||||
| direct use outside this file. Its solo purpose is setup private state | direct use outside this file. Its solo purpose is setup private state | ||||
| self.__value = None | self.__value = None | ||||
| def get_selector(self) -> list[str]: | def get_selector(self) -> list[str]: | ||||
| selector = json.loads(self.selector) | |||||
| selector: Any = json.loads(self.selector) | |||||
| if not isinstance(selector, list): | if not isinstance(selector, list): | ||||
| logger.error( | logger.error( | ||||
| "invalid selector loaded from database, type=%s, value=%s", | "invalid selector loaded from database, type=%s, value=%s", | ||||
| type(selector), | |||||
| type(selector).__name__, | |||||
| self.selector, | self.selector, | ||||
| ) | ) | ||||
| raise ValueError("invalid selector.") | raise ValueError("invalid selector.") | ||||
| return selector | |||||
| return cast(list[str], selector) | |||||
| def _set_selector(self, value: list[str]): | def _set_selector(self, value: list[str]): | ||||
| self.selector = json.dumps(value) | self.selector = json.dumps(value) | ||||
| # `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging. | # `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging. | ||||
| if isinstance(value, dict): | if isinstance(value, dict): | ||||
| if not maybe_file_object(value): | if not maybe_file_object(value): | ||||
| return value | |||||
| return cast(Any, value) | |||||
| return File.model_validate(value) | return File.model_validate(value) | ||||
| elif isinstance(value, list) and value: | elif isinstance(value, list) and value: | ||||
| first = value[0] | |||||
| value_list = cast(list[Any], value) | |||||
| first: Any = value_list[0] | |||||
| if not maybe_file_object(first): | if not maybe_file_object(first): | ||||
| return value | |||||
| return [File.model_validate(i) for i in value] | |||||
| return cast(Any, value) | |||||
| file_list: list[File] = [File.model_validate(cast(dict[str, Any], i)) for i in value_list] | |||||
| return cast(Any, file_list) | |||||
| else: | else: | ||||
| return value | |||||
| return cast(Any, value) | |||||
| @classmethod | @classmethod | ||||
| def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment: | def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment: |
| "tests/", | "tests/", | ||||
| "migrations/", | "migrations/", | ||||
| ".venv/", | ".venv/", | ||||
| "models/", | |||||
| "core/", | "core/", | ||||
| "controllers/", | "controllers/", | ||||
| "tasks/", | "tasks/", |
| import threading | import threading | ||||
| from typing import Optional | |||||
| from typing import Any, Optional | |||||
| import pytz | import pytz | ||||
| from flask_login import current_user | from flask_login import current_user | ||||
| if not app_model_config: | if not app_model_config: | ||||
| raise ValueError("App model config not found") | raise ValueError("App model config not found") | ||||
| result = { | |||||
| result: dict[str, Any] = { | |||||
| "meta": { | "meta": { | ||||
| "status": "success", | "status": "success", | ||||
| "executor": executor, | "executor": executor, |
| # get original app model config | # get original app model config | ||||
| if app.mode == AppMode.AGENT_CHAT.value or app.is_agent: | if app.mode == AppMode.AGENT_CHAT.value or app.is_agent: | ||||
| model_config = app.app_model_config | model_config = app.app_model_config | ||||
| if not model_config: | |||||
| return app | |||||
| agent_mode = model_config.agent_mode_dict | agent_mode = model_config.agent_mode_dict | ||||
| # decrypt agent tool parameters if it's secret-input | # decrypt agent tool parameters if it's secret-input | ||||
| for tool in agent_mode.get("tools") or []: | for tool in agent_mode.get("tools") or []: | ||||
| pass | pass | ||||
| # override agent mode | # override agent mode | ||||
| model_config.agent_mode = json.dumps(agent_mode) | |||||
| if model_config: | |||||
| model_config.agent_mode = json.dumps(agent_mode) | |||||
| class ModifiedApp(App): | class ModifiedApp(App): | ||||
| """ | """ |
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.enums import MessageStatus | from models.enums import MessageStatus | ||||
| from models.model import App, AppMode, AppModelConfig, Message | |||||
| from models.model import App, AppMode, Message | |||||
| from services.errors.audio import ( | from services.errors.audio import ( | ||||
| AudioTooLargeServiceError, | AudioTooLargeServiceError, | ||||
| NoAudioUploadedServiceError, | NoAudioUploadedServiceError, | ||||
| if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"): | if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"): | ||||
| raise ValueError("Speech to text is not enabled") | raise ValueError("Speech to text is not enabled") | ||||
| else: | else: | ||||
| app_model_config: AppModelConfig = app_model.app_model_config | |||||
| app_model_config = app_model.app_model_config | |||||
| if not app_model_config: | |||||
| raise ValueError("Speech to text is not enabled") | |||||
| if not app_model_config.speech_to_text_dict["enabled"]: | if not app_model_config.speech_to_text_dict["enabled"]: | ||||
| raise ValueError("Speech to text is not enabled") | raise ValueError("Speech to text is not enabled") |
| file_ids = [ | file_ids = [ | ||||
| document.data_source_info_dict["upload_file_id"] | document.data_source_info_dict["upload_file_id"] | ||||
| for document in documents | for document in documents | ||||
| if document.data_source_type == "upload_file" | |||||
| if document.data_source_type == "upload_file" and document.data_source_info_dict | |||||
| ] | ] | ||||
| batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) | batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) | ||||
| # sync document indexing | # sync document indexing | ||||
| document.indexing_status = "waiting" | document.indexing_status = "waiting" | ||||
| data_source_info = document.data_source_info_dict | data_source_info = document.data_source_info_dict | ||||
| data_source_info["mode"] = "scrape" | |||||
| document.data_source_info = json.dumps(data_source_info, ensure_ascii=False) | |||||
| if data_source_info: | |||||
| data_source_info["mode"] = "scrape" | |||||
| document.data_source_info = json.dumps(data_source_info, ensure_ascii=False) | |||||
| db.session.add(document) | db.session.add(document) | ||||
| db.session.commit() | db.session.commit() | ||||
| ) | ) | ||||
| if external_knowledge_api is None: | if external_knowledge_api is None: | ||||
| raise ValueError("api template not found") | raise ValueError("api template not found") | ||||
| if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE: | |||||
| args.get("settings")["api_key"] = external_knowledge_api.settings_dict.get("api_key") | |||||
| settings = args.get("settings") | |||||
| if settings and settings.get("api_key") == HIDDEN_VALUE and external_knowledge_api.settings_dict: | |||||
| settings["api_key"] = external_knowledge_api.settings_dict.get("api_key") | |||||
| external_knowledge_api.name = args.get("name") | external_knowledge_api.name = args.get("name") | ||||
| external_knowledge_api.description = args.get("description", "") | external_knowledge_api.description = args.get("description", "") |
| def update_mcp_provider_credentials( | def update_mcp_provider_credentials( | ||||
| cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False | cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False | ||||
| ): | ): | ||||
| provider_controller = MCPToolProviderController._from_db(mcp_provider) | |||||
| provider_controller = MCPToolProviderController.from_db(mcp_provider) | |||||
| tool_configuration = ProviderConfigEncrypter( | tool_configuration = ProviderConfigEncrypter( | ||||
| tenant_id=mcp_provider.tenant_id, | tenant_id=mcp_provider.tenant_id, | ||||
| config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type] | config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type] |
| TestCase( | TestCase( | ||||
| name="session insert with invalid type", | name="session insert with invalid type", | ||||
| action=lambda s: _session_insert_with_value(s, 1), | action=lambda s: _session_insert_with_value(s, 1), | ||||
| exc_type=TypeError, | |||||
| exc_type=ValueError, | |||||
| ), | ), | ||||
| TestCase( | TestCase( | ||||
| name="insert with invalid value", | name="insert with invalid value", | ||||
| TestCase( | TestCase( | ||||
| name="insert with invalid type", | name="insert with invalid type", | ||||
| action=lambda s: _insert_with_user(s, 1), | action=lambda s: _insert_with_user(s, 1), | ||||
| exc_type=TypeError, | |||||
| exc_type=ValueError, | |||||
| ), | ), | ||||
| ] | ] | ||||
| for idx, c in enumerate(cases, 1): | for idx, c in enumerate(cases, 1): |