소스 검색

fix: mypy issues

tags/1.0.0-beta.1
Yeuoly 9 달 전
부모
커밋
f748d6c7c4
49개의 변경된 파일, 157개의 추가작업 그리고 133개의 삭제
  1. 3
    2
      api/core/app/apps/agent_chat/generate_response_converter.py
  2. 1
    1
      api/core/app/apps/base_app_generate_response_converter.py
  3. 3
    2
      api/core/app/apps/chat/generate_response_converter.py
  4. 3
    2
      api/core/app/apps/completion/generate_response_converter.py
  5. 3
    3
      api/core/app/apps/workflow/app_generator.py
  6. 3
    2
      api/core/app/apps/workflow/generate_response_converter.py
  7. 5
    5
      api/core/entities/provider_configuration.py
  8. 1
    1
      api/core/file/upload_file_parser.py
  9. 8
    8
      api/core/llm_generator/llm_generator.py
  10. 2
    0
      api/core/model_runtime/model_providers/__base/large_language_model.py
  11. 6
    6
      api/core/model_runtime/model_providers/model_provider_factory.py
  12. 3
    3
      api/core/plugin/manager/base.py
  13. 1
    1
      api/core/provider_manager.py
  14. 3
    0
      api/core/rag/retrieval/dataset_retrieval.py
  15. 5
    6
      api/core/rag/splitter/fixed_text_splitter.py
  16. 6
    4
      api/core/rag/splitter/text_splitter.py
  17. 2
    2
      api/core/tools/__base/tool.py
  18. 2
    2
      api/core/tools/builtin_tool/provider.py
  19. 2
    1
      api/core/tools/builtin_tool/providers/audio/audio.py
  20. 3
    3
      api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py
  21. 1
    1
      api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py
  22. 1
    1
      api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py
  23. 1
    1
      api/core/tools/builtin_tool/providers/webscraper/webscraper.py
  24. 1
    1
      api/core/tools/custom_tool/provider.py
  25. 2
    2
      api/core/tools/plugin_tool/provider.py
  26. 9
    1
      api/core/tools/plugin_tool/tool.py
  27. 23
    20
      api/core/tools/tool_manager.py
  28. 1
    1
      api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
  29. 15
    3
      api/core/tools/utils/dataset_retriever_tool.py
  30. 1
    1
      api/core/tools/utils/message_transformer.py
  31. 1
    1
      api/core/tools/utils/workflow_configuration_sync.py
  32. 2
    3
      api/core/tools/workflow_as_tool/provider.py
  33. 4
    3
      api/core/tools/workflow_as_tool/tool.py
  34. 3
    3
      api/core/workflow/nodes/agent/agent_node.py
  35. 2
    2
      api/core/workflow/nodes/llm/node.py
  36. 2
    2
      api/core/workflow/nodes/tool/tool_node.py
  37. 0
    2
      api/core/workflow/workflow_entry.py
  38. 1
    1
      api/libs/helper.py
  39. 1
    1
      api/libs/login.py
  40. 2
    2
      api/models/account.py
  41. 1
    1
      api/models/model.py
  42. 2
    10
      api/models/tools.py
  43. 1
    1
      api/services/agent_service.py
  44. 1
    1
      api/services/entities/model_provider_entities.py
  45. 1
    2
      api/services/plugin/plugin_migration.py
  46. 1
    1
      api/services/tools/api_tools_manage_service.py
  47. 2
    2
      api/services/tools/tools_transform_service.py
  48. 9
    8
      api/services/tools/workflow_tools_manage_service.py
  49. 1
    1
      api/tasks/batch_create_segment_to_index_task.py

+ 3
- 2
api/core/app/apps/agent_chat/generate_response_converter.py 파일 보기

@@ -3,6 +3,7 @@ from typing import cast

from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
@@ -51,7 +52,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):

@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -82,7 +83,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):

@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

+ 1
- 1
api/core/app/apps/base_app_generate_response_converter.py 파일 보기

@@ -56,7 +56,7 @@ class AppGenerateResponseConverter(ABC):
@abstractmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, None, None]:
) -> Generator[dict | str, None, None]:
raise NotImplementedError

