Quellcode durchsuchen

refactor: replace try-except blocks with contextlib.suppress for cleaner exception handling (#24284)

tags/1.8.0
Guangdong Liu vor 2 Monaten
Ursprung
Commit
1abf1240b2
Es ist kein Account mit der E-Mail-Adresse des Committers verbunden

+ 3
- 3
api/controllers/console/wraps.py Datei anzeigen

import contextlib
import json import json
import os import os
import time import time
def cloud_utm_record(view): def cloud_utm_record(view):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
try:
with contextlib.suppress(Exception):
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)


if features.billing.enabled: if features.billing.enabled:
if utm_info: if utm_info:
utm_info_dict: dict = json.loads(utm_info) utm_info_dict: dict = json.loads(utm_info)
OperationService.record_utm(current_user.current_tenant_id, utm_info_dict) OperationService.record_utm(current_user.current_tenant_id, utm_info_dict)
except Exception as e:
pass

return view(*args, **kwargs) return view(*args, **kwargs)


return decorated return decorated

+ 2
- 3
api/core/helper/trace_id_helper.py Datei anzeigen

import contextlib
import re import re
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Optional from typing import Any, Optional
Reference: Reference:
W3C Trace Context Specification: https://www.w3.org/TR/trace-context/ W3C Trace Context Specification: https://www.w3.org/TR/trace-context/
""" """
try:
with contextlib.suppress(Exception):
parts = traceparent.split("-") parts = traceparent.split("-")
if len(parts) == 4 and len(parts[1]) == 32: if len(parts) == 4 and len(parts[1]) == 32:
return parts[1] return parts[1]
except Exception:
pass
return None return None

+ 3
- 6
api/core/provider_manager.py Datei anzeigen

import contextlib
import json import json
from collections import defaultdict from collections import defaultdict
from json import JSONDecodeError from json import JSONDecodeError


for variable in provider_credential_secret_variables: for variable in provider_credential_secret_variables:
if variable in provider_credentials: if variable in provider_credentials:
try:
with contextlib.suppress(ValueError):
provider_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable) or "", # type: ignore provider_credentials.get(variable) or "", # type: ignore
self.decoding_rsa_key, self.decoding_rsa_key,
self.decoding_cipher_rsa, self.decoding_cipher_rsa,
) )
except ValueError:
pass


# cache provider credentials # cache provider credentials
provider_credentials_cache.set(credentials=provider_credentials) provider_credentials_cache.set(credentials=provider_credentials)


for variable in model_credential_secret_variables: for variable in model_credential_secret_variables:
if variable in provider_model_credentials: if variable in provider_model_credentials:
try:
with contextlib.suppress(ValueError):
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_model_credentials.get(variable), provider_model_credentials.get(variable),
self.decoding_rsa_key, self.decoding_rsa_key,
self.decoding_cipher_rsa, self.decoding_cipher_rsa,
) )
except ValueError:
pass


# cache provider model credentials # cache provider model credentials
provider_model_credentials_cache.set(credentials=provider_model_credentials) provider_model_credentials_cache.set(credentials=provider_model_credentials)

+ 6
- 15
api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py Datei anzeigen

import contextlib
import json import json
import logging import logging
import queue import queue
return connection return connection
else: else:
# Connection expired or invalid, close it # Connection expired or invalid, close it
try:
with contextlib.suppress(Exception):
connection.close() connection.close()
except Exception:
pass


# No valid connection found, create new one # No valid connection found, create new one
return self._create_connection(config) return self._create_connection(config)


if config_key not in self._pool_locks: if config_key not in self._pool_locks:
# Pool was cleaned up, just close the connection # Pool was cleaned up, just close the connection
try:
with contextlib.suppress(Exception):
connection.close() connection.close()
except Exception:
pass
return return


with self._pool_locks[config_key]: with self._pool_locks[config_key]:
logger.debug("Returned ClickZetta connection to pool") logger.debug("Returned ClickZetta connection to pool")
else: else:
# Pool full or connection invalid, close it # Pool full or connection invalid, close it
try:
with contextlib.suppress(Exception):
connection.close() connection.close()
except Exception:
pass


