     environment_variables=self._workflow.environment_variables,
     # Based on the definition of `VariableUnion`,
     # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
-    conversation_variables=cast(list[VariableUnion], conversation_variables),
+    conversation_variables=conversation_variables,
 )
 # init graph
 from libs import rsa

-def obfuscated_token(token: str):
+def obfuscated_token(token: str) -> str:
     if not token:
         return token
     if len(token) <= 8:
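Note for reviewers: the new `-> str` annotation is what enables the caller-side cleanup later in this diff. An unannotated `def` returns `Any` under mypy, so call sites such as the `log` property below had to wrap the call in `cast(str, ...)`; once the return type is declared, those casts become dead weight. A minimal self-contained sketch of the effect (the masking body here is illustrative, not the exact library code):

def obfuscated_token(token: str) -> str:
    # Mask everything except a short prefix and suffix (illustrative logic).
    if not token:
        return token
    if len(token) <= 8:
        return "*" * len(token)
    return token[:6] + "*" * 12 + token[-2:]

# With the return type declared, mypy already knows this is `str`,
# so `cast(str, obfuscated_token(...))` at call sites is redundant.
masked: str = obfuscated_token("sk-1234567890abcdef")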
| """ | """ | ||||
| if not isinstance(self.model_type_instance, LargeLanguageModel): | if not isinstance(self.model_type_instance, LargeLanguageModel): | ||||
| raise Exception("Model type instance is not LargeLanguageModel") | raise Exception("Model type instance is not LargeLanguageModel") | ||||
| self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) | |||||
| return cast( | return cast( | ||||
| Union[LLMResult, Generator], | Union[LLMResult, Generator], | ||||
| self._round_robin_invoke( | self._round_robin_invoke( | ||||
| """ | """ | ||||
| if not isinstance(self.model_type_instance, LargeLanguageModel): | if not isinstance(self.model_type_instance, LargeLanguageModel): | ||||
| raise Exception("Model type instance is not LargeLanguageModel") | raise Exception("Model type instance is not LargeLanguageModel") | ||||
| self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) | |||||
| return cast( | return cast( | ||||
| int, | int, | ||||
| self._round_robin_invoke( | self._round_robin_invoke( | ||||
| """ | """ | ||||
| if not isinstance(self.model_type_instance, TextEmbeddingModel): | if not isinstance(self.model_type_instance, TextEmbeddingModel): | ||||
| raise Exception("Model type instance is not TextEmbeddingModel") | raise Exception("Model type instance is not TextEmbeddingModel") | ||||
| self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) | |||||
| return cast( | return cast( | ||||
| TextEmbeddingResult, | TextEmbeddingResult, | ||||
| self._round_robin_invoke( | self._round_robin_invoke( | ||||
| """ | """ | ||||
| if not isinstance(self.model_type_instance, TextEmbeddingModel): | if not isinstance(self.model_type_instance, TextEmbeddingModel): | ||||
| raise Exception("Model type instance is not TextEmbeddingModel") | raise Exception("Model type instance is not TextEmbeddingModel") | ||||
| self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) | |||||
| return cast( | return cast( | ||||
| list[int], | list[int], | ||||
| self._round_robin_invoke( | self._round_robin_invoke( | ||||
| """ | """ | ||||
| if not isinstance(self.model_type_instance, RerankModel): | if not isinstance(self.model_type_instance, RerankModel): | ||||
| raise Exception("Model type instance is not RerankModel") | raise Exception("Model type instance is not RerankModel") | ||||
| self.model_type_instance = cast(RerankModel, self.model_type_instance) | |||||
| return cast( | return cast( | ||||
| RerankResult, | RerankResult, | ||||
| self._round_robin_invoke( | self._round_robin_invoke( | ||||
| """ | """ | ||||
| if not isinstance(self.model_type_instance, ModerationModel): | if not isinstance(self.model_type_instance, ModerationModel): | ||||
| raise Exception("Model type instance is not ModerationModel") | raise Exception("Model type instance is not ModerationModel") | ||||
| self.model_type_instance = cast(ModerationModel, self.model_type_instance) | |||||
| return cast( | return cast( | ||||
| bool, | bool, | ||||
| self._round_robin_invoke( | self._round_robin_invoke( | ||||
| """ | """ | ||||
| if not isinstance(self.model_type_instance, Speech2TextModel): | if not isinstance(self.model_type_instance, Speech2TextModel): | ||||
| raise Exception("Model type instance is not Speech2TextModel") | raise Exception("Model type instance is not Speech2TextModel") | ||||
| self.model_type_instance = cast(Speech2TextModel, self.model_type_instance) | |||||
| return cast( | return cast( | ||||
| str, | str, | ||||
| self._round_robin_invoke( | self._round_robin_invoke( | ||||
| """ | """ | ||||
| if not isinstance(self.model_type_instance, TTSModel): | if not isinstance(self.model_type_instance, TTSModel): | ||||
| raise Exception("Model type instance is not TTSModel") | raise Exception("Model type instance is not TTSModel") | ||||
| self.model_type_instance = cast(TTSModel, self.model_type_instance) | |||||
| return cast( | return cast( | ||||
| Iterable[bytes], | Iterable[bytes], | ||||
| self._round_robin_invoke( | self._round_robin_invoke( | ||||
| """ | """ | ||||
| if not isinstance(self.model_type_instance, TTSModel): | if not isinstance(self.model_type_instance, TTSModel): | ||||
| raise Exception("Model type instance is not TTSModel") | raise Exception("Model type instance is not TTSModel") | ||||
| self.model_type_instance = cast(TTSModel, self.model_type_instance) | |||||
| return self.model_type_instance.get_tts_model_voices( | return self.model_type_instance.get_tts_model_voices( | ||||
| model=self.model, credentials=self.credentials, language=language | model=self.model, credentials=self.credentials, language=language | ||||
| ) | ) |
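All of the removals above rest on the same mypy rule: an `isinstance` guard that raises on failure already narrows the variable (including a `self.` attribute) to the checked class, so re-assigning it through `cast` was a no-op at runtime and redundant for the type checker. A minimal sketch of the pattern with stand-in classes (not the real model runtime):

class AIModel: ...

class LargeLanguageModel(AIModel):
    def invoke(self, prompt: str) -> str:
        return f"echo: {prompt}"

class ModelInstance:
    def __init__(self, model_type_instance: AIModel) -> None:
        self.model_type_instance = model_type_instance

    def invoke_llm(self, prompt: str) -> str:
        if not isinstance(self.model_type_instance, LargeLanguageModel):
            raise Exception("Model type instance is not LargeLanguageModel")
        # mypy has narrowed self.model_type_instance to LargeLanguageModel
        # here, so cast(LargeLanguageModel, self.model_type_instance) adds nothing.
        return self.model_type_instance.invoke(prompt)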
 if isinstance(prompt_message.content, list):
     for content in prompt_message.content:
         if content.type == PromptMessageContentType.TEXT:
-            content = cast(TextPromptMessageContent, content)
             text += content.data
         else:
             content = cast(ImagePromptMessageContent, content)
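One way the TEXT branch narrows without a cast is a `Literal`-tagged union: when each content class declares its `type` field as a `Literal`, mypy matches the `content.type == ...` comparison against the tags. A simplified sketch with string tags (the real code discriminates on the `PromptMessageContentType` enum, and the `else` branch keeps its cast because more than two content types exist there):

from dataclasses import dataclass
from typing import Literal, Union

@dataclass
class TextPromptContent:
    type: Literal["text"]
    data: str

@dataclass
class ImagePromptContent:
    type: Literal["image"]
    url: str

PromptContent = Union[TextPromptContent, ImagePromptContent]

def collect_text(contents: list[PromptContent]) -> str:
    text = ""
    for content in contents:
        if content.type == "text":
            # mypy narrows `content` to TextPromptContent via the Literal tag.
            text += content.data
    return text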
 import json
 from collections import defaultdict
 from json import JSONDecodeError
-from typing import Any, Optional, cast
+from typing import Any, Optional

 from sqlalchemy import select
 from sqlalchemy.exc import IntegrityError

 for provider_entity in provider_entities:
     # handle include, exclude
     if is_filtered(
-        include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET),
-        exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET),
+        include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
+        exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
         data=provider_entity,
         name_func=lambda x: x.provider,
     ):
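The two config casts disappear because `is_filtered` can take the settings fields directly once they carry a precise `set[str]` type at the source. A rough sketch of such a declaration, assuming pydantic-settings-style config (the actual `dify_config` definition may differ in detail):

from pydantic import Field
from pydantic_settings import BaseSettings

class PositionConfig(BaseSettings):
    # Declared once with a precise type, so call sites need no cast:
    POSITION_PROVIDER_INCLUDES_SET: set[str] = Field(default_factory=set)
    POSITION_PROVIDER_EXCLUDES_SET: set[str] = Field(default_factory=set)

dify_config = PositionConfig()

# is_filtered(include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, ...)
# now type-checks without cast(set[str], ...).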
 import uuid
 from collections.abc import Generator, Iterable, Sequence
 from itertools import islice
-from typing import TYPE_CHECKING, Any, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, Optional, Union

 import qdrant_client
 from flask import current_app

 def _reload_if_needed(self):
     if isinstance(self._client, QdrantLocal):
-        self._client = cast(QdrantLocal, self._client)
         self._client._load()

 @classmethod
 import re
 from pathlib import Path
-from typing import Optional, cast
+from typing import Optional

 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.extractor.helpers import detect_file_encodings

     markdown_tups.append((current_header, current_text))
     markdown_tups = [
-        (re.sub(r"#", "", cast(str, key)).strip() if key else None, re.sub(r"<.*?>", "", value))
+        (re.sub(r"#", "", key).strip() if key else None, re.sub(r"<.*?>", "", value))
         for key, value in markdown_tups
     ]

         f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}"
     )
-    return cast(str, data_source_binding.access_token)
+    return data_source_binding.access_token
 import contextlib
 from collections.abc import Iterator
-from typing import Optional, cast
+from typing import Optional

 from core.rag.extractor.blob.blob import Blob
 from core.rag.extractor.extractor_base import BaseExtractor

 plaintext_file_exists = False
 if self._file_cache_key:
     with contextlib.suppress(FileNotFoundError):
-        text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
+        text = storage.load(self._file_cache_key).decode("utf-8")
         plaintext_file_exists = True
         return [Document(page_content=text)]

 documents = list(self.load())
 if controller_tools is None or len(controller_tools) == 0:
     raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

-return cast(
-    WorkflowTool,
-    controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
-        runtime=ToolRuntime(
-            tenant_id=tenant_id,
-            credentials={},
-            invoke_from=invoke_from,
-            tool_invoke_from=tool_invoke_from,
-        )
-    ),
+return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
+    runtime=ToolRuntime(
+        tenant_id=tenant_id,
+        credentials={},
+        invoke_from=invoke_from,
+        tool_invoke_from=tool_invoke_from,
+    )
 )
 elif provider_type == ToolProviderType.APP:
     raise NotImplementedError("app provider not implemented")

 for provider in builtin_providers:
     # handle include, exclude
     if is_filtered(
-        include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
-        exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
+        include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
+        exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
         data=provider,
         name_func=lambda x: x.identity.name,
     ):
 from datetime import date, datetime
 from decimal import Decimal
 from mimetypes import guess_extension
-from typing import Optional, cast
+from typing import Optional
 from uuid import UUID

 import numpy as np

 elif message.type == ToolInvokeMessage.MessageType.JSON:
     if isinstance(message.message, ToolInvokeMessage.JsonMessage):
-        json_msg = cast(ToolInvokeMessage.JsonMessage, message.message)
-        json_msg.json_object = safe_json_value(json_msg.json_object)
+        message.message.json_object = safe_json_value(message.message.json_object)
         yield message
     else:
         yield message
 db.session.commit()
 try:
-    response: LLMResult = cast(
-        LLMResult,
-        model_instance.invoke_llm(
-            prompt_messages=prompt_messages,
-            model_parameters=model_parameters,
-            tools=[],
-            stop=[],
-            stream=False,
-            user=user_id,
-            callbacks=[],
-        ),
+    response: LLMResult = model_instance.invoke_llm(
+        prompt_messages=prompt_messages,
+        model_parameters=model_parameters,
+        tools=[],
+        stop=[],
+        stream=False,
+        user=user_id,
+        callbacks=[],
     )
 except InvokeRateLimitError as e:
     raise InvokeModelError(f"Invoke rate limit error: {e}")
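A call site like this can only drop the `cast` if the checker can tell that `stream=False` yields an `LLMResult` rather than a generator; the usual mechanism is `@overload` on a `Literal`-typed stream flag. A sketch of that pattern (illustrative signatures, not necessarily how `ModelInstance.invoke_llm` is declared):

from collections.abc import Generator
from typing import Literal, Union, overload

class LLMResult: ...

@overload
def invoke_llm(*, stream: Literal[True]) -> Generator[str, None, None]: ...
@overload
def invoke_llm(*, stream: Literal[False] = ...) -> LLMResult: ...
def invoke_llm(*, stream: bool = False) -> Union[LLMResult, Generator[str, None, None]]:
    if stream:
        def chunks() -> Generator[str, None, None]:
            yield "partial"
        return chunks()
    return LLMResult()

# stream=False statically selects the LLMResult overload; no cast needed.
response: LLMResult = invoke_llm(stream=False)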
 import json
 import logging
 from collections.abc import Generator
-from typing import Any, Optional, cast
+from typing import Any, Optional

 from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
 from core.tools.__base.tool import Tool

         item = self._update_file_mapping(item)
         file = build_from_mapping(
             mapping=item,
-            tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
+            tenant_id=str(self.runtime.tenant_id),
         )
         files.append(file)
     elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
         value = self._update_file_mapping(value)
         file = build_from_mapping(
             mapping=value,
-            tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
+            tenant_id=str(self.runtime.tenant_id),
         )
         files.append(file)
 from collections.abc import Sequence
-from typing import Annotated, TypeAlias, cast
+from typing import Annotated, TypeAlias
 from uuid import uuid4

 from pydantic import Discriminator, Field, Tag

 @property
 def log(self) -> str:
-    return cast(str, encrypter.obfuscated_token(self.value))
+    return encrypter.obfuscated_token(self.value)

 class NoneVariable(NoneSegment, Variable):
 if len(sub_edge_mappings) == 0:
     continue

-edge = cast(GraphEdge, sub_edge_mappings[0])
+edge = sub_edge_mappings[0]
 if edge.run_condition is None:
     logger.warning("Edge %s run condition is None", edge.target_node_id)
     continue
     messages=message_stream,
     tool_info={
         "icon": self.agent_strategy_icon,
-        "agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
+        "agent_strategy": self._node_data.agent_strategy_name,
     },
     parameters_for_log=parameters_for_log,
     user_id=self.user_id,

     current_plugin = next(
         plugin
         for plugin in plugins
-        if f"{plugin.plugin_id}/{plugin.name}"
-        == cast(AgentNodeData, self._node_data).agent_strategy_provider_name
+        if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name
     )
     icon = current_plugin.declaration.icon
 except StopIteration:
| encoding = "utf-8" | encoding = "utf-8" | ||||
| yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore")) | yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore")) | ||||
| return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)) | |||||
| return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) | |||||
| except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e: | except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e: | ||||
| # If decoding fails, try with utf-8 as last resort | # If decoding fails, try with utf-8 as last resort | ||||
| try: | try: | ||||
| yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore")) | yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore")) | ||||
| return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)) | |||||
| return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) | |||||
| except (UnicodeDecodeError, yaml.YAMLError): | except (UnicodeDecodeError, yaml.YAMLError): | ||||
| raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e | raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e | ||||
| """ | """ | ||||
| Run the node. | Run the node. | ||||
| """ | """ | ||||
| node_data = cast(ParameterExtractorNodeData, self._node_data) | |||||
| node_data = self._node_data | |||||
| variable = self.graph_runtime_state.variable_pool.get(node_data.query) | variable = self.graph_runtime_state.variable_pool.get(node_data.query) | ||||
| query = variable.text if variable else "" | query = variable.text if variable else "" | ||||
| import json | import json | ||||
| from collections.abc import Mapping, Sequence | from collections.abc import Mapping, Sequence | ||||
| from typing import TYPE_CHECKING, Any, Optional, cast | |||||
| from typing import TYPE_CHECKING, Any, Optional | |||||
| from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity | ||||
| from core.memory.token_buffer_memory import TokenBufferMemory | from core.memory.token_buffer_memory import TokenBufferMemory | ||||
| return "1" | return "1" | ||||
| def _run(self): | def _run(self): | ||||
| node_data = cast(QuestionClassifierNodeData, self._node_data) | |||||
| node_data = self._node_data | |||||
| variable_pool = self.graph_runtime_state.variable_pool | variable_pool = self.graph_runtime_state.variable_pool | ||||
| # extract variables | # extract variables |
 from collections.abc import Generator, Mapping, Sequence
-from typing import Any, Optional, cast
+from typing import Any, Optional

 from sqlalchemy import select
 from sqlalchemy.orm import Session

     Run the tool node
     """
-    node_data = cast(ToolNodeData, self._node_data)
+    node_data = self._node_data
     # fetch tool icon
     tool_info = {
 import time
 import uuid
 from collections.abc import Generator, Mapping, Sequence
-from typing import Any, Optional, cast
+from typing import Any, Optional

 from configs import dify_config
 from core.app.apps.exc import GenerateTaskStoppedError

     environment_variables=[],
 )
-node_cls = cast(type[BaseNode], node_cls)
 # init workflow run state
 node: BaseNode = node_cls(
     id=str(uuid.uuid4()),
 import urllib.parse
 import uuid
 from collections.abc import Callable, Mapping, Sequence
-from typing import Any, cast
+from typing import Any

 import httpx
 from sqlalchemy import select

 mime_type = ""
 resp = ssrf_proxy.head(url, follow_redirects=True)
-resp = cast(httpx.Response, resp)
 if resp.status_code == httpx.codes.OK:
     if content_disposition := resp.headers.get("Content-Disposition"):
         filename = str(content_disposition.split("filename=")[-1].strip('"'))
 @property
 def decrypted_server_url(self) -> str:
-    return cast(str, encrypter.decrypt_token(self.tenant_id, self.server_url))
+    return encrypter.decrypt_token(self.tenant_id, self.server_url)

 @property
 def masked_server_url(self) -> str:
 account.last_active_at = naive_utc_now()
 db.session.commit()
-return cast(Account, account)
+return account

 @staticmethod
 def get_account_jwt_token(account: Account) -> str:

 db.session.commit()
-return cast(Account, account)
+return account

 @staticmethod
 def update_account_password(account, password, new_password):

 def get_custom_config(tenant_id: str) -> dict:
     tenant = db.get_or_404(Tenant, tenant_id)
-    return cast(dict, tenant.custom_config_dict)
+    return tenant.custom_config_dict

 @staticmethod
 def is_owner(account: Account, tenant: Tenant) -> bool:
 import uuid
-from typing import cast
+from typing import Optional

 import pandas as pd
 from flask_login import current_user

 if not message:
     raise NotFound("Message Not Exists.")
-annotation = message.annotation
+annotation: Optional[MessageAnnotation] = message.annotation
 # save the message annotation
 if annotation:
     annotation.content = args["answer"]

     app_id,
     annotation_setting.collection_binding_id,
 )
-return cast(MessageAnnotation, annotation)
+return annotation

 @classmethod
 def enable_app_annotation(cls, args: dict, app_id: str) -> dict:
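Here the fix is an explicit variable annotation rather than a cast at the return: once `annotation` is declared `Optional[MessageAnnotation]`, the truthiness check narrows it, and the final `cast(MessageAnnotation, annotation)` becomes unnecessary. A minimal sketch of the pattern (simplified stand-ins, not the actual service code):

from typing import Optional

class MessageAnnotation:
    content: str = ""

class Message:
    annotation: Optional[MessageAnnotation] = None

def upsert_annotation(message: Message, answer: str) -> MessageAnnotation:
    annotation: Optional[MessageAnnotation] = message.annotation
    if annotation is None:
        annotation = MessageAnnotation()
    annotation.content = answer
    # `annotation` is narrowed to MessageAnnotation on every path; no cast.
    return annotation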
 import time
 import uuid
 from os import getenv
-from typing import cast

 import pytest

 from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
 from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
 from core.workflow.nodes.code.code_node import CodeNode
-from core.workflow.nodes.code.entities import CodeNodeData
 from core.workflow.system_variable import SystemVariable
 from models.enums import UserFrom
 from models.workflow import WorkflowType

     "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
 }
-node._node_data = cast(CodeNodeData, node._node_data)
 # validate
 node._transform_result(result, node._node_data.outputs)

     ]
 }
-node._node_data = cast(CodeNodeData, node._node_data)
 # validate
 node._transform_result(result, node._node_data.outputs)