@classmethod

+ 3
- 2
api/core/app/apps/chat/generate_response_converter.py 파일 보기

@@ -3,6 +3,7 @@ from typing import cast

from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
@@ -51,7 +52,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):

@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -82,7 +83,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):

@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

+ 3
- 2
api/core/app/apps/completion/generate_response_converter.py 파일 보기

@@ -3,6 +3,7 @@ from typing import cast

from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
CompletionAppBlockingResponse,
CompletionAppStreamResponse,
ErrorStreamResponse,
@@ -50,7 +51,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):

@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -80,7 +81,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):

@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

+ 3
- 3
api/core/app/apps/workflow/app_generator.py 파일 보기

@@ -149,7 +149,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
) -> Union[dict, Generator[str | dict, None, None]]:
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.

@@ -200,9 +200,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: dict,
args: Mapping[str, Any],
streaming: bool = True,
) -> dict[str, Any] | Generator[str | dict, Any, None]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.


+ 3
- 2
api/core/app/apps/workflow/generate_response_converter.py 파일 보기

@@ -3,6 +3,7 @@ from typing import cast

from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ErrorStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
@@ -35,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):

@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -64,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):

@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

+ 5
- 5
api/core/entities/provider_configuration.py 파일 보기

@@ -157,7 +157,7 @@ class ProviderConfiguration(BaseModel):
"""
return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0

def get_custom_credentials(self, obfuscated: bool = False):
def get_custom_credentials(self, obfuscated: bool = False) -> dict | None:
"""
Get custom credentials.

@@ -741,11 +741,11 @@ class ProviderConfiguration(BaseModel):
model_provider_factory = ModelProviderFactory(self.tenant_id)
provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)

model_types = []
model_types: list[ModelType] = []
if model_type:
model_types.append(model_type)
else:
model_types = provider_schema.supported_model_types
model_types = list(provider_schema.supported_model_types)

# Group model settings by model type and model
model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
@@ -1065,11 +1065,11 @@ class ProviderConfigurations(BaseModel):
def values(self) -> Iterator[ProviderConfiguration]:
return iter(self.configurations.values())

def get(self, key, default=None):
def get(self, key, default=None) -> ProviderConfiguration | None:
if "/" not in key:
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"

return self.configurations.get(key, default)
return self.configurations.get(key, default) # type: ignore


class ProviderModelBundle(BaseModel):

+ 1
- 1
api/core/file/upload_file_parser.py 파일 보기

@@ -20,7 +20,7 @@ class UploadFileParser:
if upload_file.extension not in IMAGE_EXTENSIONS:
return None

if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url:
if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url:
return cls.get_signed_temp_image_url(upload_file.id)
else:
# get image file base64

+ 8
- 8
api/core/llm_generator/llm_generator.py 파일 보기

@@ -48,7 +48,7 @@ class LLMGenerator:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompts, model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
prompt_messages=list(prompts), model_parameters={"max_tokens": 100, "temperature": 1}, stream=False
),
)
answer = cast(str, response.message.content)
@@ -101,7 +101,7 @@ class LLMGenerator:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
prompt_messages=list(prompt_messages),
model_parameters={"max_tokens": 256, "temperature": 0},
stream=False,
),
@@ -110,7 +110,7 @@ class LLMGenerator:
questions = output_parser.parse(cast(str, response.message.content))
except InvokeError:
questions = []
except Exception as e:
except Exception:
logging.exception("Failed to generate suggested questions after answer")
questions = []

@@ -150,7 +150,7 @@ class LLMGenerator:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
)

@@ -200,7 +200,7 @@ class LLMGenerator:
prompt_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
)
except InvokeError as e:
@@ -236,7 +236,7 @@ class LLMGenerator:
parameter_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=parameter_messages, model_parameters=model_parameters, stream=False
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
),
)
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
@@ -248,7 +248,7 @@ class LLMGenerator:
statement_content = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=statement_messages, model_parameters=model_parameters, stream=False
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
),
)
rule_config["opening_statement"] = cast(str, statement_content.message.content)
@@ -301,7 +301,7 @@ class LLMGenerator:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages, model_parameters=model_parameters, stream=False
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
)