def _cleanup_expired_connections(self) -> None: def _cleanup_expired_connections(self) -> None:
"""Clean up expired connections from all pools.""" """Clean up expired connections from all pools."""
if current_time - last_used < self._connection_timeout: if current_time - last_used < self._connection_timeout:
valid_connections.append((connection, last_used)) valid_connections.append((connection, last_used))
else: else:
try:
with contextlib.suppress(Exception):
connection.close() connection.close()
except Exception:
pass


self._pools[config_key] = valid_connections self._pools[config_key] = valid_connections


with self._pool_locks[config_key]: with self._pool_locks[config_key]:
pool = self._pools[config_key] pool = self._pools[config_key]
for connection, _ in pool: for connection, _ in pool:
try:
with contextlib.suppress(Exception):
connection.close() connection.close()
except Exception:
pass
pool.clear() pool.clear()





+ 2
- 3
api/core/rag/extractor/pdf_extractor.py Datei anzeigen

"""Abstract interface for document loader implementations.""" """Abstract interface for document loader implementations."""


import contextlib
from collections.abc import Iterator from collections.abc import Iterator
from typing import Optional, cast from typing import Optional, cast


def extract(self) -> list[Document]: def extract(self) -> list[Document]:
plaintext_file_exists = False plaintext_file_exists = False
if self._file_cache_key: if self._file_cache_key:
try:
with contextlib.suppress(FileNotFoundError):
text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8") text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
plaintext_file_exists = True plaintext_file_exists = True
return [Document(page_content=text)] return [Document(page_content=text)]
except FileNotFoundError:
pass
documents = list(self.load()) documents = list(self.load())
text_list = [] text_list = []
for document in documents: for document in documents:

+ 2
- 3
api/core/rag/extractor/unstructured/unstructured_eml_extractor.py Datei anzeigen

import base64 import base64
import contextlib
import logging import logging
from typing import Optional from typing import Optional


elements = partition_email(filename=self._file_path) elements = partition_email(filename=self._file_path)


# noinspection PyBroadException # noinspection PyBroadException
try:
with contextlib.suppress(Exception):
for element in elements: for element in elements:
element_text = element.text.strip() element_text = element.text.strip()


element_decode = base64.b64decode(element_text) element_decode = base64.b64decode(element_text)
soup = BeautifulSoup(element_decode.decode("utf-8"), "html.parser") soup = BeautifulSoup(element_decode.decode("utf-8"), "html.parser")
element.text = soup.get_text() element.text = soup.get_text()
except Exception:
pass


from unstructured.chunking.title import chunk_by_title from unstructured.chunking.title import chunk_by_title



+ 2
- 3
api/core/tools/entities/tool_entities.py Datei anzeigen

import base64 import base64
import contextlib
import enum import enum
from collections.abc import Mapping from collections.abc import Mapping
from enum import Enum from enum import Enum
@classmethod @classmethod
def decode_blob_message(cls, v): def decode_blob_message(cls, v):
if isinstance(v, dict) and "blob" in v: if isinstance(v, dict) and "blob" in v:
try:
with contextlib.suppress(Exception):
v["blob"] = base64.b64decode(v["blob"]) v["blob"] = base64.b64decode(v["blob"])
except Exception:
pass
return v return v


@field_serializer("message") @field_serializer("message")

+ 3
- 6
api/core/tools/tool_engine.py Datei anzeigen

import contextlib
import json import json
from collections.abc import Generator, Iterable from collections.abc import Generator, Iterable
from copy import deepcopy from copy import deepcopy
if parameters and len(parameters) == 1: if parameters and len(parameters) == 1:
tool_parameters = {parameters[0].name: tool_parameters} tool_parameters = {parameters[0].name: tool_parameters}
else: else:
try:
with contextlib.suppress(Exception):
tool_parameters = json.loads(tool_parameters) tool_parameters = json.loads(tool_parameters)
except Exception:
pass
if not isinstance(tool_parameters, dict): if not isinstance(tool_parameters, dict):
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")


