Преглед изворни кода

refactor: replace try-except blocks with contextlib.suppress for cleaner exception handling (#24284)

tags/1.8.0
Guangdong Liu пре 2 месеци
родитељ
комит
1abf1240b2
No account linked to committer's email address

+ 3
- 3
api/controllers/console/wraps.py Прегледај датотеку

@@ -1,3 +1,4 @@
import contextlib
import json
import os
import time
@@ -178,7 +179,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
def cloud_utm_record(view):
@wraps(view)
def decorated(*args, **kwargs):
try:
with contextlib.suppress(Exception):
features = FeatureService.get_features(current_user.current_tenant_id)

if features.billing.enabled:
@@ -187,8 +188,7 @@ def cloud_utm_record(view):
if utm_info:
utm_info_dict: dict = json.loads(utm_info)
OperationService.record_utm(current_user.current_tenant_id, utm_info_dict)
except Exception as e:
pass

return view(*args, **kwargs)

return decorated

+ 2
- 3
api/core/helper/trace_id_helper.py Прегледај датотеку

@@ -1,3 +1,4 @@
import contextlib
import re
from collections.abc import Mapping
from typing import Any, Optional
@@ -97,10 +98,8 @@ def parse_traceparent_header(traceparent: str) -> Optional[str]:
Reference:
W3C Trace Context Specification: https://www.w3.org/TR/trace-context/
"""
try:
with contextlib.suppress(Exception):
parts = traceparent.split("-")
if len(parts) == 4 and len(parts[1]) == 32:
return parts[1]
except Exception:
pass
return None

+ 3
- 6
api/core/provider_manager.py Прегледај датотеку

@@ -1,3 +1,4 @@
import contextlib
import json
from collections import defaultdict
from json import JSONDecodeError
@@ -624,14 +625,12 @@ class ProviderManager:

for variable in provider_credential_secret_variables:
if variable in provider_credentials:
try:
with contextlib.suppress(ValueError):
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable) or "", # type: ignore
self.decoding_rsa_key,
self.decoding_cipher_rsa,
)
except ValueError:
pass

# cache provider credentials
provider_credentials_cache.set(credentials=provider_credentials)
@@ -672,14 +671,12 @@ class ProviderManager:

for variable in model_credential_secret_variables:
if variable in provider_model_credentials:
try:
with contextlib.suppress(ValueError):
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_model_credentials.get(variable),
self.decoding_rsa_key,
self.decoding_cipher_rsa,
)
except ValueError:
pass

# cache provider model credentials
provider_model_credentials_cache.set(credentials=provider_model_credentials)

+ 6
- 15
api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py Прегледај датотеку

@@ -1,3 +1,4 @@
import contextlib
import json
import logging
import queue
@@ -214,10 +215,8 @@ class ClickzettaConnectionPool:
return connection
else:
# Connection expired or invalid, close it
try:
with contextlib.suppress(Exception):
connection.close()
except Exception:
pass

# No valid connection found, create new one
return self._create_connection(config)
@@ -228,10 +227,8 @@ class ClickzettaConnectionPool:

if config_key not in self._pool_locks:
# Pool was cleaned up, just close the connection
try:
with contextlib.suppress(Exception):
connection.close()
except Exception:
pass
return

with self._pool_locks[config_key]:
@@ -243,10 +240,8 @@ class ClickzettaConnectionPool:
logger.debug("Returned ClickZetta connection to pool")
else:
# Pool full or connection invalid, close it
try:
with contextlib.suppress(Exception):
connection.close()
except Exception:
pass

def _cleanup_expired_connections(self) -> None:
"""Clean up expired connections from all pools."""
@@ -265,10 +260,8 @@ class ClickzettaConnectionPool:
if current_time - last_used < self._connection_timeout:
valid_connections.append((connection, last_used))
else:
try:
with contextlib.suppress(Exception):
connection.close()
except Exception:
pass

self._pools[config_key] = valid_connections

@@ -299,10 +292,8 @@ class ClickzettaConnectionPool:
with self._pool_locks[config_key]:
pool = self._pools[config_key]
for connection, _ in pool:
try:
with contextlib.suppress(Exception):
connection.close()
except Exception:
pass
pool.clear()



+ 2
- 3
api/core/rag/extractor/pdf_extractor.py Прегледај датотеку

@@ -1,5 +1,6 @@
"""Abstract interface for document loader implementations."""

import contextlib
from collections.abc import Iterator
from typing import Optional, cast

@@ -25,12 +26,10 @@ class PdfExtractor(BaseExtractor):
def extract(self) -> list[Document]:
plaintext_file_exists = False
if self._file_cache_key:
try:
with contextlib.suppress(FileNotFoundError):
text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
plaintext_file_exists = True
return [Document(page_content=text)]
except FileNotFoundError:
pass
documents = list(self.load())
text_list = []
for document in documents:

+ 2
- 3
api/core/rag/extractor/unstructured/unstructured_eml_extractor.py Прегледај датотеку

@@ -1,4 +1,5 @@
import base64
import contextlib
import logging
from typing import Optional

@@ -33,7 +34,7 @@ class UnstructuredEmailExtractor(BaseExtractor):
elements = partition_email(filename=self._file_path)

# noinspection PyBroadException
try:
with contextlib.suppress(Exception):
for element in elements:
element_text = element.text.strip()

@@ -43,8 +44,6 @@ class UnstructuredEmailExtractor(BaseExtractor):
element_decode = base64.b64decode(element_text)
soup = BeautifulSoup(element_decode.decode("utf-8"), "html.parser")
element.text = soup.get_text()
except Exception:
pass

from unstructured.chunking.title import chunk_by_title


+ 2
- 3
api/core/tools/entities/tool_entities.py Прегледај датотеку

@@ -1,4 +1,5 @@
import base64
import contextlib
import enum
from collections.abc import Mapping
from enum import Enum
@@ -227,10 +228,8 @@ class ToolInvokeMessage(BaseModel):
@classmethod
def decode_blob_message(cls, v):
if isinstance(v, dict) and "blob" in v:
try:
with contextlib.suppress(Exception):
v["blob"] = base64.b64decode(v["blob"])
except Exception:
pass
return v

@field_serializer("message")

+ 3
- 6
api/core/tools/tool_engine.py Прегледај датотеку

@@ -1,3 +1,4 @@
import contextlib
import json
from collections.abc import Generator, Iterable
from copy import deepcopy
@@ -69,10 +70,8 @@ class ToolEngine:
if parameters and len(parameters) == 1:
tool_parameters = {parameters[0].name: tool_parameters}
else:
try:
with contextlib.suppress(Exception):
tool_parameters = json.loads(tool_parameters)
except Exception:
pass
if not isinstance(tool_parameters, dict):
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")

@@ -270,14 +269,12 @@ class ToolEngine:
if response.meta.get("mime_type"):
mimetype = response.meta.get("mime_type")
else:
try:
with contextlib.suppress(Exception):
url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text)
extension = url.suffix
guess_type_result, _ = guess_type(f"a{extension}")
if guess_type_result:
mimetype = guess_type_result
except Exception:
pass

if not mimetype:
mimetype = "image/jpeg"

+ 3
- 4
api/core/tools/utils/configuration.py Прегледај датотеку

@@ -1,3 +1,4 @@
import contextlib
from copy import deepcopy
from typing import Any

@@ -137,11 +138,9 @@ class ToolParameterConfigurationManager:
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
try:
has_secret_input = True
has_secret_input = True
with contextlib.suppress(Exception):
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
except Exception:
pass

if has_secret_input:
cache.set(parameters)

+ 2
- 3
api/core/tools/utils/encryption.py Прегледај датотеку

@@ -1,3 +1,4 @@
import contextlib
from copy import deepcopy
from typing import Any, Optional, Protocol

@@ -111,14 +112,12 @@ class ProviderConfigEncrypter:
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
try:
with contextlib.suppress(Exception):
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue

data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
except Exception:
pass

self.provider_config_cache.set(data)
return data

+ 4
- 6
api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py Прегледај датотеку

@@ -1,3 +1,4 @@
import contextlib
import json
import logging
import uuid
@@ -666,10 +667,8 @@ class ParameterExtractorNode(BaseNode):
if result[idx] == "{" or result[idx] == "[":
json_str = extract_json(result[idx:])
if json_str:
try:
with contextlib.suppress(Exception):
return cast(dict, json.loads(json_str))
except Exception:
pass
logger.info("extra error: %s", result)
return None

@@ -686,10 +685,9 @@ class ParameterExtractorNode(BaseNode):
if result[idx] == "{" or result[idx] == "[":
json_str = extract_json(result[idx:])
if json_str:
try:
with contextlib.suppress(Exception):
return cast(dict, json.loads(json_str))
except Exception:
pass

logger.info("extra error: %s", result)
return None


+ 9
- 9
api/events/event_handlers/create_document_index.py Прегледај датотеку

@@ -1,3 +1,4 @@
import contextlib
import logging
import time

@@ -38,12 +39,11 @@ def handle(sender, **kwargs):
db.session.add(document)
db.session.commit()

try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logging.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logging.info(click.style(str(ex), fg="yellow"))
except Exception:
pass
with contextlib.suppress(Exception):
try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logging.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logging.info(click.style(str(ex), fg="yellow"))

+ 2
- 4
api/extensions/ext_otel.py Прегледај датотеку

@@ -1,4 +1,5 @@
import atexit
import contextlib
import logging
import os
import platform
@@ -106,7 +107,7 @@ def init_app(app: DifyApp):
"""Custom logging handler that creates spans for logging.exception() calls"""

def emit(self, record: logging.LogRecord):
try:
with contextlib.suppress(Exception):
if record.exc_info:
tracer = get_tracer_provider().get_tracer("dify.exception.logging")
with tracer.start_as_current_span(
@@ -126,9 +127,6 @@ def init_app(app: DifyApp):
if record.exc_info[0]:
span.set_attribute("exception.type", record.exc_info[0].__name__)

except Exception:
pass

from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter

+ 2
- 3
api/services/conversation_service.py Прегледај датотеку

@@ -1,3 +1,4 @@
import contextlib
from collections.abc import Callable, Sequence
from typing import Any, Optional, Union

@@ -142,13 +143,11 @@ class ConversationService:
raise MessageNotExistsError()

# generate conversation name
try:
with contextlib.suppress(Exception):
name = LLMGenerator.generate_conversation_name(
app_model.tenant_id, message.query, conversation.id, app_model.id
)
conversation.name = name
except Exception:
pass

db.session.commit()


+ 2
- 3
api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py Прегледај датотеку

@@ -1,3 +1,4 @@
import contextlib
import os

import pytest
@@ -44,10 +45,8 @@ class TestClickzettaVector(AbstractVectorTest):
yield vector

# Cleanup: delete the test collection
try:
with contextlib.suppress(Exception):
vector.delete()
except Exception:
pass

def test_clickzetta_vector_basic_operations(self, vector_store):
"""Test basic CRUD operations on Clickzetta vector store."""

+ 5
- 14
api/tests/unit_tests/core/mcp/client/test_sse.py Прегледај датотеку

@@ -1,3 +1,4 @@
import contextlib
import json
import queue
import threading
@@ -124,13 +125,10 @@ def test_sse_client_connection_validation():
mock_event_source.iter_sse.return_value = [endpoint_event]

# Test connection
try:
with contextlib.suppress(Exception):
with sse_client(test_url) as (read_queue, write_queue):
assert read_queue is not None
assert write_queue is not None
except Exception as e:
# Connection might fail due to mocking, but we're testing the validation logic
pass


def test_sse_client_error_handling():
@@ -178,7 +176,7 @@ def test_sse_client_timeout_configuration():
mock_event_source.iter_sse.return_value = []
mock_sse_connect.return_value.__enter__.return_value = mock_event_source

try:
with contextlib.suppress(Exception):
with sse_client(
test_url, headers=custom_headers, timeout=custom_timeout, sse_read_timeout=custom_sse_timeout
) as (read_queue, write_queue):
@@ -190,9 +188,6 @@ def test_sse_client_timeout_configuration():
assert call_args is not None
timeout_arg = call_args[1]["timeout"]
assert timeout_arg.read == custom_sse_timeout
except Exception:
# Connection might fail due to mocking, but we tested the configuration
pass


def test_sse_transport_endpoint_validation():
@@ -251,12 +246,10 @@ def test_sse_client_queue_cleanup():
# Mock connection that raises an exception
mock_sse_connect.side_effect = Exception("Connection failed")

try:
with contextlib.suppress(Exception):
with sse_client(test_url) as (rq, wq):
read_queue = rq
write_queue = wq
except Exception:
pass # Expected to fail

# Queues should be cleaned up even on exception
# Note: In real implementation, cleanup should put None to signal shutdown
@@ -283,11 +276,9 @@ def test_sse_client_headers_propagation():
mock_event_source.iter_sse.return_value = []
mock_sse_connect.return_value.__enter__.return_value = mock_event_source

try:
with contextlib.suppress(Exception):
with sse_client(test_url, headers=custom_headers):
pass
except Exception:
pass # Expected due to mocking

# Verify headers were passed to client factory
mock_client_factory.assert_called_with(headers=custom_headers)

Loading…
Откажи
Сачувај