+ 2
- 0
api/core/model_runtime/model_providers/__base/large_language_model.py 파일 보기

@@ -84,6 +84,8 @@ class LargeLanguageModel(AIModel):
callbacks=callbacks,
)

result: Union[LLMResult, Generator[LLMResultChunk, None, None]]

try:
plugin_model_manager = PluginModelManager()
result = plugin_model_manager.invoke_llm(

+ 6
- 6
api/core/model_runtime/model_providers/model_provider_factory.py 파일 보기

@@ -285,17 +285,17 @@ class ModelProviderFactory:
}

if model_type == ModelType.LLM:
return LargeLanguageModel(**init_params)
return LargeLanguageModel(**init_params) # type: ignore
elif model_type == ModelType.TEXT_EMBEDDING:
return TextEmbeddingModel(**init_params)
return TextEmbeddingModel(**init_params) # type: ignore
elif model_type == ModelType.RERANK:
return RerankModel(**init_params)
return RerankModel(**init_params) # type: ignore
elif model_type == ModelType.SPEECH2TEXT:
return Speech2TextModel(**init_params)
return Speech2TextModel(**init_params) # type: ignore
elif model_type == ModelType.MODERATION:
return ModerationModel(**init_params)
return ModerationModel(**init_params) # type: ignore
elif model_type == ModelType.TTS:
return TTSModel(**init_params)
return TTSModel(**init_params) # type: ignore

def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
"""

+ 3
- 3
api/core/plugin/manager/base.py 파일 보기

@@ -119,7 +119,7 @@ class BasePluginManager:
Make a request to the plugin daemon inner API and return the response as a model.
"""
response = self._request(method, path, headers, data, params, files)
return type(**response.json())
return type(**response.json()) # type: ignore

def _request_with_plugin_daemon_response(
self,
@@ -140,7 +140,7 @@ class BasePluginManager:
if transformer:
json_response = transformer(json_response)

rep = PluginDaemonBasicResponse[type](**json_response)
rep = PluginDaemonBasicResponse[type](**json_response) # type: ignore
if rep.code != 0:
try:
error = PluginDaemonError(**json.loads(rep.message))
@@ -171,7 +171,7 @@ class BasePluginManager:
line_data = None
try:
line_data = json.loads(line)
rep = PluginDaemonBasicResponse[type](**line_data)
rep = PluginDaemonBasicResponse[type](**line_data) # type: ignore
except Exception:
# TODO modify this when line_data has code and message
if line_data and "error" in line_data:

+ 1
- 1
api/core/provider_manager.py 파일 보기

@@ -742,7 +742,7 @@ class ProviderManager:
try:
provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config)
except JSONDecodeError:
provider_credentials: dict[str, Any] = {}
provider_credentials = {}

# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(

+ 3
- 0
api/core/rag/retrieval/dataset_retrieval.py 파일 보기

@@ -601,6 +601,9 @@ class DatasetRetrieval:
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
from core.tools.utils.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool

if retrieve_config.reranking_model is None:
raise ValueError("Reranking model is required for multiple retrieval")

tool = DatasetMultiRetrieverTool.from_dataset(
dataset_ids=[dataset.id for dataset in available_datasets],
tenant_id=tenant_id,

+ 5
- 6
api/core/rag/splitter/fixed_text_splitter.py 파일 보기

@@ -30,14 +30,14 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037
**kwargs: Any,
):
def _token_encoder(text: str) -> int:
if not text:
return 0
def _token_encoder(texts: list[str]) -> list[int]:
if not texts:
return []

if embedding_model_instance:
return embedding_model_instance.get_text_embedding_num_tokens(texts=[text])
return embedding_model_instance.get_text_embedding_num_tokens(texts=texts)
else:
return GPT2Tokenizer.get_num_tokens(text)
return [GPT2Tokenizer.get_num_tokens(text) for text in texts]