if response.meta.get("mime_type"): if response.meta.get("mime_type"):
mimetype = response.meta.get("mime_type") mimetype = response.meta.get("mime_type")
else: else:
try:
with contextlib.suppress(Exception):
url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text) url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text)
extension = url.suffix extension = url.suffix
guess_type_result, _ = guess_type(f"a{extension}") guess_type_result, _ = guess_type(f"a{extension}")
if guess_type_result: if guess_type_result:
mimetype = guess_type_result mimetype = guess_type_result
except Exception:
pass


if not mimetype: if not mimetype:
mimetype = "image/jpeg" mimetype = "image/jpeg"

+ 3
- 4
api/core/tools/utils/configuration.py Datei anzeigen

import contextlib
from copy import deepcopy from copy import deepcopy
from typing import Any from typing import Any


and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
): ):
if parameter.name in parameters: if parameter.name in parameters:
try:
has_secret_input = True
has_secret_input = True
with contextlib.suppress(Exception):
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name]) parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
except Exception:
pass


if has_secret_input: if has_secret_input:
cache.set(parameters) cache.set(parameters)

+ 2
- 3
api/core/tools/utils/encryption.py Datei anzeigen

import contextlib
from copy import deepcopy from copy import deepcopy
from typing import Any, Optional, Protocol from typing import Any, Optional, Protocol


for field_name, field in fields.items(): for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT: if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data: if field_name in data:
try:
with contextlib.suppress(Exception):
# if the value is None or empty string, skip decrypt # if the value is None or empty string, skip decrypt
if not data[field_name]: if not data[field_name]:
continue continue


data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name]) data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
except Exception:
pass


self.provider_config_cache.set(data) self.provider_config_cache.set(data)
return data return data

+ 4
- 6
api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py Datei anzeigen

import contextlib
import json import json
import logging import logging
import uuid import uuid
if result[idx] == "{" or result[idx] == "[": if result[idx] == "{" or result[idx] == "[":
json_str = extract_json(result[idx:]) json_str = extract_json(result[idx:])
if json_str: if json_str:
try:
with contextlib.suppress(Exception):
return cast(dict, json.loads(json_str)) return cast(dict, json.loads(json_str))
except Exception:
pass
logger.info("extra error: %s", result) logger.info("extra error: %s", result)
return None return None


if result[idx] == "{" or result[idx] == "[": if result[idx] == "{" or result[idx] == "[":
json_str = extract_json(result[idx:]) json_str = extract_json(result[idx:])
if json_str: if json_str:
try:
with contextlib.suppress(Exception):
return cast(dict, json.loads(json_str)) return cast(dict, json.loads(json_str))
except Exception:
pass

logger.info("extra error: %s", result) logger.info("extra error: %s", result)
return None return None



+ 9
- 9
api/events/event_handlers/create_document_index.py Datei anzeigen

import contextlib
import logging import logging
import time import time


db.session.add(document) db.session.add(document)
db.session.commit() db.session.commit()


try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logging.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logging.info(click.style(str(ex), fg="yellow"))
except Exception:
pass
with contextlib.suppress(Exception):
try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logging.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logging.info(click.style(str(ex), fg="yellow"))

+ 2
- 4
api/extensions/ext_otel.py Datei anzeigen

import atexit import atexit
import contextlib
import logging import logging
import os import os
import platform import platform
"""Custom logging handler that creates spans for logging.exception() calls""" """Custom logging handler that creates spans for logging.exception() calls"""


