Files overview

update sql in batch (#24801)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
tags/1.9.0
Asuka Minato, 1 month ago · commit cbc0e639e4
49 files changed, 281 insertions(+), 277 deletions(-)
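
The change is mechanical across all 49 files: legacy Query-API calls of the form db.session.query(Model).where(...).all() are rewritten in SQLAlchemy 2.0 style, building a select() statement and executing it through Session.scalars(). A minimal before/after sketch of the pattern, using names that appear in the first hunk below:

    from sqlalchemy import select

    # Before: 1.x-style Query API, built and executed on the session
    annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app.id).all()

    # After: 2.0 style -- build a select() statement, then execute it via
    # scalars(), which yields the ORM objects (first column of each row)
    annotations = db.session.scalars(
        select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
    ).all()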
1. api/commands.py (+9, -11)
2. api/controllers/console/apikey.py (+5, -5)
3. api/controllers/console/datasets/data_source.py (+3, -5)
4. api/controllers/console/datasets/datasets.py (+16, -15)
5. api/controllers/console/datasets/datasets_document.py (+2, -1)
6. api/controllers/console/explore/installed_app.py (+9, -7)
7. api/controllers/console/workspace/account.py (+3, -1)
8. api/core/memory/token_buffer_memory.py (+13, -12)
9. api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py (+5, -6)
10. api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py (+2, -1)
11. api/core/tools/custom_tool/provider.py (+6, -5)
12. api/core/tools/tool_label_manager.py (+1, -3)
13. api/core/tools/tool_manager.py (+6, -6)
14. api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py (+3, -1)
15. api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py (+3, -1)
16. api/models/account.py (+6, -4)
17. api/models/dataset.py (+7, -7)
18. api/models/model.py (+3, -3)
19. api/schedule/clean_unused_datasets_task.py (+8, -10)
20. api/schedule/mail_clean_document_notify_task.py (+4, -3)
21. api/schedule/update_tidb_serverless_status_task.py (+6, -6)
22. api/services/annotation_service.py (+3, -5)
23. api/services/auth/api_key_auth_service.py (+7, -5)
24. api/services/clear_free_plan_tenant_expired_logs.py (+2, -1)
25. api/services/dataset_service.py (+38, -59)
26. api/services/model_load_balancing_service.py (+4, -6)
27. api/services/recommend_app/database/database_retrieval.py (+8, -10)
28. api/services/tag_service.py (+15, -20)
29. api/services/tools/api_tools_manage_service.py (+2, -3)
30. api/services/tools/workflow_tools_manage_service.py (+4, -2)
31. api/tasks/annotation/enable_annotation_reply_task.py (+2, -1)
32. api/tasks/batch_clean_document_task.py (+5, -2)
33. api/tasks/clean_dataset_task.py (+3, -2)
34. api/tasks/clean_document_task.py (+2, -1)
35. api/tasks/clean_notion_document_task.py (+4, -1)
36. api/tasks/deal_dataset_vector_index_task.py (+7, -10)
37. api/tasks/disable_segments_from_index_task.py (+4, -5)
38. api/tasks/document_indexing_sync_task.py (+4, -1)
39. api/tasks/document_indexing_update_task.py (+2, -1)
40. api/tasks/duplicate_document_indexing_task.py (+4, -1)
41. api/tasks/enable_segments_to_index_task.py (+4, -5)
42. api/tasks/remove_document_from_index_task.py (+2, -1)
43. api/tasks/retry_document_indexing_task.py (+4, -1)
44. api/tasks/sync_website_document_indexing_task.py (+2, -1)
45. api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py (+4, -3)
46. api/tests/test_containers_integration_tests/services/test_tag_service.py (+7, -2)
47. api/tests/test_containers_integration_tests/services/test_web_conversation_service.py (+4, -5)
48. api/tests/unit_tests/services/auth/test_api_key_auth_service.py (+12, -8)
49. api/tests/unit_tests/services/auth/test_auth_integration.py (+2, -2)

api/commands.py (+9, -11)

 if not dataset_collection_binding:
     click.echo(f"App annotation collection binding not found: {app.id}")
     continue
-annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app.id).all()
+annotations = db.session.scalars(
+    select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
+).all()
 dataset = Dataset(
     id=app.id,
     tenant_id=app.tenant_id,
 )
 raise e

-dataset_documents = (
-    db.session.query(DatasetDocument)
-    .where(
+dataset_documents = db.session.scalars(
+    select(DatasetDocument).where(
         DatasetDocument.dataset_id == dataset.id,
         DatasetDocument.indexing_status == "completed",
         DatasetDocument.enabled == True,
         DatasetDocument.archived == False,
     )
-    .all()
-)
+).all()

 documents = []
 segments_count = 0
 for dataset_document in dataset_documents:
-    segments = (
-        db.session.query(DocumentSegment)
-        .where(
+    segments = db.session.scalars(
+        select(DocumentSegment).where(
             DocumentSegment.document_id == dataset_document.id,
             DocumentSegment.status == "completed",
             DocumentSegment.enabled == True,
         )
-        .all()
-    )
+    ).all()

     for segment in segments:
         document = Document(

api/controllers/console/apikey.py (+5, -5)

 assert self.resource_id_field is not None, "resource_id_field must be set"
 resource_id = str(resource_id)
 _get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
-keys = (
-    db.session.query(ApiToken)
-    .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
-    .all()
-)
+keys = db.session.scalars(
+    select(ApiToken).where(
+        ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
+    )
+).all()
 return {"items": keys}

 @marshal_with(api_key_fields)

api/controllers/console/datasets/data_source.py (+3, -5)

 @marshal_with(integrate_list_fields)
 def get(self):
     # get workspace data source integrates
-    data_source_integrates = (
-        db.session.query(DataSourceOauthBinding)
-        .where(
+    data_source_integrates = db.session.scalars(
+        select(DataSourceOauthBinding).where(
             DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
             DataSourceOauthBinding.disabled == False,
         )
-        .all()
-    )
+    ).all()

     base_url = request.url_root.rstrip("/")
     data_source_oauth_base_path = "/console/api/oauth/data-source"

api/controllers/console/datasets/datasets.py (+16, -15)

 from flask import request
 from flask_login import current_user
 from flask_restx import Resource, marshal, marshal_with, reqparse
+from sqlalchemy import select
 from werkzeug.exceptions import Forbidden, NotFound

 import services

 extract_settings = []
 if args["info_list"]["data_source_type"] == "upload_file":
     file_ids = args["info_list"]["file_info_list"]["file_ids"]
-    file_details = (
-        db.session.query(UploadFile)
-        .where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
-        .all()
-    )
+    file_details = db.session.scalars(
+        select(UploadFile).where(
+            UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)
+        )
+    ).all()

     if file_details is None:
         raise NotFound("File not found.")

 @account_initialization_required
 def get(self, dataset_id):
     dataset_id = str(dataset_id)
-    documents = (
-        db.session.query(Document)
-        .where(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
-        .all()
-    )
+    documents = db.session.scalars(
+        select(Document).where(
+            Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id
+        )
+    ).all()
     documents_status = []
     for document in documents:
         completed_segments = (

 @account_initialization_required
 @marshal_with(api_key_list)
 def get(self):
-    keys = (
-        db.session.query(ApiToken)
-        .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
-        .all()
-    )
+    keys = db.session.scalars(
+        select(ApiToken).where(
+            ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id
+        )
+    ).all()
     return {"items": keys}

 @setup_required

api/controllers/console/datasets/datasets_document.py (+2, -1)

 import logging
 from argparse import ArgumentTypeError
+from collections.abc import Sequence
 from typing import Literal, cast

 from flask import request

     return document

-def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]:
+def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
     dataset = DatasetService.get_dataset(dataset_id)
     if not dataset:
         raise NotFound("Dataset not found.")

api/controllers/console/explore/installed_app.py (+9, -7)

 from flask import request
 from flask_restx import Resource, inputs, marshal_with, reqparse
-from sqlalchemy import and_
+from sqlalchemy import and_, select
 from werkzeug.exceptions import BadRequest, Forbidden, NotFound

 from controllers.console import api

 current_tenant_id = current_user.current_tenant_id

 if app_id:
-    installed_apps = (
-        db.session.query(InstalledApp)
-        .where(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
-        .all()
-    )
+    installed_apps = db.session.scalars(
+        select(InstalledApp).where(
+            and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)
+        )
+    ).all()
 else:
-    installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all()
+    installed_apps = db.session.scalars(
+        select(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id)
+    ).all()

 if current_user.current_tenant is None:
     raise ValueError("current_user.current_tenant must not be None")

api/controllers/console/workspace/account.py (+3, -1)

     raise ValueError("Invalid user account")
 account = current_user

-account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all()
+account_integrates = db.session.scalars(
+    select(AccountIntegrate).where(AccountIntegrate.account_id == account.id)
+).all()

 base_url = request.url_root.rstrip("/")
 oauth_base_path = "/console/api/oauth/login"

api/core/memory/token_buffer_memory.py (+13, -12)

 self.model_instance = model_instance

 def _build_prompt_message_with_files(
-    self, message_files: list[MessageFile], text_content: str, message: Message, app_record, is_user_message: bool
+    self,
+    message_files: Sequence[MessageFile],
+    text_content: str,
+    message: Message,
+    app_record,
+    is_user_message: bool,
 ) -> PromptMessage:
     """
     Build prompt message with files.
-    :param message_files: list of MessageFile objects
+    :param message_files: Sequence of MessageFile objects
     :param text_content: text content of the message
     :param message: Message object
     :param app_record: app record

 prompt_messages: list[PromptMessage] = []
 for message in messages:
     # Process user message with files
-    user_files = (
-        db.session.query(MessageFile)
-        .where(
+    user_files = db.session.scalars(
+        select(MessageFile).where(
             MessageFile.message_id == message.id,
             (MessageFile.belongs_to == "user") | (MessageFile.belongs_to.is_(None)),
         )
-        .all()
-    )
+    ).all()

     if user_files:
         user_prompt_message = self._build_prompt_message_with_files(

     prompt_messages.append(UserPromptMessage(content=message.query))

     # Process assistant message with files
-    assistant_files = (
-        db.session.query(MessageFile)
-        .where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant")
-        .all()
-    )
+    assistant_files = db.session.scalars(
+        select(MessageFile).where(MessageFile.message_id == message.id, MessageFile.belongs_to == "assistant")
+    ).all()

     if assistant_files:
         assistant_prompt_message = self._build_prompt_message_with_files(

api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py (+5, -6)

 from opentelemetry.sdk.trace.export import SimpleSpanProcessor
 from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
 from opentelemetry.trace import SpanContext, TraceFlags, TraceState
+from sqlalchemy import select

 from core.ops.base_trace_instance import BaseTraceInstance
 from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig

 def _get_workflow_nodes(self, workflow_run_id: str):
     """Helper method to get workflow nodes"""
-    workflow_nodes = (
-        db.session.query(
+    workflow_nodes = db.session.scalars(
+        select(
             WorkflowNodeExecutionModel.id,
             WorkflowNodeExecutionModel.tenant_id,
             WorkflowNodeExecutionModel.app_id,
             WorkflowNodeExecutionModel.elapsed_time,
             WorkflowNodeExecutionModel.process_data,
             WorkflowNodeExecutionModel.execution_metadata,
-        )
-        .where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
-        .all()
-    )
+        ).where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
+    ).all()
     return workflow_nodes

 def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
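
One behavioral nuance worth flagging in this hunk: Session.scalars() reduces each result row to its first column, so applying it to a multi-column select() as above yields only WorkflowNodeExecutionModel.id values rather than row tuples. A 2.0-style form that preserves every selected column would go through session.execute() instead; a hedged sketch (not what the commit does):

    # Hypothetical alternative keeping all selected columns as Row tuples:
    rows = db.session.execute(
        select(
            WorkflowNodeExecutionModel.id,
            WorkflowNodeExecutionModel.elapsed_time,
            WorkflowNodeExecutionModel.execution_metadata,
        ).where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
    ).all()  # list of Row tuples, one per node execution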

api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py (+2, -1)

 import time
 import uuid
+from collections.abc import Sequence

 import requests
 from requests.auth import HTTPDigestAuth

 @staticmethod
 def batch_update_tidb_serverless_cluster_status(
-    tidb_serverless_list: list[TidbAuthBinding],
+    tidb_serverless_list: Sequence[TidbAuthBinding],
     project_id: str,
     api_url: str,
     iam_url: str,

api/core/tools/custom_tool/provider.py (+6, -5)

 from pydantic import Field
+from sqlalchemy import select

 from core.entities.provider_entities import ProviderConfig
 from core.tools.__base.tool_provider import ToolProviderController

 tools: list[ApiTool] = []

 # get tenant api providers
-db_providers: list[ApiToolProvider] = (
-    db.session.query(ApiToolProvider)
-    .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
-    .all()
-)
+db_providers = db.session.scalars(
+    select(ApiToolProvider).where(
+        ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name
+    )
+).all()

 if db_providers and len(db_providers) != 0:
     for db_provider in db_providers:

api/core/tools/tool_label_manager.py (+1, -3)

 assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
 provider_ids.append(controller.provider_id)  # ty: ignore [unresolved-attribute]

-labels: list[ToolLabelBinding] = (
-    db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids)).all()
-)
+labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()

 tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}



api/core/tools/tool_manager.py (+6, -6)

 # get db api providers
 if "api" in filters:
-    db_api_providers: list[ApiToolProvider] = (
-        db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all()
-    )
+    db_api_providers = db.session.scalars(
+        select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)
+    ).all()

     api_provider_controllers: list[dict[str, Any]] = [
         {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}

 if "workflow" in filters:
     # get workflow providers
-    workflow_providers: list[WorkflowToolProvider] = (
-        db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
-    )
+    workflow_providers = db.session.scalars(
+        select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
+    ).all()

     workflow_provider_controllers: list[WorkflowToolProviderController] = []
     for workflow_provider in workflow_providers:

api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py (+3, -1)

+from sqlalchemy import select
+
 from events.app_event import app_model_config_was_updated
 from extensions.ext_database import db
 from models.dataset import AppDatasetJoin

 dataset_ids = get_dataset_ids_from_model_config(app_model_config)

-app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all()
+app_dataset_joins = db.session.scalars(select(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id)).all()

 removed_dataset_ids: set[str] = set()
 if not app_dataset_joins:

api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py (+3, -1)

 from typing import cast

+from sqlalchemy import select
+
 from core.workflow.nodes import NodeType
 from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
 from events.app_event import app_published_workflow_was_updated

 published_workflow = cast(Workflow, published_workflow)

 dataset_ids = get_dataset_ids_from_workflow(published_workflow)
-app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all()
+app_dataset_joins = db.session.scalars(select(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id)).all()

 removed_dataset_ids: set[str] = set()
 if not app_dataset_joins:

api/models/account.py (+6, -4)

 updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())

 def get_accounts(self) -> list[Account]:
-    return (
-        db.session.query(Account)
-        .where(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id)
-        .all()
+    return list(
+        db.session.scalars(
+            select(Account).where(
+                Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id
+            )
+        ).all()
     )

 @property
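
The extra list(...) wrapper appears wherever a signature keeps a concrete list return type: ScalarResult.all() is typed as returning a Sequence, so get_accounts() converts the result back to the list[Account] its annotation promises. A condensed sketch of the typing involved (annotations inferred from the hunk):

    from collections.abc import Sequence
    from sqlalchemy import select

    seq: Sequence[Account] = db.session.scalars(select(Account)).all()
    accounts: list[Account] = list(seq)  # narrow Sequence back to list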

api/models/dataset.py (+7, -7)

 @property
 def doc_metadata(self):
-    dataset_metadatas = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id).all()
+    dataset_metadatas = db.session.scalars(
+        select(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id)
+    ).all()

     doc_metadata = [
         {

 @property
 def dataset_bindings(self) -> list[dict[str, Any]]:
-    external_knowledge_bindings = (
-        db.session.query(ExternalKnowledgeBindings)
-        .where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
-        .all()
-    )
+    external_knowledge_bindings = db.session.scalars(
+        select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
+    ).all()
     dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
-    datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
+    datasets = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
     dataset_bindings: list[dict[str, Any]] = []
     for dataset in datasets:
         dataset_bindings.append({"id": dataset.id, "name": dataset.name})

api/models/model.py (+3, -3)

 @property
 def status_count(self):
-    messages = db.session.query(Message).where(Message.conversation_id == self.id).all()
+    messages = db.session.scalars(select(Message).where(Message.conversation_id == self.id)).all()
     status_counts = {
         WorkflowExecutionStatus.RUNNING: 0,
         WorkflowExecutionStatus.SUCCEEDED: 0,

 @property
 def feedbacks(self):
-    feedbacks = db.session.query(MessageFeedback).where(MessageFeedback.message_id == self.id).all()
+    feedbacks = db.session.scalars(select(MessageFeedback).where(MessageFeedback.message_id == self.id)).all()
     return feedbacks

 @property
 def message_files(self) -> list[dict[str, Any]]:
     from factories import file_factory

-    message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all()
+    message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
     current_app = db.session.query(App).where(App.id == self.app_id).first()
     if not current_app:
         raise ValueError(f"App {self.app_id} not found")

api/schedule/clean_unused_datasets_task.py (+8, -10)

 break

 for dataset in datasets:
-    dataset_query = (
-        db.session.query(DatasetQuery)
-        .where(DatasetQuery.created_at > clean_day, DatasetQuery.dataset_id == dataset.id)
-        .all()
-    )
+    dataset_query = db.session.scalars(
+        select(DatasetQuery).where(
+            DatasetQuery.created_at > clean_day, DatasetQuery.dataset_id == dataset.id
+        )
+    ).all()

     if not dataset_query or len(dataset_query) == 0:
         try:

 if should_clean:
     # Add auto disable log if required
     if add_logs:
-        documents = (
-            db.session.query(Document)
-            .where(
+        documents = db.session.scalars(
+            select(Document).where(
                 Document.dataset_id == dataset.id,
                 Document.enabled == True,
                 Document.archived == False,
             )
-            .all()
-        )
+        ).all()
         for document in documents:
             dataset_auto_disable_log = DatasetAutoDisableLog(
                 tenant_id=dataset.tenant_id,

api/schedule/mail_clean_document_notify_task.py (+4, -3)

 from collections import defaultdict

 import click
+from sqlalchemy import select

 import app
 from configs import dify_config

 # send document clean notify mail
 try:
-    dataset_auto_disable_logs = (
-        db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False).all()
-    )
+    dataset_auto_disable_logs = db.session.scalars(
+        select(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False)
+    ).all()
     # group by tenant_id
     dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
     for dataset_auto_disable_log in dataset_auto_disable_logs:

api/schedule/update_tidb_serverless_status_task.py (+6, -6)

 import time
+from collections.abc import Sequence

 import click
+from sqlalchemy import select

 import app
 from configs import dify_config

 start_at = time.perf_counter()
 try:
     # check the number of idle tidb serverless
-    tidb_serverless_list = (
-        db.session.query(TidbAuthBinding)
-        .where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
-        .all()
-    )
+    tidb_serverless_list = db.session.scalars(
+        select(TidbAuthBinding).where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
+    ).all()
     if len(tidb_serverless_list) == 0:
         return
     # update tidb serverless status

 click.echo(click.style(f"Update tidb serverless status task success latency: {end_at - start_at}", fg="green"))

-def update_clusters(tidb_serverless_list: list[TidbAuthBinding]):
+def update_clusters(tidb_serverless_list: Sequence[TidbAuthBinding]):
     try:
         # batch 20
         for i in range(0, len(tidb_serverless_list), 20):

api/services/annotation_service.py (+3, -5)

 db.session.delete(annotation)

-annotation_hit_histories = (
-    db.session.query(AppAnnotationHitHistory)
-    .where(AppAnnotationHitHistory.annotation_id == annotation_id)
-    .all()
-)
+annotation_hit_histories = db.session.scalars(
+    select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation_id)
+).all()
 if annotation_hit_histories:
     for annotation_hit_history in annotation_hit_histories:
         db.session.delete(annotation_hit_history)

api/services/auth/api_key_auth_service.py (+7, -5)

 import json

+from sqlalchemy import select
+
 from core.helper import encrypter
 from extensions.ext_database import db
 from models.source import DataSourceApiKeyAuthBinding

 class ApiKeyAuthService:
     @staticmethod
     def get_provider_auth_list(tenant_id: str):
-        data_source_api_key_bindings = (
-            db.session.query(DataSourceApiKeyAuthBinding)
-            .where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
-            .all()
-        )
+        data_source_api_key_bindings = db.session.scalars(
+            select(DataSourceApiKeyAuthBinding).where(
+                DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)
+            )
+        ).all()
         return data_source_api_key_bindings

     @staticmethod

api/services/clear_free_plan_tenant_expired_logs.py (+2, -1)

 import click
 from flask import Flask, current_app
+from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker

 from configs import dify_config

 @classmethod
 def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int):
     with flask_app.app_context():
-        apps = db.session.query(App).where(App.tenant_id == tenant_id).all()
+        apps = db.session.scalars(select(App).where(App.tenant_id == tenant_id)).all()
         app_ids = [app.id for app in apps]
         while True:
             with Session(db.engine).no_autoflush as session:

api/services/dataset_service.py (+38, -59)

 import time
 import uuid
 from collections import Counter
+from collections.abc import Sequence
 from typing import Any, Literal, Optional

 import sqlalchemy as sa

 }
 # get recent 30 days auto disable logs
 start_date = datetime.datetime.now() - datetime.timedelta(days=30)
-dataset_auto_disable_logs = (
-    db.session.query(DatasetAutoDisableLog)
-    .where(
+dataset_auto_disable_logs = db.session.scalars(
+    select(DatasetAutoDisableLog).where(
         DatasetAutoDisableLog.dataset_id == dataset_id,
         DatasetAutoDisableLog.created_at >= start_date,
     )
-    .all()
-)
+).all()
 if dataset_auto_disable_logs:
     return {
         "document_ids": [log.document_id for log in dataset_auto_disable_logs],

 return document

 @staticmethod
-def get_document_by_ids(document_ids: list[str]) -> list[Document]:
-    documents = (
-        db.session.query(Document)
-        .where(
+def get_document_by_ids(document_ids: list[str]) -> Sequence[Document]:
+    documents = db.session.scalars(
+        select(Document).where(
             Document.id.in_(document_ids),
             Document.enabled == True,
             Document.indexing_status == "completed",
             Document.archived == False,
         )
-        .all()
-    )
+    ).all()
     return documents

 @staticmethod
-def get_document_by_dataset_id(dataset_id: str) -> list[Document]:
-    documents = (
-        db.session.query(Document)
-        .where(
+def get_document_by_dataset_id(dataset_id: str) -> Sequence[Document]:
+    documents = db.session.scalars(
+        select(Document).where(
             Document.dataset_id == dataset_id,
             Document.enabled == True,
         )
-        .all()
-    )
+    ).all()

     return documents

 @staticmethod
-def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]:
-    documents = (
-        db.session.query(Document)
-        .where(
+def get_working_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]:
+    documents = db.session.scalars(
+        select(Document).where(
             Document.dataset_id == dataset_id,
             Document.enabled == True,
             Document.indexing_status == "completed",
             Document.archived == False,
         )
-        .all()
-    )
+    ).all()

     return documents

 @staticmethod
-def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]:
-    documents = (
-        db.session.query(Document)
-        .where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
-        .all()
-    )
+def get_error_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]:
+    documents = db.session.scalars(
+        select(Document).where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
+    ).all()
     return documents

 @staticmethod
-def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
+def get_batch_documents(dataset_id: str, batch: str) -> Sequence[Document]:
     assert isinstance(current_user, Account)

-    documents = (
-        db.session.query(Document)
-        .where(
+    documents = db.session.scalars(
+        select(Document).where(
             Document.batch == batch,
             Document.dataset_id == dataset_id,
             Document.tenant_id == current_user.current_tenant_id,
         )
-        .all()
-    )
+    ).all()

     return documents

 # Check if document_ids is not empty to avoid WHERE false condition
 if not document_ids or len(document_ids) == 0:
     return
-documents = db.session.query(Document).where(Document.id.in_(document_ids)).all()
+documents = db.session.scalars(select(Document).where(Document.id.in_(document_ids))).all()
 file_ids = [
     document.data_source_info_dict["upload_file_id"]
     for document in documents

 if not segment_ids or len(segment_ids) == 0:
     return
 if action == "enable":
-    segments = (
-        db.session.query(DocumentSegment)
-        .where(
+    segments = db.session.scalars(
+        select(DocumentSegment).where(
             DocumentSegment.id.in_(segment_ids),
             DocumentSegment.dataset_id == dataset.id,
             DocumentSegment.document_id == document.id,
             DocumentSegment.enabled == False,
         )
-        .all()
-    )
+    ).all()
     if not segments:
         return
     real_deal_segment_ids = []

     enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id)
 elif action == "disable":
-    segments = (
-        db.session.query(DocumentSegment)
-        .where(
+    segments = db.session.scalars(
+        select(DocumentSegment).where(
             DocumentSegment.id.in_(segment_ids),
             DocumentSegment.dataset_id == dataset.id,
             DocumentSegment.document_id == document.id,
             DocumentSegment.enabled == True,
         )
-        .all()
-    )
+    ).all()
     if not segments:
         return
     real_deal_segment_ids = []

     dataset: Dataset,
 ) -> list[ChildChunk]:
     assert isinstance(current_user, Account)

-    child_chunks = (
-        db.session.query(ChildChunk)
-        .where(
+    child_chunks = db.session.scalars(
+        select(ChildChunk).where(
             ChildChunk.dataset_id == dataset.id,
             ChildChunk.document_id == document.id,
             ChildChunk.segment_id == segment.id,
         )
-        .all()
-    )
+    ).all()
     child_chunks_map = {chunk.id: chunk for chunk in child_chunks}

     new_child_chunks, update_child_chunks, delete_child_chunks, new_child_chunks_args = [], [], [], []

 class DatasetPermissionService:
     @classmethod
     def get_dataset_partial_member_list(cls, dataset_id):
-        user_list_query = (
-            db.session.query(
+        user_list_query = db.session.scalars(
+            select(
                 DatasetPermission.account_id,
-            )
-            .where(DatasetPermission.dataset_id == dataset_id)
-            .all()
-        )
+            ).where(DatasetPermission.dataset_id == dataset_id)
+        ).all()

         user_list = []
         for user in user_list_query:
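
Where callers only iterate the results, this file takes the opposite route from models/account.py above: rather than wrapping in list(...), the service signatures are widened from list[Document] to Sequence[Document], matching what scalars(...).all() already returns. On the caller side nothing changes for read-only use; a sketch (process is a placeholder for the caller's logic):

    # Sequence supports iteration, len() and indexing -- all the callers need;
    # mutating calls such as .append() are no longer part of the declared type.
    docs = DatasetService.get_document_by_dataset_id(dataset_id)
    for doc in docs:
        process(doc)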

api/services/model_load_balancing_service.py (+4, -6)

 from json import JSONDecodeError
 from typing import Optional, Union

-from sqlalchemy import or_
+from sqlalchemy import or_, select

 from constants import HIDDEN_VALUE
 from core.entities.provider_configuration import ProviderConfiguration

 if not isinstance(configs, list):
     raise ValueError("Invalid load balancing configs")

-current_load_balancing_configs = (
-    db.session.query(LoadBalancingModelConfig)
-    .where(
+current_load_balancing_configs = db.session.scalars(
+    select(LoadBalancingModelConfig).where(
         LoadBalancingModelConfig.tenant_id == tenant_id,
         LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
         LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
         LoadBalancingModelConfig.model_name == model,
     )
-    .all()
-)
+).all()

 # id as key, config as value
 current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs}

api/services/recommend_app/database/database_retrieval.py (+8, -10)

 from typing import Optional

+from sqlalchemy import select
+
 from constants.languages import languages
 from extensions.ext_database import db
 from models.model import App, RecommendedApp

 :param language: language
 :return:
 """
-recommended_apps = (
-    db.session.query(RecommendedApp)
-    .where(RecommendedApp.is_listed == True, RecommendedApp.language == language)
-    .all()
-)
+recommended_apps = db.session.scalars(
+    select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == language)
+).all()

 if len(recommended_apps) == 0:
-    recommended_apps = (
-        db.session.query(RecommendedApp)
-        .where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
-        .all()
-    )
+    recommended_apps = db.session.scalars(
+        select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
+    ).all()

 categories = set()
 recommended_apps_result = []

api/services/tag_service.py (+15, -20)

 from typing import Optional

 from flask_login import current_user
-from sqlalchemy import func
+from sqlalchemy import func, select
 from werkzeug.exceptions import NotFound

 from extensions.ext_database import db

 # Check if tag_ids is not empty to avoid WHERE false condition
 if not tag_ids or len(tag_ids) == 0:
     return []
-tags = (
-    db.session.query(Tag)
-    .where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
-    .all()
-)
+tags = db.session.scalars(
+    select(Tag).where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
+).all()
 if not tags:
     return []
 tag_ids = [tag.id for tag in tags]

 # Check if tag_ids is not empty to avoid WHERE false condition
 if not tag_ids or len(tag_ids) == 0:
     return []
-tag_bindings = (
-    db.session.query(TagBinding.target_id)
-    .where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
-    .all()
-)
-if not tag_bindings:
-    return []
-results = [tag_binding.target_id for tag_binding in tag_bindings]
-return results
+tag_bindings = db.session.scalars(
+    select(TagBinding.target_id).where(
+        TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id
+    )
+).all()
+return tag_bindings

 @staticmethod
 def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str):
     if not tag_type or not tag_name:
         return []
-    tags = (
-        db.session.query(Tag)
-        .where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
-        .all()
+    tags = list(
+        db.session.scalars(
+            select(Tag).where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
+        ).all()
     )
     if not tags:
         return []

     raise NotFound("Tag not found")
 db.session.delete(tag)
 # delete tag binding
-tag_bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).all()
+tag_bindings = db.session.scalars(select(TagBinding).where(TagBinding.tag_id == tag_id)).all()
 if tag_bindings:
     for tag_binding in tag_bindings:
         db.session.delete(tag_binding)
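
The second hunk above also sheds a post-processing step: with a single-column select(TagBinding.target_id), scalars() yields the column values themselves, so the old [tag_binding.target_id for tag_binding in tag_bindings] comprehension and its empty-result guard become unnecessary; an empty result is simply an empty sequence. A condensed sketch:

    # select() on one column + scalars() returns the raw values directly
    target_ids = db.session.scalars(
        select(TagBinding.target_id).where(TagBinding.tag_id.in_(tag_ids))
    ).all()  # a Sequence of target ids, possibly empty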

api/services/tools/api_tools_manage_service.py (+2, -3)

 from typing import Any, cast

 from httpx import get
+from sqlalchemy import select

 from core.entities.provider_entities import ProviderConfig
 from core.model_runtime.utils.encoders import jsonable_encoder

 list api tools
 """
 # get all api providers
-db_providers: list[ApiToolProvider] = (
-    db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() or []
-)
+db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all()

 result: list[ToolProviderApiEntity] = []



api/services/tools/workflow_tools_manage_service.py (+4, -2)

 from datetime import datetime
 from typing import Any

-from sqlalchemy import or_
+from sqlalchemy import or_, select

 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.tools.__base.tool_provider import ToolProviderController

 :param tenant_id: the tenant id
 :return: the list of tools
 """
-db_tools = db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
+db_tools = db.session.scalars(
+    select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
+).all()

 tools: list[WorkflowToolProviderController] = []
 for provider in db_tools:

api/tasks/annotation/enable_annotation_reply_task.py (+2, -1)

 import click
 from celery import shared_task
+from sqlalchemy import select

 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.models.document import Document

     db.session.close()
     return

-annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).all()
+annotations = db.session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all()
 enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
 enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"



api/tasks/batch_clean_document_task.py (+5, -2)

 import click
 from celery import shared_task
+from sqlalchemy import select

 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.tools.utils.web_reader_tool import get_image_upload_file_ids

 if not dataset:
     raise Exception("Document has no dataset")

-segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)).all()
+segments = db.session.scalars(
+    select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
+).all()
 # check segment is exist
 if segments:
     index_node_ids = [segment.index_node_id for segment in segments]

 db.session.commit()
 if file_ids:
-    files = db.session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all()
+    files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
     for file in files:
         try:
             storage.delete(file.key)

api/tasks/clean_dataset_task.py (+3, -2)

 import click
 from celery import shared_task
+from sqlalchemy import select

 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.tools.utils.web_reader_tool import get_image_upload_file_ids

     index_struct=index_struct,
     collection_binding_id=collection_binding_id,
 )
-documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all()
-segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all()
+documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
+segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()

 # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
 # This ensures all invalid doc_form values are properly handled

api/tasks/clean_document_task.py (+2, -1)

 import click
 from celery import shared_task
+from sqlalchemy import select

 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.tools.utils.web_reader_tool import get_image_upload_file_ids

 if not dataset:
     raise Exception("Document has no dataset")

-segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
+segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
 # check segment is exist
 if segments:
     index_node_ids = [segment.index_node_id for segment in segments]

api/tasks/clean_notion_document_task.py (+4, -1)

 import click
 from celery import shared_task
+from sqlalchemy import select

 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from extensions.ext_database import db

 document = db.session.query(Document).where(Document.id == document_id).first()
 db.session.delete(document)

-segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
+segments = db.session.scalars(
+    select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+).all()
 index_node_ids = [segment.index_node_id for segment in segments]

 index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

api/tasks/deal_dataset_vector_index_task.py (+7, -10)

 import click
 from celery import shared_task
+from sqlalchemy import select

 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory

 if action == "remove":
     index_processor.clean(dataset, None, with_keywords=False)
 elif action == "add":
-    dataset_documents = (
-        db.session.query(DatasetDocument)
-        .where(
+    dataset_documents = db.session.scalars(
+        select(DatasetDocument).where(
             DatasetDocument.dataset_id == dataset_id,
             DatasetDocument.indexing_status == "completed",
             DatasetDocument.enabled == True,
             DatasetDocument.archived == False,
         )
-        .all()
-    )
+    ).all()

     if dataset_documents:
         dataset_documents_ids = [doc.id for doc in dataset_documents]

     )
     db.session.commit()
 elif action == "update":
-    dataset_documents = (
-        db.session.query(DatasetDocument)
-        .where(
+    dataset_documents = db.session.scalars(
+        select(DatasetDocument).where(
             DatasetDocument.dataset_id == dataset_id,
             DatasetDocument.indexing_status == "completed",
             DatasetDocument.enabled == True,
             DatasetDocument.archived == False,
         )
-        .all()
-    )
+    ).all()
     # add new index
     if dataset_documents:
         # update document status

api/tasks/disable_segments_from_index_task.py (+4, -5)

 import click
 from celery import shared_task
+from sqlalchemy import select

 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from extensions.ext_database import db

 # sync index processor
 index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()

-segments = (
-    db.session.query(DocumentSegment)
-    .where(
+segments = db.session.scalars(
+    select(DocumentSegment).where(
         DocumentSegment.id.in_(segment_ids),
         DocumentSegment.dataset_id == dataset_id,
         DocumentSegment.document_id == document_id,
     )
-    .all()
-)
+).all()

 if not segments:
     db.session.close()

api/tasks/document_indexing_sync_task.py (+4, -1)

 import click
 from celery import shared_task
+from sqlalchemy import select

 from core.indexing_runner import DocumentIsPausedError, IndexingRunner
 from core.rag.extractor.notion_extractor import NotionExtractor

 index_type = document.doc_form
 index_processor = IndexProcessorFactory(index_type).init_index_processor()

-segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
+segments = db.session.scalars(
+    select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+).all()
 index_node_ids = [segment.index_node_id for segment in segments]

 # delete from vector index

api/tasks/document_indexing_update_task.py (+2, -1)

 import click
 from celery import shared_task
+from sqlalchemy import select

 from core.indexing_runner import DocumentIsPausedError, IndexingRunner
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory

 index_type = document.doc_form
 index_processor = IndexProcessorFactory(index_type).init_index_processor()

-segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
+segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
 if segments:
     index_node_ids = [segment.index_node_id for segment in segments]



api/tasks/duplicate_document_indexing_task.py (+4 -1)

  import click
  from celery import shared_task
+ from sqlalchemy import select

  from configs import dify_config
  from core.indexing_runner import DocumentIsPausedError, IndexingRunner

  index_type = document.doc_form
  index_processor = IndexProcessorFactory(index_type).init_index_processor()

- segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
+ segments = db.session.scalars(
+     select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+ ).all()
  if segments:
      index_node_ids = [segment.index_node_id for segment in segments]
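The statement `select(DocumentSegment).where(DocumentSegment.document_id == ...)` now recurs across several tasks. Since 2.0-style statements are immutable objects and every `.where()` returns a new statement, they compose well; `segments_of` below is a hypothetical shared helper, not something this PR introduces:

    from sqlalchemy import select

    def segments_of(document_id: str):
        # Base statement; callers can extend it without mutating the original.
        return select(DocumentSegment).where(DocumentSegment.document_id == document_id)

    all_segments = db.session.scalars(segments_of(document_id)).all()
    enabled_segments = db.session.scalars(
        segments_of(document_id).where(DocumentSegment.enabled == True)  # noqa: E712
    ).all()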



api/tasks/enable_segments_to_index_task.py (+4 -5)

  import click
  from celery import shared_task
+ from sqlalchemy import select

  from core.rag.index_processor.constant.index_type import IndexType
  from core.rag.index_processor.index_processor_factory import IndexProcessorFactory

  # sync index processor
  index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()

- segments = (
-     db.session.query(DocumentSegment)
-     .where(
+ segments = db.session.scalars(
+     select(DocumentSegment).where(
          DocumentSegment.id.in_(segment_ids),
          DocumentSegment.dataset_id == dataset_id,
          DocumentSegment.document_id == document_id,
      )
-     .all()
- )
+ ).all()
  if not segments:
      logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan"))
      db.session.close()

api/tasks/remove_document_from_index_task.py (+2 -1)

  import click
  from celery import shared_task
+ from sqlalchemy import select

  from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
  from extensions.ext_database import db

  index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

- segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).all()
+ segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()
  index_node_ids = [segment.index_node_id for segment in segments]
  if index_node_ids:
      try:
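For readers newer to the 2.0 API: `db.session.scalars(stmt)` is shorthand for `db.session.execute(stmt).scalars()`. A single-entity `select()` makes `execute()` yield one-element `Row` objects, and `scalars()` unwraps them, which is why the converted code receives `DocumentSegment` instances directly. Equivalent forms:

    from sqlalchemy import select

    stmt = select(DocumentSegment).where(DocumentSegment.document_id == document.id)

    segments_a = db.session.scalars(stmt).all()            # shorthand used throughout this PR
    segments_b = db.session.execute(stmt).scalars().all()  # equivalent long form
    rows = db.session.execute(stmt).all()                  # Row objects, each wrapping one segment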

api/tasks/retry_document_indexing_task.py (+4 -1)

  import click
  from celery import shared_task
+ from sqlalchemy import select

  from core.indexing_runner import IndexingRunner
  from core.rag.index_processor.index_processor_factory import IndexProcessorFactory

  # clean old data
  index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

- segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
+ segments = db.session.scalars(
+     select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+ ).all()
  if segments:
      index_node_ids = [segment.index_node_id for segment in segments]
      # delete from vector index
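Like its predecessor, the converted query loads every matching segment into memory at once, so behavior is unchanged. If a document can carry very many segments, 2.0-style execution also supports batched streaming through the `yield_per` execution option. A possible follow-up, sketched under that assumption (`handle` is a hypothetical placeholder):

    from sqlalchemy import select

    stmt = (
        select(DocumentSegment)
        .where(DocumentSegment.document_id == document_id)
        .execution_options(yield_per=500)  # fetch and hydrate 500 rows at a time
    )
    for segment in db.session.scalars(stmt):
        handle(segment)  # hypothetical per-segment processing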

api/tasks/sync_website_document_indexing_task.py (+2 -1)

  import click
  from celery import shared_task
+ from sqlalchemy import select

  from core.indexing_runner import IndexingRunner
  from core.rag.index_processor.index_processor_factory import IndexProcessorFactory

  # clean old data
  index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

- segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
+ segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
  if segments:
      index_node_ids = [segment.index_node_id for segment in segments]
      # delete from vector index

api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py (+4 -3)

  import pytest
  from faker import Faker
+ from sqlalchemy import select

  from models.account import TenantAccountJoin, TenantAccountRole
  from models.model import Account, Tenant

  assert load_balancing_config.id is not None

  # Verify inherit config was created in database
- inherit_configs = (
-     db.session.query(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__").all()
- )
+ inherit_configs = db.session.scalars(
+     select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__")
+ ).all()
  assert len(inherit_configs) == 1
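Assertions of the exactly-one-row shape, as above, can also lean on the result API itself: `ScalarResult.one()` raises `NoResultFound` or `MultipleResultsFound` unless exactly one row matched. A hedged alternative to the `len(...) == 1` idiom, not what this PR does:

    from sqlalchemy import select

    # Fails loudly if zero or multiple __inherit__ configs exist.
    # LoadBalancingModelConfig and db as already imported in this test module.
    inherit_config = db.session.scalars(
        select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__")
    ).one()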

api/tests/test_containers_integration_tests/services/test_tag_service.py (+7 -2)

  import pytest
  from faker import Faker
+ from sqlalchemy import select
  from werkzeug.exceptions import NotFound

  from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole

  from extensions.ext_database import db

  # Verify only one binding exists
- bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all()
+ bindings = db.session.scalars(
+     select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id)
+ ).all()
  assert len(bindings) == 1

  def test_save_tag_binding_invalid_target_type(self, db_session_with_containers, mock_external_service_dependencies):

  # No error should be raised, and database state should remain unchanged
  from extensions.ext_database import db

- bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id).all()
+ bindings = db.session.scalars(
+     select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id)
+ ).all()
  assert len(bindings) == 0

  def test_check_target_exists_knowledge_success(

api/tests/test_containers_integration_tests/services/test_web_conversation_service.py (+4 -5)

  import pytest
  from faker import Faker
+ from sqlalchemy import select

  from core.app.entities.app_invoke_entities import InvokeFrom
  from models.account import Account

  # Verify only one pinned conversation record exists
  from extensions.ext_database import db

- pinned_conversations = (
-     db.session.query(PinnedConversation)
-     .where(
+ pinned_conversations = db.session.scalars(
+     select(PinnedConversation).where(
          PinnedConversation.app_id == app.id,
          PinnedConversation.conversation_id == conversation.id,
          PinnedConversation.created_by_role == "account",
          PinnedConversation.created_by == account.id,
      )
-     .all()
- )
+ ).all()

  assert len(pinned_conversations) == 1
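When a test only checks how many rows match, hydrating full ORM objects does more work than the assertion needs; a `count()` scalar expresses the same check. A sketch against the same model, as an alternative the tests could adopt:

    from sqlalchemy import func, select

    # PinnedConversation and db as already imported in this test module.
    pinned_count = db.session.scalar(
        select(func.count())
        .select_from(PinnedConversation)
        .where(
            PinnedConversation.app_id == app.id,
            PinnedConversation.conversation_id == conversation.id,
        )
    )
    assert pinned_count == 1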



api/tests/unit_tests/services/auth/test_api_key_auth_service.py (+12 -8)

  mock_binding.provider = self.provider
  mock_binding.disabled = False

- mock_session.query.return_value.where.return_value.all.return_value = [mock_binding]
+ mock_session.scalars.return_value.all.return_value = [mock_binding]

  result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)

  assert len(result) == 1
  assert result[0].tenant_id == self.tenant_id
- mock_session.query.assert_called_once_with(DataSourceApiKeyAuthBinding)
+ assert mock_session.scalars.call_count == 1
+ select_arg = mock_session.scalars.call_args[0][0]
+ assert "data_source_api_key_auth_binding" in str(select_arg).lower()

  @patch("services.auth.api_key_auth_service.db.session")
  def test_get_provider_auth_list_empty(self, mock_session):
      """Test get provider auth list - empty result"""
-     mock_session.query.return_value.where.return_value.all.return_value = []
+     mock_session.scalars.return_value.all.return_value = []

      result = ApiKeyAuthService.get_provider_auth_list(self.tenant_id)

  @patch("services.auth.api_key_auth_service.db.session")
  def test_get_provider_auth_list_filters_disabled(self, mock_session):
      """Test get provider auth list - filters disabled items"""
-     mock_session.query.return_value.where.return_value.all.return_value = []
+     mock_session.scalars.return_value.all.return_value = []

      ApiKeyAuthService.get_provider_auth_list(self.tenant_id)

-     # Verify where conditions include disabled.is_(False)
-     where_call = mock_session.query.return_value.where.call_args[0]
-     assert len(where_call) == 2  # tenant_id and disabled filter conditions
+     select_stmt = mock_session.scalars.call_args[0][0]
+     where_clauses = list(getattr(select_stmt, "_where_criteria", []) or [])
+     # Ensure both tenant filter and disabled filter exist
+     where_strs = [str(c).lower() for c in where_clauses]
+     assert any("tenant_id" in s for s in where_strs)
+     assert any("disabled" in s for s in where_strs)

  @patch("services.auth.api_key_auth_service.db.session")
  @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")

api/tests/unit_tests/services/auth/test_auth_integration.py (+2 -2)

  tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials)
  tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials)

- mock_session.query.return_value.where.return_value.all.return_value = [tenant1_binding]
+ mock_session.scalars.return_value.all.return_value = [tenant1_binding]
  result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1)

- mock_session.query.return_value.where.return_value.all.return_value = [tenant2_binding]
+ mock_session.scalars.return_value.all.return_value = [tenant2_binding]
  result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2)

  assert len(result1) == 1
