Co-authored-by: Yongtao Huang <99629139+hyongtao-db@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
tags/1.8.1
| """ | """ | ||||
| Save agent thought | Save agent thought | ||||
| """ | """ | ||||
| agent_thought = db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id).first() | |||||
| stmt = select(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id) | |||||
| agent_thought = db.session.scalar(stmt) | |||||
| if not agent_thought: | if not agent_thought: | ||||
| raise ValueError("agent thought not found") | raise ValueError("agent thought not found") | ||||
| return result | return result | ||||
| def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: | def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: | ||||
| files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all() | |||||
| stmt = select(MessageFile).where(MessageFile.message_id == message.id) | |||||
| files = db.session.scalars(stmt).all() | |||||
| if not files: | if not files: | ||||
| return UserPromptMessage(content=message.query) | return UserPromptMessage(content=message.query) | ||||
| if message.app_model_config: | if message.app_model_config: |
| with Session(db.engine, expire_on_commit=False) as session: | with Session(db.engine, expire_on_commit=False) as session: | ||||
| app_record = session.scalar(select(App).where(App.id == app_config.app_id)) | app_record = session.scalar(select(App).where(App.id == app_config.app_id)) | ||||
| if not app_record: | if not app_record: | ||||
| raise ValueError("App not found") | raise ValueError("App not found") | ||||
| import logging | import logging | ||||
| from typing import cast | from typing import cast | ||||
| from sqlalchemy import select | |||||
| from core.agent.cot_chat_agent_runner import CotChatAgentRunner | from core.agent.cot_chat_agent_runner import CotChatAgentRunner | ||||
| from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner | from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner | ||||
| from core.agent.entities import AgentEntity | from core.agent.entities import AgentEntity | ||||
| """ | """ | ||||
| app_config = application_generate_entity.app_config | app_config = application_generate_entity.app_config | ||||
| app_config = cast(AgentChatAppConfig, app_config) | app_config = cast(AgentChatAppConfig, app_config) | ||||
| app_record = db.session.query(App).where(App.id == app_config.app_id).first() | |||||
| app_stmt = select(App).where(App.id == app_config.app_id) | |||||
| app_record = db.session.scalar(app_stmt) | |||||
| if not app_record: | if not app_record: | ||||
| raise ValueError("App not found") | raise ValueError("App not found") | ||||
| if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []): | if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []): | ||||
| agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING | agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING | ||||
| conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first() | |||||
| conversation_stmt = select(Conversation).where(Conversation.id == conversation.id) | |||||
| conversation_result = db.session.scalar(conversation_stmt) | |||||
| if conversation_result is None: | if conversation_result is None: | ||||
| raise ValueError("Conversation not found") | raise ValueError("Conversation not found") | ||||
| message_result = db.session.query(Message).where(Message.id == message.id).first() | |||||
| msg_stmt = select(Message).where(Message.id == message.id) | |||||
| message_result = db.session.scalar(msg_stmt) | |||||
| if message_result is None: | if message_result is None: | ||||
| raise ValueError("Message not found") | raise ValueError("Message not found") | ||||
| db.session.close() | db.session.close() |
| import logging | import logging | ||||
| from typing import cast | from typing import cast | ||||
| from sqlalchemy import select | |||||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | ||||
| from core.app.apps.base_app_runner import AppRunner | from core.app.apps.base_app_runner import AppRunner | ||||
| from core.app.apps.chat.app_config_manager import ChatAppConfig | from core.app.apps.chat.app_config_manager import ChatAppConfig | ||||
| """ | """ | ||||
| app_config = application_generate_entity.app_config | app_config = application_generate_entity.app_config | ||||
| app_config = cast(ChatAppConfig, app_config) | app_config = cast(ChatAppConfig, app_config) | ||||
| app_record = db.session.query(App).where(App.id == app_config.app_id).first() | |||||
| stmt = select(App).where(App.id == app_config.app_id) | |||||
| app_record = db.session.scalar(stmt) | |||||
| if not app_record: | if not app_record: | ||||
| raise ValueError("App not found") | raise ValueError("App not found") | ||||
| from flask import Flask, copy_current_request_context, current_app | from flask import Flask, copy_current_request_context, current_app | ||||
| from pydantic import ValidationError | from pydantic import ValidationError | ||||
| from sqlalchemy import select | |||||
| from configs import dify_config | from configs import dify_config | ||||
| from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter | from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter | ||||
| :param invoke_from: invoke from source | :param invoke_from: invoke from source | ||||
| :param stream: is stream | :param stream: is stream | ||||
| """ | """ | ||||
| message = ( | |||||
| db.session.query(Message) | |||||
| .where( | |||||
| Message.id == message_id, | |||||
| Message.app_id == app_model.id, | |||||
| Message.from_source == ("api" if isinstance(user, EndUser) else "console"), | |||||
| Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), | |||||
| Message.from_account_id == (user.id if isinstance(user, Account) else None), | |||||
| ) | |||||
| .first() | |||||
| stmt = select(Message).where( | |||||
| Message.id == message_id, | |||||
| Message.app_id == app_model.id, | |||||
| Message.from_source == ("api" if isinstance(user, EndUser) else "console"), | |||||
| Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), | |||||
| Message.from_account_id == (user.id if isinstance(user, Account) else None), | |||||
| ) | ) | ||||
| message = db.session.scalar(stmt) | |||||
| if not message: | if not message: | ||||
| raise MessageNotExistsError() | raise MessageNotExistsError() |
| import logging | import logging | ||||
| from typing import cast | from typing import cast | ||||
| from sqlalchemy import select | |||||
| from core.app.apps.base_app_queue_manager import AppQueueManager | from core.app.apps.base_app_queue_manager import AppQueueManager | ||||
| from core.app.apps.base_app_runner import AppRunner | from core.app.apps.base_app_runner import AppRunner | ||||
| from core.app.apps.completion.app_config_manager import CompletionAppConfig | from core.app.apps.completion.app_config_manager import CompletionAppConfig | ||||
| """ | """ | ||||
| app_config = application_generate_entity.app_config | app_config = application_generate_entity.app_config | ||||
| app_config = cast(CompletionAppConfig, app_config) | app_config = cast(CompletionAppConfig, app_config) | ||||
| app_record = db.session.query(App).where(App.id == app_config.app_id).first() | |||||
| stmt = select(App).where(App.id == app_config.app_id) | |||||
| app_record = db.session.scalar(stmt) | |||||
| if not app_record: | if not app_record: | ||||
| raise ValueError("App not found") | raise ValueError("App not found") | ||||
| def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig: | def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig: | ||||
| if conversation: | if conversation: | ||||
| app_model_config = ( | |||||
| db.session.query(AppModelConfig) | |||||
| .where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id) | |||||
| .first() | |||||
| stmt = select(AppModelConfig).where( | |||||
| AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id | |||||
| ) | ) | ||||
| app_model_config = db.session.scalar(stmt) | |||||
| if not app_model_config: | if not app_model_config: | ||||
| raise AppModelConfigBrokenError() | raise AppModelConfigBrokenError() |
| import logging | import logging | ||||
| from typing import Optional | from typing import Optional | ||||
| from sqlalchemy import select | |||||
| from core.app.entities.app_invoke_entities import InvokeFrom | from core.app.entities.app_invoke_entities import InvokeFrom | ||||
| from core.rag.datasource.vdb.vector_factory import Vector | from core.rag.datasource.vdb.vector_factory import Vector | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| :param invoke_from: invoke from | :param invoke_from: invoke from | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| annotation_setting = ( | |||||
| db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first() | |||||
| ) | |||||
| stmt = select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id) | |||||
| annotation_setting = db.session.scalar(stmt) | |||||
| if not annotation_setting: | if not annotation_setting: | ||||
| return None | return None |
| def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str): | def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str): | ||||
| with flask_app.app_context(): | with flask_app.app_context(): | ||||
| # get conversation and message | # get conversation and message | ||||
| conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first() | |||||
| stmt = select(Conversation).where(Conversation.id == conversation_id) | |||||
| conversation = db.session.scalar(stmt) | |||||
| if not conversation: | if not conversation: | ||||
| return | return |
| import logging | import logging | ||||
| from collections.abc import Sequence | from collections.abc import Sequence | ||||
| from sqlalchemy import select | |||||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | ||||
| from core.app.entities.app_invoke_entities import InvokeFrom | from core.app.entities.app_invoke_entities import InvokeFrom | ||||
| from core.app.entities.queue_entities import QueueRetrieverResourcesEvent | from core.app.entities.queue_entities import QueueRetrieverResourcesEvent | ||||
| for document in documents: | for document in documents: | ||||
| if document.metadata is not None: | if document.metadata is not None: | ||||
| document_id = document.metadata["document_id"] | document_id = document.metadata["document_id"] | ||||
| dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() | |||||
| dataset_document_stmt = select(DatasetDocument).where(DatasetDocument.id == document_id) | |||||
| dataset_document = db.session.scalar(dataset_document_stmt) | |||||
| if not dataset_document: | if not dataset_document: | ||||
| _logger.warning( | _logger.warning( | ||||
| "Expected DatasetDocument record to exist, but none was found, document_id=%s", | "Expected DatasetDocument record to exist, but none was found, document_id=%s", | ||||
| ) | ) | ||||
| continue | continue | ||||
| if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | ||||
| child_chunk = ( | |||||
| db.session.query(ChildChunk) | |||||
| .where( | |||||
| ChildChunk.index_node_id == document.metadata["doc_id"], | |||||
| ChildChunk.dataset_id == dataset_document.dataset_id, | |||||
| ChildChunk.document_id == dataset_document.id, | |||||
| ) | |||||
| .first() | |||||
| child_chunk_stmt = select(ChildChunk).where( | |||||
| ChildChunk.index_node_id == document.metadata["doc_id"], | |||||
| ChildChunk.dataset_id == dataset_document.dataset_id, | |||||
| ChildChunk.document_id == dataset_document.id, | |||||
| ) | ) | ||||
| child_chunk = db.session.scalar(child_chunk_stmt) | |||||
| if child_chunk: | if child_chunk: | ||||
| segment = ( | segment = ( | ||||
| db.session.query(DocumentSegment) | db.session.query(DocumentSegment) |
| from typing import Optional | from typing import Optional | ||||
| from sqlalchemy import select | |||||
| from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor | from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor | ||||
| from core.external_data_tool.base import ExternalDataTool | from core.external_data_tool.base import ExternalDataTool | ||||
| from core.helper import encrypter | from core.helper import encrypter | ||||
| api_based_extension_id = config.get("api_based_extension_id") | api_based_extension_id = config.get("api_based_extension_id") | ||||
| if not api_based_extension_id: | if not api_based_extension_id: | ||||
| raise ValueError("api_based_extension_id is required") | raise ValueError("api_based_extension_id is required") | ||||
| # get api_based_extension | # get api_based_extension | ||||
| api_based_extension = ( | |||||
| db.session.query(APIBasedExtension) | |||||
| .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) | |||||
| .first() | |||||
| stmt = select(APIBasedExtension).where( | |||||
| APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id | |||||
| ) | ) | ||||
| api_based_extension = db.session.scalar(stmt) | |||||
| if not api_based_extension: | if not api_based_extension: | ||||
| raise ValueError("api_based_extension_id is invalid") | raise ValueError("api_based_extension_id is invalid") | ||||
| raise ValueError(f"config is required, config: {self.config}") | raise ValueError(f"config is required, config: {self.config}") | ||||
| api_based_extension_id = self.config.get("api_based_extension_id") | api_based_extension_id = self.config.get("api_based_extension_id") | ||||
| assert api_based_extension_id is not None, "api_based_extension_id is required" | assert api_based_extension_id is not None, "api_based_extension_id is required" | ||||
| # get api_based_extension | # get api_based_extension | ||||
| api_based_extension = ( | |||||
| db.session.query(APIBasedExtension) | |||||
| .where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id) | |||||
| .first() | |||||
| stmt = select(APIBasedExtension).where( | |||||
| APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id | |||||
| ) | ) | ||||
| api_based_extension = db.session.scalar(stmt) | |||||
| if not api_based_extension: | if not api_based_extension: | ||||
| raise ValueError( | raise ValueError( |
| from typing import Any, Optional, cast | from typing import Any, Optional, cast | ||||
| from flask import current_app | from flask import current_app | ||||
| from sqlalchemy import select | |||||
| from sqlalchemy.orm.exc import ObjectDeletedError | from sqlalchemy.orm.exc import ObjectDeletedError | ||||
| from configs import dify_config | from configs import dify_config | ||||
| if not dataset: | if not dataset: | ||||
| raise ValueError("no dataset found") | raise ValueError("no dataset found") | ||||
| # get the process rule | # get the process rule | ||||
| processing_rule = ( | |||||
| db.session.query(DatasetProcessRule) | |||||
| .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) | |||||
| .first() | |||||
| stmt = select(DatasetProcessRule).where( | |||||
| DatasetProcessRule.id == dataset_document.dataset_process_rule_id | |||||
| ) | ) | ||||
| processing_rule = db.session.scalar(stmt) | |||||
| if not processing_rule: | if not processing_rule: | ||||
| raise ValueError("no process rule found") | raise ValueError("no process rule found") | ||||
| index_type = dataset_document.doc_form | index_type = dataset_document.doc_form | ||||
| db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete() | db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete() | ||||
| db.session.commit() | db.session.commit() | ||||
| # get the process rule | # get the process rule | ||||
| processing_rule = ( | |||||
| db.session.query(DatasetProcessRule) | |||||
| .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) | |||||
| .first() | |||||
| ) | |||||
| stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) | |||||
| processing_rule = db.session.scalar(stmt) | |||||
| if not processing_rule: | if not processing_rule: | ||||
| raise ValueError("no process rule found") | raise ValueError("no process rule found") | ||||
| child_documents.append(child_document) | child_documents.append(child_document) | ||||
| document.children = child_documents | document.children = child_documents | ||||
| documents.append(document) | documents.append(document) | ||||
| # build index | # build index | ||||
| index_type = dataset_document.doc_form | index_type = dataset_document.doc_form | ||||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | index_processor = IndexProcessorFactory(index_type).init_index_processor() | ||||
| # delete image files and related db records | # delete image files and related db records | ||||
| image_upload_file_ids = get_image_upload_file_ids(document.page_content) | image_upload_file_ids = get_image_upload_file_ids(document.page_content) | ||||
| for upload_file_id in image_upload_file_ids: | for upload_file_id in image_upload_file_ids: | ||||
| image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() | |||||
| stmt = select(UploadFile).where(UploadFile.id == upload_file_id) | |||||
| image_file = db.session.scalar(stmt) | |||||
| if image_file is None: | if image_file is None: | ||||
| continue | continue | ||||
| try: | try: | ||||
| if dataset_document.data_source_type == "upload_file": | if dataset_document.data_source_type == "upload_file": | ||||
| if not data_source_info or "upload_file_id" not in data_source_info: | if not data_source_info or "upload_file_id" not in data_source_info: | ||||
| raise ValueError("no upload file found") | raise ValueError("no upload file found") | ||||
| file_detail = ( | |||||
| db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none() | |||||
| ) | |||||
| stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]) | |||||
| file_detail = db.session.scalars(stmt).one_or_none() | |||||
| if file_detail: | if file_detail: | ||||
| extract_setting = ExtractSetting( | extract_setting = ExtractSetting( |
| else: | else: | ||||
| message_limit = 500 | message_limit = 500 | ||||
| stmt = stmt.limit(message_limit) | |||||
| msg_limit_stmt = stmt.limit(message_limit) | |||||
| messages = db.session.scalars(stmt).all() | |||||
| messages = db.session.scalars(msg_limit_stmt).all() | |||||
| # instead of all messages from the conversation, we only need to extract messages | # instead of all messages from the conversation, we only need to extract messages | ||||
| # that belong to the thread of last message | # that belong to the thread of last message |
| from typing import Optional | from typing import Optional | ||||
| from pydantic import BaseModel, Field | from pydantic import BaseModel, Field | ||||
| from sqlalchemy import select | |||||
| from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor | from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor | ||||
| from core.helper.encrypter import decrypt_token | from core.helper.encrypter import decrypt_token | ||||
| @staticmethod | @staticmethod | ||||
| def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]: | def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]: | ||||
| extension = ( | |||||
| db.session.query(APIBasedExtension) | |||||
| .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) | |||||
| .first() | |||||
| stmt = select(APIBasedExtension).where( | |||||
| APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id | |||||
| ) | ) | ||||
| extension = db.session.scalar(stmt) | |||||
| return extension | return extension |
| from urllib.parse import urljoin | from urllib.parse import urljoin | ||||
| from opentelemetry.trace import Link, Status, StatusCode | from opentelemetry.trace import Link, Status, StatusCode | ||||
| from sqlalchemy import select | |||||
| from sqlalchemy.orm import Session, sessionmaker | from sqlalchemy.orm import Session, sessionmaker | ||||
| from core.ops.aliyun_trace.data_exporter.traceclient import ( | from core.ops.aliyun_trace.data_exporter.traceclient import ( | ||||
| app_id = trace_info.metadata.get("app_id") | app_id = trace_info.metadata.get("app_id") | ||||
| if not app_id: | if not app_id: | ||||
| raise ValueError("No app_id found in trace_info metadata") | raise ValueError("No app_id found in trace_info metadata") | ||||
| app = session.query(App).where(App.id == app_id).first() | |||||
| app_stmt = select(App).where(App.id == app_id) | |||||
| app = session.scalar(app_stmt) | |||||
| if not app: | if not app: | ||||
| raise ValueError(f"App with id {app_id} not found") | raise ValueError(f"App with id {app_id} not found") | ||||
| if not app.created_by: | if not app.created_by: | ||||
| raise ValueError(f"App with id {app_id} has no creator (created_by is None)") | raise ValueError(f"App with id {app_id} has no creator (created_by is None)") | ||||
| service_account = session.query(Account).where(Account.id == app.created_by).first() | |||||
| account_stmt = select(Account).where(Account.id == app.created_by) | |||||
| service_account = session.scalar(account_stmt) | |||||
| if not service_account: | if not service_account: | ||||
| raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") | raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") | ||||
| current_tenant = ( | current_tenant = ( |
| from abc import ABC, abstractmethod | from abc import ABC, abstractmethod | ||||
| from sqlalchemy import select | |||||
| from sqlalchemy.orm import Session | from sqlalchemy.orm import Session | ||||
| from core.ops.entities.config_entity import BaseTracingConfig | from core.ops.entities.config_entity import BaseTracingConfig | ||||
| """ | """ | ||||
| with Session(db.engine, expire_on_commit=False) as session: | with Session(db.engine, expire_on_commit=False) as session: | ||||
| # Get the app to find its creator | # Get the app to find its creator | ||||
| app = session.query(App).where(App.id == app_id).first() | |||||
| app_stmt = select(App).where(App.id == app_id) | |||||
| app = session.scalar(app_stmt) | |||||
| if not app: | if not app: | ||||
| raise ValueError(f"App with id {app_id} not found") | raise ValueError(f"App with id {app_id} not found") | ||||
| if not app.created_by: | if not app.created_by: | ||||
| raise ValueError(f"App with id {app_id} has no creator (created_by is None)") | raise ValueError(f"App with id {app_id} has no creator (created_by is None)") | ||||
| service_account = session.query(Account).where(Account.id == app.created_by).first() | |||||
| account_stmt = select(Account).where(Account.id == app.created_by) | |||||
| service_account = session.scalar(account_stmt) | |||||
| if not service_account: | if not service_account: | ||||
| raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") | raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") | ||||
| if not trace_config_data: | if not trace_config_data: | ||||
| return None | return None | ||||
| # decrypt_token | # decrypt_token | ||||
| app = db.session.query(App).where(App.id == app_id).first() | |||||
| stmt = select(App).where(App.id == app_id) | |||||
| app = db.session.scalar(stmt) | |||||
| if not app: | if not app: | ||||
| raise ValueError("App not found") | raise ValueError("App not found") | ||||
| @classmethod | @classmethod | ||||
| def get_app_config_through_message_id(cls, message_id: str): | def get_app_config_through_message_id(cls, message_id: str): | ||||
| app_model_config = None | app_model_config = None | ||||
| message_data = db.session.query(Message).where(Message.id == message_id).first() | |||||
| message_stmt = select(Message).where(Message.id == message_id) | |||||
| message_data = db.session.scalar(message_stmt) | |||||
| if not message_data: | if not message_data: | ||||
| return None | return None | ||||
| conversation_id = message_data.conversation_id | conversation_id = message_data.conversation_id | ||||
| conversation_data = db.session.query(Conversation).where(Conversation.id == conversation_id).first() | |||||
| conversation_stmt = select(Conversation).where(Conversation.id == conversation_id) | |||||
| conversation_data = db.session.scalar(conversation_stmt) | |||||
| if not conversation_data: | if not conversation_data: | ||||
| return None | return None | ||||
| if conversation_data.app_model_config_id: | if conversation_data.app_model_config_id: | ||||
| app_model_config = ( | |||||
| db.session.query(AppModelConfig) | |||||
| .where(AppModelConfig.id == conversation_data.app_model_config_id) | |||||
| .first() | |||||
| ) | |||||
| config_stmt = select(AppModelConfig).where(AppModelConfig.id == conversation_data.app_model_config_id) | |||||
| app_model_config = db.session.scalar(config_stmt) | |||||
| elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs: | elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs: | ||||
| app_model_config = conversation_data.override_model_configs | app_model_config = conversation_data.override_model_configs | ||||
| from collections.abc import Generator, Mapping | from collections.abc import Generator, Mapping | ||||
| from typing import Optional, Union | from typing import Optional, Union | ||||
| from sqlalchemy import select | |||||
| from controllers.service_api.wraps import create_or_update_end_user_for_user_id | from controllers.service_api.wraps import create_or_update_end_user_for_user_id | ||||
| from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict | from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict | ||||
| from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator | from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator | ||||
| """ | """ | ||||
| get the user by user id | get the user by user id | ||||
| """ | """ | ||||
| user = db.session.query(EndUser).where(EndUser.id == user_id).first() | |||||
| stmt = select(EndUser).where(EndUser.id == user_id) | |||||
| user = db.session.scalar(stmt) | |||||
| if not user: | if not user: | ||||
| user = db.session.query(Account).where(Account.id == user_id).first() | |||||
| stmt = select(Account).where(Account.id == user_id) | |||||
| user = db.session.scalar(stmt) | |||||
| if not user: | if not user: | ||||
| raise ValueError("user not found") | raise ValueError("user not found") |
| :param model_type: model type | :param model_type: model type | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| # Get the corresponding TenantDefaultModel record | |||||
| default_model = ( | |||||
| db.session.query(TenantDefaultModel) | |||||
| .where( | |||||
| TenantDefaultModel.tenant_id == tenant_id, | |||||
| TenantDefaultModel.model_type == model_type.to_origin_model_type(), | |||||
| ) | |||||
| .first() | |||||
| stmt = select(TenantDefaultModel).where( | |||||
| TenantDefaultModel.tenant_id == tenant_id, | |||||
| TenantDefaultModel.model_type == model_type.to_origin_model_type(), | |||||
| ) | ) | ||||
| default_model = db.session.scalar(stmt) | |||||
| # If it does not exist, get the first available provider model from get_configurations | # If it does not exist, get the first available provider model from get_configurations | ||||
| # and update the TenantDefaultModel record | # and update the TenantDefaultModel record | ||||
| model_names = [model.model for model in available_models] | model_names = [model.model for model in available_models] | ||||
| if model not in model_names: | if model not in model_names: | ||||
| raise ValueError(f"Model {model} does not exist.") | raise ValueError(f"Model {model} does not exist.") | ||||
| # Get the list of available models from get_configurations and check if it is LLM | |||||
| default_model = ( | |||||
| db.session.query(TenantDefaultModel) | |||||
| .where( | |||||
| TenantDefaultModel.tenant_id == tenant_id, | |||||
| TenantDefaultModel.model_type == model_type.to_origin_model_type(), | |||||
| ) | |||||
| .first() | |||||
| stmt = select(TenantDefaultModel).where( | |||||
| TenantDefaultModel.tenant_id == tenant_id, | |||||
| TenantDefaultModel.model_type == model_type.to_origin_model_type(), | |||||
| ) | ) | ||||
| default_model = db.session.scalar(stmt) | |||||
| # create or update TenantDefaultModel record | # create or update TenantDefaultModel record | ||||
| if default_model: | if default_model: | ||||
| provider_name_to_provider_records_dict[provider_name].append(new_provider_record) | provider_name_to_provider_records_dict[provider_name].append(new_provider_record) | ||||
| except IntegrityError: | except IntegrityError: | ||||
| db.session.rollback() | db.session.rollback() | ||||
| existed_provider_record = ( | |||||
| db.session.query(Provider) | |||||
| .where( | |||||
| Provider.tenant_id == tenant_id, | |||||
| Provider.provider_name == ModelProviderID(provider_name).provider_name, | |||||
| Provider.provider_type == ProviderType.SYSTEM.value, | |||||
| Provider.quota_type == ProviderQuotaType.TRIAL.value, | |||||
| ) | |||||
| .first() | |||||
| stmt = select(Provider).where( | |||||
| Provider.tenant_id == tenant_id, | |||||
| Provider.provider_name == ModelProviderID(provider_name).provider_name, | |||||
| Provider.provider_type == ProviderType.SYSTEM.value, | |||||
| Provider.quota_type == ProviderQuotaType.TRIAL.value, | |||||
| ) | ) | ||||
| existed_provider_record = db.session.scalar(stmt) | |||||
| if not existed_provider_record: | if not existed_provider_record: | ||||
| continue | continue | ||||
| import orjson | import orjson | ||||
| from pydantic import BaseModel | from pydantic import BaseModel | ||||
| from sqlalchemy import select | |||||
| from configs import dify_config | from configs import dify_config | ||||
| from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler | from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler | ||||
| return sorted_chunk_indices[:k] | return sorted_chunk_indices[:k] | ||||
| def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]): | def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]): | ||||
| document_segment = ( | |||||
| db.session.query(DocumentSegment) | |||||
| .where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id) | |||||
| .first() | |||||
| stmt = select(DocumentSegment).where( | |||||
| DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id | |||||
| ) | ) | ||||
| document_segment = db.session.scalar(stmt) | |||||
| if document_segment: | if document_segment: | ||||
| document_segment.keywords = keywords | document_segment.keywords = keywords | ||||
| db.session.add(document_segment) | db.session.add(document_segment) |
| from typing import Optional | from typing import Optional | ||||
| from flask import Flask, current_app | from flask import Flask, current_app | ||||
| from sqlalchemy import select | |||||
| from sqlalchemy.orm import Session, load_only | from sqlalchemy.orm import Session, load_only | ||||
| from configs import dify_config | from configs import dify_config | ||||
| external_retrieval_model: Optional[dict] = None, | external_retrieval_model: Optional[dict] = None, | ||||
| metadata_filtering_conditions: Optional[dict] = None, | metadata_filtering_conditions: Optional[dict] = None, | ||||
| ): | ): | ||||
| dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() | |||||
| stmt = select(Dataset).where(Dataset.id == dataset_id) | |||||
| dataset = db.session.scalar(stmt) | |||||
| if not dataset: | if not dataset: | ||||
| return [] | return [] | ||||
| metadata_condition = ( | metadata_condition = ( | ||||
| if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | ||||
| # Handle parent-child documents | # Handle parent-child documents | ||||
| child_index_node_id = document.metadata.get("doc_id") | child_index_node_id = document.metadata.get("doc_id") | ||||
| child_chunk = ( | |||||
| db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first() | |||||
| ) | |||||
| child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id) | |||||
| child_chunk = db.session.scalar(child_chunk_stmt) | |||||
| if not child_chunk: | if not child_chunk: | ||||
| continue | continue | ||||
| index_node_id = document.metadata.get("doc_id") | index_node_id = document.metadata.get("doc_id") | ||||
| if not index_node_id: | if not index_node_id: | ||||
| continue | continue | ||||
| segment = ( | |||||
| db.session.query(DocumentSegment) | |||||
| .where( | |||||
| DocumentSegment.dataset_id == dataset_document.dataset_id, | |||||
| DocumentSegment.enabled == True, | |||||
| DocumentSegment.status == "completed", | |||||
| DocumentSegment.index_node_id == index_node_id, | |||||
| ) | |||||
| .first() | |||||
| document_segment_stmt = select(DocumentSegment).where( | |||||
| DocumentSegment.dataset_id == dataset_document.dataset_id, | |||||
| DocumentSegment.enabled == True, | |||||
| DocumentSegment.status == "completed", | |||||
| DocumentSegment.index_node_id == index_node_id, | |||||
| ) | ) | ||||
| segment = db.session.scalar(document_segment_stmt) | |||||
| if not segment: | if not segment: | ||||
| continue | continue |
| TokenizerType, | TokenizerType, | ||||
| ) | ) | ||||
| from qdrant_client.local.qdrant_local import QdrantLocal | from qdrant_client.local.qdrant_local import QdrantLocal | ||||
| from sqlalchemy import select | |||||
| from configs import dify_config | from configs import dify_config | ||||
| from core.rag.datasource.vdb.field import Field | from core.rag.datasource.vdb.field import Field | ||||
| class QdrantVectorFactory(AbstractVectorFactory): | class QdrantVectorFactory(AbstractVectorFactory): | ||||
| def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector: | def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector: | ||||
| if dataset.collection_binding_id: | if dataset.collection_binding_id: | ||||
| dataset_collection_binding = ( | |||||
| db.session.query(DatasetCollectionBinding) | |||||
| .where(DatasetCollectionBinding.id == dataset.collection_binding_id) | |||||
| .one_or_none() | |||||
| ) | |||||
| stmt = select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == dataset.collection_binding_id) | |||||
| dataset_collection_binding = db.session.scalars(stmt).one_or_none() | |||||
| if dataset_collection_binding: | if dataset_collection_binding: | ||||
| collection_name = dataset_collection_binding.collection_name | collection_name = dataset_collection_binding.collection_name | ||||
| else: | else: |
| ) | ) | ||||
| from qdrant_client.local.qdrant_local import QdrantLocal | from qdrant_client.local.qdrant_local import QdrantLocal | ||||
| from requests.auth import HTTPDigestAuth | from requests.auth import HTTPDigestAuth | ||||
| from sqlalchemy import select | |||||
| from configs import dify_config | from configs import dify_config | ||||
| from core.rag.datasource.vdb.field import Field | from core.rag.datasource.vdb.field import Field | ||||
| class TidbOnQdrantVectorFactory(AbstractVectorFactory): | class TidbOnQdrantVectorFactory(AbstractVectorFactory): | ||||
| def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector: | def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector: | ||||
| tidb_auth_binding = ( | |||||
| db.session.query(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none() | |||||
| ) | |||||
| stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id) | |||||
| tidb_auth_binding = db.session.scalars(stmt).one_or_none() | |||||
| if not tidb_auth_binding: | if not tidb_auth_binding: | ||||
| with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): | with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): | ||||
| tidb_auth_binding = ( | |||||
| db.session.query(TidbAuthBinding) | |||||
| .where(TidbAuthBinding.tenant_id == dataset.tenant_id) | |||||
| .one_or_none() | |||||
| ) | |||||
| stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id) | |||||
| tidb_auth_binding = db.session.scalars(stmt).one_or_none() | |||||
| if tidb_auth_binding: | if tidb_auth_binding: | ||||
| TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" | TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}" | ||||
| from abc import ABC, abstractmethod | from abc import ABC, abstractmethod | ||||
| from typing import Any, Optional | from typing import Any, Optional | ||||
| from sqlalchemy import select | |||||
| from configs import dify_config | from configs import dify_config | ||||
| from core.model_manager import ModelManager | from core.model_manager import ModelManager | ||||
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| vector_type = self._dataset.index_struct_dict["type"] | vector_type = self._dataset.index_struct_dict["type"] | ||||
| else: | else: | ||||
| if dify_config.VECTOR_STORE_WHITELIST_ENABLE: | if dify_config.VECTOR_STORE_WHITELIST_ENABLE: | ||||
| whitelist = ( | |||||
| db.session.query(Whitelist) | |||||
| .where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db") | |||||
| .one_or_none() | |||||
| stmt = select(Whitelist).where( | |||||
| Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db" | |||||
| ) | ) | ||||
| whitelist = db.session.scalars(stmt).one_or_none() | |||||
| if whitelist: | if whitelist: | ||||
| vector_type = VectorType.TIDB_ON_QDRANT | vector_type = VectorType.TIDB_ON_QDRANT | ||||
| from collections.abc import Sequence | from collections.abc import Sequence | ||||
| from typing import Any, Optional | from typing import Any, Optional | ||||
| from sqlalchemy import func | |||||
| from sqlalchemy import func, select | |||||
| from core.model_manager import ModelManager | from core.model_manager import ModelManager | ||||
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| @property | @property | ||||
| def docs(self) -> dict[str, Document]: | def docs(self) -> dict[str, Document]: | ||||
| document_segments = ( | |||||
| db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id).all() | |||||
| ) | |||||
| stmt = select(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id) | |||||
| document_segments = db.session.scalars(stmt).all() | |||||
| output = {} | output = {} | ||||
| for document_segment in document_segments: | for document_segment in document_segments: | ||||
| return data | return data | ||||
| def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]: | def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]: | ||||
| document_segment = ( | |||||
| db.session.query(DocumentSegment) | |||||
| .where(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id) | |||||
| .first() | |||||
| stmt = select(DocumentSegment).where( | |||||
| DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id | |||||
| ) | ) | ||||
| document_segment = db.session.scalar(stmt) | |||||
| return document_segment | return document_segment |
| from typing import Any, Optional, cast | from typing import Any, Optional, cast | ||||
| import requests | import requests | ||||
| from sqlalchemy import select | |||||
| from configs import dify_config | from configs import dify_config | ||||
| from core.rag.extractor.extractor_base import BaseExtractor | from core.rag.extractor.extractor_base import BaseExtractor | ||||
| @classmethod | @classmethod | ||||
| def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: | def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: | ||||
| data_source_binding = ( | |||||
| db.session.query(DataSourceOauthBinding) | |||||
| .where( | |||||
| db.and_( | |||||
| DataSourceOauthBinding.tenant_id == tenant_id, | |||||
| DataSourceOauthBinding.provider == "notion", | |||||
| DataSourceOauthBinding.disabled == False, | |||||
| DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"', | |||||
| ) | |||||
| ) | |||||
| .first() | |||||
| stmt = select(DataSourceOauthBinding).where( | |||||
| DataSourceOauthBinding.tenant_id == tenant_id, | |||||
| DataSourceOauthBinding.provider == "notion", | |||||
| DataSourceOauthBinding.disabled == False, | |||||
| DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"', | |||||
| ) | ) | ||||
| data_source_binding = db.session.scalar(stmt) | |||||
| if not data_source_binding: | if not data_source_binding: | ||||
| raise Exception( | raise Exception( |
| from typing import Any, Optional, Union, cast | from typing import Any, Optional, Union, cast | ||||
| from flask import Flask, current_app | from flask import Flask, current_app | ||||
| from sqlalchemy import Float, and_, or_, text | |||||
| from sqlalchemy import Float, and_, or_, select, text | |||||
| from sqlalchemy import cast as sqlalchemy_cast | from sqlalchemy import cast as sqlalchemy_cast | ||||
| from sqlalchemy.orm import Session | from sqlalchemy.orm import Session | ||||
| available_datasets = [] | available_datasets = [] | ||||
| for dataset_id in dataset_ids: | for dataset_id in dataset_ids: | ||||
| # get dataset from dataset id | # get dataset from dataset id | ||||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||||
| dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id) | |||||
| dataset = db.session.scalar(dataset_stmt) | |||||
| # pass if dataset is not available | # pass if dataset is not available | ||||
| if not dataset: | if not dataset: | ||||
| for record in records: | for record in records: | ||||
| segment = record.segment | segment = record.segment | ||||
| dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() | dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() | ||||
| document = ( | |||||
| db.session.query(DatasetDocument) | |||||
| .where( | |||||
| DatasetDocument.id == segment.document_id, | |||||
| DatasetDocument.enabled == True, | |||||
| DatasetDocument.archived == False, | |||||
| ) | |||||
| .first() | |||||
| dataset_document_stmt = select(DatasetDocument).where( | |||||
| DatasetDocument.id == segment.document_id, | |||||
| DatasetDocument.enabled == True, | |||||
| DatasetDocument.archived == False, | |||||
| ) | ) | ||||
| document = db.session.scalar(dataset_document_stmt) | |||||
| if dataset and document: | if dataset and document: | ||||
| source = RetrievalSourceMetadata( | source = RetrievalSourceMetadata( | ||||
| dataset_id=dataset.id, | dataset_id=dataset.id, | ||||
| if dataset_id: | if dataset_id: | ||||
| # get retrieval model config | # get retrieval model config | ||||
| dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() | |||||
| dataset_stmt = select(Dataset).where(Dataset.id == dataset_id) | |||||
| dataset = db.session.scalar(dataset_stmt) | |||||
| if dataset: | if dataset: | ||||
| results = [] | results = [] | ||||
| if dataset.provider == "external": | if dataset.provider == "external": | ||||
| dify_documents = [document for document in documents if document.provider == "dify"] | dify_documents = [document for document in documents if document.provider == "dify"] | ||||
| for document in dify_documents: | for document in dify_documents: | ||||
| if document.metadata is not None: | if document.metadata is not None: | ||||
| dataset_document = ( | |||||
| db.session.query(DatasetDocument) | |||||
| .where(DatasetDocument.id == document.metadata["document_id"]) | |||||
| .first() | |||||
| dataset_document_stmt = select(DatasetDocument).where( | |||||
| DatasetDocument.id == document.metadata["document_id"] | |||||
| ) | ) | ||||
| dataset_document = db.session.scalar(dataset_document_stmt) | |||||
| if dataset_document: | if dataset_document: | ||||
| if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | ||||
| child_chunk = ( | |||||
| db.session.query(ChildChunk) | |||||
| .where( | |||||
| ChildChunk.index_node_id == document.metadata["doc_id"], | |||||
| ChildChunk.dataset_id == dataset_document.dataset_id, | |||||
| ChildChunk.document_id == dataset_document.id, | |||||
| ) | |||||
| .first() | |||||
| child_chunk_stmt = select(ChildChunk).where( | |||||
| ChildChunk.index_node_id == document.metadata["doc_id"], | |||||
| ChildChunk.dataset_id == dataset_document.dataset_id, | |||||
| ChildChunk.document_id == dataset_document.id, | |||||
| ) | ) | ||||
| child_chunk = db.session.scalar(child_chunk_stmt) | |||||
| if child_chunk: | if child_chunk: | ||||
| segment = ( | segment = ( | ||||
| db.session.query(DocumentSegment) | db.session.query(DocumentSegment) | ||||
| ): | ): | ||||
| with flask_app.app_context(): | with flask_app.app_context(): | ||||
| with Session(db.engine) as session: | with Session(db.engine) as session: | ||||
| dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() | |||||
| dataset_stmt = select(Dataset).where(Dataset.id == dataset_id) | |||||
| dataset = db.session.scalar(dataset_stmt) | |||||
| if not dataset: | if not dataset: | ||||
| return [] | return [] | ||||
| available_datasets = [] | available_datasets = [] | ||||
| for dataset_id in dataset_ids: | for dataset_id in dataset_ids: | ||||
| # get dataset from dataset id | # get dataset from dataset id | ||||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||||
| dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id) | |||||
| dataset = db.session.scalar(dataset_stmt) | |||||
| # pass if dataset is not available | # pass if dataset is not available | ||||
| if not dataset: | if not dataset: | ||||
| self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig | self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig | ||||
| ) -> Optional[list[dict[str, Any]]]: | ) -> Optional[list[dict[str, Any]]]: | ||||
| # get all metadata field | # get all metadata field | ||||
| metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all() | |||||
| metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)) | |||||
| metadata_fields = db.session.scalars(metadata_stmt).all() | |||||
| all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] | all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] | ||||
| # get metadata model config | # get metadata model config | ||||
| if metadata_model_config is None: | if metadata_model_config is None: |
| from sqlalchemy import select | |||||
| from core.tools.__base.tool_provider import ToolProviderController | from core.tools.__base.tool_provider import ToolProviderController | ||||
| from core.tools.builtin_tool.provider import BuiltinToolProviderController | from core.tools.builtin_tool.provider import BuiltinToolProviderController | ||||
| from core.tools.custom_tool.provider import ApiToolProviderController | from core.tools.custom_tool.provider import ApiToolProviderController | ||||
| return controller.tool_labels | return controller.tool_labels | ||||
| else: | else: | ||||
| raise ValueError("Unsupported tool type") | raise ValueError("Unsupported tool type") | ||||
| labels = ( | |||||
| db.session.query(ToolLabelBinding.label_name) | |||||
| .where( | |||||
| ToolLabelBinding.tool_id == provider_id, | |||||
| ToolLabelBinding.tool_type == controller.provider_type.value, | |||||
| ) | |||||
| .all() | |||||
| stmt = select(ToolLabelBinding.label_name).where( | |||||
| ToolLabelBinding.tool_id == provider_id, | |||||
| ToolLabelBinding.tool_type == controller.provider_type.value, | |||||
| ) | ) | ||||
| labels = db.session.scalars(stmt).all() | |||||
| return [label.label_name for label in labels] | |||||
| return list(labels) | |||||
| @classmethod | @classmethod | ||||
| def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]: | def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]: |
| import sqlalchemy as sa | import sqlalchemy as sa | ||||
| from pydantic import TypeAdapter | from pydantic import TypeAdapter | ||||
| from sqlalchemy import select | |||||
| from sqlalchemy.orm import Session | from sqlalchemy.orm import Session | ||||
| from yarl import URL | from yarl import URL | ||||
| # get specific credentials | # get specific credentials | ||||
| if is_valid_uuid(credential_id): | if is_valid_uuid(credential_id): | ||||
| try: | try: | ||||
| builtin_provider = ( | |||||
| db.session.query(BuiltinToolProvider) | |||||
| .where( | |||||
| BuiltinToolProvider.tenant_id == tenant_id, | |||||
| BuiltinToolProvider.id == credential_id, | |||||
| ) | |||||
| .first() | |||||
| builtin_provider_stmt = select(BuiltinToolProvider).where( | |||||
| BuiltinToolProvider.tenant_id == tenant_id, | |||||
| BuiltinToolProvider.id == credential_id, | |||||
| ) | ) | ||||
| builtin_provider = db.session.scalar(builtin_provider_stmt) | |||||
| except Exception as e: | except Exception as e: | ||||
| builtin_provider = None | builtin_provider = None | ||||
| logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True) | logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True) | ||||
| ), | ), | ||||
| ) | ) | ||||
| elif provider_type == ToolProviderType.WORKFLOW: | elif provider_type == ToolProviderType.WORKFLOW: | ||||
| workflow_provider = ( | |||||
| db.session.query(WorkflowToolProvider) | |||||
| .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) | |||||
| .first() | |||||
| workflow_provider_stmt = select(WorkflowToolProvider).where( | |||||
| WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id | |||||
| ) | ) | ||||
| workflow_provider = db.session.scalar(workflow_provider_stmt) | |||||
| if workflow_provider is None: | if workflow_provider is None: | ||||
| raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") | raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") |
| from flask import Flask, current_app | from flask import Flask, current_app | ||||
| from pydantic import BaseModel, Field | from pydantic import BaseModel, Field | ||||
| from sqlalchemy import select | |||||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | ||||
| from core.model_manager import ModelManager | from core.model_manager import ModelManager | ||||
| document_context_list = [] | document_context_list = [] | ||||
| index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] | index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] | ||||
| segments = ( | |||||
| db.session.query(DocumentSegment) | |||||
| .where( | |||||
| DocumentSegment.dataset_id.in_(self.dataset_ids), | |||||
| DocumentSegment.completed_at.isnot(None), | |||||
| DocumentSegment.status == "completed", | |||||
| DocumentSegment.enabled == True, | |||||
| DocumentSegment.index_node_id.in_(index_node_ids), | |||||
| ) | |||||
| .all() | |||||
| document_segment_stmt = select(DocumentSegment).where( | |||||
| DocumentSegment.dataset_id.in_(self.dataset_ids), | |||||
| DocumentSegment.completed_at.isnot(None), | |||||
| DocumentSegment.status == "completed", | |||||
| DocumentSegment.enabled == True, | |||||
| DocumentSegment.index_node_id.in_(index_node_ids), | |||||
| ) | ) | ||||
| segments = db.session.scalars(document_segment_stmt).all() | |||||
| if segments: | if segments: | ||||
| index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} | index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} | ||||
| resource_number = 1 | resource_number = 1 | ||||
| for segment in sorted_segments: | for segment in sorted_segments: | ||||
| dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() | dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() | ||||
| document = ( | |||||
| db.session.query(Document) | |||||
| .where( | |||||
| Document.id == segment.document_id, | |||||
| Document.enabled == True, | |||||
| Document.archived == False, | |||||
| ) | |||||
| .first() | |||||
| document_stmt = select(Document).where( | |||||
| Document.id == segment.document_id, | |||||
| Document.enabled == True, | |||||
| Document.archived == False, | |||||
| ) | ) | ||||
| document = db.session.scalar(document_stmt) | |||||
| if dataset and document: | if dataset and document: | ||||
| source = RetrievalSourceMetadata( | source = RetrievalSourceMetadata( | ||||
| position=resource_number, | position=resource_number, | ||||
| hit_callbacks: list[DatasetIndexToolCallbackHandler], | hit_callbacks: list[DatasetIndexToolCallbackHandler], | ||||
| ): | ): | ||||
| with flask_app.app_context(): | with flask_app.app_context(): | ||||
| dataset = ( | |||||
| db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first() | |||||
| ) | |||||
| stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id) | |||||
| dataset = db.session.scalar(stmt) | |||||
| if not dataset: | if not dataset: | ||||
| return [] | return [] |
| from typing import Any, Optional, cast | from typing import Any, Optional, cast | ||||
| from pydantic import BaseModel, Field | from pydantic import BaseModel, Field | ||||
| from sqlalchemy import select | |||||
| from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig | from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig | ||||
| from core.rag.datasource.retrieval_service import RetrievalService | from core.rag.datasource.retrieval_service import RetrievalService | ||||
| ) | ) | ||||
| def _run(self, query: str) -> str: | def _run(self, query: str) -> str: | ||||
| dataset = ( | |||||
| db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first() | |||||
| ) | |||||
| dataset_stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id) | |||||
| dataset = db.session.scalar(dataset_stmt) | |||||
| if not dataset: | if not dataset: | ||||
| return "" | return "" | ||||
| for record in records: | for record in records: | ||||
| segment = record.segment | segment = record.segment | ||||
| dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() | dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() | ||||
| document = ( | |||||
| db.session.query(DatasetDocument) # type: ignore | |||||
| .where( | |||||
| DatasetDocument.id == segment.document_id, | |||||
| DatasetDocument.enabled == True, | |||||
| DatasetDocument.archived == False, | |||||
| ) | |||||
| .first() | |||||
| dataset_document_stmt = select(DatasetDocument).where( | |||||
| DatasetDocument.id == segment.document_id, | |||||
| DatasetDocument.enabled == True, | |||||
| DatasetDocument.archived == False, | |||||
| ) | ) | ||||
| document = db.session.scalar(dataset_document_stmt) # type: ignore | |||||
| if dataset and document: | if dataset and document: | ||||
| source = RetrievalSourceMetadata( | source = RetrievalSourceMetadata( | ||||
| dataset_id=dataset.id, | dataset_id=dataset.id, |
| from collections.abc import Generator | from collections.abc import Generator | ||||
| from typing import Any, Optional | from typing import Any, Optional | ||||
| from sqlalchemy import select | |||||
| from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod | from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod | ||||
| from core.tools.__base.tool import Tool | from core.tools.__base.tool import Tool | ||||
| from core.tools.__base.tool_runtime import ToolRuntime | from core.tools.__base.tool_runtime import ToolRuntime | ||||
| .first() | .first() | ||||
| ) | ) | ||||
| else: | else: | ||||
| workflow = db.session.query(Workflow).where(Workflow.app_id == app_id, Workflow.version == version).first() | |||||
| stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version) | |||||
| workflow = db.session.scalar(stmt) | |||||
| if not workflow: | if not workflow: | ||||
| raise ValueError("workflow not found or not published") | raise ValueError("workflow not found or not published") | ||||
| """ | """ | ||||
| get the app by app id | get the app by app id | ||||
| """ | """ | ||||
| app = db.session.query(App).where(App.id == app_id).first() | |||||
| stmt = select(App).where(App.id == app_id) | |||||
| app = db.session.scalar(stmt) | |||||
| if not app: | if not app: | ||||
| raise ValueError("app not found") | raise ValueError("app not found") | ||||
| from collections.abc import Mapping, Sequence | from collections.abc import Mapping, Sequence | ||||
| from typing import TYPE_CHECKING, Any, Optional, cast | from typing import TYPE_CHECKING, Any, Optional, cast | ||||
| from sqlalchemy import Float, and_, func, or_, text | |||||
| from sqlalchemy import Float, and_, func, or_, select, text | |||||
| from sqlalchemy import cast as sqlalchemy_cast | from sqlalchemy import cast as sqlalchemy_cast | ||||
| from sqlalchemy.orm import sessionmaker | from sqlalchemy.orm import sessionmaker | ||||
| for record in records: | for record in records: | ||||
| segment = record.segment | segment = record.segment | ||||
| dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore | dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore | ||||
| document = ( | |||||
| db.session.query(Document) | |||||
| .where( | |||||
| Document.id == segment.document_id, | |||||
| Document.enabled == True, | |||||
| Document.archived == False, | |||||
| ) | |||||
| .first() | |||||
| stmt = select(Document).where( | |||||
| Document.id == segment.document_id, | |||||
| Document.enabled == True, | |||||
| Document.archived == False, | |||||
| ) | ) | ||||
| document = db.session.scalar(stmt) | |||||
| if dataset and document: | if dataset and document: | ||||
| source = { | source = { | ||||
| "metadata": { | "metadata": { | ||||
| self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData | self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData | ||||
| ) -> list[dict[str, Any]]: | ) -> list[dict[str, Any]]: | ||||
| # get all metadata field | # get all metadata field | ||||
| metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all() | |||||
| stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)) | |||||
| metadata_fields = db.session.scalars(stmt).all() | |||||
| all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] | all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] | ||||
| if node_data.metadata_model_config is None: | if node_data.metadata_model_config is None: | ||||
| raise ValueError("metadata_model_config is required") | raise ValueError("metadata_model_config is required") |