if issubclass(cls, TokenTextSplitter):
extra_kwargs = {
@@ -96,7 +96,6 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
_good_splits_lengths = [] # cache the lengths of the splits
s_lens = self._length_function(splits)
for s, s_len in zip(splits, s_lens):
s_len = self._length_function(s)
if s_len < self._chunk_size:
_good_splits.append(s)
_good_splits_lengths.append(s_len)

+ 6
- 4
api/core/rag/splitter/text_splitter.py 파일 보기

@@ -106,7 +106,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]:
# We now want to combine these smaller pieces into medium size
# chunks to send to the LLM.
separator_len = self._length_function(separator)
separator_len = self._length_function([separator])[0]

docs = []
current_doc: list[str] = []
@@ -129,7 +129,9 @@ class TextSplitter(BaseDocumentTransformer, ABC):
while total > self._chunk_overlap or (
total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0
):
total -= self._length_function(current_doc[0]) + (separator_len if len(current_doc) > 1 else 0)
total -= self._length_function([current_doc[0]])[0] + (
separator_len if len(current_doc) > 1 else 0
)
current_doc = current_doc[1:]
current_doc.append(d)
total += _len + (separator_len if len(current_doc) > 1 else 0)
@@ -155,7 +157,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
raise ValueError(
"Could not import transformers python package. Please install it with `pip install transformers`."
)
return cls(length_function=_huggingface_tokenizer_length, **kwargs)
return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)

@classmethod
def from_tiktoken_encoder(
@@ -199,7 +201,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
}
kwargs = {**kwargs, **extra_kwargs}

return cls(length_function=_tiktoken_encoder, **kwargs)
return cls(length_function=lambda x: [_tiktoken_encoder(text) for text in x], **kwargs)

def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
"""Transform sequence of documents by splitting them."""

+ 2
- 2
api/core/tools/__base/tool.py 파일 보기

@@ -71,13 +71,13 @@ class Tool(ABC):

if isinstance(result, ToolInvokeMessage):

def single_generator():
def single_generator() -> Generator[ToolInvokeMessage, None, None]:
yield result

return single_generator()
elif isinstance(result, list):

def generator():
def generator() -> Generator[ToolInvokeMessage, None, None]:
yield from result

return generator()

+ 2
- 2
api/core/tools/builtin_tool/provider.py 파일 보기