def emit(self, record: logging.LogRecord): def emit(self, record: logging.LogRecord):
try:
with contextlib.suppress(Exception):
if record.exc_info: if record.exc_info:
tracer = get_tracer_provider().get_tracer("dify.exception.logging") tracer = get_tracer_provider().get_tracer("dify.exception.logging")
with tracer.start_as_current_span( with tracer.start_as_current_span(
if record.exc_info[0]: if record.exc_info[0]:
span.set_attribute("exception.type", record.exc_info[0].__name__) span.set_attribute("exception.type", record.exc_info[0].__name__)


except Exception:
pass

from opentelemetry import trace from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter

+ 2
- 3
api/services/conversation_service.py Datei anzeigen

import contextlib
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import Any, Optional, Union from typing import Any, Optional, Union


raise MessageNotExistsError() raise MessageNotExistsError()


# generate conversation name # generate conversation name
try:
with contextlib.suppress(Exception):
name = LLMGenerator.generate_conversation_name( name = LLMGenerator.generate_conversation_name(
app_model.tenant_id, message.query, conversation.id, app_model.id app_model.tenant_id, message.query, conversation.id, app_model.id
) )
conversation.name = name conversation.name = name
except Exception:
pass


db.session.commit() db.session.commit()



+ 2
- 3
api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py Datei anzeigen

import contextlib
import os import os


import pytest import pytest
yield vector yield vector


# Cleanup: delete the test collection # Cleanup: delete the test collection
try:
with contextlib.suppress(Exception):
vector.delete() vector.delete()
except Exception:
pass


def test_clickzetta_vector_basic_operations(self, vector_store): def test_clickzetta_vector_basic_operations(self, vector_store):
"""Test basic CRUD operations on Clickzetta vector store.""" """Test basic CRUD operations on Clickzetta vector store."""

+ 5
- 14
api/tests/unit_tests/core/mcp/client/test_sse.py Datei anzeigen

import contextlib
import json import json
import queue import queue
import threading import threading
mock_event_source.iter_sse.return_value = [endpoint_event] mock_event_source.iter_sse.return_value = [endpoint_event]


# Test connection # Test connection
try:
with contextlib.suppress(Exception):
with sse_client(test_url) as (read_queue, write_queue): with sse_client(test_url) as (read_queue, write_queue):
assert read_queue is not None assert read_queue is not None
assert write_queue is not None assert write_queue is not None
except Exception as e:
# Connection might fail due to mocking, but we're testing the validation logic
pass




def test_sse_client_error_handling(): def test_sse_client_error_handling():
mock_event_source.iter_sse.return_value = [] mock_event_source.iter_sse.return_value = []
mock_sse_connect.return_value.__enter__.return_value = mock_event_source mock_sse_connect.return_value.__enter__.return_value = mock_event_source


try:
with contextlib.suppress(Exception):
with sse_client( with sse_client(
test_url, headers=custom_headers, timeout=custom_timeout, sse_read_timeout=custom_sse_timeout test_url, headers=custom_headers, timeout=custom_timeout, sse_read_timeout=custom_sse_timeout
) as (read_queue, write_queue): ) as (read_queue, write_queue):
assert call_args is not None assert call_args is not None
timeout_arg = call_args[1]["timeout"] timeout_arg = call_args[1]["timeout"]
assert timeout_arg.read == custom_sse_timeout assert timeout_arg.read == custom_sse_timeout
except Exception:
# Connection might fail due to mocking, but we tested the configuration
pass




def test_sse_transport_endpoint_validation(): def test_sse_transport_endpoint_validation():
# Mock connection that raises an exception # Mock connection that raises an exception
mock_sse_connect.side_effect = Exception("Connection failed") mock_sse_connect.side_effect = Exception("Connection failed")


try:
with contextlib.suppress(Exception):
with sse_client(test_url) as (rq, wq): with sse_client(test_url) as (rq, wq):
read_queue = rq read_queue = rq
write_queue = wq write_queue = wq
except Exception:
pass # Expected to fail


# Queues should be cleaned up even on exception # Queues should be cleaned up even on exception
# Note: In real implementation, cleanup should put None to signal shutdown # Note: In real implementation, cleanup should put None to signal shutdown
mock_event_source.iter_sse.return_value = [] mock_event_source.iter_sse.return_value = []
mock_sse_connect.return_value.__enter__.return_value = mock_event_source mock_sse_connect.return_value.__enter__.return_value = mock_event_source


try:
with contextlib.suppress(Exception):
with sse_client(test_url, headers=custom_headers): with sse_client(test_url, headers=custom_headers):
pass pass
except Exception:
pass # Expected due to mocking


# Verify headers were passed to client factory # Verify headers were passed to client factory
mock_client_factory.assert_called_with(headers=custom_headers) mock_client_factory.assert_called_with(headers=custom_headers)

Laden…
Abbrechen
Speichern