| @@ -3,6 +3,7 @@ from typing import cast | |||
| from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter | |||
| from core.app.entities.task_entities import ( | |||
| AppStreamResponse, | |||
| ChatbotAppBlockingResponse, | |||
| ChatbotAppStreamResponse, | |||
| ErrorStreamResponse, | |||
| @@ -51,7 +52,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| @classmethod | |||
| def convert_stream_full_response( | |||
| cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[dict | str, None, None]: | |||
| """ | |||
| Convert stream full response. | |||
| @@ -82,7 +83,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| @classmethod | |||
| def convert_stream_simple_response( | |||
| cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[dict | str, None, None]: | |||
| """ | |||
| Convert stream simple response. | |||
| @@ -56,7 +56,7 @@ class AppGenerateResponseConverter(ABC): | |||
| @abstractmethod | |||
| def convert_stream_simple_response( | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[str, None, None]: | |||
| ) -> Generator[dict | str, None, None]: | |||
| raise NotImplementedError | |||
| @classmethod | |||
| @@ -3,6 +3,7 @@ from typing import cast | |||
| from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter | |||
| from core.app.entities.task_entities import ( | |||
| AppStreamResponse, | |||
| ChatbotAppBlockingResponse, | |||
| ChatbotAppStreamResponse, | |||
| ErrorStreamResponse, | |||
| @@ -51,7 +52,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| @classmethod | |||
| def convert_stream_full_response( | |||
| cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[dict | str, None, None]: | |||
| """ | |||
| Convert stream full response. | |||
| @@ -82,7 +83,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| @classmethod | |||
| def convert_stream_simple_response( | |||
| cls, stream_response: Generator[ChatbotAppStreamResponse, None, None] | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[dict | str, None, None]: | |||
| """ | |||
| Convert stream simple response. | |||
| @@ -3,6 +3,7 @@ from typing import cast | |||
| from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter | |||
| from core.app.entities.task_entities import ( | |||
| AppStreamResponse, | |||
| CompletionAppBlockingResponse, | |||
| CompletionAppStreamResponse, | |||
| ErrorStreamResponse, | |||
| @@ -50,7 +51,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| @classmethod | |||
| def convert_stream_full_response( | |||
| cls, stream_response: Generator[CompletionAppStreamResponse, None, None] | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[dict | str, None, None]: | |||
| """ | |||
| Convert stream full response. | |||
| @@ -80,7 +81,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| @classmethod | |||
| def convert_stream_simple_response( | |||
| cls, stream_response: Generator[CompletionAppStreamResponse, None, None] | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[dict | str, None, None]: | |||
| """ | |||
| Convert stream simple response. | |||
| @@ -149,7 +149,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| invoke_from: InvokeFrom, | |||
| streaming: bool = True, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| ) -> Union[dict, Generator[str | dict, None, None]]: | |||
| ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: | |||
| """ | |||
| Generate App response. | |||
| @@ -200,9 +200,9 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| workflow: Workflow, | |||
| node_id: str, | |||
| user: Account | EndUser, | |||
| args: dict, | |||
| args: Mapping[str, Any], | |||
| streaming: bool = True, | |||
| ) -> dict[str, Any] | Generator[str | dict, Any, None]: | |||
| ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]: | |||
| """ | |||
| Generate App response. | |||
| @@ -3,6 +3,7 @@ from typing import cast | |||
| from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter | |||
| from core.app.entities.task_entities import ( | |||
| AppStreamResponse, | |||
| ErrorStreamResponse, | |||
| NodeFinishStreamResponse, | |||
| NodeStartStreamResponse, | |||
| @@ -35,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| @classmethod | |||
| def convert_stream_full_response( | |||
| cls, stream_response: Generator[WorkflowAppStreamResponse, None, None] | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[dict | str, None, None]: | |||
| """ | |||
| Convert stream full response. | |||
| @@ -64,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): | |||
| @classmethod | |||
| def convert_stream_simple_response( | |||
| cls, stream_response: Generator[WorkflowAppStreamResponse, None, None] | |||
| cls, stream_response: Generator[AppStreamResponse, None, None] | |||
| ) -> Generator[dict | str, None, None]: | |||
| """ | |||
| Convert stream simple response. | |||
| @@ -157,7 +157,7 @@ class ProviderConfiguration(BaseModel): | |||
| """ | |||
| return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0 | |||
| def get_custom_credentials(self, obfuscated: bool = False): | |||
| def get_custom_credentials(self, obfuscated: bool = False) -> dict | None: | |||
| """ | |||
| Get custom credentials. | |||
| @@ -741,11 +741,11 @@ class ProviderConfiguration(BaseModel): | |||
| model_provider_factory = ModelProviderFactory(self.tenant_id) | |||
| provider_schema = model_provider_factory.get_provider_schema(self.provider.provider) | |||
| model_types = [] | |||
| model_types: list[ModelType] = [] | |||
| if model_type: | |||
| model_types.append(model_type) | |||
| else: | |||
| model_types = provider_schema.supported_model_types | |||
| model_types = list(provider_schema.supported_model_types) | |||
| # Group model settings by model type and model | |||
| model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict) | |||
| @@ -1065,11 +1065,11 @@ class ProviderConfigurations(BaseModel): | |||
| def values(self) -> Iterator[ProviderConfiguration]: | |||
| return iter(self.configurations.values()) | |||
| def get(self, key, default=None): | |||
| def get(self, key, default=None) -> ProviderConfiguration | None: | |||
| if "/" not in key: | |||
| key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}" | |||
| return self.configurations.get(key, default) | |||
| return self.configurations.get(key, default) # type: ignore | |||
| class ProviderModelBundle(BaseModel): | |||
| @@ -20,7 +20,7 @@ class UploadFileParser: | |||
| if upload_file.extension not in IMAGE_EXTENSIONS: | |||
| return None | |||
| if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url: | |||
| if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url: | |||
| return cls.get_signed_temp_image_url(upload_file.id) | |||
| else: | |||
| # get image file base64 | |||
| @@ -48,7 +48,7 @@ class LLMGenerator: | |||
| response = cast( | |||
| LLMResult, | |||
| model_instance.invoke_llm( | |||
| prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False | |||
| prompt_messages=list(prompts), model_parameters={"max_tokens": 100, "temperature": 1}, stream=False | |||
| ), | |||
| ) | |||
| answer = cast(str, response.message.content) | |||
| @@ -101,7 +101,7 @@ class LLMGenerator: | |||
| response = cast( | |||
| LLMResult, | |||
| model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| prompt_messages=list(prompt_messages), | |||
| model_parameters={"max_tokens": 256, "temperature": 0}, | |||
| stream=False, | |||
| ), | |||
| @@ -110,7 +110,7 @@ class LLMGenerator: | |||
| questions = output_parser.parse(cast(str, response.message.content)) | |||
| except InvokeError: | |||
| questions = [] | |||
| except Exception as e: | |||
| except Exception: | |||
| logging.exception("Failed to generate suggested questions after answer") | |||
| questions = [] | |||
| @@ -150,7 +150,7 @@ class LLMGenerator: | |||
| response = cast( | |||
| LLMResult, | |||
| model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False | |||
| prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False | |||
| ), | |||
| ) | |||
| @@ -200,7 +200,7 @@ class LLMGenerator: | |||
| prompt_content = cast( | |||
| LLMResult, | |||
| model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False | |||
| prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False | |||
| ), | |||
| ) | |||
| except InvokeError as e: | |||
| @@ -236,7 +236,7 @@ class LLMGenerator: | |||
| parameter_content = cast( | |||
| LLMResult, | |||
| model_instance.invoke_llm( | |||
| prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False | |||
| prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False | |||
| ), | |||
| ) | |||
| rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content)) | |||
| @@ -248,7 +248,7 @@ class LLMGenerator: | |||
| statement_content = cast( | |||
| LLMResult, | |||
| model_instance.invoke_llm( | |||
| prompt_messages=statement_messages, model_parameters=model_parameters, stream=False | |||
| prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False | |||
| ), | |||
| ) | |||
| rule_config["opening_statement"] = cast(str, statement_content.message.content) | |||
| @@ -301,7 +301,7 @@ class LLMGenerator: | |||
| response = cast( | |||
| LLMResult, | |||
| model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False | |||
| prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False | |||
| ), | |||
| ) | |||
| @@ -84,6 +84,8 @@ class LargeLanguageModel(AIModel): | |||
| callbacks=callbacks, | |||
| ) | |||
| result: Union[LLMResult, Generator[LLMResultChunk, None, None]] | |||
| try: | |||
| plugin_model_manager = PluginModelManager() | |||
| result = plugin_model_manager.invoke_llm( | |||
| @@ -285,17 +285,17 @@ class ModelProviderFactory: | |||
| } | |||
| if model_type == ModelType.LLM: | |||
| return LargeLanguageModel(**init_params) | |||
| return LargeLanguageModel(**init_params) # type: ignore | |||
| elif model_type == ModelType.TEXT_EMBEDDING: | |||
| return TextEmbeddingModel(**init_params) | |||
| return TextEmbeddingModel(**init_params) # type: ignore | |||
| elif model_type == ModelType.RERANK: | |||
| return RerankModel(**init_params) | |||
| return RerankModel(**init_params) # type: ignore | |||
| elif model_type == ModelType.SPEECH2TEXT: | |||
| return Speech2TextModel(**init_params) | |||
| return Speech2TextModel(**init_params) # type: ignore | |||
| elif model_type == ModelType.MODERATION: | |||
| return ModerationModel(**init_params) | |||
| return ModerationModel(**init_params) # type: ignore | |||
| elif model_type == ModelType.TTS: | |||
| return TTSModel(**init_params) | |||
| return TTSModel(**init_params) # type: ignore | |||
| def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: | |||
| """ | |||
| @@ -119,7 +119,7 @@ class BasePluginManager: | |||
| Make a request to the plugin daemon inner API and return the response as a model. | |||
| """ | |||
| response = self._request(method, path, headers, data, params, files) | |||
| return type(**response.json()) | |||
| return type(**response.json()) # type: ignore | |||
| def _request_with_plugin_daemon_response( | |||
| self, | |||
| @@ -140,7 +140,7 @@ class BasePluginManager: | |||
| if transformer: | |||
| json_response = transformer(json_response) | |||
| rep = PluginDaemonBasicResponse[type](**json_response) | |||
| rep = PluginDaemonBasicResponse[type](**json_response) # type: ignore | |||
| if rep.code != 0: | |||
| try: | |||
| error = PluginDaemonError(**json.loads(rep.message)) | |||
| @@ -171,7 +171,7 @@ class BasePluginManager: | |||
| line_data = None | |||
| try: | |||
| line_data = json.loads(line) | |||
| rep = PluginDaemonBasicResponse[type](**line_data) | |||
| rep = PluginDaemonBasicResponse[type](**line_data) # type: ignore | |||
| except Exception: | |||
| # TODO modify this when line_data has code and message | |||
| if line_data and "error" in line_data: | |||
| @@ -742,7 +742,7 @@ class ProviderManager: | |||
| try: | |||
| provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config) | |||
| except JSONDecodeError: | |||
| provider_credentials: dict[str, Any] = {} | |||
| provider_credentials = {} | |||
| # Get provider credential secret variables | |||
| provider_credential_secret_variables = self._extract_secret_variables( | |||
| @@ -601,6 +601,9 @@ class DatasetRetrieval: | |||
| elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: | |||
| from core.tools.utils.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool | |||
| if retrieve_config.reranking_model is None: | |||
| raise ValueError("Reranking model is required for multiple retrieval") | |||
| tool = DatasetMultiRetrieverTool.from_dataset( | |||
| dataset_ids=[dataset.id for dataset in available_datasets], | |||
| tenant_id=tenant_id, | |||
| @@ -30,14 +30,14 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): | |||
| disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037 | |||
| **kwargs: Any, | |||
| ): | |||
| def _token_encoder(text: str) -> int: | |||
| if not text: | |||
| return 0 | |||
| def _token_encoder(texts: list[str]) -> list[int]: | |||
| if not texts: | |||
| return [] | |||
| if embedding_model_instance: | |||
| return embedding_model_instance.get_text_embedding_num_tokens(texts=[text]) | |||
| return embedding_model_instance.get_text_embedding_num_tokens(texts=texts) | |||
| else: | |||
| return GPT2Tokenizer.get_num_tokens(text) | |||
| return [GPT2Tokenizer.get_num_tokens(text) for text in texts] | |||
| if issubclass(cls, TokenTextSplitter): | |||
| extra_kwargs = { | |||
| @@ -96,7 +96,6 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) | |||
| _good_splits_lengths = [] # cache the lengths of the splits | |||
| s_lens = self._length_function(splits) | |||
| for s, s_len in zip(splits, s_lens): | |||
| s_len = self._length_function(s) | |||
| if s_len < self._chunk_size: | |||
| _good_splits.append(s) | |||
| _good_splits_lengths.append(s_len) | |||
| @@ -106,7 +106,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): | |||
| def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]: | |||
| # We now want to combine these smaller pieces into medium size | |||
| # chunks to send to the LLM. | |||
| separator_len = self._length_function(separator) | |||
| separator_len = self._length_function([separator])[0] | |||
| docs = [] | |||
| current_doc: list[str] = [] | |||
| @@ -129,7 +129,9 @@ class TextSplitter(BaseDocumentTransformer, ABC): | |||
| while total > self._chunk_overlap or ( | |||
| total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0 | |||
| ): | |||
| total -= self._length_function(current_doc[0]) + (separator_len if len(current_doc) > 1 else 0) | |||
| total -= self._length_function([current_doc[0]])[0] + ( | |||
| separator_len if len(current_doc) > 1 else 0 | |||
| ) | |||
| current_doc = current_doc[1:] | |||
| current_doc.append(d) | |||
| total += _len + (separator_len if len(current_doc) > 1 else 0) | |||
| @@ -155,7 +157,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): | |||
| raise ValueError( | |||
| "Could not import transformers python package. Please install it with `pip install transformers`." | |||
| ) | |||
| return cls(length_function=_huggingface_tokenizer_length, **kwargs) | |||
| return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs) | |||
| @classmethod | |||
| def from_tiktoken_encoder( | |||
| @@ -199,7 +201,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): | |||
| } | |||
| kwargs = {**kwargs, **extra_kwargs} | |||
| return cls(length_function=_tiktoken_encoder, **kwargs) | |||
| return cls(length_function=lambda x: [_tiktoken_encoder(text) for text in x], **kwargs) | |||
| def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: | |||
| """Transform sequence of documents by splitting them.""" | |||
| @@ -71,13 +71,13 @@ class Tool(ABC): | |||
| if isinstance(result, ToolInvokeMessage): | |||
| def single_generator(): | |||
| def single_generator() -> Generator[ToolInvokeMessage, None, None]: | |||
| yield result | |||
| return single_generator() | |||
| elif isinstance(result, list): | |||
| def generator(): | |||
| def generator() -> Generator[ToolInvokeMessage, None, None]: | |||
| yield from result | |||
| return generator() | |||
| @@ -109,11 +109,11 @@ class BuiltinToolProviderController(ToolProviderController): | |||
| """ | |||
| return self._get_builtin_tools() | |||
| def get_tool(self, tool_name: str) -> BuiltinTool | None: | |||
| def get_tool(self, tool_name: str) -> BuiltinTool | None: # type: ignore | |||
| """ | |||
| returns the tool that the provider can provide | |||
| """ | |||
| return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None) | |||
| return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None) # type: ignore | |||
| @property | |||
| def need_credentials(self) -> bool: | |||
| @@ -1,6 +1,7 @@ | |||
| from typing import Any | |||
| from core.tools.builtin_tool.provider import BuiltinToolProviderController | |||
| class AudioToolProvider(BuiltinToolProviderController): | |||
| def _validate_credentials(self, credentials: dict) -> None: | |||
| def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: | |||
| pass | |||
| @@ -27,7 +27,7 @@ class LocaltimeToTimestampTool(BuiltinTool): | |||
| timezone = None | |||
| time_format = "%Y-%m-%d %H:%M:%S" | |||
| timestamp = self.localtime_to_timestamp(localtime, time_format, timezone) | |||
| timestamp = self.localtime_to_timestamp(localtime, time_format, timezone) # type: ignore | |||
| if not timestamp: | |||
| yield self.create_text_message(f"Invalid localtime: {localtime}") | |||
| return | |||
| @@ -42,8 +42,8 @@ class LocaltimeToTimestampTool(BuiltinTool): | |||
| if isinstance(local_tz, str): | |||
| local_tz = pytz.timezone(local_tz) | |||
| local_time = datetime.strptime(localtime, time_format) | |||
| localtime = local_tz.localize(local_time) | |||
| timestamp = int(localtime.timestamp()) | |||
| localtime = local_tz.localize(local_time) # type: ignore | |||
| timestamp = int(localtime.timestamp()) # type: ignore | |||
| return timestamp | |||
| except Exception as e: | |||
| raise ToolInvokeError(str(e)) | |||
| @@ -21,7 +21,7 @@ class TimestampToLocaltimeTool(BuiltinTool): | |||
| """ | |||
| Convert timestamp to localtime | |||
| """ | |||
| timestamp = tool_parameters.get("timestamp") | |||
| timestamp: int = tool_parameters.get("timestamp", 0) | |||
| timezone = tool_parameters.get("timezone", "Asia/Shanghai") | |||
| if not timezone: | |||
| timezone = None | |||
| @@ -24,7 +24,7 @@ class TimezoneConversionTool(BuiltinTool): | |||
| current_time = tool_parameters.get("current_time") | |||
| current_timezone = tool_parameters.get("current_timezone", "Asia/Shanghai") | |||
| target_timezone = tool_parameters.get("target_timezone", "Asia/Tokyo") | |||
| target_time = self.timezone_convert(current_time, current_timezone, target_timezone) | |||
| target_time = self.timezone_convert(current_time, current_timezone, target_timezone) # type: ignore | |||
| if not target_time: | |||
| yield self.create_text_message( | |||
| f"Invalid datatime and timezone: {current_time},{current_timezone},{target_timezone}" | |||
| @@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController | |||
| class WebscraperProvider(BuiltinToolProviderController): | |||
| def _validate_credentials(self, credentials: dict[str, Any]) -> None: | |||
| def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: | |||
| pass | |||
| @@ -31,7 +31,7 @@ class ApiToolProviderController(ToolProviderController): | |||
| self.tools = [] | |||
| @classmethod | |||
| def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType): | |||
| def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController": | |||
| credentials_schema = [ | |||
| ProviderConfig( | |||
| name="auth_type", | |||
| @@ -44,7 +44,7 @@ class PluginToolProviderController(BuiltinToolProviderController): | |||
| ): | |||
| raise ToolProviderCredentialValidationError("Invalid credentials") | |||
| def get_tool(self, tool_name: str) -> PluginTool: | |||
| def get_tool(self, tool_name: str) -> PluginTool: # type: ignore | |||
| """ | |||
| return tool with given name | |||
| """ | |||
| @@ -61,7 +61,7 @@ class PluginToolProviderController(BuiltinToolProviderController): | |||
| plugin_unique_identifier=self.plugin_unique_identifier, | |||
| ) | |||
| def get_tools(self) -> list[PluginTool]: | |||
| def get_tools(self) -> list[PluginTool]: # type: ignore | |||
| """ | |||
| get all tools | |||
| """ | |||
| @@ -59,7 +59,12 @@ class PluginTool(Tool): | |||
| plugin_unique_identifier=self.plugin_unique_identifier, | |||
| ) | |||
| def get_runtime_parameters(self) -> list[ToolParameter]: | |||
| def get_runtime_parameters( | |||
| self, | |||
| conversation_id: Optional[str] = None, | |||
| app_id: Optional[str] = None, | |||
| message_id: Optional[str] = None, | |||
| ) -> list[ToolParameter]: | |||
| """ | |||
| get the runtime parameters | |||
| """ | |||
| @@ -76,6 +81,9 @@ class PluginTool(Tool): | |||
| provider=self.entity.identity.provider, | |||
| tool=self.entity.identity.name, | |||
| credentials=self.runtime.credentials, | |||
| conversation_id=conversation_id, | |||
| app_id=app_id, | |||
| message_id=message_id, | |||
| ) | |||
| return self.runtime_parameters | |||
| @@ -4,7 +4,7 @@ import mimetypes | |||
| from collections.abc import Generator | |||
| from os import listdir, path | |||
| from threading import Lock | |||
| from typing import TYPE_CHECKING, Any, Optional, Union, cast | |||
| from typing import TYPE_CHECKING, Any, Union, cast | |||
| from yarl import URL | |||
| @@ -57,7 +57,7 @@ logger = logging.getLogger(__name__) | |||
| class ToolManager: | |||
| _builtin_provider_lock = Lock() | |||
| _hardcoded_providers = {} | |||
| _hardcoded_providers: dict[str, BuiltinToolProviderController] = {} | |||
| _builtin_providers_loaded = False | |||
| _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {} | |||
| @@ -203,7 +203,7 @@ class ToolManager: | |||
| if builtin_provider is None: | |||
| raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") | |||
| else: | |||
| builtin_provider: BuiltinToolProvider | None = ( | |||
| builtin_provider = ( | |||
| db.session.query(BuiltinToolProvider) | |||
| .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) | |||
| .first() | |||
| @@ -270,9 +270,7 @@ class ToolManager: | |||
| raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") | |||
| controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) | |||
| controller_tools: Optional[list[Tool]] = controller.get_tools( | |||
| user_id="", tenant_id=workflow_provider.tenant_id | |||
| ) | |||
| controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id) | |||
| if controller_tools is None or len(controller_tools) == 0: | |||
| raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") | |||
| @@ -747,18 +745,21 @@ class ToolManager: | |||
| # add tool labels | |||
| labels = ToolLabelManager.get_tool_labels(controller) | |||
| return jsonable_encoder( | |||
| { | |||
| "schema_type": provider_obj.schema_type, | |||
| "schema": provider_obj.schema, | |||
| "tools": provider_obj.tools, | |||
| "icon": icon, | |||
| "description": provider_obj.description, | |||
| "credentials": masked_credentials, | |||
| "privacy_policy": provider_obj.privacy_policy, | |||
| "custom_disclaimer": provider_obj.custom_disclaimer, | |||
| "labels": labels, | |||
| } | |||
| return cast( | |||
| dict, | |||
| jsonable_encoder( | |||
| { | |||
| "schema_type": provider_obj.schema_type, | |||
| "schema": provider_obj.schema, | |||
| "tools": provider_obj.tools, | |||
| "icon": icon, | |||
| "description": provider_obj.description, | |||
| "credentials": masked_credentials, | |||
| "privacy_policy": provider_obj.privacy_policy, | |||
| "custom_disclaimer": provider_obj.custom_disclaimer, | |||
| "labels": labels, | |||
| } | |||
| ), | |||
| ) | |||
| @classmethod | |||
| @@ -795,7 +796,8 @@ class ToolManager: | |||
| if workflow_provider is None: | |||
| raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") | |||
| return json.loads(workflow_provider.icon) | |||
| icon: dict = json.loads(workflow_provider.icon) | |||
| return icon | |||
| except Exception: | |||
| return {"background": "#252525", "content": "\ud83d\ude01"} | |||
| @@ -811,7 +813,8 @@ class ToolManager: | |||
| if api_provider is None: | |||
| raise ToolProviderNotFoundError(f"api provider {provider_id} not found") | |||
| return json.loads(api_provider.icon) | |||
| icon: dict = json.loads(api_provider.icon) | |||
| return icon | |||
| except Exception: | |||
| return {"background": "#252525", "content": "\ud83d\ude01"} | |||
| @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rag.models.document import Document as RetrievalDocument | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool | |||
| from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, Document, DocumentSegment | |||
| from services.external_knowledge_service import ExternalDatasetService | |||
| @@ -1,5 +1,5 @@ | |||
| from collections.abc import Generator | |||
| from typing import Any | |||
| from typing import Any, Optional | |||
| from core.app.app_config.entities import DatasetRetrieveConfigEntity | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| @@ -83,7 +83,12 @@ class DatasetRetrieverTool(Tool): | |||
| return tools | |||
| def get_runtime_parameters(self) -> list[ToolParameter]: | |||
| def get_runtime_parameters( | |||
| self, | |||
| conversation_id: Optional[str] = None, | |||
| app_id: Optional[str] = None, | |||
| message_id: Optional[str] = None, | |||
| ) -> list[ToolParameter]: | |||
| return [ | |||
| ToolParameter( | |||
| name="query", | |||
| @@ -101,7 +106,14 @@ class DatasetRetrieverTool(Tool): | |||
| def tool_provider_type(self) -> ToolProviderType: | |||
| return ToolProviderType.DATASET_RETRIEVAL | |||
| def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]: | |||
| def _invoke( | |||
| self, | |||
| user_id: str, | |||
| tool_parameters: dict[str, Any], | |||
| conversation_id: Optional[str] = None, | |||
| app_id: Optional[str] = None, | |||
| message_id: Optional[str] = None, | |||
| ) -> Generator[ToolInvokeMessage, None, None]: | |||
| """ | |||
| invoke dataset retriever tool | |||
| """ | |||
| @@ -91,7 +91,7 @@ class ToolFileMessageTransformer: | |||
| ) | |||
| elif message.type == ToolInvokeMessage.MessageType.FILE: | |||
| meta = message.meta or {} | |||
| file = meta.get("file") | |||
| file = meta.get("file", None) | |||
| if isinstance(file, File): | |||
| if file.transfer_method == FileTransferMethod.TOOL_FILE: | |||
| assert file.related_id is not None | |||
| @@ -27,7 +27,7 @@ class WorkflowToolConfigurationUtils: | |||
| @classmethod | |||
| def check_is_synced( | |||
| cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration] | |||
| ) -> bool: | |||
| ): | |||
| """ | |||
| check is synced | |||
| @@ -6,7 +6,6 @@ from pydantic import Field | |||
| from core.app.app_config.entities import VariableEntity, VariableEntityType | |||
| from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager | |||
| from core.plugin.entities.parameters import PluginParameterOption | |||
| from core.tools.__base.tool import Tool | |||
| from core.tools.__base.tool_provider import ToolProviderController | |||
| from core.tools.__base.tool_runtime import ToolRuntime | |||
| from core.tools.entities.common_entities import I18nObject | |||
| @@ -101,7 +100,7 @@ class WorkflowToolProviderController(ToolProviderController): | |||
| variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph) | |||
| def fetch_workflow_variable(variable_name: str) -> VariableEntity | None: | |||
| return next(filter(lambda x: x.variable == variable_name, variables), None) | |||
| return next(filter(lambda x: x.variable == variable_name, variables), None) # type: ignore | |||
| user = db_provider.user | |||
| @@ -212,7 +211,7 @@ class WorkflowToolProviderController(ToolProviderController): | |||
| return self.tools | |||
| def get_tool(self, tool_name: str) -> Optional[Tool]: | |||
| def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: # type: ignore | |||
| """ | |||
| get tool by name | |||
| @@ -106,9 +106,9 @@ class WorkflowTool(Tool): | |||
| if outputs is None: | |||
| outputs = {} | |||
| else: | |||
| outputs, files = self._extract_files(outputs) | |||
| outputs, files = self._extract_files(outputs) # type: ignore | |||
| for file in files: | |||
| yield self.create_file_message(file) | |||
| yield self.create_file_message(file) # type: ignore | |||
| yield self.create_text_message(json.dumps(outputs, ensure_ascii=False)) | |||
| yield self.create_json_message(outputs) | |||
| @@ -217,7 +217,7 @@ class WorkflowTool(Tool): | |||
| :param result: the result | |||
| :return: the result, files | |||
| """ | |||
| files = [] | |||
| files: list[File] = [] | |||
| result = {} | |||
| for key, value in outputs.items(): | |||
| if isinstance(value, list): | |||
| @@ -238,4 +238,5 @@ class WorkflowTool(Tool): | |||
| files.append(file) | |||
| result[key] = value | |||
| return result, files | |||
| @@ -27,7 +27,7 @@ class AgentNode(ToolNode): | |||
| Agent Node | |||
| """ | |||
| _node_data_cls = AgentNodeData | |||
| _node_data_cls = AgentNodeData # type: ignore | |||
| _node_type = NodeType.AGENT | |||
| def _run(self) -> Generator: | |||
| @@ -125,7 +125,7 @@ class AgentNode(ToolNode): | |||
| """ | |||
| agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters} | |||
| result = {} | |||
| result: dict[str, Any] = {} | |||
| for parameter_name in node_data.agent_parameters: | |||
| parameter = agent_parameters_dictionary.get(parameter_name) | |||
| if not parameter: | |||
| @@ -214,7 +214,7 @@ class AgentNode(ToolNode): | |||
| :return: | |||
| """ | |||
| node_data = cast(AgentNodeData, node_data) | |||
| result = {} | |||
| result: dict[str, Any] = {} | |||
| for parameter_name in node_data.agent_parameters: | |||
| input = node_data.agent_parameters[parameter_name] | |||
| if input.type == "mixed": | |||
| @@ -233,9 +233,9 @@ class LLMNode(BaseNode[LLMNodeData]): | |||
| db.session.close() | |||
| invoke_result = model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| prompt_messages=list(prompt_messages), | |||
| model_parameters=node_data_model.completion_params, | |||
| stop=stop, | |||
| stop=list(stop or []), | |||
| stream=True, | |||
| user=self.user_id, | |||
| ) | |||
| @@ -1,5 +1,5 @@ | |||
| from collections.abc import Generator, Mapping, Sequence | |||
| from typing import Any, Optional, cast | |||
| from typing import Any, cast | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| @@ -197,7 +197,7 @@ class ToolNode(BaseNode[ToolNodeData]): | |||
| json: list[dict] = [] | |||
| agent_logs: list[AgentLogEvent] = [] | |||
| agent_execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = {} | |||
| agent_execution_metadata: Mapping[NodeRunMetadataKey, Any] = {} | |||
| variables: dict[str, Any] = {} | |||
| @@ -284,8 +284,6 @@ class WorkflowEntry: | |||
| user_inputs=user_inputs, | |||
| variable_pool=variable_pool, | |||
| tenant_id=tenant_id, | |||
| node_type=node_type, | |||
| node_data=node_instance.node_data, | |||
| ) | |||
| # run node | |||
| @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union, cast | |||
| from zoneinfo import available_timezones | |||
| from flask import Response, stream_with_context | |||
| from flask_restful import fields | |||
| from flask_restful import fields # type: ignore | |||
| from configs import dify_config | |||
| from core.app.features.rate_limiting.rate_limit import RateLimitGenerator | |||
| @@ -102,6 +102,6 @@ def _get_user() -> EndUser | Account | None: | |||
| if "_login_user" not in g: | |||
| current_app.login_manager._load_user() # type: ignore | |||
| return g._login_user | |||
| return g._login_user # type: ignore | |||
| return None | |||
| @@ -1,7 +1,7 @@ | |||
| import enum | |||
| import json | |||
| from flask_login import UserMixin | |||
| from flask_login import UserMixin # type: ignore | |||
| from sqlalchemy import func | |||
| from sqlalchemy.orm import Mapped, mapped_column | |||
| @@ -56,7 +56,7 @@ class Account(UserMixin, Base): | |||
| if ta: | |||
| tenant.current_role = ta.role | |||
| else: | |||
| tenant = None | |||
| tenant = None # type: ignore | |||
| self._current_tenant = tenant | |||
| @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Literal, cast | |||
| import sqlalchemy as sa | |||
| from flask import request | |||
| from flask_login import UserMixin | |||
| from flask_login import UserMixin # type: ignore | |||
| from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text | |||
| from sqlalchemy.orm import Mapped, Session, mapped_column | |||
| @@ -1,6 +1,6 @@ | |||
| import json | |||
| from datetime import datetime | |||
| from typing import Any, Optional | |||
| from typing import Any, Optional, cast | |||
| import sqlalchemy as sa | |||
| from deprecated import deprecated | |||
| @@ -48,7 +48,7 @@ class BuiltinToolProvider(Base): | |||
| @property | |||
| def credentials(self) -> dict: | |||
| return json.loads(self.encrypted_credentials) | |||
| return cast(dict, json.loads(self.encrypted_credentials)) | |||
| class ApiToolProvider(Base): | |||
| @@ -302,13 +302,9 @@ class DeprecatedPublishedAppTool(Base): | |||
| db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"), | |||
| ) | |||
| # id of the tool provider | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| # id of the app | |||
| app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False) | |||
| # who published this tool | |||
| user_id = db.Column(StringUUID, nullable=False) | |||
| # description of the tool, stored in i18n format, for human | |||
| description = db.Column(db.Text, nullable=False) | |||
| # llm_description of the tool, for LLM | |||
| llm_description = db.Column(db.Text, nullable=False) | |||
| @@ -328,10 +324,6 @@ class DeprecatedPublishedAppTool(Base): | |||
| def description_i18n(self) -> I18nObject: | |||
| return I18nObject(**json.loads(self.description)) | |||
| @property | |||
| def app(self) -> App: | |||
| return db.session.query(App).filter(App.id == self.app_id).first() | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| user_id: Mapped[str] = db.Column(StringUUID, nullable=False) | |||
| tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) | |||
| @@ -23,7 +23,7 @@ class AgentService: | |||
| contexts.plugin_tool_providers.set({}) | |||
| contexts.plugin_tool_providers_lock.set(threading.Lock()) | |||
| conversation: Conversation = ( | |||
| conversation: Conversation | None = ( | |||
| db.session.query(Conversation) | |||
| .filter( | |||
| Conversation.id == conversation_id, | |||
| @@ -156,7 +156,7 @@ class DefaultModelResponse(BaseModel): | |||
| model_config = ConfigDict(protected_namespaces=()) | |||
| class ModelWithProviderEntityResponse(ModelWithProviderEntity): | |||
| class ModelWithProviderEntityResponse(ProviderModelWithStatusEntity): | |||
| """ | |||
| Model with provider entity. | |||
| """ | |||
| @@ -173,9 +173,8 @@ class PluginMigration: | |||
| """ | |||
| Extract model tables. | |||
| NOTE: rename google to gemini | |||
| """ | |||
| models = [] | |||
| models: list[str] = [] | |||
| table_pairs = [ | |||
| ("providers", "provider_name"), | |||
| ("provider_models", "provider_name"), | |||
| @@ -439,7 +439,7 @@ class ApiToolManageService: | |||
| tenant_id=tenant_id, | |||
| ) | |||
| ) | |||
| result = runtime_tool.validate_credentials(credentials, parameters) | |||
| result = tool.validate_credentials(credentials, parameters) | |||
| except Exception as e: | |||
| return {"error": str(e)} | |||
| @@ -1,6 +1,6 @@ | |||
| import json | |||
| import logging | |||
| from typing import Optional, Union | |||
| from typing import Optional, Union, cast | |||
| from yarl import URL | |||
| @@ -44,7 +44,7 @@ class ToolTransformService: | |||
| elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}: | |||
| try: | |||
| if isinstance(icon, str): | |||
| return json.loads(icon) | |||
| return cast(dict, json.loads(icon)) | |||
| return icon | |||
| except Exception: | |||
| return {"background": "#252525", "content": "\ud83d\ude01"} | |||
| @@ -1,7 +1,7 @@ | |||
| import json | |||
| from collections.abc import Mapping, Sequence | |||
| from collections.abc import Mapping | |||
| from datetime import datetime | |||
| from typing import Any, Optional | |||
| from typing import Any | |||
| from sqlalchemy import or_ | |||
| @@ -11,6 +11,7 @@ from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntit | |||
| from core.tools.tool_label_manager import ToolLabelManager | |||
| from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils | |||
| from core.tools.workflow_as_tool.provider import WorkflowToolProviderController | |||
| from core.tools.workflow_as_tool.tool import WorkflowTool | |||
| from extensions.ext_database import db | |||
| from models.model import App | |||
| from models.tools import WorkflowToolProvider | |||
| @@ -187,7 +188,7 @@ class WorkflowToolManageService: | |||
| """ | |||
| db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() | |||
| tools: Sequence[WorkflowToolProviderController] = [] | |||
| tools: list[WorkflowToolProviderController] = [] | |||
| for provider in db_tools: | |||
| try: | |||
| tools.append(ToolTransformService.workflow_provider_to_controller(provider)) | |||
| @@ -264,7 +265,7 @@ class WorkflowToolManageService: | |||
| return cls._get_workflow_tool(tenant_id, db_tool) | |||
| @classmethod | |||
| def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None): | |||
| def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None) -> dict: | |||
| """ | |||
| Get a workflow tool. | |||
| :db_tool: the database tool | |||
| @@ -285,8 +286,8 @@ class WorkflowToolManageService: | |||
| raise ValueError("Workflow not found") | |||
| tool = ToolTransformService.workflow_provider_to_controller(db_tool) | |||
| to_user_tool: Optional[list[ToolApiEntity]] = tool.get_tools(tenant_id) | |||
| if to_user_tool is None or len(to_user_tool) == 0: | |||
| workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id) | |||
| if len(workflow_tools) == 0: | |||
| raise ValueError(f"Tool {db_tool.id} not found") | |||
| return { | |||
| @@ -325,8 +326,8 @@ class WorkflowToolManageService: | |||
| raise ValueError(f"Tool {workflow_tool_id} not found") | |||
| tool = ToolTransformService.workflow_provider_to_controller(db_tool) | |||
| to_user_tool: Optional[list[ToolApiEntity]] = tool.get_tools(user_id, tenant_id) | |||
| if to_user_tool is None or len(to_user_tool) == 0: | |||
| workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id) | |||
| if len(workflow_tools) == 0: | |||
| raise ValueError(f"Tool {workflow_tool_id} not found") | |||
| return [ | |||
| @@ -67,7 +67,7 @@ def batch_create_segment_to_index_task( | |||
| for segment, tokens in zip(content, tokens_list): | |||
| content = segment["content"] | |||
| doc_id = str(uuid.uuid4()) | |||
| segment_hash = helper.generate_text_hash(content) | |||
| segment_hash = helper.generate_text_hash(content) # type: ignore | |||
| max_position = ( | |||
| db.session.query(func.max(DocumentSegment.position)) | |||
| .filter(DocumentSegment.document_id == dataset_document.id) | |||