@@ -109,11 +109,11 @@ class BuiltinToolProviderController(ToolProviderController):
"""
return self._get_builtin_tools()

def get_tool(self, tool_name: str) -> BuiltinTool | None:
def get_tool(self, tool_name: str) -> BuiltinTool | None: # type: ignore
"""
returns the tool that the provider can provide
"""
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None) # type: ignore

@property
def need_credentials(self) -> bool:

+ 2
- 1
api/core/tools/builtin_tool/providers/audio/audio.py 파일 보기

@@ -1,6 +1,7 @@
from typing import Any
from core.tools.builtin_tool.provider import BuiltinToolProviderController


class AudioToolProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
pass

+ 3
- 3
api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py 파일 보기

@@ -27,7 +27,7 @@ class LocaltimeToTimestampTool(BuiltinTool):
timezone = None
time_format = "%Y-%m-%d %H:%M:%S"

timestamp = self.localtime_to_timestamp(localtime, time_format, timezone)
timestamp = self.localtime_to_timestamp(localtime, time_format, timezone) # type: ignore
if not timestamp:
yield self.create_text_message(f"Invalid localtime: {localtime}")
return
@@ -42,8 +42,8 @@ class LocaltimeToTimestampTool(BuiltinTool):
if isinstance(local_tz, str):
local_tz = pytz.timezone(local_tz)
local_time = datetime.strptime(localtime, time_format)
localtime = local_tz.localize(local_time)
timestamp = int(localtime.timestamp())
localtime = local_tz.localize(local_time) # type: ignore
timestamp = int(localtime.timestamp()) # type: ignore
return timestamp
except Exception as e:
raise ToolInvokeError(str(e))

+ 1
- 1
api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py 파일 보기

@@ -21,7 +21,7 @@ class TimestampToLocaltimeTool(BuiltinTool):
"""
Convert timestamp to localtime
"""
timestamp = tool_parameters.get("timestamp")
timestamp: int = tool_parameters.get("timestamp", 0)
timezone = tool_parameters.get("timezone", "Asia/Shanghai")
if not timezone:
timezone = None

+ 1
- 1
api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py 파일 보기

@@ -24,7 +24,7 @@ class TimezoneConversionTool(BuiltinTool):
current_time = tool_parameters.get("current_time")
current_timezone = tool_parameters.get("current_timezone", "Asia/Shanghai")
target_timezone = tool_parameters.get("target_timezone", "Asia/Tokyo")
target_time = self.timezone_convert(current_time, current_timezone, target_timezone)
target_time = self.timezone_convert(current_time, current_timezone, target_timezone) # type: ignore
if not target_time:
yield self.create_text_message(
f"Invalid datatime and timezone: {current_time},{current_timezone},{target_timezone}"

+ 1
- 1
api/core/tools/builtin_tool/providers/webscraper/webscraper.py 파일 보기

@@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController


class WebscraperProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
pass

+ 1
- 1
api/core/tools/custom_tool/provider.py 파일 보기

@@ -31,7 +31,7 @@ class ApiToolProviderController(ToolProviderController):
self.tools = []

@classmethod
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType):
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
credentials_schema = [
ProviderConfig(
name="auth_type",

+ 2
- 2
api/core/tools/plugin_tool/provider.py 파일 보기

@@ -44,7 +44,7 @@ class PluginToolProviderController(BuiltinToolProviderController):
):
raise ToolProviderCredentialValidationError("Invalid credentials")

def get_tool(self, tool_name: str) -> PluginTool:
def get_tool(self, tool_name: str) -> PluginTool: # type: ignore
"""
return tool with given name
"""
@@ -61,7 +61,7 @@ class PluginToolProviderController(BuiltinToolProviderController):
plugin_unique_identifier=self.plugin_unique_identifier,
)

def get_tools(self) -> list[PluginTool]:
def get_tools(self) -> list[PluginTool]: # type: ignore
"""
get all tools
"""

+ 9
- 1
api/core/tools/plugin_tool/tool.py 파일 보기

@@ -59,7 +59,12 @@ class PluginTool(Tool):
plugin_unique_identifier=self.plugin_unique_identifier,
)

def get_runtime_parameters(self) -> list[ToolParameter]:
def get_runtime_parameters(
self,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> list[ToolParameter]:
"""
get the runtime parameters
"""
@@ -76,6 +81,9 @@ class PluginTool(Tool):
provider=self.entity.identity.provider,
tool=self.entity.identity.name,
credentials=self.runtime.credentials,
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
)

return self.runtime_parameters

+ 23
- 20
api/core/tools/tool_manager.py 파일 보기

@@ -4,7 +4,7 @@ import mimetypes
from collections.abc import Generator
from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Union, cast

from yarl import URL

@@ -57,7 +57,7 @@ logger = logging.getLogger(__name__)

class ToolManager:
_builtin_provider_lock = Lock()
_hardcoded_providers = {}
_hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
_builtin_providers_loaded = False
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}

@@ -203,7 +203,7 @@ class ToolManager:
if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
else:
builtin_provider: BuiltinToolProvider | None = (
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
.first()
@@ -270,9 +270,7 @@ class ToolManager:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
controller_tools: Optional[list[Tool]] = controller.get_tools(
user_id="", tenant_id=workflow_provider.tenant_id
)
controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id)
if controller_tools is None or len(controller_tools) == 0:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

@@ -747,18 +745,21 @@ class ToolManager:
# add tool labels
labels = ToolLabelManager.get_tool_labels(controller)

return jsonable_encoder(
{
"schema_type": provider_obj.schema_type,
"schema": provider_obj.schema,
"tools": provider_obj.tools,
"icon": icon,
"description": provider_obj.description,
"credentials": masked_credentials,
"privacy_policy": provider_obj.privacy_policy,
"custom_disclaimer": provider_obj.custom_disclaimer,
"labels": labels,
}
return cast(
dict,
jsonable_encoder(
{
"schema_type": provider_obj.schema_type,
"schema": provider_obj.schema,
"tools": provider_obj.tools,
"icon": icon,
"description": provider_obj.description,
"credentials": masked_credentials,
"privacy_policy": provider_obj.privacy_policy,
"custom_disclaimer": provider_obj.custom_disclaimer,
"labels": labels,
}
),
)

@classmethod
@@ -795,7 +796,8 @@ class ToolManager:
if workflow_provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

return json.loads(workflow_provider.icon)
icon: dict = json.loads(workflow_provider.icon)
return icon
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}

@@ -811,7 +813,8 @@ class ToolManager:
if api_provider is None:
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")

return json.loads(api_provider.icon)
icon: dict = json.loads(api_provider.icon)
return icon
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}


+ 1
- 1
api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py 파일 보기

@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from services.external_knowledge_service import ExternalDatasetService

+ 15
- 3
api/core/tools/utils/dataset_retriever_tool.py 파일 보기

@@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any
from typing import Any, Optional

from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -83,7 +83,12 @@ class DatasetRetrieverTool(Tool):

return tools

def get_runtime_parameters(self) -> list[ToolParameter]:
def get_runtime_parameters(
self,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> list[ToolParameter]:
return [
ToolParameter(
name="query",
@@ -101,7 +106,14 @@ class DatasetRetrieverTool(Tool):
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.DATASET_RETRIEVAL

def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
invoke dataset retriever tool
"""

