
update sql in batch (#24801)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
tags/1.9.0
Asuka Minato, 1 month ago
commit cbc0e639e4
No account linked to the committer's email address
49 changed files with 281 additions and 277 deletions
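The change applied across all 49 files is the same mechanical rewrite: legacy db.session.query(Model).where(...).all() calls become 2.0-style statements built with select() and executed through db.session.scalars(...). Below is a minimal, self-contained sketch of the before/after pattern; it uses an in-memory SQLite database and a stand-in MessageAnnotation model rather than Dify's real db.session and models.

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class MessageAnnotation(Base):  # stand-in for Dify's MessageAnnotation model
    __tablename__ = "message_annotations"
    id: Mapped[int] = mapped_column(primary_key=True)
    app_id: Mapped[str] = mapped_column(String(36))

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(MessageAnnotation(app_id="app-1"))
    session.commit()

    # Legacy Query API, the form this commit removes:
    legacy = session.query(MessageAnnotation).where(MessageAnnotation.app_id == "app-1").all()

    # SQLAlchemy 2.0 style, the form this commit introduces:
    modern = session.scalars(
        select(MessageAnnotation).where(MessageAnnotation.app_id == "app-1")
    ).all()

    assert [a.id for a in legacy] == [a.id for a in modern]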
1. api/commands.py (+9 -11)
2. api/controllers/console/apikey.py (+5 -5)
3. api/controllers/console/datasets/data_source.py (+3 -5)
4. api/controllers/console/datasets/datasets.py (+16 -15)
5. api/controllers/console/datasets/datasets_document.py (+2 -1)
6. api/controllers/console/explore/installed_app.py (+9 -7)
7. api/controllers/console/workspace/account.py (+3 -1)
8. api/core/memory/token_buffer_memory.py (+13 -12)
9. api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py (+5 -6)
10. api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py (+2 -1)
11. api/core/tools/custom_tool/provider.py (+6 -5)
12. api/core/tools/tool_label_manager.py (+1 -3)
13. api/core/tools/tool_manager.py (+6 -6)
14. api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py (+3 -1)
15. api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py (+3 -1)
16. api/models/account.py (+6 -4)
17. api/models/dataset.py (+7 -7)
18. api/models/model.py (+3 -3)
19. api/schedule/clean_unused_datasets_task.py (+8 -10)
20. api/schedule/mail_clean_document_notify_task.py (+4 -3)
21. api/schedule/update_tidb_serverless_status_task.py (+6 -6)
22. api/services/annotation_service.py (+3 -5)
23. api/services/auth/api_key_auth_service.py (+7 -5)
24. api/services/clear_free_plan_tenant_expired_logs.py (+2 -1)
25. api/services/dataset_service.py (+38 -59)
26. api/services/model_load_balancing_service.py (+4 -6)
27. api/services/recommend_app/database/database_retrieval.py (+8 -10)
28. api/services/tag_service.py (+15 -20)
29. api/services/tools/api_tools_manage_service.py (+2 -3)
30. api/services/tools/workflow_tools_manage_service.py (+4 -2)
31. api/tasks/annotation/enable_annotation_reply_task.py (+2 -1)
32. api/tasks/batch_clean_document_task.py (+5 -2)
33. api/tasks/clean_dataset_task.py (+3 -2)
34. api/tasks/clean_document_task.py (+2 -1)
35. api/tasks/clean_notion_document_task.py (+4 -1)
36. api/tasks/deal_dataset_vector_index_task.py (+7 -10)
37. api/tasks/disable_segments_from_index_task.py (+4 -5)
38. api/tasks/document_indexing_sync_task.py (+4 -1)
39. api/tasks/document_indexing_update_task.py (+2 -1)
40. api/tasks/duplicate_document_indexing_task.py (+4 -1)
41. api/tasks/enable_segments_to_index_task.py (+4 -5)
42. api/tasks/remove_document_from_index_task.py (+2 -1)
43. api/tasks/retry_document_indexing_task.py (+4 -1)
44. api/tasks/sync_website_document_indexing_task.py (+2 -1)
45. api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py (+4 -3)
46. api/tests/test_containers_integration_tests/services/test_tag_service.py (+7 -2)
47. api/tests/test_containers_integration_tests/services/test_web_conversation_service.py (+4 -5)
48. api/tests/unit_tests/services/auth/test_api_key_auth_service.py (+12 -8)
49. api/tests/unit_tests/services/auth/test_auth_integration.py (+2 -2)

api/commands.py (+9 -11)

@@ -212,7 +212,9 @@ def migrate_annotation_vector_database():
if not dataset_collection_binding:
click.echo(f"App annotation collection binding not found: {app.id}")
continue
annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app.id).all()
annotations = db.session.scalars(
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
).all()
dataset = Dataset(
id=app.id,
tenant_id=app.tenant_id,
@@ -367,29 +369,25 @@ def migrate_knowledge_vector_database():
)
raise e

dataset_documents = (
db.session.query(DatasetDocument)
.where(
dataset_documents = db.session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
).all()

documents = []
segments_count = 0
for dataset_document in dataset_documents:
segments = (
db.session.query(DocumentSegment)
.where(
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
)
.all()
)
).all()

for segment in segments:
document = Document(

api/controllers/console/apikey.py (+5 -5)

@@ -60,11 +60,11 @@ class BaseApiKeyListResource(Resource):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
keys = (
db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.all()
)
keys = db.session.scalars(
select(ApiToken).where(
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
)
).all()
return {"items": keys}

@marshal_with(api_key_fields)

api/controllers/console/datasets/data_source.py (+3 -5)

@@ -29,14 +29,12 @@ class DataSourceApi(Resource):
@marshal_with(integrate_list_fields)
def get(self):
# get workspace data source integrates
data_source_integrates = (
db.session.query(DataSourceOauthBinding)
.where(
data_source_integrates = db.session.scalars(
select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.disabled == False,
)
.all()
)
).all()

base_url = request.url_root.rstrip("/")
data_source_oauth_base_path = "/console/api/oauth/data-source"

api/controllers/console/datasets/datasets.py (+16 -15)

@@ -2,6 +2,7 @@ import flask_restx
from flask import request
from flask_login import current_user
from flask_restx import Resource, marshal, marshal_with, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound

import services
@@ -411,11 +412,11 @@ class DatasetIndexingEstimateApi(Resource):
extract_settings = []
if args["info_list"]["data_source_type"] == "upload_file":
file_ids = args["info_list"]["file_info_list"]["file_ids"]
file_details = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
.all()
)
file_details = db.session.scalars(
select(UploadFile).where(
UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)
)
).all()

if file_details is None:
raise NotFound("File not found.")
@@ -518,11 +519,11 @@ class DatasetIndexingStatusApi(Resource):
@account_initialization_required
def get(self, dataset_id):
dataset_id = str(dataset_id)
documents = (
db.session.query(Document)
.where(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
.all()
)
documents = db.session.scalars(
select(Document).where(
Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id
)
).all()
documents_status = []
for document in documents:
completed_segments = (
@@ -569,11 +570,11 @@ class DatasetApiKeyApi(Resource):
@account_initialization_required
@marshal_with(api_key_list)
def get(self):
keys = (
db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.all()
)
keys = db.session.scalars(
select(ApiToken).where(
ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id
)
).all()
return {"items": keys}

@setup_required

api/controllers/console/datasets/datasets_document.py (+2 -1)

@@ -1,5 +1,6 @@
import logging
from argparse import ArgumentTypeError
from collections.abc import Sequence
from typing import Literal, cast

from flask import request
@@ -79,7 +80,7 @@ class DocumentResource(Resource):

return document

def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]:
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
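Because ScalarResult.all() is typed as returning Sequence[T] rather than list[T], helpers that simply return the query result have their annotations loosened to Sequence[...], as in get_batch_documents above. A small illustrative sketch of that typing choice follows; the Document model and session wiring here are stand-ins, not the project's.

from collections.abc import Sequence

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Document(Base):  # stand-in for the project's Document model
    __tablename__ = "documents"
    id: Mapped[int] = mapped_column(primary_key=True)
    dataset_id: Mapped[str] = mapped_column(String(36))
    batch: Mapped[str] = mapped_column(String(36))

def get_batch_documents(session: Session, dataset_id: str, batch: str) -> Sequence[Document]:
    # scalars().all() is annotated as Sequence[Document], so the helper's
    # signature mirrors that instead of promising a concrete list.
    return session.scalars(
        select(Document).where(Document.dataset_id == dataset_id, Document.batch == batch)
    ).all()

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
with Session(engine) as session:
    docs = get_batch_documents(session, "ds-1", "batch-1")
    mutable_docs = list(docs)  # copy only where a real list is needed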

api/controllers/console/explore/installed_app.py (+9 -7)

@@ -3,7 +3,7 @@ from typing import Any

from flask import request
from flask_restx import Resource, inputs, marshal_with, reqparse
from sqlalchemy import and_
from sqlalchemy import and_, select
from werkzeug.exceptions import BadRequest, Forbidden, NotFound

from controllers.console import api
@@ -33,13 +33,15 @@ class InstalledAppsListApi(Resource):
current_tenant_id = current_user.current_tenant_id

if app_id:
installed_apps = (
db.session.query(InstalledApp)
.where(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
.all()
)
installed_apps = db.session.scalars(
select(InstalledApp).where(
and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)
)
).all()
else:
installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all()
installed_apps = db.session.scalars(
select(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id)
).all()

if current_user.current_tenant is None:
raise ValueError("current_user.current_tenant must not be None")

api/controllers/console/workspace/account.py (+3 -1)

@@ -248,7 +248,9 @@ class AccountIntegrateApi(Resource):
raise ValueError("Invalid user account")
account = current_user

account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all()
account_integrates = db.session.scalars(
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id)
).all()

base_url = request.url_root.rstrip("/")
oauth_base_path = "/console/api/oauth/login"

api/core/memory/token_buffer_memory.py (+13 -12)

@@ -32,11 +32,16 @@ class TokenBufferMemory:
self.model_instance = model_instance

def _build_prompt_message_with_files(
self, message_files: list[MessageFile], text_content: str, message: Message, app_record, is_user_message: bool
self,
message_files: Sequence[MessageFile],
text_content: str,
message: Message,
app_record,
is_user_message: bool,
) -> PromptMessage:
"""
Build prompt message with files.
:param message_files: list of MessageFile objects
:param message_files: Sequence of MessageFile objects
:param text_content: text content of the message
:param message: Message object
:param app_record: app record
@@ -128,14 +133,12 @@ class TokenBufferMemory:
prompt_messages: list[PromptMessage] = []
for message in messages:
# Process user message with files
user_files = (
db.session.query(MessageFile)
.where(
user_files = db.session.scalars(
select(MessageFile).where(
MessageFile.message_id == message.id,
(MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)),
)
.all()
)
).all()

if user_files:
user_prompt_message = self._build_prompt_message_with_files(
@@ -150,11 +153,9 @@ class TokenBufferMemory:
prompt_messages.append(UserPromptMessage(content=message.query))

# Process assistant message with files
assistant_files = (
db.session.query(MessageFile)
.where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant")
.all()
)
assistant_files = db.session.scalars(
select(MessageFile).where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant")
).all()

if assistant_files:
assistant_prompt_message = self._build_prompt_message_with_files(

api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py (+5 -6)

@@ -15,6 +15,7 @@ from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
from opentelemetry.trace import SpanContext, TraceFlags, TraceState
from sqlalchemy import select

from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
@@ -699,8 +700,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):

def _get_workflow_nodes(self, workflow_run_id: str):
"""Helper method to get workflow nodes"""
workflow_nodes = (
db.session.query(
workflow_nodes = db.session.scalars(
select(
WorkflowNodeExecutionModel.id,
WorkflowNodeExecutionModel.tenant_id,
WorkflowNodeExecutionModel.app_id,
@@ -713,10 +714,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
WorkflowNodeExecutionModel.elapsed_time,
WorkflowNodeExecutionModel.process_data,
WorkflowNodeExecutionModel.execution_metadata,
)
.where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.all()
)
).where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
).all()
return workflow_nodes

def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:

api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py (+2 -1)

@@ -1,5 +1,6 @@
import time
import uuid
from collections.abc import Sequence

import requests
from requests.auth import HTTPDigestAuth
@@ -139,7 +140,7 @@ class TidbService:

@staticmethod
def batch_update_tidb_serverless_cluster_status(
tidb_serverless_list: list[TidbAuthBinding],
tidb_serverless_list: Sequence[TidbAuthBinding],
project_id: str,
api_url: str,
iam_url: str,

api/core/tools/custom_tool/provider.py (+6 -5)

@@ -1,4 +1,5 @@
from pydantic import Field
from sqlalchemy import select

from core.entities.provider_entities import ProviderConfig
from core.tools.__base.tool_provider import ToolProviderController
@@ -176,11 +177,11 @@ class ApiToolProviderController(ToolProviderController):
tools: list[ApiTool] = []

# get tenant api providers
db_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider)
.where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
.all()
)
db_providers = db.session.scalars(
select(ApiToolProvider).where(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name
)
).all()

if db_providers and len(db_providers) != 0:
for db_provider in db_providers:

api/core/tools/tool_label_manager.py (+1 -3)

@@ -87,9 +87,7 @@ class ToolLabelManager:
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
provider_ids.append(controller.provider_id) # ty: ignore [unresolved-attribute]

labels: list[ToolLabelBinding] = (
db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids)).all()
)
labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()

tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}


api/core/tools/tool_manager.py (+6 -6)

@@ -667,9 +667,9 @@ class ToolManager:

# get db api providers
if "api" in filters:
db_api_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all()
)
db_api_providers = db.session.scalars(
select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)
).all()

api_provider_controllers: list[dict[str, Any]] = [
{"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
@@ -690,9 +690,9 @@ class ToolManager:

if "workflow" in filters:
# get workflow providers
workflow_providers: list[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
)
workflow_providers = db.session.scalars(
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
).all()

workflow_provider_controllers: list[WorkflowToolProviderController] = []
for workflow_provider in workflow_providers:

api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py (+3 -1)

@@ -1,3 +1,5 @@
from sqlalchemy import select

from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from models.dataset import AppDatasetJoin
@@ -13,7 +15,7 @@ def handle(sender, **kwargs):

dataset_ids = get_dataset_ids_from_model_config(app_model_config)

app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all()
app_dataset_joins = db.session.scalars(select(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id)).all()

removed_dataset_ids: set[str] = set()
if not app_dataset_joins:

api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py (+3 -1)

@@ -1,5 +1,7 @@
from typing import cast

from sqlalchemy import select

from core.workflow.nodes import NodeType
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from events.app_event import app_published_workflow_was_updated
@@ -15,7 +17,7 @@ def handle(sender, **kwargs):
published_workflow = cast(Workflow, published_workflow)

dataset_ids = get_dataset_ids_from_workflow(published_workflow)
app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all()
app_dataset_joins = db.session.scalars(select(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id)).all()

removed_dataset_ids: set[str] = set()
if not app_dataset_joins:

api/models/account.py (+6 -4)

@@ -218,10 +218,12 @@ class Tenant(Base):
updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())

def get_accounts(self) -> list[Account]:
return (
db.session.query(Account)
.where(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id)
.all()
return list(
db.session.scalars(
select(Account).where(
Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id
)
).all()
)

@property
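Where a declared return type is part of the public surface, the result is wrapped in list() instead, as in Tenant.get_accounts above: the internal call now returns a Sequence, but callers still get the promised list[Account]. A compact sketch of that wrapping, with a simplified stand-in model (the real query also joins TenantAccountJoin):

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Account(Base):  # simplified stand-in for the project's Account model
    __tablename__ = "accounts"
    id: Mapped[int] = mapped_column(primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String(36))

def get_accounts(session: Session, tenant_id: str) -> list[Account]:
    # list() keeps the helper's public return type unchanged even though
    # scalars().all() itself only promises a Sequence.
    return list(session.scalars(select(Account).where(Account.tenant_id == tenant_id)).all())

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(Account(tenant_id="t-1"))
    session.commit()
    accounts = get_accounts(session, "t-1")
    accounts.append(accounts[0])  # list-only operations keep working for callers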

api/models/dataset.py (+7 -7)

@@ -208,7 +208,9 @@ class Dataset(Base):

@property
def doc_metadata(self):
dataset_metadatas = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id).all()
dataset_metadatas = db.session.scalars(
select(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id)
).all()

doc_metadata = [
{
@@ -1055,13 +1057,11 @@ class ExternalKnowledgeApis(Base):

@property
def dataset_bindings(self) -> list[dict[str, Any]]:
external_knowledge_bindings = (
db.session.query(ExternalKnowledgeBindings)
.where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
.all()
)
external_knowledge_bindings = db.session.scalars(
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
).all()
dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
datasets = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
dataset_bindings: list[dict[str, Any]] = []
for dataset in datasets:
dataset_bindings.append({"id": dataset.id, "name": dataset.name})

api/models/model.py (+3 -3)

@@ -812,7 +812,7 @@ class Conversation(Base):

@property
def status_count(self):
messages = db.session.query(Message).where(Message.conversation_id == self.id).all()
messages = db.session.scalars(select(Message).where(Message.conversation_id == self.id)).all()
status_counts = {
WorkflowExecutionStatus.RUNNING: 0,
WorkflowExecutionStatus.SUCCEEDED: 0,
@@ -1090,7 +1090,7 @@ class Message(Base):

@property
def feedbacks(self):
feedbacks = db.session.query(MessageFeedback).where(MessageFeedback.message_id == self.id).all()
feedbacks = db.session.scalars(select(MessageFeedback).where(MessageFeedback.message_id == self.id)).all()
return feedbacks

@property
@@ -1145,7 +1145,7 @@ class Message(Base):
def message_files(self) -> list[dict[str, Any]]:
from factories import file_factory

message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all()
message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
current_app = db.session.query(App).where(App.id == self.app_id).first()
if not current_app:
raise ValueError(f"App {self.app_id} not found")

api/schedule/clean_unused_datasets_task.py (+8 -10)

@@ -96,11 +96,11 @@ def clean_unused_datasets_task():
break

for dataset in datasets:
dataset_query = (
db.session.query(DatasetQuery)
.where(DatasetQuery.created_at > clean_day, DatasetQuery.dataset_id == dataset.id)
.all()
)
dataset_query = db.session.scalars(
select(DatasetQuery).where(
DatasetQuery.created_at > clean_day, DatasetQuery.dataset_id == dataset.id
)
).all()

if not dataset_query or len(dataset_query) == 0:
try:
@@ -121,15 +121,13 @@ def clean_unused_datasets_task():
if should_clean:
# Add auto disable log if required
if add_logs:
documents = (
db.session.query(Document)
.where(
documents = db.session.scalars(
select(Document).where(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
.all()
)
).all()
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,

api/schedule/mail_clean_document_notify_task.py (+4 -3)

@@ -3,6 +3,7 @@ import time
from collections import defaultdict

import click
from sqlalchemy import select

import app
from configs import dify_config
@@ -31,9 +32,9 @@ def mail_clean_document_notify_task():

# send document clean notify mail
try:
dataset_auto_disable_logs = (
db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False).all()
)
dataset_auto_disable_logs = db.session.scalars(
select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False)
).all()
# group by tenant_id
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
for dataset_auto_disable_log in dataset_auto_disable_logs:

api/schedule/update_tidb_serverless_status_task.py (+6 -6)

@@ -1,6 +1,8 @@
import time
from collections.abc import Sequence

import click
from sqlalchemy import select

import app
from configs import dify_config
@@ -15,11 +17,9 @@ def update_tidb_serverless_status_task():
start_at = time.perf_counter()
try:
# check the number of idle tidb serverless
tidb_serverless_list = (
db.session.query(TidbAuthBinding)
.where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
.all()
)
tidb_serverless_list = db.session.scalars(
select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
).all()
if len(tidb_serverless_list) == 0:
return
# update tidb serverless status
@@ -32,7 +32,7 @@ def update_tidb_serverless_status_task():
click.echo(click.style(f"Update tidb serverless status task success latency: {end_at - start_at}", fg="green"))


def update_clusters(tidb_serverless_list: list[TidbAuthBinding]):
def update_clusters(tidb_serverless_list: Sequence[TidbAuthBinding]):
try:
# batch 20
for i in range(0, len(tidb_serverless_list), 20):
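On the parameter side, annotations are widened from list[...] to Sequence[...] (here and in TidbService.batch_update_tidb_serverless_cluster_status) so the Sequence returned by scalars().all() can be passed straight through without copying. A tiny sketch of why the wider type is enough, using a hypothetical stand-in class:

from collections.abc import Sequence

class TidbAuthBinding:  # hypothetical stand-in for the ORM model
    def __init__(self, cluster_id: str):
        self.cluster_id = cluster_id

def update_clusters(tidb_serverless_list: Sequence[TidbAuthBinding]) -> None:
    # len() and slicing are all this loop needs, and Sequence provides both,
    # so plain lists and scalars().all() results are equally acceptable inputs.
    for i in range(0, len(tidb_serverless_list), 20):
        batch = tidb_serverless_list[i : i + 20]
        print([binding.cluster_id for binding in batch])

update_clusters([TidbAuthBinding("cluster-1"), TidbAuthBinding("cluster-2")])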

api/services/annotation_service.py (+3 -5)

@@ -263,11 +263,9 @@ class AppAnnotationService:

db.session.delete(annotation)

annotation_hit_histories = (
db.session.query(AppAnnotationHitHistory)
.where(AppAnnotationHitHistory.annotation_id == annotation_id)
.all()
)
annotation_hit_histories = db.session.scalars(
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation_id)
).all()
if annotation_hit_histories:
for annotation_hit_history in annotation_hit_histories:
db.session.delete(annotation_hit_history)

api/services/auth/api_key_auth_service.py (+7 -5)

@@ -1,5 +1,7 @@
import json

from sqlalchemy import select

from core.helper import encrypter
from extensions.ext_database import db
from models.source import DataSourceApiKeyAuthBinding
@@ -9,11 +11,11 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory
class ApiKeyAuthService:
@staticmethod
def get_provider_auth_list(tenant_id: str):
data_source_api_key_bindings = (
db.session.query(DataSourceApiKeyAuthBinding)
.where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
.all()
)
data_source_api_key_bindings = db.session.scalars(
select(DataSourceApiKeyAuthBinding).where(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)
)
).all()
return data_source_api_key_bindings

@staticmethod

api/services/clear_free_plan_tenant_expired_logs.py (+2 -1)

@@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor

import click
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker

from configs import dify_config
@@ -115,7 +116,7 @@ class ClearFreePlanTenantExpiredLogs:
@classmethod
def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int):
with flask_app.app_context():
apps = db.session.query(App).where(App.tenant_id == tenant_id).all()
apps = db.session.scalars(select(App).where(App.tenant_id == tenant_id)).all()
app_ids = [app.id for app in apps]
while True:
with Session(db.engine).no_autoflush as session:

api/services/dataset_service.py (+38 -59)

@@ -6,6 +6,7 @@ import secrets
import time
import uuid
from collections import Counter
from collections.abc import Sequence
from typing import Any, Literal, Optional

import sqlalchemy as sa
@@ -741,14 +742,12 @@ class DatasetService:
}
# get recent 30 days auto disable logs
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
dataset_auto_disable_logs = (
db.session.query(DatasetAutoDisableLog)
.where(
dataset_auto_disable_logs = db.session.scalars(
select(DatasetAutoDisableLog).where(
DatasetAutoDisableLog.dataset_id == dataset_id,
DatasetAutoDisableLog.created_at >= start_date,
)
.all()
)
).all()
if dataset_auto_disable_logs:
return {
"document_ids": [log.document_id for log in dataset_auto_disable_logs],
@@ -885,69 +884,58 @@ class DocumentService:
return document

@staticmethod
def get_document_by_ids(document_ids: list[str]) -> list[Document]:
documents = (
db.session.query(Document)
.where(
def get_document_by_ids(document_ids: list[str]) -> Sequence[Document]:
documents = db.session.scalars(
select(Document).where(
Document.id.in_(document_ids),
Document.enabled == True,
Document.indexing_status == "completed",
Document.archived == False,
)
.all()
)
).all()
return documents

@staticmethod
def get_document_by_dataset_id(dataset_id: str) -> list[Document]:
documents = (
db.session.query(Document)
.where(
def get_document_by_dataset_id(dataset_id: str) -> Sequence[Document]:
documents = db.session.scalars(
select(Document).where(
Document.dataset_id == dataset_id,
Document.enabled == True,
)
.all()
)
).all()

return documents

@staticmethod
def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]:
documents = (
db.session.query(Document)
.where(
def get_working_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]:
documents = db.session.scalars(
select(Document).where(
Document.dataset_id == dataset_id,
Document.enabled == True,
Document.indexing_status == "completed",
Document.archived == False,
)
.all()
)
).all()

return documents

@staticmethod
def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]:
documents = (
db.session.query(Document)
.where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
.all()
)
def get_error_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]:
documents = db.session.scalars(
select(Document).where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
).all()
return documents

@staticmethod
def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
def get_batch_documents(dataset_id: str, batch: str) -> Sequence[Document]:
assert isinstance(current_user, Account)

documents = (
db.session.query(Document)
.where(
documents = db.session.scalars(
select(Document).where(
Document.batch == batch,
Document.dataset_id == dataset_id,
Document.tenant_id == current_user.current_tenant_id,
)
.all()
)
).all()

return documents

@@ -984,7 +972,7 @@ class DocumentService:
# Check if document_ids is not empty to avoid WHERE false condition
if not document_ids or len(document_ids) == 0:
return
documents = db.session.query(Document).where(Document.id.in_(document_ids)).all()
documents = db.session.scalars(select(Document).where(Document.id.in_(document_ids))).all()
file_ids = [
document.data_source_info_dict["upload_file_id"]
for document in documents
@@ -2424,16 +2412,14 @@ class SegmentService:
if not segment_ids or len(segment_ids) == 0:
return
if action == "enable":
segments = (
db.session.query(DocumentSegment)
.where(
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.enabled == False,
)
.all()
)
).all()
if not segments:
return
real_deal_segment_ids = []
@@ -2451,16 +2437,14 @@ class SegmentService:

enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
elif action == "disable":
segments = (
db.session.query(DocumentSegment)
.where(
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
DocumentSegment.enabled == True,
)
.all()
)
).all()
if not segments:
return
real_deal_segment_ids = []
@@ -2532,16 +2516,13 @@ class SegmentService:
dataset: Dataset,
) -> list[ChildChunk]:
assert isinstance(current_user, Account)

child_chunks = (
db.session.query(ChildChunk)
.where(
child_chunks = db.session.scalars(
select(ChildChunk).where(
ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id,
ChildChunk.segment_id == segment.id,
)
.all()
)
).all()
child_chunks_map = {chunk.id: chunk for chunk in child_chunks}

new_child_chunks, update_child_chunks, delete_child_chunks, new_child_chunks_args = [], [], [], []
@@ -2751,13 +2732,11 @@ class DatasetCollectionBindingService:
class DatasetPermissionService:
@classmethod
def get_dataset_partial_member_list(cls, dataset_id):
user_list_query = (
db.session.query(
user_list_query = db.session.scalars(
select(
DatasetPermission.account_id,
)
.where(DatasetPermission.dataset_id == dataset_id)
.all()
)
).where(DatasetPermission.dataset_id == dataset_id)
).all()

user_list = []
for user in user_list_query:

api/services/model_load_balancing_service.py (+4 -6)

@@ -3,7 +3,7 @@ import logging
from json import JSONDecodeError
from typing import Optional, Union

from sqlalchemy import or_
from sqlalchemy import or_, select

from constants import HIDDEN_VALUE
from core.entities.provider_configuration import ProviderConfiguration
@@ -322,16 +322,14 @@ class ModelLoadBalancingService:
if not isinstance(configs, list):
raise ValueError("Invalid load balancing configs")

current_load_balancing_configs = (
db.session.query(LoadBalancingModelConfig)
.where(
current_load_balancing_configs = db.session.scalars(
select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
.all()
)
).all()

# id as key, config as value
current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs}

api/services/recommend_app/database/database_retrieval.py (+8 -10)

@@ -1,5 +1,7 @@
from typing import Optional

from sqlalchemy import select

from constants.languages import languages
from extensions.ext_database import db
from models.model import App, RecommendedApp
@@ -31,18 +33,14 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
:param language: language
:return:
"""
recommended_apps = (
db.session.query(RecommendedApp)
.where(RecommendedApp.is_listed == True, RecommendedApp.language == language)
.all()
)
recommended_apps = db.session.scalars(
select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == language)
).all()

if len(recommended_apps) == 0:
recommended_apps = (
db.session.query(RecommendedApp)
.where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
.all()
)
recommended_apps = db.session.scalars(
select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
).all()

categories = set()
recommended_apps_result = []

api/services/tag_service.py (+15 -20)

@@ -2,7 +2,7 @@ import uuid
from typing import Optional

from flask_login import current_user
from sqlalchemy import func
from sqlalchemy import func, select
from werkzeug.exceptions import NotFound

from extensions.ext_database import db
@@ -29,35 +29,30 @@ class TagService:
# Check if tag_ids is not empty to avoid WHERE false condition
if not tag_ids or len(tag_ids) == 0:
return []
tags = (
db.session.query(Tag)
.where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
.all()
)
tags = db.session.scalars(
select(Tag).where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
).all()
if not tags:
return []
tag_ids = [tag.id for tag in tags]
# Check if tag_ids is not empty to avoid WHERE false condition
if not tag_ids or len(tag_ids) == 0:
return []
tag_bindings = (
db.session.query(TagBinding.target_id)
.where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
.all()
)
if not tag_bindings:
return []
results = [tag_binding.target_id for tag_binding in tag_bindings]
return results
tag_bindings = db.session.scalars(
select(TagBinding.target_id).where(
TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id
)
).all()
return tag_bindings

@staticmethod
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str):
if not tag_type or not tag_name:
return []
tags = (
db.session.query(Tag)
.where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
.all()
tags = list(
db.session.scalars(
select(Tag).where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
).all()
)
if not tags:
return []
@@ -117,7 +112,7 @@ class TagService:
raise NotFound("Tag not found")
db.session.delete(tag)
# delete tag binding
tag_bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).all()
tag_bindings = db.session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all()
if tag_bindings:
for tag_binding in tag_bindings:
db.session.delete(tag_binding)
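The first hunk above also drops the per-row attribute access: when select() is given a single column such as TagBinding.target_id, scalars() yields the column values themselves, so the result can be returned directly. A self-contained sketch of that behaviour with a stand-in model:

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class TagBinding(Base):  # stand-in for the project's TagBinding model
    __tablename__ = "tag_bindings"
    id: Mapped[int] = mapped_column(primary_key=True)
    tag_id: Mapped[str] = mapped_column(String(36))
    target_id: Mapped[str] = mapped_column(String(36))

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(TagBinding(tag_id="tag-1", target_id="app-1"))
    session.commit()

    # select() of a single column plus scalars() yields the column values
    # themselves, so no row.target_id unpacking is needed afterwards.
    target_ids = session.scalars(
        select(TagBinding.target_id).where(TagBinding.tag_id == "tag-1")
    ).all()
    assert target_ids == ["app-1"]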

api/services/tools/api_tools_manage_service.py (+2 -3)

@@ -4,6 +4,7 @@ from collections.abc import Mapping
from typing import Any, cast

from httpx import get
from sqlalchemy import select

from core.entities.provider_entities import ProviderConfig
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -443,9 +444,7 @@ class ApiToolManageService:
list api tools
"""
# get all api providers
db_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() or []
)
db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all()

result: list[ToolProviderApiEntity] = []


api/services/tools/workflow_tools_manage_service.py (+4 -2)

@@ -3,7 +3,7 @@ from collections.abc import Mapping
from datetime import datetime
from typing import Any

from sqlalchemy import or_
from sqlalchemy import or_, select

from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController
@@ -186,7 +186,9 @@ class WorkflowToolManageService:
:param tenant_id: the tenant id
:return: the list of tools
"""
db_tools = db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
db_tools = db.session.scalars(
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
).all()

tools: list[WorkflowToolProviderController] = []
for provider in db_tools:

api/tasks/annotation/enable_annotation_reply_task.py (+2 -1)

@@ -3,6 +3,7 @@ import time

import click
from celery import shared_task
from sqlalchemy import select

from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
@@ -39,7 +40,7 @@ def enable_annotation_reply_task(
db.session.close()
return

annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).all()
annotations = db.session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all()
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"


api/tasks/batch_clean_document_task.py (+5 -2)

@@ -3,6 +3,7 @@ import time

import click
from celery import shared_task
from sqlalchemy import select

from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
@@ -34,7 +35,9 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
if not dataset:
raise Exception("Document has no dataset")

segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)).all()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
@@ -59,7 +62,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form

db.session.commit()
if file_ids:
files = db.session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all()
files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
for file in files:
try:
storage.delete(file.key)

api/tasks/clean_dataset_task.py (+3 -2)

@@ -3,6 +3,7 @@ import time

import click
from celery import shared_task
from sqlalchemy import select

from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
@@ -55,8 +56,8 @@ def clean_dataset_task(
index_struct=index_struct,
collection_binding_id=collection_binding_id,
)
documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all()
segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all()
documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()

# Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
# This ensures all invalid doc_form values are properly handled

api/tasks/clean_document_task.py (+2 -1)

@@ -4,6 +4,7 @@ from typing import Optional

import click
from celery import shared_task
from sqlalchemy import select

from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
@@ -35,7 +36,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
if not dataset:
raise Exception("Document has no dataset")

segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]

api/tasks/clean_notion_document_task.py (+4 -1)

@@ -3,6 +3,7 @@ import time

import click
from celery import shared_task
from sqlalchemy import select

from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
@@ -34,7 +35,9 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
document = db.session.query(Document).where(Document.id == document_id).first()
db.session.delete(document)

segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]

index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

api/tasks/deal_dataset_vector_index_task.py (+7 -10)

@@ -4,6 +4,7 @@ from typing import Literal

import click
from celery import shared_task
from sqlalchemy import select

from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -36,16 +37,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
if action == "remove":
index_processor.clean(dataset, None, with_keywords=False)
elif action == "add":
dataset_documents = (
db.session.query(DatasetDocument)
.where(
dataset_documents = db.session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
).all()

if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
@@ -89,16 +88,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
)
db.session.commit()
elif action == "update":
dataset_documents = (
db.session.query(DatasetDocument)
.where(
dataset_documents = db.session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
).all()
# add new index
if dataset_documents:
# update document status

api/tasks/disable_segments_from_index_task.py (+4 -5)

@@ -3,6 +3,7 @@ import time

import click
from celery import shared_task
from sqlalchemy import select

from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
@@ -44,15 +45,13 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()

segments = (
db.session.query(DocumentSegment)
.where(
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
.all()
)
).all()

if not segments:
db.session.close()

api/tasks/document_indexing_sync_task.py (+4 -1)

@@ -3,6 +3,7 @@ import time

import click
from celery import shared_task
from sqlalchemy import select

from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.extractor.notion_extractor import NotionExtractor
@@ -85,7 +86,9 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()

segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]

# delete from vector index

api/tasks/document_indexing_update_task.py (+2 -1)

@@ -3,6 +3,7 @@ import time

import click
from celery import shared_task
from sqlalchemy import select

from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -45,7 +46,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()

segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]


api/tasks/duplicate_document_indexing_task.py (+4 -1)

@@ -3,6 +3,7 @@ import time

import click
from celery import shared_task
from sqlalchemy import select

from configs import dify_config
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
@@ -79,7 +80,9 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()

segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]


api/tasks/enable_segments_to_index_task.py (+4 -5)

@@ -3,6 +3,7 @@ import time

import click
from celery import shared_task
from sqlalchemy import select

from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -45,15 +46,13 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()

segments = (
db.session.query(DocumentSegment)
.where(
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
.all()
)
).all()
if not segments:
logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan"))
db.session.close()

api/tasks/remove_document_from_index_task.py (+2 -1)

@@ -3,6 +3,7 @@ import time

import click
from celery import shared_task
from sqlalchemy import select

from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
@@ -45,7 +46,7 @@ def remove_document_from_index_task(document_id: str):

index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).all()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()
index_node_ids = [segment.index_node_id for segment in segments]
if index_node_ids:
try:

api/tasks/retry_document_indexing_task.py (+4 -1)

@@ -3,6 +3,7 @@ import time

import click
from celery import shared_task
from sqlalchemy import select

from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -69,7 +70,9 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index

api/tasks/sync_website_document_indexing_task.py (+2 -1)

@@ -3,6 +3,7 @@ import time

import click
from celery import shared_task
from sqlalchemy import select

from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -63,7 +64,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index

api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py (+4 -3)

@@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch

import pytest
from faker import Faker
from sqlalchemy import select

from models.account import TenantAccountJoin, TenantAccountRole
from models.model import Account, Tenant
@@ -468,7 +469,7 @@ class TestModelLoadBalancingService:
assert load_balancing_config.id is not None

# Verify inherit config was created in database
inherit_configs = (
db.session.query(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__").all()
)
inherit_configs = db.session.scalars(
select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__")
).all()
assert len(inherit_configs) == 1

api/tests/test_containers_integration_tests/services/test_tag_service.py (+7 -2)

@@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch

import pytest
from faker import Faker
from sqlalchemy import select
from werkzeug.exceptions import NotFound

from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
@@ -954,7 +955,9 @@ class TestTagService:
from extensions.ext_database import db

# Verify only one binding exists
bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all()
bindings = db.session.scalars(
select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id)
).all()
assert len(bindings) == 1

def test_save_tag_binding_invalid_target_type(self, db_session_with_containers, mock_external_service_dependencies):
@@ -1064,7 +1067,9 @@ class TestTagService:
# No error should be raised, and database state should remain unchanged
from extensions.ext_database import db

bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all()
bindings = db.session.scalars(
select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id)
).all()
assert len(bindings) == 0

def test_check_target_exists_knowledge_success(

api/tests/test_containers_integration_tests/services/test_web_conversation_service.py (+4 -5)

@@ -2,6 +2,7 @@ from unittest.mock import patch

import pytest
from faker import Faker
from sqlalchemy import select

from core.app.entities.app_invoke_entities import InvokeFrom
from models.account import Account
@@ -354,16 +355,14 @@ class TestWebConversationService:
# Verify only one pinned conversation record exists
from extensions.ext_database import db

pinned_conversations = (
db.session.query(PinnedConversation)
.where(
pinned_conversations = db.session.scalars(
select(PinnedConversation).where(
PinnedConversation.app_id == app.id,
PinnedConversation.conversation_id == conversation.id,
PinnedConversation.created_by_role == "account",
PinnedConversation.created_by == account.id,
)
.all()
)
).all()

assert len(pinned_conversations) == 1


api/tests/unit_tests/services/auth/test_api_key_auth_service.py (+12 -8)

@@ -28,18 +28,20 @@ class TestApiKeyAuthService:
mock_binding.provider = self.provider
mock_binding.disabled = False

mock_session.query.return_value.where.return_value.all.return_value = [mock_binding]
mock_session.scalars.return_value.all.return_value = [mock_binding]

result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)

assert len(result) == 1
assert result[0].tenant_id == self.tenant_id
mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding)
assert mock_session.scalars.call_count == 1
select_arg = mock_session.scalars.call_args[0][0]
assert "data_source_api_key_auth_binding" in str(select_arg).lower()

@patch("services.auth.api_key_auth_service.db.session")
def test_get_provider_auth_list_empty(self, mock_session):
"""Test get provider auth list - empty result"""
mock_session.query.return_value.where.return_value.all.return_value = []
mock_session.scalars.return_value.all.return_value = []

result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)

@@ -48,13 +50,15 @@ class TestApiKeyAuthService:
@patch("services.auth.api_key_auth_service.db.session")
def test_get_provider_auth_list_filters_disabled(self, mock_session):
"""Test get provider auth list - filters disabled items"""
mock_session.query.return_value.where.return_value.all.return_value = []
mock_session.scalars.return_value.all.return_value = []

ApiKeyAuthService.get_provider_auth_list(self.tenant_id)

# Verify where conditions include disabled.is_(False)
where_call = mock_session.query.return_value.where.call_args[0]
assert len(where_call) == 2 # tenant_id and disabled filter conditions
select_stmt = mock_session.scalars.call_args[0][0]
where_clauses = list(getattr(select_stmt, "_where_criteria", []) or [])
# Ensure both tenant filter and disabled filter exist
where_strs = [str(c).lower() for c in where_clauses]
assert any("tenant_id" in s for s in where_strs)
assert any("disabled" in s for s in where_strs)

@patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")

api/tests/unit_tests/services/auth/test_auth_integration.py (+2 -2)

@@ -63,10 +63,10 @@ class TestAuthIntegration:
tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials)
tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials)

mock_session.query.return_value.where.return_value.all.return_value = [tenant1_binding]
mock_session.scalars.return_value.all.return_value = [tenant1_binding]
result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1)

mock_session.query.return_value.where.return_value.all.return_value = [tenant2_binding]
mock_session.scalars.return_value.all.return_value = [tenant2_binding]
result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2)

assert len(result1) == 1
