@@ -42,11 +42,7 @@ jobs:
       - name: Run Unit tests
         run: |
           uv run --project api bash dev/pytest/pytest_unit_tests.sh
-      - name: Run ty check
-        run: |
-          cd api
-          uv add --dev ty
-          uv run ty check || true
       - name: Run pyrefly check
         run: |
           cd api
@@ -44,6 +44,9 @@ jobs:
         if: steps.changed-files.outputs.any_changed == 'true'
         run: uv sync --project api --dev
+      - name: Run ty check
+        run: dev/ty-check
       - name: Dotenv check
         if: steps.changed-files.outputs.any_changed == 'true'
         run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
@@ -27,7 +27,7 @@ class NacosHttpClient:
             response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
             response.raise_for_status()
             return response.text
-        except requests.exceptions.RequestException as e:
+        except requests.RequestException as e:
             return f"Request to Nacos failed: {e}"
 
     def _inject_auth_info(self, headers, params, module="config"):
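
Note: `requests` re-exports its exception classes at the package top level, so the shorter spelling used throughout this patch names the identical classes and the `except` clauses behave exactly as before. A minimal sketch to confirm the aliasing:

```python
import requests

# Top-level names are re-exports from requests.exceptions (see
# requests/__init__.py), so identity holds, not mere equality.
assert requests.RequestException is requests.exceptions.RequestException
assert requests.HTTPError is requests.exceptions.HTTPError
assert requests.ConnectionError is requests.exceptions.ConnectionError
assert requests.Timeout is requests.exceptions.Timeout
```
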
@@ -84,7 +84,7 @@ class BaseApiKeyListResource(Resource):
             flask_restx.abort(
                 400,
                 message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
-                code="max_keys_exceeded",
+                custom="max_keys_exceeded",
             )
 
         key = ApiToken.generate_api_key(self.token_prefix, 24)
@@ -81,7 +81,7 @@ class OAuthDataSourceBinding(Resource):
             return {"error": "Invalid code"}, 400
         try:
             oauth_provider.get_access_token(code)
-        except requests.exceptions.HTTPError as e:
+        except requests.HTTPError as e:
             logger.exception(
                 "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
             )
@@ -104,7 +104,7 @@ class OAuthDataSourceSync(Resource):
             return {"error": "Invalid provider"}, 400
         try:
             oauth_provider.sync_data_source(binding_id)
-        except requests.exceptions.HTTPError as e:
+        except requests.HTTPError as e:
             logger.exception(
                 "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
             )
@@ -80,7 +80,7 @@ class OAuthCallback(Resource):
         try:
             token = oauth_provider.get_access_token(code)
             user_info = oauth_provider.get_user_info(token)
-        except requests.exceptions.RequestException as e:
+        except requests.RequestException as e:
             error_text = e.response.text if e.response else str(e)
             logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
             return {"error": "OAuth process failed"}, 400
@@ -55,7 +55,7 @@ class AudioApi(Resource):
         file = request.files["file"]
 
         try:
-            response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)
+            response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user.id)
             return response
         except services.errors.app_model_config.AppModelConfigBrokenError:
@@ -118,7 +118,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                     data = cls._error_to_stream_response(sub_stream_response.err)
                     response_chunk.update(data)
                 elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
-                    response_chunk.update(sub_stream_response.to_ignore_detail_dict())
+                    response_chunk.update(sub_stream_response.to_ignore_detail_dict())  # ty: ignore [unresolved-attribute]
                 else:
                     response_chunk.update(sub_stream_response.to_dict())
@@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
                     data = cls._error_to_stream_response(sub_stream_response.err)
                     response_chunk.update(data)
                 elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
-                    response_chunk.update(sub_stream_response.to_ignore_detail_dict())
+                    response_chunk.update(sub_stream_response.to_ignore_detail_dict())  # ty: ignore [unresolved-attribute]
                 else:
                     response_chunk.update(sub_stream_response.to_dict())
                 yield response_chunk
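
Note: ty's suppression comments accept an optional bracketed rule name, so each `# ty: ignore [...]` added in this patch silences only the named diagnostic on that line rather than everything, the way a bare `# type: ignore` would. A contrived sketch (not from this codebase):

```python
# Only the named rule is suppressed; any other diagnostic on this
# line would still be reported.
count: int = "not a number"  # ty: ignore [invalid-assignment]
```
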
@@ -96,7 +96,11 @@ class RateLimit:
         if isinstance(generator, Mapping):
             return generator
         else:
-            return RateLimitGenerator(rate_limit=self, generator=generator, request_id=request_id)
+            return RateLimitGenerator(
+                rate_limit=self,
+                generator=generator,  # ty: ignore [invalid-argument-type]
+                request_id=request_id,
+            )
 
 
 class RateLimitGenerator:
@@ -50,7 +50,7 @@ class BasedGenerateTaskPipeline:
         if isinstance(e, InvokeAuthorizationError):
             err = InvokeAuthorizationError("Incorrect API key provided")
         elif isinstance(e, InvokeError | ValueError):
-            err = e
+            err = e  # ty: ignore [invalid-assignment]
         else:
             description = getattr(e, "description", None)
             err = Exception(description if description is not None else str(e))
@@ -43,9 +43,9 @@ class APIBasedExtensionRequestor:
                 timeout=self.timeout,
                 proxies=proxies,
             )
-        except requests.exceptions.Timeout:
+        except requests.Timeout:
             raise ValueError("request timeout")
-        except requests.exceptions.ConnectionError:
+        except requests.ConnectionError:
             raise ValueError("request connection error")
 
         if response.status_code != 200:
@@ -47,7 +47,7 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type]
 
 
 def load_single_subclass_from_source(
-    *, module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False
+    *, module_name: str, script_path: str, parent_type: type, use_lazy_loader: bool = False
 ) -> type:
     """
     Load a single subclass from the source
@@ -56,11 +56,8 @@ class LLMGenerator:
         prompts = [UserPromptMessage(content=prompt)]
 
         with measure_time() as timer:
-            response = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
-                ),
-            )
+            response: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
+            )
         answer = cast(str, response.message.content)
         cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL)
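
Note: this is the pattern repeated through the rest of `LLMGenerator` and the dataset routers below. `typing.cast` is an unchecked assertion that a checker takes on faith, while an annotated assignment states the same intended type but is actually verified. A contrived sketch of the difference, using a hypothetical `parse` helper:

```python
from typing import cast

def parse(raw: str) -> int | None:
    # Hypothetical helper: returns None on bad input.
    return int(raw) if raw.isdigit() else None

value = parse("42")

# cast() is a runtime no-op and the checker takes the claim on faith,
# so a wrong cast silently hides a bug.
n1 = cast(int, value)

# An annotated assignment is verified instead: ty reports
# invalid-assignment because `int | None` is not assignable to `int`.
n2: int = value  # ty: ignore [invalid-assignment]
```
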
@@ -113,13 +110,10 @@ class LLMGenerator:
         prompt_messages = [UserPromptMessage(content=prompt)]
 
         try:
-            response = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(prompt_messages),
-                    model_parameters={"max_tokens": 256, "temperature": 0},
-                    stream=False,
-                ),
-            )
+            response: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(prompt_messages),
+                model_parameters={"max_tokens": 256, "temperature": 0},
+                stream=False,
+            )
 
             text_content = response.message.get_text_content()
@@ -162,11 +156,8 @@ class LLMGenerator:
         )
 
         try:
-            response = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
-                ),
-            )
+            response: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
+            )
 
             rule_config["prompt"] = cast(str, response.message.content)
@@ -212,11 +203,8 @@ class LLMGenerator:
         try:
             try:
                 # the first step to generate the task prompt
-                prompt_content = cast(
-                    LLMResult,
-                    model_instance.invoke_llm(
-                        prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
-                    ),
-                )
+                prompt_content: LLMResult = model_instance.invoke_llm(
+                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
+                )
             except InvokeError as e:
                 error = str(e)
@@ -248,11 +236,8 @@ class LLMGenerator:
         statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
 
         try:
-            parameter_content = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
-                ),
-            )
+            parameter_content: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
+            )
             rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
         except InvokeError as e:
@@ -260,11 +245,8 @@ class LLMGenerator:
             error_step = "generate variables"
 
         try:
-            statement_content = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
-                ),
-            )
+            statement_content: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
+            )
             rule_config["opening_statement"] = cast(str, statement_content.message.content)
         except InvokeError as e:
@@ -307,11 +289,8 @@ class LLMGenerator:
         prompt_messages = [UserPromptMessage(content=prompt)]
         model_parameters = model_config.get("completion_params", {})
         try:
-            response = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
-                ),
-            )
+            response: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
+            )
 
             generated_code = cast(str, response.message.content)
@@ -338,13 +317,10 @@ class LLMGenerator:
         prompt_messages = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]
 
-        response = cast(
-            LLMResult,
-            model_instance.invoke_llm(
-                prompt_messages=prompt_messages,
-                model_parameters={"temperature": 0.01, "max_tokens": 2000},
-                stream=False,
-            ),
-        )
+        response: LLMResult = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            model_parameters={"temperature": 0.01, "max_tokens": 2000},
+            stream=False,
+        )
 
         answer = cast(str, response.message.content)
@@ -367,11 +343,8 @@ class LLMGenerator:
         model_parameters = model_config.get("model_parameters", {})
 
         try:
-            response = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
-                ),
-            )
+            response: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
+            )
 
             raw_content = response.message.content
@@ -555,11 +528,8 @@ class LLMGenerator:
         model_parameters = {"temperature": 0.4}
 
         try:
-            response = cast(
-                LLMResult,
-                model_instance.invoke_llm(
-                    prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
-                ),
-            )
+            response: LLMResult = model_instance.invoke_llm(
+                prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
+            )
 
             generated_raw = cast(str, response.message.content)
@@ -72,7 +72,7 @@ class TraceClient:
             else:
                 logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
                 return False
-        except requests.exceptions.RequestException as e:
+        except requests.RequestException as e:
             logger.debug("AliyunTrace API check failed: %s", str(e))
             raise ValueError(f"AliyunTrace API check failed: {str(e)}")
@@ -64,7 +64,7 @@ class BasePluginClient:
             response = requests.request(
                 method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files
             )
-        except requests.exceptions.ConnectionError:
+        except requests.ConnectionError:
             logger.exception("Request to Plugin Daemon Service failed")
             raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
@@ -192,8 +192,8 @@ class AnalyticdbVectorOpenAPI:
             collection=self._collection_name,
             metrics=self.config.metrics,
             include_values=True,
-            vector=None,
-            content=None,
+            vector=None,  # ty: ignore [invalid-argument-type]
+            content=None,  # ty: ignore [invalid-argument-type]
             top_k=1,
             filter=f"ref_doc_id='{id}'",
         )
@@ -211,7 +211,7 @@ class AnalyticdbVectorOpenAPI:
             namespace=self.config.namespace,
             namespace_password=self.config.namespace_password,
             collection=self._collection_name,
-            collection_data=None,
+            collection_data=None,  # ty: ignore [invalid-argument-type]
             collection_data_filter=f"ref_doc_id IN {ids_str}",
         )
         self._client.delete_collection_data(request)
@@ -225,7 +225,7 @@ class AnalyticdbVectorOpenAPI:
             namespace=self.config.namespace,
             namespace_password=self.config.namespace_password,
             collection=self._collection_name,
-            collection_data=None,
+            collection_data=None,  # ty: ignore [invalid-argument-type]
             collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
         )
         self._client.delete_collection_data(request)
@@ -249,7 +249,7 @@ class AnalyticdbVectorOpenAPI:
             include_values=kwargs.pop("include_values", True),
             metrics=self.config.metrics,
             vector=query_vector,
-            content=None,
+            content=None,  # ty: ignore [invalid-argument-type]
             top_k=kwargs.get("top_k", 4),
             filter=where_clause,
         )
@@ -285,7 +285,7 @@ class AnalyticdbVectorOpenAPI:
             collection=self._collection_name,
             include_values=kwargs.pop("include_values", True),
             metrics=self.config.metrics,
-            vector=None,
+            vector=None,  # ty: ignore [invalid-argument-type]
             content=query,
             top_k=kwargs.get("top_k", 4),
             filter=where_clause,
@@ -12,7 +12,7 @@ import clickzetta  # type: ignore
 from pydantic import BaseModel, model_validator
 
 if TYPE_CHECKING:
-    from clickzetta import Connection
+    from clickzetta.connector.v0.connection import Connection  # type: ignore
 
 from configs import dify_config
 from core.rag.datasource.vdb.field import Field
@@ -306,7 +306,7 @@ class CouchbaseVector(BaseVector):
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 4)
         try:
-            CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
+            CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))  # ty: ignore [too-many-positional-arguments]
             search_iter = self._scope.search(
                 self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
             )
@@ -138,7 +138,7 @@ class ElasticSearchVector(BaseVector):
             if not client.ping():
                 raise ConnectionError("Failed to connect to Elasticsearch")
 
-        except requests.exceptions.ConnectionError as e:
+        except requests.ConnectionError as e:
             raise ConnectionError(f"Vector database connection error: {str(e)}")
         except Exception as e:
             raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}")
@@ -376,7 +376,12 @@ class MilvusVector(BaseVector):
         if config.token:
             client = MilvusClient(uri=config.uri, token=config.token, db_name=config.database)
         else:
-            client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
+            client = MilvusClient(
+                uri=config.uri,
+                user=config.user or "",
+                password=config.password or "",
+                db_name=config.database,
+            )
         return client
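
Note: `config.user` and `config.password` are presumably `Optional[str]` while `MilvusClient` wants plain strings; the `or ""` idiom narrows the type for the checker and maps a missing value to the empty string. A minimal sketch with a hypothetical stand-in constructor:

```python
from typing import Optional

def connect(uri: str, user: str, password: str) -> None:
    """Hypothetical stand-in for a client that requires plain str args."""

user: Optional[str] = None
password: Optional[str] = None

# `x or ""` narrows Optional[str] to str; it also coerces an explicitly
# empty value to "", which the client presumably reads as "no credentials".
connect("http://localhost:19530", user or "", password or "")
```
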
@@ -32,9 +32,9 @@ class VikingDBConfig(BaseModel):
     scheme: str
     connection_timeout: int
     socket_timeout: int
-    index_type: str = IndexType.HNSW
-    distance: str = DistanceType.L2
-    quant: str = QuantType.Float
+    index_type: str = str(IndexType.HNSW)
+    distance: str = str(DistanceType.L2)
+    quant: str = str(QuantType.Float)
 
 
 class VikingDBVector(BaseVector):
@@ -37,22 +37,22 @@ class WeaviateVector(BaseVector):
         self._attributes = attributes
 
     def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
-        auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
+        auth_config = weaviate.AuthApiKey(api_key=config.api_key or "")
 
-        weaviate.connect.connection.has_grpc = False
+        weaviate.connect.connection.has_grpc = False  # ty: ignore [unresolved-attribute]
 
         # Fix to minimize the performance impact of the deprecation check in weaviate-client 3.24.0,
         # by changing the connection timeout to pypi.org from 1 second to 0.001 seconds.
         # TODO: This can be removed once weaviate-client is updated to 3.26.7 or higher,
         # which does not contain the deprecation check.
-        if hasattr(weaviate.connect.connection, "PYPI_TIMEOUT"):
-            weaviate.connect.connection.PYPI_TIMEOUT = 0.001
+        if hasattr(weaviate.connect.connection, "PYPI_TIMEOUT"):  # ty: ignore [unresolved-attribute]
+            weaviate.connect.connection.PYPI_TIMEOUT = 0.001  # ty: ignore [unresolved-attribute]
 
         try:
             client = weaviate.Client(
                 url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None
             )
-        except requests.exceptions.ConnectionError:
+        except requests.ConnectionError:
             raise ConnectionError("Vector database connection error")
 
         client.batch.configure(
@@ -1,4 +1,4 @@
-from typing import Union, cast
+from typing import Union
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.model_manager import ModelInstance
@@ -28,14 +28,11 @@ class FunctionCallMultiDatasetRouter:
             SystemPromptMessage(content="You are a helpful AI assistant."),
             UserPromptMessage(content=query),
         ]
-        result = cast(
-            LLMResult,
-            model_instance.invoke_llm(
-                prompt_messages=prompt_messages,
-                tools=dataset_tools,
-                stream=False,
-                model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
-            ),
-        )
+        result: LLMResult = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            tools=dataset_tools,
+            stream=False,
+            model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
+        )
         if result.message.tool_calls:
             # get retrieval model config
@@ -1,5 +1,5 @@
 from collections.abc import Generator, Sequence
-from typing import Union, cast
+from typing import Union
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.model_manager import ModelInstance
@@ -150,15 +150,12 @@ class ReactMultiDatasetRouter:
         :param stop: stop
         :return:
         """
-        invoke_result = cast(
-            Generator[LLMResult, None, None],
-            model_instance.invoke_llm(
-                prompt_messages=prompt_messages,
-                model_parameters=completion_param,
-                stop=stop,
-                stream=True,
-                user=user_id,
-            ),
-        )
+        invoke_result: Generator[LLMResult, None, None] = model_instance.invoke_llm(
+            prompt_messages=prompt_messages,
+            model_parameters=completion_param,
+            stop=stop,
+            stream=True,
+            user=user_id,
+        )
 
         # handle invoke result
@@ -74,7 +74,7 @@ class BuiltinToolProviderController(ToolProviderController):
             tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False)
 
             # get tool class, import the module
-            assistant_tool_class: type[BuiltinTool] = load_single_subclass_from_source(
+            assistant_tool_class: type = load_single_subclass_from_source(
                 module_name=f"core.tools.builtin_tool.providers.{provider}.tools.{tool_name}",
                 script_path=path.join(
                     path.dirname(path.realpath(__file__)),
@@ -26,7 +26,7 @@ class ToolLabelManager:
         labels = cls.filter_tool_labels(labels)
 
         if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
-            provider_id = controller.provider_id
+            provider_id = controller.provider_id  # ty: ignore [unresolved-attribute]
         else:
             raise ValueError("Unsupported tool type")
@@ -51,7 +51,7 @@ class ToolLabelManager:
         Get tool labels
         """
         if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
-            provider_id = controller.provider_id
+            provider_id = controller.provider_id  # ty: ignore [unresolved-attribute]
         elif isinstance(controller, BuiltinToolProviderController):
             return controller.tool_labels
         else:
@@ -85,7 +85,7 @@ class ToolLabelManager:
         provider_ids = []
         for controller in tool_providers:
             assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
-            provider_ids.append(controller.provider_id)
+            provider_ids.append(controller.provider_id)  # ty: ignore [unresolved-attribute]
 
         labels: list[ToolLabelBinding] = (
             db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids)).all()
@@ -1,7 +1,6 @@
-from abc import abstractmethod
+from abc import ABC, abstractmethod
 from typing import Optional
 
-from msal_extensions.persistence import ABC  # type: ignore
 from pydantic import BaseModel, ConfigDict
 
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
@@ -52,12 +52,12 @@ class AnswerStreamProcessor(StreamProcessor):
                 yield event
             elif isinstance(event, NodeRunSucceededEvent | NodeRunExceptionEvent):
                 yield event
-                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
+                if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:  # ty: ignore [unresolved-attribute]
                     # update self.route_position after all stream event finished
-                    for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
+                    for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:  # ty: ignore [unresolved-attribute]
                         self.route_position[answer_node_id] += 1
-                    del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
+                    del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]  # ty: ignore [unresolved-attribute]
 
                 self._remove_unreachable_nodes(event)
@@ -83,7 +83,7 @@ class IfElseNode(BaseNode):
         else:
             # TODO: Update database then remove this
             # Fallback to old structure if cases are not defined
-            input_conditions, group_result, final_result = _should_not_use_old_function(
+            input_conditions, group_result, final_result = _should_not_use_old_function(  # ty: ignore [deprecated]
                 condition_processor=condition_processor,
                 variable_pool=self.graph_runtime_state.variable_pool,
                 conditions=self._node_data.conditions or [],
@@ -441,8 +441,8 @@ class IterationNode(BaseNode):
             iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}"
             next_index = int(current_index) + 1
 
             for event in rst:
-                if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
-                    event.in_iteration_id = self.node_id
+                if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:  # ty: ignore [unresolved-attribute]
+                    event.in_iteration_id = self.node_id  # ty: ignore [unresolved-attribute]
 
                 if (
                     isinstance(event, BaseNodeEvent)
@@ -299,8 +299,8 @@ class LoopNode(BaseNode):
         check_break_result = False
 
         for event in rst:
-            if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id:
-                event.in_loop_id = self.node_id
+            if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id:  # ty: ignore [unresolved-attribute]
+                event.in_loop_id = self.node_id  # ty: ignore [unresolved-attribute]
 
             if (
                 isinstance(event, BaseNodeEvent)
@@ -103,7 +103,7 @@ def init_app(app: DifyApp):
     def shutdown_tracer():
         provider = trace.get_tracer_provider()
         if hasattr(provider, "force_flush"):
-            provider.force_flush()
+            provider.force_flush()  # ty: ignore [call-non-callable]
 
 
 class ExceptionLoggingHandler(logging.Handler):
     """Custom logging handler that creates spans for logging.exception() calls"""
@@ -260,7 +260,8 @@ def redis_fallback(default_return: Optional[Any] = None):
             try:
                 return func(*args, **kwargs)
             except RedisError as e:
-                logger.warning("Redis operation failed in %s: %s", func.__name__, str(e), exc_info=True)
+                func_name = getattr(func, "__name__", "Unknown")
+                logger.warning("Redis operation failed in %s: %s", func_name, str(e), exc_info=True)
                 return default_return
 
         return wrapper
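
Note: direct `func.__name__` access fails for callables that lack the attribute, e.g. `functools.partial` objects, which would turn the fallback path itself into an `AttributeError`. A small sketch of the failure mode:

```python
from functools import partial

def get_value(key: str) -> None: ...

wrapped = partial(get_value, "some-key")

# partial objects carry no __name__, so attribute access raises
# AttributeError; getattr with a default keeps the log call safe
# for any callable the decorator might wrap.
name = getattr(wrapped, "__name__", "Unknown")  # -> "Unknown"
```
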
@@ -101,7 +101,7 @@ def register_external_error_handlers(api: Api) -> None:
         exc_info: Any = sys.exc_info()
         if exc_info[1] is None:
             exc_info = None
-        current_app.log_exception(exc_info)
+        current_app.log_exception(exc_info)  # ty: ignore [invalid-argument-type]
 
         return data, status_code
@@ -136,7 +136,7 @@ class PKCS1OAepCipher:
         # Step 3a (OS2IP)
         em_int = bytes_to_long(em)
         # Step 3b (RSAEP)
-        m_int = gmpy2.powmod(em_int, self._key.e, self._key.n)
+        m_int = gmpy2.powmod(em_int, self._key.e, self._key.n)  # ty: ignore [unresolved-attribute]
         # Step 3c (I2OSP)
         c = long_to_bytes(m_int, k)
         return c
@@ -169,7 +169,7 @@ class PKCS1OAepCipher:
         ct_int = bytes_to_long(ciphertext)
         # Step 2b (RSADP)
         # m_int = self._key._decrypt(ct_int)
-        m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)
+        m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)  # ty: ignore [unresolved-attribute]
         # Complete step 2c (I2OSP)
         em = long_to_bytes(m_int, k)
         # Step 3a
@@ -14,11 +14,11 @@ class PassportService:
     def verify(self, token):
         try:
             return jwt.decode(token, self.sk, algorithms=["HS256"])
-        except jwt.exceptions.ExpiredSignatureError:
+        except jwt.ExpiredSignatureError:
             raise Unauthorized("Token has expired.")
-        except jwt.exceptions.InvalidSignatureError:
+        except jwt.InvalidSignatureError:
             raise Unauthorized("Invalid token signature.")
-        except jwt.exceptions.DecodeError:
+        except jwt.DecodeError:
             raise Unauthorized("Invalid token.")
-        except jwt.exceptions.PyJWTError:  # Catch-all for other JWT errors
+        except jwt.PyJWTError:  # Catch-all for other JWT errors
             raise Unauthorized("Invalid token.")
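
Note: PyJWT mirrors `requests` here; the exception classes are re-exported at the package top level, so the shorter names bind the identical classes:

```python
import jwt

# Re-exports from jwt.exceptions, so each pair is the same class object.
assert jwt.PyJWTError is jwt.exceptions.PyJWTError
assert jwt.ExpiredSignatureError is jwt.exceptions.ExpiredSignatureError
assert jwt.InvalidSignatureError is jwt.exceptions.InvalidSignatureError
assert jwt.DecodeError is jwt.exceptions.DecodeError
```
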
@@ -26,9 +26,9 @@ class SendGridClient:
         to_email = To(_to)
         subject = mail["subject"]
         content = Content("text/html", mail["html"])
-        mail = Mail(from_email, to_email, subject, content)
-        mail_json = mail.get()  # type: ignore
-        response = sg.client.mail.send.post(request_body=mail_json)
+        sg_mail = Mail(from_email, to_email, subject, content)
+        mail_json = sg_mail.get()
+        response = sg.client.mail.send.post(request_body=mail_json)  # ty: ignore [call-non-callable]
         logger.debug(response.status_code)
         logger.debug(response.body)
         logger.debug(response.headers)
@@ -110,6 +110,7 @@ dev = [
     "dotenv-linter~=0.5.0",
     "faker~=32.1.0",
     "lxml-stubs~=0.5.1",
+    "ty~=0.0.1a19",
     "mypy~=1.17.1",
     "ruff~=0.12.3",
     "pytest~=8.3.2",
@@ -133,7 +133,11 @@ class DatasetService:
         # Check if tag_ids is not empty to avoid WHERE false condition
         if tag_ids and len(tag_ids) > 0:
-            target_ids = TagService.get_target_ids_by_tag_ids("knowledge", tenant_id, tag_ids)
+            target_ids = TagService.get_target_ids_by_tag_ids(
+                "knowledge",
+                tenant_id,  # ty: ignore [invalid-argument-type]
+                tag_ids,
+            )
             if target_ids and len(target_ids) > 0:
                 query = query.where(Dataset.id.in_(target_ids))
             else:
@@ -2361,7 +2365,9 @@ class SegmentService:
         index_node_ids = [seg.index_node_id for seg in segments]
         total_words = sum(seg.word_count for seg in segments)
 
-        document.word_count -= total_words
+        document.word_count = (
+            document.word_count - total_words if document.word_count and document.word_count > total_words else 0
+        )
         db.session.add(document)
 
         delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id)
@@ -229,7 +229,7 @@ class ExternalDatasetService:
     @staticmethod
     def get_external_knowledge_api_settings(settings: dict) -> ExternalKnowledgeApiSetting:
-        return ExternalKnowledgeApiSetting.parse_obj(settings)
+        return ExternalKnowledgeApiSetting.model_validate(settings)
 
     @staticmethod
     def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
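
Note: Pydantic v2 keeps `parse_obj()` only as a deprecated alias of `model_validate()`; it still works but warns at runtime and is marked deprecated in the stubs, which is what the checker flags. A minimal sketch with a hypothetical model:

```python
from pydantic import BaseModel

class ExampleSetting(BaseModel):  # hypothetical stand-in model
    endpoint: str

# model_validate() is the v2 API; parse_obj() would emit a
# PydanticDeprecatedSince20 warning here.
setting = ExampleSetting.model_validate({"endpoint": "https://example.com"})
```
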
@@ -170,7 +170,9 @@ class ModelLoadBalancingService:
                 if variable in credentials:
                     try:
                         credentials[variable] = encrypter.decrypt_token_with_decoding(
-                            credentials.get(variable), decoding_rsa_key, decoding_cipher_rsa
+                            credentials.get(variable),  # ty: ignore [invalid-argument-type]
+                            decoding_rsa_key,
+                            decoding_cipher_rsa,
                         )
                     except ValueError:
                         pass
@@ -229,7 +229,7 @@ class MCPToolManageService:
         provider_controller = MCPToolProviderController._from_db(mcp_provider)
         tool_configuration = ProviderConfigEncrypter(
             tenant_id=mcp_provider.tenant_id,
-            config=list(provider_controller.get_credentials_schema()),
+            config=list(provider_controller.get_credentials_schema()),  # ty: ignore [invalid-argument-type]
             provider_config_cache=NoOpProviderCredentialCache(),
         )
         credentials = tool_configuration.encrypt(credentials)
@@ -0,0 +1,16 @@
+[src]
+exclude = [
+    # TODO: enable when violations fixed
+    "core/app/apps/workflow_app_runner.py",
+    "controllers/console/app",
+    "controllers/console/explore",
+    "controllers/console/datasets",
+    "controllers/console/workspace",
+    # non-production or generated code
+    "migrations",
+    "tests",
+]
+
+[rules]
+missing-argument = "ignore"  # TODO: restore when **kwargs for constructor is supported properly
+possibly-unbound-attribute = "ignore"
@@ -1353,6 +1353,7 @@ dev = [
     { name = "ruff" },
     { name = "scipy-stubs" },
     { name = "testcontainers" },
+    { name = "ty" },
     { name = "types-aiofiles" },
     { name = "types-beautifulsoup4" },
     { name = "types-cachetools" },
@@ -1542,6 +1543,7 @@ dev = [
     { name = "ruff", specifier = "~=0.12.3" },
     { name = "scipy-stubs", specifier = ">=1.15.3.0" },
     { name = "testcontainers", specifier = "~=4.10.0" },
+    { name = "ty", specifier = "~=0.0.1a19" },
     { name = "types-aiofiles", specifier = "~=24.1.0" },
     { name = "types-beautifulsoup4", specifier = "~=4.12.0" },
     { name = "types-cachetools", specifier = "~=5.5.0" },
@@ -5782,6 +5784,31 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/41/b1/d7520cc5cb69c825599042eb3a7c986fa9baa8a8d2dea9acd78e152c81e2/transformers-4.53.3-py3-none-any.whl", hash = "sha256:5aba81c92095806b6baf12df35d756cf23b66c356975fb2a7fa9e536138d7c75", size = 10826382 },
 ]
 
+[[package]]
+name = "ty"
+version = "0.0.1a19"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/c0/04/281c1a3c9c53dae5826b9d01a3412de653e3caf1ca50ce1265da66e06d73/ty-0.0.1a19.tar.gz", hash = "sha256:894f6a13a43989c8ef891ae079b3b60a0c0eae00244abbfbbe498a3840a235ac", size = 4098412, upload-time = "2025-08-19T13:29:58.559Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/3e/65/a61cfcc7248b0257a3110bf98d3d910a4729c1063abdbfdcd1cad9012323/ty-0.0.1a19-py3-none-linux_armv6l.whl", hash = "sha256:e0e7762f040f4bab1b37c57cb1b43cc3bc5afb703fa5d916dfcafa2ef885190e", size = 8143744, upload-time = "2025-08-19T13:29:13.88Z" },
+    { url = "https://files.pythonhosted.org/packages/02/d9/232afef97d9afa2274d23a4c49a3ad690282ca9696e1b6bbb6e4e9a1b072/ty-0.0.1a19-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:cd0a67ac875f49f34d9a0b42dcabf4724194558a5dd36867209d5695c67768f7", size = 8305799, upload-time = "2025-08-19T13:29:17.322Z" },
+    { url = "https://files.pythonhosted.org/packages/20/14/099d268da7a9cccc6ba38dfc124f6742a1d669bc91f2c61a3465672b4f71/ty-0.0.1a19-py3-none-macosx_11_0_arm64.whl", hash = "sha256:ff8b1c0b85137333c39eccd96c42603af8ba7234d6e2ed0877f66a4a26750dd4", size = 7901431, upload-time = "2025-08-19T13:29:21.635Z" },
+    { url = "https://files.pythonhosted.org/packages/c2/cd/3f1ca6e1d7f77cc4d08910a3fc4826313c031c0aae72286ae859e737670c/ty-0.0.1a19-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4fef34a29f4b97d78aa30e60adbbb12137cf52b8b2b0f1a408dd0feb0466908a", size = 8051501, upload-time = "2025-08-19T13:29:23.741Z" },
+    { url = "https://files.pythonhosted.org/packages/47/72/ddbec39f48ce3f5f6a3fa1f905c8fff2873e59d2030f738814032bd783e3/ty-0.0.1a19-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b0f219cb43c0c50fc1091f8ebd5548d3ef31ee57866517b9521d5174978af9fd", size = 7981234, upload-time = "2025-08-19T13:29:25.839Z" },
+    { url = "https://files.pythonhosted.org/packages/f2/0f/58e76b8d4634df066c790d362e8e73b25852279cd6f817f099b42a555a66/ty-0.0.1a19-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:22abb6c1f14c65c1a2fafd38e25dd3c87994b3ab88cb0b323235b51dbad082d9", size = 8916394, upload-time = "2025-08-19T13:29:27.932Z" },
+    { url = "https://files.pythonhosted.org/packages/70/30/01bfd93ccde11540b503e2539e55f6a1fc6e12433a229191e248946eb753/ty-0.0.1a19-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:5b49225c349a3866e38dd297cb023a92d084aec0e895ed30ca124704bff600e6", size = 9412024, upload-time = "2025-08-19T13:29:30.942Z" },
+    { url = "https://files.pythonhosted.org/packages/a8/a2/2216d752f5f22c5c0995f9b13f18337301220f2a7d952c972b33e6a63583/ty-0.0.1a19-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:88f41728b3b07402e0861e3c34412ca963268e55f6ab1690208f25d37cb9d63c", size = 9032657, upload-time = "2025-08-19T13:29:33.933Z" },
+    { url = "https://files.pythonhosted.org/packages/24/c7/e6650b0569be1b69a03869503d07420c9fb3e90c9109b09726c44366ce63/ty-0.0.1a19-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:33814a1197ec3e930fcfba6fb80969fe7353957087b42b88059f27a173f7510b", size = 8812775, upload-time = "2025-08-19T13:29:36.505Z" },
+    { url = "https://files.pythonhosted.org/packages/35/c6/b8a20e06b97fe8203059d56d8f91cec4f9633e7ba65f413d80f16aa0be04/ty-0.0.1a19-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d71b7f2b674a287258f628acafeecd87691b169522945ff6192cd8a69af15857", size = 8631417, upload-time = "2025-08-19T13:29:38.837Z" },
+    { url = "https://files.pythonhosted.org/packages/be/99/821ca1581dcf3d58ffb7bbe1cde7e1644dbdf53db34603a16a459a0b302c/ty-0.0.1a19-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3a7f8ef9ac4c38e8651c18c7380649c5a3fa9adb1a6012c721c11f4bbdc0ce24", size = 7928900, upload-time = "2025-08-19T13:29:41.08Z" },
+    { url = "https://files.pythonhosted.org/packages/08/cb/59f74a0522e57565fef99e2287b2bc803ee47ff7dac250af26960636939f/ty-0.0.1a19-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:60f40e72f0fbf4e54aa83d9a6cb1959f551f83de73af96abbb94711c1546bd60", size = 8003310, upload-time = "2025-08-19T13:29:43.165Z" },
+    { url = "https://files.pythonhosted.org/packages/4c/b3/1209b9acb5af00a2755114042e48fb0f71decc20d9d77a987bf5b3d1a102/ty-0.0.1a19-py3-none-musllinux_1_2_i686.whl", hash = "sha256:64971e4d3e3f83dc79deb606cc438255146cab1ab74f783f7507f49f9346d89d", size = 8496463, upload-time = "2025-08-19T13:29:46.136Z" },
+    { url = "https://files.pythonhosted.org/packages/a2/d6/a4b6ba552d347a08196d83a4d60cb23460404a053dd3596e23a922bce544/ty-0.0.1a19-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:9aadbff487e2e1486e83543b4f4c2165557f17432369f419be9ba48dc47625ca", size = 8700633, upload-time = "2025-08-19T13:29:49.351Z" },
+    { url = "https://files.pythonhosted.org/packages/96/c5/258f318d68b95685c8d98fb654a38882c9d01ce5d9426bed06124f690f04/ty-0.0.1a19-py3-none-win32.whl", hash = "sha256:00b75b446357ee22bcdeb837cb019dc3bc1dc5e5013ff0f46a22dfe6ce498fe2", size = 7811441, upload-time = "2025-08-19T13:29:52.077Z" },
+    { url = "https://files.pythonhosted.org/packages/fb/bb/039227eee3c0c0cddc25f45031eea0f7f10440713f12d333f2f29cf8e934/ty-0.0.1a19-py3-none-win_amd64.whl", hash = "sha256:aaef76b2f44f6379c47adfe58286f0c56041cb2e374fd8462ae8368788634469", size = 8441186, upload-time = "2025-08-19T13:29:54.53Z" },
+    { url = "https://files.pythonhosted.org/packages/74/5f/bceb29009670ae6f759340f9cb434121bc5ed84ad0f07bdc6179eaaa3204/ty-0.0.1a19-py3-none-win_arm64.whl", hash = "sha256:893755bb35f30653deb28865707e3b16907375c830546def2741f6ff9a764710", size = 8000810, upload-time = "2025-08-19T13:29:56.796Z" },
+]
+
 [[package]]
 name = "typer"
 version = "0.16.0"
@@ -14,5 +14,8 @@ uv run --directory api --dev ruff format ./
 # run dotenv-linter linter
 uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example
 
+# run ty check
+dev/ty-check
+
 # run mypy check
 dev/mypy-check
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+set -x
+
+SCRIPT_DIR="$(dirname "$(realpath "$0")")"
+cd "$SCRIPT_DIR/.."
+
+# run ty checks
+uv run --directory api --dev \
+    ty check
@@ -41,6 +41,15 @@ if $api_modified; then
         echo "Please run 'dev/reformat' to fix the fixable linting errors."
         exit 1
     fi
+
+    # run ty checks
+    uv run --directory api --dev ty check || status=$?
+    status=${status:-0}
+    if [ $status -ne 0 ]; then
+        echo "ty type checker on api module error, exit code: $status"
+        echo "Please run 'dev/ty-check' to check the type errors."
+        exit 1
+    fi
 fi
 
 if $web_modified; then