+ 1
- 1
api/core/tools/utils/message_transformer.py 파일 보기

@@ -91,7 +91,7 @@ class ToolFileMessageTransformer:
)
elif message.type == ToolInvokeMessage.MessageType.FILE:
meta = message.meta or {}
file = meta.get("file")
file = meta.get("file", None)
if isinstance(file, File):
if file.transfer_method == FileTransferMethod.TOOL_FILE:
assert file.related_id is not None

+ 1
- 1
api/core/tools/utils/workflow_configuration_sync.py 파일 보기

@@ -27,7 +27,7 @@ class WorkflowToolConfigurationUtils:
@classmethod
def check_is_synced(
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
) -> bool:
):
"""
check is synced


+ 2
- 3
api/core/tools/workflow_as_tool/provider.py 파일 보기

@@ -6,7 +6,6 @@ from pydantic import Field
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.plugin.entities.parameters import PluginParameterOption
from core.tools.__base.tool import Tool
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
@@ -101,7 +100,7 @@ class WorkflowToolProviderController(ToolProviderController):
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)

def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
return next(filter(lambda x: x.variable == variable_name, variables), None)
return next(filter(lambda x: x.variable == variable_name, variables), None) # type: ignore

user = db_provider.user

@@ -212,7 +211,7 @@ class WorkflowToolProviderController(ToolProviderController):

return self.tools

def get_tool(self, tool_name: str) -> Optional[Tool]:
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: # type: ignore
"""
get tool by name


+ 4
- 3
api/core/tools/workflow_as_tool/tool.py 파일 보기

@@ -106,9 +106,9 @@ class WorkflowTool(Tool):
if outputs is None:
outputs = {}
else:
outputs, files = self._extract_files(outputs)
outputs, files = self._extract_files(outputs) # type: ignore
for file in files:
yield self.create_file_message(file)
yield self.create_file_message(file) # type: ignore

yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
yield self.create_json_message(outputs)
@@ -217,7 +217,7 @@ class WorkflowTool(Tool):
:param result: the result
:return: the result, files
"""
files = []
files: list[File] = []
result = {}
for key, value in outputs.items():
if isinstance(value, list):
@@ -238,4 +238,5 @@ class WorkflowTool(Tool):
files.append(file)

result[key] = value

return result, files

+ 3
- 3
api/core/workflow/nodes/agent/agent_node.py 파일 보기

@@ -27,7 +27,7 @@ class AgentNode(ToolNode):
Agent Node
"""

_node_data_cls = AgentNodeData
_node_data_cls = AgentNodeData # type: ignore
_node_type = NodeType.AGENT

def _run(self) -> Generator:
@@ -125,7 +125,7 @@ class AgentNode(ToolNode):
"""
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}

result = {}
result: dict[str, Any] = {}
for parameter_name in node_data.agent_parameters:
parameter = agent_parameters_dictionary.get(parameter_name)
if not parameter:
@@ -214,7 +214,7 @@ class AgentNode(ToolNode):
:return:
"""
node_data = cast(AgentNodeData, node_data)
result = {}
result: dict[str, Any] = {}
for parameter_name in node_data.agent_parameters:
input = node_data.agent_parameters[parameter_name]
if input.type == "mixed":

+ 2
- 2
api/core/workflow/nodes/llm/node.py 파일 보기

@@ -233,9 +233,9 @@ class LLMNode(BaseNode[LLMNodeData]):
db.session.close()

invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
prompt_messages=list(prompt_messages),
model_parameters=node_data_model.completion_params,
stop=stop,
stop=list(stop or []),
stream=True,
user=self.user_id,
)

+ 2
- 2
api/core/workflow/nodes/tool/tool_node.py 파일 보기

@@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, cast

from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -197,7 +197,7 @@ class ToolNode(BaseNode[ToolNodeData]):
json: list[dict] = []

agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = {}
agent_execution_metadata: Mapping[NodeRunMetadataKey, Any] = {}

variables: dict[str, Any] = {}


+ 0
- 2
api/core/workflow/workflow_entry.py 파일 보기

@@ -284,8 +284,6 @@ class WorkflowEntry:
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=tenant_id,
node_type=node_type,
node_data=node_instance.node_data,
)

# run node

+ 1
- 1
api/libs/helper.py 파일 보기

@@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union, cast
from zoneinfo import available_timezones

from flask import Response, stream_with_context
from flask_restful import fields
from flask_restful import fields # type: ignore

from configs import dify_config
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator

+ 1
- 1
api/libs/login.py 파일 보기

@@ -102,6 +102,6 @@ def _get_user() -> EndUser | Account | None:
if "_login_user" not in g:
current_app.login_manager._load_user() # type: ignore

return g._login_user
return g._login_user # type: ignore

return None

+ 2
- 2
api/models/account.py 파일 보기

@@ -1,7 +1,7 @@
import enum
import json

from flask_login import UserMixin
from flask_login import UserMixin # type: ignore
from sqlalchemy import func
from sqlalchemy.orm import Mapped, mapped_column

@@ -56,7 +56,7 @@ class Account(UserMixin, Base):
if ta:
tenant.current_role = ta.role
else:
tenant = None
tenant = None # type: ignore

self._current_tenant = tenant


+ 1
- 1
api/models/model.py 파일 보기

@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Literal, cast

import sqlalchemy as sa
from flask import request
from flask_login import UserMixin
from flask_login import UserMixin # type: ignore
from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text
from sqlalchemy.orm import Mapped, Session, mapped_column


+ 2
- 10
api/models/tools.py 파일 보기

@@ -1,6 +1,6 @@
import json
from datetime import datetime
from typing import Any, Optional
from typing import Any, Optional, cast

import sqlalchemy as sa
from deprecated import deprecated
@@ -48,7 +48,7 @@ class BuiltinToolProvider(Base):

@property
def credentials(self) -> dict:
return json.loads(self.encrypted_credentials)
return cast(dict, json.loads(self.encrypted_credentials))


class ApiToolProvider(Base):
@@ -302,13 +302,9 @@ class DeprecatedPublishedAppTool(Base):
db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
)

# id of the tool provider
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# id of the app
app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False)
# who published this tool
user_id = db.Column(StringUUID, nullable=False)
# description of the tool, stored in i18n format, for human
description = db.Column(db.Text, nullable=False)
# llm_description of the tool, for LLM
llm_description = db.Column(db.Text, nullable=False)
@@ -328,10 +324,6 @@ class DeprecatedPublishedAppTool(Base):
def description_i18n(self) -> I18nObject:
return I18nObject(**json.loads(self.description))

@property
def app(self) -> App:
return db.session.query(App).filter(App.id == self.app_id).first()

id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
user_id: Mapped[str] = db.Column(StringUUID, nullable=False)
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)

+ 1
- 1
api/services/agent_service.py 파일 보기

@@ -23,7 +23,7 @@ class AgentService:
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())

conversation: Conversation = (
conversation: Conversation | None = (
db.session.query(Conversation)
.filter(
Conversation.id == conversation_id,

+ 1
- 1
api/services/entities/model_provider_entities.py 파일 보기

@@ -156,7 +156,7 @@ class DefaultModelResponse(BaseModel):
model_config = ConfigDict(protected_namespaces=())


class ModelWithProviderEntityResponse(ModelWithProviderEntity):
class ModelWithProviderEntityResponse(ProviderModelWithStatusEntity):
"""
Model with provider entity.
"""

+ 1
- 2
api/services/plugin/plugin_migration.py 파일 보기

@@ -173,9 +173,8 @@ class PluginMigration:
"""
Extract model tables.

NOTE: rename google to gemini
"""
models = []
models: list[str] = []
table_pairs = [
("providers", "provider_name"),
("provider_models", "provider_name"),

+ 1
- 1
api/services/tools/api_tools_manage_service.py 파일 보기

@@ -439,7 +439,7 @@ class ApiToolManageService:
tenant_id=tenant_id,
)
)
result = runtime_tool.validate_credentials(credentials, parameters)
result = tool.validate_credentials(credentials, parameters)
except Exception as e:
return {"error": str(e)}


+ 2
- 2
api/services/tools/tools_transform_service.py 파일 보기

@@ -1,6 +1,6 @@
import json
import logging
from typing import Optional, Union
from typing import Optional, Union, cast

from yarl import URL

@@ -44,7 +44,7 @@ class ToolTransformService:
elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
try:
if isinstance(icon, str):
return json.loads(icon)
return cast(dict, json.loads(icon))
return icon
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}

+ 9
- 8
api/services/tools/workflow_tools_manage_service.py 파일 보기

@@ -1,7 +1,7 @@
import json
from collections.abc import Mapping, Sequence
from collections.abc import Mapping
from datetime import datetime
from typing import Any, Optional
from typing import Any

from sqlalchemy import or_

@@ -11,6 +11,7 @@ from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntit
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.model import App
from models.tools import WorkflowToolProvider
@@ -187,7 +188,7 @@ class WorkflowToolManageService:
"""
db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()

tools: Sequence[WorkflowToolProviderController] = []
tools: list[WorkflowToolProviderController] = []
for provider in db_tools:
try:
tools.append(ToolTransformService.workflow_provider_to_controller(provider))
@@ -264,7 +265,7 @@ class WorkflowToolManageService:
return cls._get_workflow_tool(tenant_id, db_tool)

@classmethod
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None):
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None) -> dict:
"""
Get a workflow tool.
:db_tool: the database tool
@@ -285,8 +286,8 @@ class WorkflowToolManageService:
raise ValueError("Workflow not found")

tool = ToolTransformService.workflow_provider_to_controller(db_tool)
to_user_tool: Optional[list[ToolApiEntity]] = tool.get_tools(tenant_id)
if to_user_tool is None or len(to_user_tool) == 0:
workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
if len(workflow_tools) == 0:
raise ValueError(f"Tool {db_tool.id} not found")

return {
@@ -325,8 +326,8 @@ class WorkflowToolManageService:
raise ValueError(f"Tool {workflow_tool_id} not found")

tool = ToolTransformService.workflow_provider_to_controller(db_tool)
to_user_tool: Optional[list[ToolApiEntity]] = tool.get_tools(user_id, tenant_id)
if to_user_tool is None or len(to_user_tool) == 0:
workflow_tools: list[WorkflowTool] = tool.get_tools(tenant_id)
if len(workflow_tools) == 0:
raise ValueError(f"Tool {workflow_tool_id} not found")

return [

+ 1
- 1
api/tasks/batch_create_segment_to_index_task.py 파일 보기

@@ -67,7 +67,7 @@ def batch_create_segment_to_index_task(
for segment, tokens in zip(content, tokens_list):
content = segment["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
segment_hash = helper.generate_text_hash(content) # type: ignore
max_position = (
db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == dataset_document.id)

Loading…
취소
저장