| @@ -0,0 +1,27 @@ | |||
| name: autofix.ci | |||
| on: | |||
| workflow_call: | |||
| pull_request: | |||
| push: | |||
| branches: [ "main" ] | |||
| permissions: | |||
| contents: read | |||
| jobs: | |||
| autofix: | |||
| runs-on: ubuntu-latest | |||
| steps: | |||
| - uses: actions/checkout@v4 | |||
| # Use uv to ensure we have the same ruff version in CI and locally. | |||
| - uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f | |||
| - run: | | |||
| cd api | |||
| uv sync --dev | |||
| # Fix lint errors | |||
| uv run ruff check --fix-only . | |||
| # Format code | |||
| uv run ruff format . | |||
| - uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27 | |||
| @@ -471,6 +471,16 @@ APP_MAX_ACTIVE_REQUESTS=0 | |||
| # Celery beat configuration | |||
| CELERY_BEAT_SCHEDULER_TIME=1 | |||
| # Celery schedule tasks configuration | |||
| ENABLE_CLEAN_EMBEDDING_CACHE_TASK=false | |||
| ENABLE_CLEAN_UNUSED_DATASETS_TASK=false | |||
| ENABLE_CREATE_TIDB_SERVERLESS_TASK=false | |||
| ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false | |||
| ENABLE_CLEAN_MESSAGES=false | |||
| ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false | |||
| ENABLE_DATASETS_QUEUE_MONITOR=false | |||
| ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true | |||
| # Position configuration | |||
| POSITION_TOOL_PINS= | |||
| POSITION_TOOL_INCLUDES= | |||
| @@ -74,7 +74,12 @@ | |||
| 10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service. | |||
| ```bash | |||
| uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion | |||
| uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin | |||
| ``` | |||
| Addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal: | |||
| ```bash | |||
| uv run celery -A app.celery beat | |||
| ``` | |||
| ## Testing | |||
| @@ -51,7 +51,7 @@ def reset_password(email, new_password, password_confirm): | |||
| click.echo(click.style("Passwords do not match.", fg="red")) | |||
| return | |||
| account = db.session.query(Account).filter(Account.email == email).one_or_none() | |||
| account = db.session.query(Account).where(Account.email == email).one_or_none() | |||
| if not account: | |||
| click.echo(click.style("Account not found for email: {}".format(email), fg="red")) | |||
| @@ -90,7 +90,7 @@ def reset_email(email, new_email, email_confirm): | |||
| click.echo(click.style("New emails do not match.", fg="red")) | |||
| return | |||
| account = db.session.query(Account).filter(Account.email == email).one_or_none() | |||
| account = db.session.query(Account).where(Account.email == email).one_or_none() | |||
| if not account: | |||
| click.echo(click.style("Account not found for email: {}".format(email), fg="red")) | |||
| @@ -137,8 +137,8 @@ def reset_encrypt_key_pair(): | |||
| tenant.encrypt_public_key = generate_key_pair(tenant.id) | |||
| db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete() | |||
| db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete() | |||
| db.session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete() | |||
| db.session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete() | |||
| db.session.commit() | |||
| click.echo( | |||
| @@ -173,7 +173,7 @@ def migrate_annotation_vector_database(): | |||
| per_page = 50 | |||
| apps = ( | |||
| db.session.query(App) | |||
| .filter(App.status == "normal") | |||
| .where(App.status == "normal") | |||
| .order_by(App.created_at.desc()) | |||
| .limit(per_page) | |||
| .offset((page - 1) * per_page) | |||
| @@ -193,7 +193,7 @@ def migrate_annotation_vector_database(): | |||
| try: | |||
| click.echo("Creating app annotation index: {}".format(app.id)) | |||
| app_annotation_setting = ( | |||
| db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first() | |||
| db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first() | |||
| ) | |||
| if not app_annotation_setting: | |||
| @@ -203,13 +203,13 @@ def migrate_annotation_vector_database(): | |||
| # get dataset_collection_binding info | |||
| dataset_collection_binding = ( | |||
| db.session.query(DatasetCollectionBinding) | |||
| .filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id) | |||
| .where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id) | |||
| .first() | |||
| ) | |||
| if not dataset_collection_binding: | |||
| click.echo("App annotation collection binding not found: {}".format(app.id)) | |||
| continue | |||
| annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all() | |||
| annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app.id).all() | |||
| dataset = Dataset( | |||
| id=app.id, | |||
| tenant_id=app.tenant_id, | |||
| @@ -306,7 +306,7 @@ def migrate_knowledge_vector_database(): | |||
| while True: | |||
| try: | |||
| stmt = ( | |||
| select(Dataset).filter(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc()) | |||
| select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc()) | |||
| ) | |||
| datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) | |||
| @@ -333,7 +333,7 @@ def migrate_knowledge_vector_database(): | |||
| if dataset.collection_binding_id: | |||
| dataset_collection_binding = ( | |||
| db.session.query(DatasetCollectionBinding) | |||
| .filter(DatasetCollectionBinding.id == dataset.collection_binding_id) | |||
| .where(DatasetCollectionBinding.id == dataset.collection_binding_id) | |||
| .one_or_none() | |||
| ) | |||
| if dataset_collection_binding: | |||
| @@ -368,7 +368,7 @@ def migrate_knowledge_vector_database(): | |||
| dataset_documents = ( | |||
| db.session.query(DatasetDocument) | |||
| .filter( | |||
| .where( | |||
| DatasetDocument.dataset_id == dataset.id, | |||
| DatasetDocument.indexing_status == "completed", | |||
| DatasetDocument.enabled == True, | |||
| @@ -382,7 +382,7 @@ def migrate_knowledge_vector_database(): | |||
| for dataset_document in dataset_documents: | |||
| segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter( | |||
| .where( | |||
| DocumentSegment.document_id == dataset_document.id, | |||
| DocumentSegment.status == "completed", | |||
| DocumentSegment.enabled == True, | |||
| @@ -469,7 +469,7 @@ def convert_to_agent_apps(): | |||
| app_id = str(i.id) | |||
| if app_id not in proceeded_app_ids: | |||
| proceeded_app_ids.append(app_id) | |||
| app = db.session.query(App).filter(App.id == app_id).first() | |||
| app = db.session.query(App).where(App.id == app_id).first() | |||
| if app is not None: | |||
| apps.append(app) | |||
| @@ -484,7 +484,7 @@ def convert_to_agent_apps(): | |||
| db.session.commit() | |||
| # update conversation mode to agent | |||
| db.session.query(Conversation).filter(Conversation.app_id == app.id).update( | |||
| db.session.query(Conversation).where(Conversation.app_id == app.id).update( | |||
| {Conversation.mode: AppMode.AGENT_CHAT.value} | |||
| ) | |||
| @@ -561,7 +561,7 @@ def old_metadata_migration(): | |||
| try: | |||
| stmt = ( | |||
| select(DatasetDocument) | |||
| .filter(DatasetDocument.doc_metadata.is_not(None)) | |||
| .where(DatasetDocument.doc_metadata.is_not(None)) | |||
| .order_by(DatasetDocument.created_at.desc()) | |||
| ) | |||
| documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) | |||
| @@ -579,7 +579,7 @@ def old_metadata_migration(): | |||
| else: | |||
| dataset_metadata = ( | |||
| db.session.query(DatasetMetadata) | |||
| .filter(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key) | |||
| .where(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key) | |||
| .first() | |||
| ) | |||
| if not dataset_metadata: | |||
| @@ -603,7 +603,7 @@ def old_metadata_migration(): | |||
| else: | |||
| dataset_metadata_binding = ( | |||
| db.session.query(DatasetMetadataBinding) # type: ignore | |||
| .filter( | |||
| .where( | |||
| DatasetMetadataBinding.dataset_id == document.dataset_id, | |||
| DatasetMetadataBinding.document_id == document.id, | |||
| DatasetMetadataBinding.metadata_id == dataset_metadata.id, | |||
| @@ -718,7 +718,7 @@ where sites.id is null limit 1000""" | |||
| continue | |||
| try: | |||
| app = db.session.query(App).filter(App.id == app_id).first() | |||
| app = db.session.query(App).where(App.id == app_id).first() | |||
| if not app: | |||
| print(f"App {app_id} not found") | |||
| continue | |||
| @@ -832,6 +832,41 @@ class CeleryBeatConfig(BaseSettings): | |||
| ) | |||
| class CeleryScheduleTasksConfig(BaseSettings): | |||
| ENABLE_CLEAN_EMBEDDING_CACHE_TASK: bool = Field( | |||
| description="Enable clean embedding cache task", | |||
| default=False, | |||
| ) | |||
| ENABLE_CLEAN_UNUSED_DATASETS_TASK: bool = Field( | |||
| description="Enable clean unused datasets task", | |||
| default=False, | |||
| ) | |||
| ENABLE_CREATE_TIDB_SERVERLESS_TASK: bool = Field( | |||
| description="Enable create tidb service job task", | |||
| default=False, | |||
| ) | |||
| ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK: bool = Field( | |||
| description="Enable update tidb service job status task", | |||
| default=False, | |||
| ) | |||
| ENABLE_CLEAN_MESSAGES: bool = Field( | |||
| description="Enable clean messages task", | |||
| default=False, | |||
| ) | |||
| ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field( | |||
| description="Enable mail clean document notify task", | |||
| default=False, | |||
| ) | |||
| ENABLE_DATASETS_QUEUE_MONITOR: bool = Field( | |||
| description="Enable queue monitor task", | |||
| default=False, | |||
| ) | |||
| ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: bool = Field( | |||
| description="Enable check upgradable plugin task", | |||
| default=True, | |||
| ) | |||
| class PositionConfig(BaseSettings): | |||
| POSITION_PROVIDER_PINS: str = Field( | |||
| description="Comma-separated list of pinned model providers", | |||
| @@ -961,5 +996,6 @@ class FeatureConfig( | |||
| # hosted services config | |||
| HostedServiceConfig, | |||
| CeleryBeatConfig, | |||
| CeleryScheduleTasksConfig, | |||
| ): | |||
| pass | |||
| @@ -56,7 +56,7 @@ class InsertExploreAppListApi(Resource): | |||
| parser.add_argument("position", type=int, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| app = db.session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none() | |||
| app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none() | |||
| if not app: | |||
| raise NotFound(f"App '{args['app_id']}' is not found") | |||
| @@ -74,7 +74,7 @@ class InsertExploreAppListApi(Resource): | |||
| with Session(db.engine) as session: | |||
| recommended_app = session.execute( | |||
| select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]) | |||
| select(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]) | |||
| ).scalar_one_or_none() | |||
| if not recommended_app: | |||
| @@ -117,21 +117,21 @@ class InsertExploreAppApi(Resource): | |||
| def delete(self, app_id): | |||
| with Session(db.engine) as session: | |||
| recommended_app = session.execute( | |||
| select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id)) | |||
| select(RecommendedApp).where(RecommendedApp.app_id == str(app_id)) | |||
| ).scalar_one_or_none() | |||
| if not recommended_app: | |||
| return {"result": "success"}, 204 | |||
| with Session(db.engine) as session: | |||
| app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none() | |||
| app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none() | |||
| if app: | |||
| app.is_public = False | |||
| with Session(db.engine) as session: | |||
| installed_apps = session.execute( | |||
| select(InstalledApp).filter( | |||
| select(InstalledApp).where( | |||
| InstalledApp.app_id == recommended_app.app_id, | |||
| InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id, | |||
| ) | |||
| @@ -61,7 +61,7 @@ class BaseApiKeyListResource(Resource): | |||
| _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) | |||
| keys = ( | |||
| db.session.query(ApiToken) | |||
| .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) | |||
| .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) | |||
| .all() | |||
| ) | |||
| return {"items": keys} | |||
| @@ -76,7 +76,7 @@ class BaseApiKeyListResource(Resource): | |||
| current_key_count = ( | |||
| db.session.query(ApiToken) | |||
| .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) | |||
| .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) | |||
| .count() | |||
| ) | |||
| @@ -117,7 +117,7 @@ class BaseApiKeyResource(Resource): | |||
| key = ( | |||
| db.session.query(ApiToken) | |||
| .filter( | |||
| .where( | |||
| getattr(ApiToken, self.resource_id_field) == resource_id, | |||
| ApiToken.type == self.resource_type, | |||
| ApiToken.id == api_key_id, | |||
| @@ -128,7 +128,7 @@ class BaseApiKeyResource(Resource): | |||
| if key is None: | |||
| flask_restful.abort(404, message="API key not found") | |||
| db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() | |||
| db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() | |||
| db.session.commit() | |||
| return {"result": "success"}, 204 | |||
| @@ -49,7 +49,7 @@ class CompletionConversationApi(Resource): | |||
| query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion") | |||
| if args["keyword"]: | |||
| query = query.join(Message, Message.conversation_id == Conversation.id).filter( | |||
| query = query.join(Message, Message.conversation_id == Conversation.id).where( | |||
| or_( | |||
| Message.query.ilike("%{}%".format(args["keyword"])), | |||
| Message.answer.ilike("%{}%".format(args["keyword"])), | |||
| @@ -121,7 +121,7 @@ class CompletionConversationDetailApi(Resource): | |||
| conversation = ( | |||
| db.session.query(Conversation) | |||
| .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) | |||
| .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) | |||
| .first() | |||
| ) | |||
| @@ -181,7 +181,7 @@ class ChatConversationApi(Resource): | |||
| Message.conversation_id == Conversation.id, | |||
| ) | |||
| .join(subquery, subquery.c.conversation_id == Conversation.id) | |||
| .filter( | |||
| .where( | |||
| or_( | |||
| Message.query.ilike(keyword_filter), | |||
| Message.answer.ilike(keyword_filter), | |||
| @@ -286,7 +286,7 @@ class ChatConversationDetailApi(Resource): | |||
| conversation = ( | |||
| db.session.query(Conversation) | |||
| .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) | |||
| .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) | |||
| .first() | |||
| ) | |||
| @@ -308,7 +308,7 @@ api.add_resource(ChatConversationDetailApi, "/apps/<uuid:app_id>/chat-conversati | |||
| def _get_conversation(app_model, conversation_id): | |||
| conversation = ( | |||
| db.session.query(Conversation) | |||
| .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) | |||
| .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) | |||
| .first() | |||
| ) | |||
| @@ -26,7 +26,7 @@ class AppMCPServerController(Resource): | |||
| @get_app_model | |||
| @marshal_with(app_server_fields) | |||
| def get(self, app_model): | |||
| server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == app_model.id).first() | |||
| server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first() | |||
| return server | |||
| @setup_required | |||
| @@ -73,7 +73,7 @@ class AppMCPServerController(Resource): | |||
| parser.add_argument("parameters", type=dict, required=True, location="json") | |||
| parser.add_argument("status", type=str, required=False, location="json") | |||
| args = parser.parse_args() | |||
| server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first() | |||
| server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first() | |||
| if not server: | |||
| raise NotFound() | |||
| @@ -104,8 +104,8 @@ class AppMCPServerRefreshController(Resource): | |||
| raise NotFound() | |||
| server = ( | |||
| db.session.query(AppMCPServer) | |||
| .filter(AppMCPServer.id == server_id) | |||
| .filter(AppMCPServer.tenant_id == current_user.current_tenant_id) | |||
| .where(AppMCPServer.id == server_id) | |||
| .where(AppMCPServer.tenant_id == current_user.current_tenant_id) | |||
| .first() | |||
| ) | |||
| if not server: | |||
| @@ -56,7 +56,7 @@ class ChatMessageListApi(Resource): | |||
| conversation = ( | |||
| db.session.query(Conversation) | |||
| .filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id) | |||
| .where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id) | |||
| .first() | |||
| ) | |||
| @@ -66,7 +66,7 @@ class ChatMessageListApi(Resource): | |||
| if args["first_id"]: | |||
| first_message = ( | |||
| db.session.query(Message) | |||
| .filter(Message.conversation_id == conversation.id, Message.id == args["first_id"]) | |||
| .where(Message.conversation_id == conversation.id, Message.id == args["first_id"]) | |||
| .first() | |||
| ) | |||
| @@ -75,7 +75,7 @@ class ChatMessageListApi(Resource): | |||
| history_messages = ( | |||
| db.session.query(Message) | |||
| .filter( | |||
| .where( | |||
| Message.conversation_id == conversation.id, | |||
| Message.created_at < first_message.created_at, | |||
| Message.id != first_message.id, | |||
| @@ -87,7 +87,7 @@ class ChatMessageListApi(Resource): | |||
| else: | |||
| history_messages = ( | |||
| db.session.query(Message) | |||
| .filter(Message.conversation_id == conversation.id) | |||
| .where(Message.conversation_id == conversation.id) | |||
| .order_by(Message.created_at.desc()) | |||
| .limit(args["limit"]) | |||
| .all() | |||
| @@ -98,7 +98,7 @@ class ChatMessageListApi(Resource): | |||
| current_page_first_message = history_messages[-1] | |||
| rest_count = ( | |||
| db.session.query(Message) | |||
| .filter( | |||
| .where( | |||
| Message.conversation_id == conversation.id, | |||
| Message.created_at < current_page_first_message.created_at, | |||
| Message.id != current_page_first_message.id, | |||
| @@ -167,7 +167,7 @@ class MessageAnnotationCountApi(Resource): | |||
| @account_initialization_required | |||
| @get_app_model | |||
| def get(self, app_model): | |||
| count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count() | |||
| count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count() | |||
| return {"count": count} | |||
| @@ -214,7 +214,7 @@ class MessageApi(Resource): | |||
| def get(self, app_model, message_id): | |||
| message_id = str(message_id) | |||
| message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() | |||
| message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() | |||
| if not message: | |||
| raise NotFound("Message Not Exists.") | |||
| @@ -42,7 +42,7 @@ class ModelConfigResource(Resource): | |||
| if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: | |||
| # get original app model config | |||
| original_app_model_config = ( | |||
| db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() | |||
| db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first() | |||
| ) | |||
| if original_app_model_config is None: | |||
| raise ValueError("Original app model config not found") | |||
| @@ -49,7 +49,7 @@ class AppSite(Resource): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| site = db.session.query(Site).filter(Site.app_id == app_model.id).first() | |||
| site = db.session.query(Site).where(Site.app_id == app_model.id).first() | |||
| if not site: | |||
| raise NotFound | |||
| @@ -93,7 +93,7 @@ class AppSiteAccessTokenReset(Resource): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| site = db.session.query(Site).filter(Site.app_id == app_model.id).first() | |||
| site = db.session.query(Site).where(Site.app_id == app_model.id).first() | |||
| if not site: | |||
| raise NotFound | |||
| @@ -11,7 +11,7 @@ from models import App, AppMode | |||
| def _load_app_model(app_id: str) -> Optional[App]: | |||
| app_model = ( | |||
| db.session.query(App) | |||
| .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| .first() | |||
| ) | |||
| return app_model | |||
| @@ -113,3 +113,9 @@ class MemberNotInTenantError(BaseHTTPException): | |||
| error_code = "member_not_in_tenant" | |||
| description = "The member is not in the workspace." | |||
| code = 400 | |||
| class AccountInFreezeError(BaseHTTPException): | |||
| error_code = "account_in_freeze" | |||
| description = "This email is temporarily unavailable." | |||
| code = 400 | |||
| @@ -30,7 +30,7 @@ class DataSourceApi(Resource): | |||
| # get workspace data source integrates | |||
| data_source_integrates = ( | |||
| db.session.query(DataSourceOauthBinding) | |||
| .filter( | |||
| .where( | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.disabled == False, | |||
| ) | |||
| @@ -171,7 +171,7 @@ class DataSourceNotionApi(Resource): | |||
| page_id = str(page_id) | |||
| with Session(db.engine) as session: | |||
| data_source_binding = session.execute( | |||
| select(DataSourceOauthBinding).filter( | |||
| select(DataSourceOauthBinding).where( | |||
| db.and_( | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| @@ -421,7 +421,7 @@ class DatasetIndexingEstimateApi(Resource): | |||
| file_ids = args["info_list"]["file_info_list"]["file_ids"] | |||
| file_details = ( | |||
| db.session.query(UploadFile) | |||
| .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)) | |||
| .where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)) | |||
| .all() | |||
| ) | |||
| @@ -526,14 +526,14 @@ class DatasetIndexingStatusApi(Resource): | |||
| dataset_id = str(dataset_id) | |||
| documents = ( | |||
| db.session.query(Document) | |||
| .filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id) | |||
| .where(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id) | |||
| .all() | |||
| ) | |||
| documents_status = [] | |||
| for document in documents: | |||
| completed_segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter( | |||
| .where( | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.document_id == str(document.id), | |||
| DocumentSegment.status != "re_segment", | |||
| @@ -542,7 +542,7 @@ class DatasetIndexingStatusApi(Resource): | |||
| ) | |||
| total_segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") | |||
| .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") | |||
| .count() | |||
| ) | |||
| # Create a dictionary with document attributes and additional fields | |||
| @@ -577,7 +577,7 @@ class DatasetApiKeyApi(Resource): | |||
| def get(self): | |||
| keys = ( | |||
| db.session.query(ApiToken) | |||
| .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) | |||
| .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) | |||
| .all() | |||
| ) | |||
| return {"items": keys} | |||
| @@ -593,7 +593,7 @@ class DatasetApiKeyApi(Resource): | |||
| current_key_count = ( | |||
| db.session.query(ApiToken) | |||
| .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) | |||
| .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) | |||
| .count() | |||
| ) | |||
| @@ -629,7 +629,7 @@ class DatasetApiDeleteApi(Resource): | |||
| key = ( | |||
| db.session.query(ApiToken) | |||
| .filter( | |||
| .where( | |||
| ApiToken.tenant_id == current_user.current_tenant_id, | |||
| ApiToken.type == self.resource_type, | |||
| ApiToken.id == api_key_id, | |||
| @@ -640,7 +640,7 @@ class DatasetApiDeleteApi(Resource): | |||
| if key is None: | |||
| flask_restful.abort(404, message="API key not found") | |||
| db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() | |||
| db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() | |||
| db.session.commit() | |||
| return {"result": "success"}, 204 | |||
| @@ -126,7 +126,7 @@ class GetProcessRuleApi(Resource): | |||
| # get the latest process rule | |||
| dataset_process_rule = ( | |||
| db.session.query(DatasetProcessRule) | |||
| .filter(DatasetProcessRule.dataset_id == document.dataset_id) | |||
| .where(DatasetProcessRule.dataset_id == document.dataset_id) | |||
| .order_by(DatasetProcessRule.created_at.desc()) | |||
| .limit(1) | |||
| .one_or_none() | |||
| @@ -178,7 +178,7 @@ class DatasetDocumentListApi(Resource): | |||
| if search: | |||
| search = f"%{search}%" | |||
| query = query.filter(Document.name.like(search)) | |||
| query = query.where(Document.name.like(search)) | |||
| if sort.startswith("-"): | |||
| sort_logic = desc | |||
| @@ -214,7 +214,7 @@ class DatasetDocumentListApi(Resource): | |||
| for document in documents: | |||
| completed_segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter( | |||
| .where( | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.document_id == str(document.id), | |||
| DocumentSegment.status != "re_segment", | |||
| @@ -223,7 +223,7 @@ class DatasetDocumentListApi(Resource): | |||
| ) | |||
| total_segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") | |||
| .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") | |||
| .count() | |||
| ) | |||
| document.completed_segments = completed_segments | |||
| @@ -419,7 +419,7 @@ class DocumentIndexingEstimateApi(DocumentResource): | |||
| file = ( | |||
| db.session.query(UploadFile) | |||
| .filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) | |||
| .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) | |||
| .first() | |||
| ) | |||
| @@ -494,7 +494,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| file_id = data_source_info["upload_file_id"] | |||
| file_detail = ( | |||
| db.session.query(UploadFile) | |||
| .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id) | |||
| .where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id) | |||
| .first() | |||
| ) | |||
| @@ -570,7 +570,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource): | |||
| for document in documents: | |||
| completed_segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter( | |||
| .where( | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.document_id == str(document.id), | |||
| DocumentSegment.status != "re_segment", | |||
| @@ -579,7 +579,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource): | |||
| ) | |||
| total_segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") | |||
| .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") | |||
| .count() | |||
| ) | |||
| # Create a dictionary with document attributes and additional fields | |||
| @@ -613,7 +613,7 @@ class DocumentIndexingStatusApi(DocumentResource): | |||
| completed_segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter( | |||
| .where( | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.document_id == str(document_id), | |||
| DocumentSegment.status != "re_segment", | |||
| @@ -622,7 +622,7 @@ class DocumentIndexingStatusApi(DocumentResource): | |||
| ) | |||
| total_segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment") | |||
| .where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment") | |||
| .count() | |||
| ) | |||
| @@ -78,7 +78,7 @@ class DatasetDocumentSegmentListApi(Resource): | |||
| query = ( | |||
| select(DocumentSegment) | |||
| .filter( | |||
| .where( | |||
| DocumentSegment.document_id == str(document_id), | |||
| DocumentSegment.tenant_id == current_user.current_tenant_id, | |||
| ) | |||
| @@ -86,19 +86,19 @@ class DatasetDocumentSegmentListApi(Resource): | |||
| ) | |||
| if status_list: | |||
| query = query.filter(DocumentSegment.status.in_(status_list)) | |||
| query = query.where(DocumentSegment.status.in_(status_list)) | |||
| if hit_count_gte is not None: | |||
| query = query.filter(DocumentSegment.hit_count >= hit_count_gte) | |||
| query = query.where(DocumentSegment.hit_count >= hit_count_gte) | |||
| if keyword: | |||
| query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) | |||
| if args["enabled"].lower() != "all": | |||
| if args["enabled"].lower() == "true": | |||
| query = query.filter(DocumentSegment.enabled == True) | |||
| query = query.where(DocumentSegment.enabled == True) | |||
| elif args["enabled"].lower() == "false": | |||
| query = query.filter(DocumentSegment.enabled == False) | |||
| query = query.where(DocumentSegment.enabled == False) | |||
| segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) | |||
| @@ -285,7 +285,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): | |||
| segment_id = str(segment_id) | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) | |||
| .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) | |||
| .first() | |||
| ) | |||
| if not segment: | |||
| @@ -331,7 +331,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): | |||
| segment_id = str(segment_id) | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) | |||
| .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) | |||
| .first() | |||
| ) | |||
| if not segment: | |||
| @@ -436,7 +436,7 @@ class ChildChunkAddApi(Resource): | |||
| segment_id = str(segment_id) | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) | |||
| .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) | |||
| .first() | |||
| ) | |||
| if not segment: | |||
| @@ -493,7 +493,7 @@ class ChildChunkAddApi(Resource): | |||
| segment_id = str(segment_id) | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) | |||
| .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) | |||
| .first() | |||
| ) | |||
| if not segment: | |||
| @@ -540,7 +540,7 @@ class ChildChunkAddApi(Resource): | |||
| segment_id = str(segment_id) | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) | |||
| .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) | |||
| .first() | |||
| ) | |||
| if not segment: | |||
| @@ -586,7 +586,7 @@ class ChildChunkUpdateApi(Resource): | |||
| segment_id = str(segment_id) | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) | |||
| .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) | |||
| .first() | |||
| ) | |||
| if not segment: | |||
| @@ -595,7 +595,7 @@ class ChildChunkUpdateApi(Resource): | |||
| child_chunk_id = str(child_chunk_id) | |||
| child_chunk = ( | |||
| db.session.query(ChildChunk) | |||
| .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) | |||
| .where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) | |||
| .first() | |||
| ) | |||
| if not child_chunk: | |||
| @@ -635,7 +635,7 @@ class ChildChunkUpdateApi(Resource): | |||
| segment_id = str(segment_id) | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) | |||
| .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) | |||
| .first() | |||
| ) | |||
| if not segment: | |||
| @@ -644,7 +644,7 @@ class ChildChunkUpdateApi(Resource): | |||
| child_chunk_id = str(child_chunk_id) | |||
| child_chunk = ( | |||
| db.session.query(ChildChunk) | |||
| .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) | |||
| .where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) | |||
| .first() | |||
| ) | |||
| if not child_chunk: | |||
| @@ -34,11 +34,11 @@ class InstalledAppsListApi(Resource): | |||
| if app_id: | |||
| installed_apps = ( | |||
| db.session.query(InstalledApp) | |||
| .filter(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)) | |||
| .where(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)) | |||
| .all() | |||
| ) | |||
| else: | |||
| installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all() | |||
| installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all() | |||
| current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) | |||
| installed_app_list: list[dict[str, Any]] = [ | |||
| @@ -94,12 +94,12 @@ class InstalledAppsListApi(Resource): | |||
| parser.add_argument("app_id", type=str, required=True, help="Invalid app_id") | |||
| args = parser.parse_args() | |||
| recommended_app = db.session.query(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]).first() | |||
| recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first() | |||
| if recommended_app is None: | |||
| raise NotFound("App not found") | |||
| current_tenant_id = current_user.current_tenant_id | |||
| app = db.session.query(App).filter(App.id == args["app_id"]).first() | |||
| app = db.session.query(App).where(App.id == args["app_id"]).first() | |||
| if app is None: | |||
| raise NotFound("App not found") | |||
| @@ -109,7 +109,7 @@ class InstalledAppsListApi(Resource): | |||
| installed_app = ( | |||
| db.session.query(InstalledApp) | |||
| .filter(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)) | |||
| .where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)) | |||
| .first() | |||
| ) | |||
| @@ -28,7 +28,7 @@ def installed_app_required(view=None): | |||
| installed_app = ( | |||
| db.session.query(InstalledApp) | |||
| .filter( | |||
| .where( | |||
| InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id | |||
| ) | |||
| .first() | |||
| @@ -21,7 +21,7 @@ def plugin_permission_required( | |||
| with Session(db.engine) as session: | |||
| permission = ( | |||
| session.query(TenantPluginPermission) | |||
| .filter( | |||
| .where( | |||
| TenantPluginPermission.tenant_id == tenant_id, | |||
| ) | |||
| .first() | |||
| @@ -9,6 +9,7 @@ from configs import dify_config | |||
| from constants.languages import supported_language | |||
| from controllers.console import api | |||
| from controllers.console.auth.error import ( | |||
| AccountInFreezeError, | |||
| EmailAlreadyInUseError, | |||
| EmailChangeLimitError, | |||
| EmailCodeError, | |||
| @@ -68,7 +69,7 @@ class AccountInitApi(Resource): | |||
| # check invitation code | |||
| invitation_code = ( | |||
| db.session.query(InvitationCode) | |||
| .filter( | |||
| .where( | |||
| InvitationCode.code == args["invitation_code"], | |||
| InvitationCode.status == "unused", | |||
| ) | |||
| @@ -228,7 +229,7 @@ class AccountIntegrateApi(Resource): | |||
| def get(self): | |||
| account = current_user | |||
| account_integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all() | |||
| account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all() | |||
| base_url = request.url_root.rstrip("/") | |||
| oauth_base_path = "/console/api/oauth/login" | |||
| @@ -479,21 +480,28 @@ class ChangeEmailResetApi(Resource): | |||
| parser.add_argument("token", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| if AccountService.is_account_in_freeze(args["new_email"]): | |||
| raise AccountInFreezeError() | |||
| if not AccountService.check_email_unique(args["new_email"]): | |||
| raise EmailAlreadyInUseError() | |||
| reset_data = AccountService.get_change_email_data(args["token"]) | |||
| if not reset_data: | |||
| raise InvalidTokenError() | |||
| AccountService.revoke_change_email_token(args["token"]) | |||
| if not AccountService.check_email_unique(args["new_email"]): | |||
| raise EmailAlreadyInUseError() | |||
| old_email = reset_data.get("old_email", "") | |||
| if current_user.email != old_email: | |||
| raise AccountNotFound() | |||
| updated_account = AccountService.update_account(current_user, email=args["new_email"]) | |||
| AccountService.send_change_email_completed_notify_email( | |||
| email=args["new_email"], | |||
| ) | |||
| return updated_account | |||
| @@ -108,7 +108,7 @@ class MemberCancelInviteApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def delete(self, member_id): | |||
| member = db.session.query(Account).filter(Account.id == str(member_id)).first() | |||
| member = db.session.query(Account).where(Account.id == str(member_id)).first() | |||
| if member is None: | |||
| abort(404) | |||
| else: | |||
| @@ -12,7 +12,8 @@ from controllers.console.wraps import account_initialization_required, setup_req | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from core.plugin.impl.exc import PluginDaemonClientSideError | |||
| from libs.login import login_required | |||
| from models.account import TenantPluginPermission | |||
| from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission | |||
| from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService | |||
| from services.plugin.plugin_parameter_service import PluginParameterService | |||
| from services.plugin.plugin_permission_service import PluginPermissionService | |||
| from services.plugin.plugin_service import PluginService | |||
| @@ -534,6 +535,114 @@ class PluginFetchDynamicSelectOptionsApi(Resource): | |||
| return jsonable_encoder({"options": options}) | |||
| class PluginChangePreferencesApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| user = current_user | |||
| if not user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| req = reqparse.RequestParser() | |||
| req.add_argument("permission", type=dict, required=True, location="json") | |||
| req.add_argument("auto_upgrade", type=dict, required=True, location="json") | |||
| args = req.parse_args() | |||
| tenant_id = user.current_tenant_id | |||
| permission = args["permission"] | |||
| install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone")) | |||
| debug_permission = TenantPluginPermission.DebugPermission(permission.get("debug_permission", "everyone")) | |||
| auto_upgrade = args["auto_upgrade"] | |||
| strategy_setting = TenantPluginAutoUpgradeStrategy.StrategySetting( | |||
| auto_upgrade.get("strategy_setting", "fix_only") | |||
| ) | |||
| upgrade_time_of_day = auto_upgrade.get("upgrade_time_of_day", 0) | |||
| upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode(auto_upgrade.get("upgrade_mode", "exclude")) | |||
| exclude_plugins = auto_upgrade.get("exclude_plugins", []) | |||
| include_plugins = auto_upgrade.get("include_plugins", []) | |||
| # set permission | |||
| set_permission_result = PluginPermissionService.change_permission( | |||
| tenant_id, | |||
| install_permission, | |||
| debug_permission, | |||
| ) | |||
| if not set_permission_result: | |||
| return jsonable_encoder({"success": False, "message": "Failed to set permission"}) | |||
| # set auto upgrade strategy | |||
| set_auto_upgrade_strategy_result = PluginAutoUpgradeService.change_strategy( | |||
| tenant_id, | |||
| strategy_setting, | |||
| upgrade_time_of_day, | |||
| upgrade_mode, | |||
| exclude_plugins, | |||
| include_plugins, | |||
| ) | |||
| if not set_auto_upgrade_strategy_result: | |||
| return jsonable_encoder({"success": False, "message": "Failed to set auto upgrade strategy"}) | |||
| return jsonable_encoder({"success": True}) | |||
| class PluginFetchPreferencesApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| tenant_id = current_user.current_tenant_id | |||
| permission = PluginPermissionService.get_permission(tenant_id) | |||
| permission_dict = { | |||
| "install_permission": TenantPluginPermission.InstallPermission.EVERYONE, | |||
| "debug_permission": TenantPluginPermission.DebugPermission.EVERYONE, | |||
| } | |||
| if permission: | |||
| permission_dict["install_permission"] = permission.install_permission | |||
| permission_dict["debug_permission"] = permission.debug_permission | |||
| auto_upgrade = PluginAutoUpgradeService.get_strategy(tenant_id) | |||
| auto_upgrade_dict = { | |||
| "strategy_setting": TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED, | |||
| "upgrade_time_of_day": 0, | |||
| "upgrade_mode": TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE, | |||
| "exclude_plugins": [], | |||
| "include_plugins": [], | |||
| } | |||
| if auto_upgrade: | |||
| auto_upgrade_dict = { | |||
| "strategy_setting": auto_upgrade.strategy_setting, | |||
| "upgrade_time_of_day": auto_upgrade.upgrade_time_of_day, | |||
| "upgrade_mode": auto_upgrade.upgrade_mode, | |||
| "exclude_plugins": auto_upgrade.exclude_plugins, | |||
| "include_plugins": auto_upgrade.include_plugins, | |||
| } | |||
| return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict}) | |||
| class PluginAutoUpgradeExcludePluginApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| # exclude one single plugin | |||
| tenant_id = current_user.current_tenant_id | |||
| req = reqparse.RequestParser() | |||
| req.add_argument("plugin_id", type=str, required=True, location="json") | |||
| args = req.parse_args() | |||
| return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])}) | |||
| api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key") | |||
| api.add_resource(PluginListApi, "/workspaces/current/plugin/list") | |||
| api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions") | |||
| @@ -560,3 +669,7 @@ api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permissi | |||
| api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch") | |||
| api.add_resource(PluginFetchDynamicSelectOptionsApi, "/workspaces/current/plugin/parameters/dynamic-options") | |||
| api.add_resource(PluginFetchPreferencesApi, "/workspaces/current/plugin/preferences/fetch") | |||
| api.add_resource(PluginChangePreferencesApi, "/workspaces/current/plugin/preferences/change") | |||
| api.add_resource(PluginAutoUpgradeExcludePluginApi, "/workspaces/current/plugin/preferences/autoupgrade/exclude") | |||
| @@ -739,7 +739,7 @@ class ToolOAuthCallback(Resource): | |||
| raise Forbidden("no oauth available client config found for this tool provider") | |||
| redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" | |||
| credentials = oauth_handler.get_credentials( | |||
| credentials_response = oauth_handler.get_credentials( | |||
| tenant_id=tenant_id, | |||
| user_id=user_id, | |||
| plugin_id=plugin_id, | |||
| @@ -747,7 +747,10 @@ class ToolOAuthCallback(Resource): | |||
| redirect_uri=redirect_uri, | |||
| system_credentials=oauth_client_params, | |||
| request=request, | |||
| ).credentials | |||
| ) | |||
| credentials = credentials_response.credentials | |||
| expires_at = credentials_response.expires_at | |||
| if not credentials: | |||
| raise Exception("the plugin credentials failed") | |||
| @@ -758,6 +761,7 @@ class ToolOAuthCallback(Resource): | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| credentials=dict(credentials), | |||
| expires_at=expires_at, | |||
| api_type=CredentialType.OAUTH2, | |||
| ) | |||
| return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") | |||
| @@ -22,7 +22,7 @@ def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser: | |||
| user_id = "DEFAULT-USER" | |||
| if user_id == "DEFAULT-USER": | |||
| user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first() | |||
| user_model = session.query(EndUser).where(EndUser.session_id == "DEFAULT-USER").first() | |||
| if not user_model: | |||
| user_model = EndUser( | |||
| tenant_id=tenant_id, | |||
| @@ -36,7 +36,7 @@ def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser: | |||
| else: | |||
| user_model = AccountService.load_user(user_id) | |||
| if not user_model: | |||
| user_model = session.query(EndUser).filter(EndUser.id == user_id).first() | |||
| user_model = session.query(EndUser).where(EndUser.id == user_id).first() | |||
| if not user_model: | |||
| raise ValueError("user not found") | |||
| except Exception: | |||
| @@ -71,7 +71,7 @@ def get_user_tenant(view: Optional[Callable] = None): | |||
| try: | |||
| tenant_model = ( | |||
| db.session.query(Tenant) | |||
| .filter( | |||
| .where( | |||
| Tenant.id == tenant_id, | |||
| ) | |||
| .first() | |||
| @@ -55,7 +55,7 @@ def enterprise_inner_api_user_auth(view): | |||
| if signature_base64 != token: | |||
| return view(*args, **kwargs) | |||
| kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first() | |||
| kwargs["user"] = db.session.query(EndUser).where(EndUser.id == user_id).first() | |||
| return view(*args, **kwargs) | |||
| @@ -30,7 +30,7 @@ class MCPAppApi(Resource): | |||
| request_id = args.get("id") | |||
| server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first() | |||
| server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first() | |||
| if not server: | |||
| return helper.compact_generate_response( | |||
| create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found") | |||
| @@ -41,7 +41,7 @@ class MCPAppApi(Resource): | |||
| create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active") | |||
| ) | |||
| app = db.session.query(App).filter(App.id == server.app_id).first() | |||
| app = db.session.query(App).where(App.id == server.app_id).first() | |||
| if not app: | |||
| return helper.compact_generate_response( | |||
| create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found") | |||
| @@ -1,5 +1,6 @@ | |||
| import logging | |||
| from flask import request | |||
| from flask_restful import Resource, reqparse | |||
| from werkzeug.exceptions import InternalServerError, NotFound | |||
| @@ -23,6 +24,7 @@ from core.errors.error import ( | |||
| ProviderTokenNotInitError, | |||
| QuotaExceededError, | |||
| ) | |||
| from core.helper.trace_id_helper import get_external_trace_id | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from libs import helper | |||
| from libs.helper import uuid_value | |||
| @@ -111,6 +113,10 @@ class ChatApi(Resource): | |||
| args = parser.parse_args() | |||
| external_trace_id = get_external_trace_id(request) | |||
| if external_trace_id: | |||
| args["external_trace_id"] = external_trace_id | |||
| streaming = args["response_mode"] == "streaming" | |||
| try: | |||
| @@ -16,7 +16,7 @@ class AppSiteApi(Resource): | |||
| @marshal_with(fields.site_fields) | |||
| def get(self, app_model: App): | |||
| """Retrieve app site info.""" | |||
| site = db.session.query(Site).filter(Site.app_id == app_model.id).first() | |||
| site = db.session.query(Site).where(Site.app_id == app_model.id).first() | |||
| if not site: | |||
| raise Forbidden() | |||
| @@ -1,6 +1,7 @@ | |||
| import logging | |||
| from dateutil.parser import isoparse | |||
| from flask import request | |||
| from flask_restful import Resource, fields, marshal_with, reqparse | |||
| from flask_restful.inputs import int_range | |||
| from sqlalchemy.orm import Session, sessionmaker | |||
| @@ -23,6 +24,7 @@ from core.errors.error import ( | |||
| ProviderTokenNotInitError, | |||
| QuotaExceededError, | |||
| ) | |||
| from core.helper.trace_id_helper import get_external_trace_id | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from core.workflow.entities.workflow_execution import WorkflowExecutionStatus | |||
| from extensions.ext_database import db | |||
| @@ -90,7 +92,9 @@ class WorkflowRunApi(Resource): | |||
| parser.add_argument("files", type=list, required=False, location="json") | |||
| parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") | |||
| args = parser.parse_args() | |||
| external_trace_id = get_external_trace_id(request) | |||
| if external_trace_id: | |||
| args["external_trace_id"] = external_trace_id | |||
| streaming = args.get("response_mode") == "streaming" | |||
| try: | |||
| @@ -63,7 +63,7 @@ class DocumentAddByTextApi(DatasetApiResource): | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise ValueError("Dataset does not exist.") | |||
| @@ -136,7 +136,7 @@ class DocumentUpdateByTextApi(DatasetApiResource): | |||
| args = parser.parse_args() | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise ValueError("Dataset does not exist.") | |||
| @@ -206,7 +206,7 @@ class DocumentAddByFileApi(DatasetApiResource): | |||
| # get dataset info | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise ValueError("Dataset does not exist.") | |||
| @@ -299,7 +299,7 @@ class DocumentUpdateByFileApi(DatasetApiResource): | |||
| # get dataset info | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise ValueError("Dataset does not exist.") | |||
| @@ -367,7 +367,7 @@ class DocumentDeleteApi(DatasetApiResource): | |||
| tenant_id = str(tenant_id) | |||
| # get dataset info | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise ValueError("Dataset does not exist.") | |||
| @@ -398,7 +398,7 @@ class DocumentListApi(DatasetApiResource): | |||
| page = request.args.get("page", default=1, type=int) | |||
| limit = request.args.get("limit", default=20, type=int) | |||
| search = request.args.get("keyword", default=None, type=str) | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| @@ -406,7 +406,7 @@ class DocumentListApi(DatasetApiResource): | |||
| if search: | |||
| search = f"%{search}%" | |||
| query = query.filter(Document.name.like(search)) | |||
| query = query.where(Document.name.like(search)) | |||
| query = query.order_by(desc(Document.created_at), desc(Document.position)) | |||
| @@ -430,7 +430,7 @@ class DocumentIndexingStatusApi(DatasetApiResource): | |||
| batch = str(batch) | |||
| tenant_id = str(tenant_id) | |||
| # get dataset | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| # get documents | |||
| @@ -441,7 +441,7 @@ class DocumentIndexingStatusApi(DatasetApiResource): | |||
| for document in documents: | |||
| completed_segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter( | |||
| .where( | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.document_id == str(document.id), | |||
| DocumentSegment.status != "re_segment", | |||
| @@ -450,7 +450,7 @@ class DocumentIndexingStatusApi(DatasetApiResource): | |||
| ) | |||
| total_segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") | |||
| .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") | |||
| .count() | |||
| ) | |||
| # Create a dictionary with document attributes and additional fields | |||
| @@ -42,7 +42,7 @@ class SegmentApi(DatasetApiResource): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| # check document | |||
| @@ -89,7 +89,7 @@ class SegmentApi(DatasetApiResource): | |||
| tenant_id = str(tenant_id) | |||
| page = request.args.get("page", default=1, type=int) | |||
| limit = request.args.get("limit", default=20, type=int) | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| # check document | |||
| @@ -146,7 +146,7 @@ class DatasetSegmentApi(DatasetApiResource): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| # check user's model setting | |||
| @@ -170,7 +170,7 @@ class DatasetSegmentApi(DatasetApiResource): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| # check user's model setting | |||
| @@ -216,7 +216,7 @@ class DatasetSegmentApi(DatasetApiResource): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| # check user's model setting | |||
| @@ -246,7 +246,7 @@ class ChildChunkApi(DatasetApiResource): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| @@ -296,7 +296,7 @@ class ChildChunkApi(DatasetApiResource): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| @@ -343,7 +343,7 @@ class DatasetChildChunkApi(DatasetApiResource): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| @@ -382,7 +382,7 @@ class DatasetChildChunkApi(DatasetApiResource): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| @@ -17,7 +17,7 @@ class UploadFileApi(DatasetApiResource): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| # check document | |||
| @@ -31,7 +31,7 @@ class UploadFileApi(DatasetApiResource): | |||
| data_source_info = document.data_source_info_dict | |||
| if data_source_info and "upload_file_id" in data_source_info: | |||
| file_id = data_source_info["upload_file_id"] | |||
| upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() | |||
| upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() | |||
| if not upload_file: | |||
| raise NotFound("UploadFile not found.") | |||
| else: | |||
| @@ -44,7 +44,7 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio | |||
| def decorated_view(*args, **kwargs): | |||
| api_token = validate_and_get_api_token("app") | |||
| app_model = db.session.query(App).filter(App.id == api_token.app_id).first() | |||
| app_model = db.session.query(App).where(App.id == api_token.app_id).first() | |||
| if not app_model: | |||
| raise Forbidden("The app no longer exists.") | |||
| @@ -54,7 +54,7 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio | |||
| if not app_model.enable_api: | |||
| raise Forbidden("The app's API service has been disabled.") | |||
| tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first() | |||
| tenant = db.session.query(Tenant).where(Tenant.id == app_model.tenant_id).first() | |||
| if tenant is None: | |||
| raise ValueError("Tenant does not exist.") | |||
| if tenant.status == TenantStatus.ARCHIVE: | |||
| @@ -62,15 +62,15 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio | |||
| tenant_account_join = ( | |||
| db.session.query(Tenant, TenantAccountJoin) | |||
| .filter(Tenant.id == api_token.tenant_id) | |||
| .filter(TenantAccountJoin.tenant_id == Tenant.id) | |||
| .filter(TenantAccountJoin.role.in_(["owner"])) | |||
| .filter(Tenant.status == TenantStatus.NORMAL) | |||
| .where(Tenant.id == api_token.tenant_id) | |||
| .where(TenantAccountJoin.tenant_id == Tenant.id) | |||
| .where(TenantAccountJoin.role.in_(["owner"])) | |||
| .where(Tenant.status == TenantStatus.NORMAL) | |||
| .one_or_none() | |||
| ) # TODO: only owner information is required, so only one is returned. | |||
| if tenant_account_join: | |||
| tenant, ta = tenant_account_join | |||
| account = db.session.query(Account).filter(Account.id == ta.account_id).first() | |||
| account = db.session.query(Account).where(Account.id == ta.account_id).first() | |||
| # Login admin | |||
| if account: | |||
| account.current_tenant = tenant | |||
| @@ -213,15 +213,15 @@ def validate_dataset_token(view=None): | |||
| api_token = validate_and_get_api_token("dataset") | |||
| tenant_account_join = ( | |||
| db.session.query(Tenant, TenantAccountJoin) | |||
| .filter(Tenant.id == api_token.tenant_id) | |||
| .filter(TenantAccountJoin.tenant_id == Tenant.id) | |||
| .filter(TenantAccountJoin.role.in_(["owner"])) | |||
| .filter(Tenant.status == TenantStatus.NORMAL) | |||
| .where(Tenant.id == api_token.tenant_id) | |||
| .where(TenantAccountJoin.tenant_id == Tenant.id) | |||
| .where(TenantAccountJoin.role.in_(["owner"])) | |||
| .where(Tenant.status == TenantStatus.NORMAL) | |||
| .one_or_none() | |||
| ) # TODO: only owner information is required, so only one is returned. | |||
| if tenant_account_join: | |||
| tenant, ta = tenant_account_join | |||
| account = db.session.query(Account).filter(Account.id == ta.account_id).first() | |||
| account = db.session.query(Account).where(Account.id == ta.account_id).first() | |||
| # Login admin | |||
| if account: | |||
| account.current_tenant = tenant | |||
| @@ -293,7 +293,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] | |||
| end_user = ( | |||
| db.session.query(EndUser) | |||
| .filter( | |||
| .where( | |||
| EndUser.tenant_id == app_model.tenant_id, | |||
| EndUser.app_id == app_model.id, | |||
| EndUser.session_id == user_id, | |||
| @@ -320,7 +320,7 @@ class DatasetApiResource(Resource): | |||
| method_decorators = [validate_dataset_token] | |||
| def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset: | |||
| dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first() | |||
| dataset = db.session.query(Dataset).where(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| @@ -3,6 +3,7 @@ from datetime import UTC, datetime, timedelta | |||
| from flask import request | |||
| from flask_restful import Resource | |||
| from sqlalchemy import func, select | |||
| from werkzeug.exceptions import NotFound, Unauthorized | |||
| from configs import dify_config | |||
| @@ -42,17 +43,17 @@ class PassportResource(Resource): | |||
| raise WebAppAuthRequiredError() | |||
| # get site from db and check if it is normal | |||
| site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() | |||
| site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal")) | |||
| if not site: | |||
| raise NotFound() | |||
| # get app from db and check if it is normal and enable_site | |||
| app_model = db.session.query(App).filter(App.id == site.app_id).first() | |||
| app_model = db.session.scalar(select(App).where(App.id == site.app_id)) | |||
| if not app_model or app_model.status != "normal" or not app_model.enable_site: | |||
| raise NotFound() | |||
| if user_id: | |||
| end_user = ( | |||
| db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first() | |||
| end_user = db.session.scalar( | |||
| select(EndUser).where(EndUser.app_id == app_model.id, EndUser.session_id == user_id) | |||
| ) | |||
| if end_user: | |||
| @@ -121,11 +122,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: | |||
| if not user_auth_type: | |||
| raise Unauthorized("Missing auth_type in the token.") | |||
| site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() | |||
| site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal")) | |||
| if not site: | |||
| raise NotFound() | |||
| app_model = db.session.query(App).filter(App.id == site.app_id).first() | |||
| app_model = db.session.scalar(select(App).where(App.id == site.app_id)) | |||
| if not app_model or app_model.status != "normal" or not app_model.enable_site: | |||
| raise NotFound() | |||
| @@ -140,16 +141,14 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: | |||
| end_user = None | |||
| if end_user_id: | |||
| end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first() | |||
| end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id)) | |||
| if session_id: | |||
| end_user = ( | |||
| db.session.query(EndUser) | |||
| .filter( | |||
| end_user = db.session.scalar( | |||
| select(EndUser).where( | |||
| EndUser.session_id == session_id, | |||
| EndUser.tenant_id == app_model.tenant_id, | |||
| EndUser.app_id == app_model.id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if not end_user: | |||
| if not session_id: | |||
| @@ -187,8 +186,8 @@ def _exchange_for_public_app_token(app_model, site, token_decoded): | |||
| user_id = token_decoded.get("user_id") | |||
| end_user = None | |||
| if user_id: | |||
| end_user = ( | |||
| db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first() | |||
| end_user = db.session.scalar( | |||
| select(EndUser).where(EndUser.app_id == app_model.id, EndUser.session_id == user_id) | |||
| ) | |||
| if not end_user: | |||
| @@ -224,6 +223,8 @@ def generate_session_id(): | |||
| """ | |||
| while True: | |||
| session_id = str(uuid.uuid4()) | |||
| existing_count = db.session.query(EndUser).filter(EndUser.session_id == session_id).count() | |||
| existing_count = db.session.scalar( | |||
| select(func.count()).select_from(EndUser).where(EndUser.session_id == session_id) | |||
| ) | |||
| if existing_count == 0: | |||
| return session_id | |||
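| The generate_session_id change above replaces Query.count() with a scalar select(func.count()). A minimal runnable sketch of that pattern, using a stand-in EndUser model (the real one lives in models.model): | |||
| ```python | |||
| from sqlalchemy import create_engine, func, select | |||
| from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column | |||
| class Base(DeclarativeBase): | |||
|     pass | |||
| class EndUser(Base): | |||
|     # Minimal stand-in for models.model.EndUser | |||
|     __tablename__ = "end_users" | |||
|     id: Mapped[int] = mapped_column(primary_key=True) | |||
|     session_id: Mapped[str] | |||
| engine = create_engine("sqlite://") | |||
| Base.metadata.create_all(engine) | |||
| with Session(engine) as session: | |||
|     session.add(EndUser(session_id="abc")) | |||
|     session.commit() | |||
|     # 2.0-style scalar count, matching the updated generate_session_id | |||
|     existing_count = session.scalar( | |||
|         select(func.count()).select_from(EndUser).where(EndUser.session_id == "abc") | |||
|     ) | |||
|     assert existing_count == 1 | |||
| ``` | |||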
| @@ -57,7 +57,7 @@ class AppSiteApi(WebApiResource): | |||
| def get(self, app_model, end_user): | |||
| """Retrieve app site info.""" | |||
| # get site | |||
| site = db.session.query(Site).filter(Site.app_id == app_model.id).first() | |||
| site = db.session.query(Site).where(Site.app_id == app_model.id).first() | |||
| if not site: | |||
| raise Forbidden() | |||
| @@ -3,6 +3,7 @@ from functools import wraps | |||
| from flask import request | |||
| from flask_restful import Resource | |||
| from sqlalchemy import select | |||
| from werkzeug.exceptions import BadRequest, NotFound, Unauthorized | |||
| from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError | |||
| @@ -48,8 +49,8 @@ def decode_jwt_token(): | |||
| decoded = PassportService().verify(tk) | |||
| app_code = decoded.get("app_code") | |||
| app_id = decoded.get("app_id") | |||
| app_model = db.session.query(App).filter(App.id == app_id).first() | |||
| site = db.session.query(Site).filter(Site.code == app_code).first() | |||
| app_model = db.session.scalar(select(App).where(App.id == app_id)) | |||
| site = db.session.scalar(select(Site).where(Site.code == app_code)) | |||
| if not app_model: | |||
| raise NotFound() | |||
| if not app_code or not site: | |||
| @@ -57,7 +58,7 @@ def decode_jwt_token(): | |||
| if app_model.enable_site is False: | |||
| raise BadRequest("Site is disabled.") | |||
| end_user_id = decoded.get("end_user_id") | |||
| end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first() | |||
| end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id)) | |||
| if not end_user: | |||
| raise NotFound() | |||
| @@ -99,7 +99,7 @@ class BaseAgentRunner(AppRunner): | |||
| # get how many agent thoughts have been created | |||
| self.agent_thought_count = ( | |||
| db.session.query(MessageAgentThought) | |||
| .filter( | |||
| .where( | |||
| MessageAgentThought.message_id == self.message.id, | |||
| ) | |||
| .count() | |||
| @@ -336,7 +336,7 @@ class BaseAgentRunner(AppRunner): | |||
| Save agent thought | |||
| """ | |||
| updated_agent_thought = ( | |||
| db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() | |||
| db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought.id).first() | |||
| ) | |||
| if not updated_agent_thought: | |||
| raise ValueError("agent thought not found") | |||
| @@ -496,7 +496,7 @@ class BaseAgentRunner(AppRunner): | |||
| return result | |||
| def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: | |||
| files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() | |||
| files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all() | |||
| if not files: | |||
| return UserPromptMessage(content=message.query) | |||
| if message.app_model_config: | |||
| @@ -1,48 +0,0 @@ | |||
| ## Guidelines for Database Connection Management in App Runner and Task Pipeline | |||
| App Runner contains long-running tasks such as LLM generation and external requests, while Flask-SQLAlchemy's connection-pooling strategy allocates one connection (transaction) per request. A connection therefore stays occupied even during non-database work, and under high concurrency multiple long-running tasks can exhaust the pool, leaving no connections available for new requests. | |||
| Therefore, database operations in App Runner and Task Pipeline must close their connections immediately after use, and it is better to pass IDs rather than Model objects between steps to avoid detached-instance errors; a sketch of the short-lived-session variant used by the refactored runners follows the examples below. | |||
| Examples: | |||
| 1. Creating a new record: | |||
| ```python | |||
| app = App(id=1) | |||
| db.session.add(app) | |||
| db.session.commit() | |||
| db.session.refresh(app) # Load server-generated defaults (e.g. created_at) into the app object; the cached values remain usable after close | |||
| # Handle short tasks here, or copy the needed App fields into local variables before closing | |||
| db.session.close() | |||
| return app.id | |||
| ``` | |||
| 2. Fetching a record from the table: | |||
| ```python | |||
| app = db.session.query(App).filter(App.id == app_id).first() | |||
| created_at = app.created_at | |||
| db.session.close() | |||
| # Handle tasks here (including long-running ones); the connection is already released | |||
| ``` | |||
| 3. Updating a table field: | |||
| ```python | |||
| app = db.session.query(App).filter(App.id == app_id).first() | |||
| app.updated_at = datetime.utcnow() | |||
| db.session.commit() | |||
| db.session.close() | |||
| return app_id | |||
| ``` | |||
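| For comparison, the refactored runners later in this changeset apply the same principle with a short-lived session. A sketch of that pattern, assuming the Workflow model, db handle, and workflow_id from the surrounding code (run_long_task is illustrative): | |||
| ```python | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| # expire_on_commit=False keeps loaded attribute values usable after the | |||
| # session closes, avoiding detached-instance errors. | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
|     workflow = session.scalar(select(Workflow).where(Workflow.id == workflow_id)) | |||
|     if workflow is None: | |||
|         raise ValueError("Workflow not found") | |||
| # The connection is back in the pool here; long-running work can follow. | |||
| run_long_task(workflow.graph_dict) | |||
| ``` | |||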
| @@ -7,7 +7,8 @@ from typing import Any, Literal, Optional, Union, overload | |||
| from flask import Flask, current_app | |||
| from pydantic import ValidationError | |||
| from sqlalchemy.orm import sessionmaker | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session, sessionmaker | |||
| import contexts | |||
| from configs import dify_config | |||
| @@ -23,6 +24,7 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator | |||
| from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager | |||
| from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom | |||
| from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse | |||
| from core.helper.trace_id_helper import extract_external_trace_id_from_args | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.prompt.utils.get_thread_messages_length import get_thread_messages_length | |||
| @@ -112,7 +114,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| query = query.replace("\x00", "") | |||
| inputs = args["inputs"] | |||
| extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)} | |||
| extras = { | |||
| "auto_generate_conversation_name": args.get("auto_generate_name", False), | |||
| **extract_external_trace_id_from_args(args), | |||
| } | |||
| # get conversation | |||
| conversation = None | |||
| @@ -482,21 +487,52 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): | |||
| """ | |||
| with preserve_flask_contexts(flask_app, context_vars=context): | |||
| try: | |||
| # get conversation and message | |||
| conversation = self._get_conversation(conversation_id) | |||
| message = self._get_message(message_id) | |||
| # chatbot app | |||
| runner = AdvancedChatAppRunner( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| conversation=conversation, | |||
| message=message, | |||
| dialogue_count=self._dialogue_count, | |||
| variable_loader=variable_loader, | |||
| # get conversation and message | |||
| conversation = self._get_conversation(conversation_id) | |||
| message = self._get_message(message_id) | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow = session.scalar( | |||
| select(Workflow).where( | |||
| Workflow.tenant_id == application_generate_entity.app_config.tenant_id, | |||
| Workflow.app_id == application_generate_entity.app_config.app_id, | |||
| Workflow.id == application_generate_entity.app_config.workflow_id, | |||
| ) | |||
| ) | |||
| if workflow is None: | |||
| raise ValueError("Workflow not found") | |||
| # Determine system_user_id based on invocation source | |||
| is_external_api_call = application_generate_entity.invoke_from in { | |||
| InvokeFrom.WEB_APP, | |||
| InvokeFrom.SERVICE_API, | |||
| } | |||
| if is_external_api_call: | |||
| # For external API calls, use end user's session ID | |||
| end_user = session.scalar(select(EndUser).where(EndUser.id == application_generate_entity.user_id)) | |||
| system_user_id = end_user.session_id if end_user else "" | |||
| else: | |||
| # For internal calls, use the original user ID | |||
| system_user_id = application_generate_entity.user_id | |||
| app = session.scalar(select(App).where(App.id == application_generate_entity.app_config.app_id)) | |||
| if app is None: | |||
| raise ValueError("App not found") | |||
| runner = AdvancedChatAppRunner( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| conversation=conversation, | |||
| message=message, | |||
| dialogue_count=self._dialogue_count, | |||
| variable_loader=variable_loader, | |||
| workflow=workflow, | |||
| system_user_id=system_user_id, | |||
| app=app, | |||
| ) | |||
| try: | |||
| runner.run() | |||
| except GenerateTaskStoppedError: | |||
| pass | |||
| @@ -1,6 +1,6 @@ | |||
| import logging | |||
| from collections.abc import Mapping | |||
| from typing import Any, cast | |||
| from typing import Any, Optional, cast | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| @@ -9,13 +9,19 @@ from configs import dify_config | |||
| from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager | |||
| from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner | |||
| from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom | |||
| from core.app.entities.app_invoke_entities import ( | |||
| AdvancedChatAppGenerateEntity, | |||
| AppGenerateEntity, | |||
| InvokeFrom, | |||
| ) | |||
| from core.app.entities.queue_entities import ( | |||
| QueueAnnotationReplyEvent, | |||
| QueueStopEvent, | |||
| QueueTextChunkEvent, | |||
| ) | |||
| from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature | |||
| from core.moderation.base import ModerationError | |||
| from core.moderation.input_moderation import InputModeration | |||
| from core.variables.variables import VariableUnion | |||
| from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback | |||
| from core.workflow.entities.variable_pool import VariablePool | |||
| @@ -23,8 +29,9 @@ from core.workflow.system_variable import SystemVariable | |||
| from core.workflow.variable_loader import VariableLoader | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from extensions.ext_database import db | |||
| from models import Workflow | |||
| from models.enums import UserFrom | |||
| from models.model import App, Conversation, EndUser, Message | |||
| from models.model import App, Conversation, Message, MessageAnnotation | |||
| from models.workflow import ConversationVariable, WorkflowType | |||
| logger = logging.getLogger(__name__) | |||
| @@ -37,42 +44,38 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| def __init__( | |||
| self, | |||
| *, | |||
| application_generate_entity: AdvancedChatAppGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| conversation: Conversation, | |||
| message: Message, | |||
| dialogue_count: int, | |||
| variable_loader: VariableLoader, | |||
| workflow: Workflow, | |||
| system_user_id: str, | |||
| app: App, | |||
| ) -> None: | |||
| super().__init__(queue_manager, variable_loader) | |||
| super().__init__( | |||
| queue_manager=queue_manager, | |||
| variable_loader=variable_loader, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| ) | |||
| self.application_generate_entity = application_generate_entity | |||
| self.conversation = conversation | |||
| self.message = message | |||
| self._dialogue_count = dialogue_count | |||
| def _get_app_id(self) -> str: | |||
| return self.application_generate_entity.app_config.app_id | |||
| self._workflow = workflow | |||
| self.system_user_id = system_user_id | |||
| self._app = app | |||
| def run(self) -> None: | |||
| app_config = self.application_generate_entity.app_config | |||
| app_config = cast(AdvancedChatAppConfig, app_config) | |||
| app_record = db.session.query(App).filter(App.id == app_config.app_id).first() | |||
| app_record = db.session.query(App).where(App.id == app_config.app_id).first() | |||
| if not app_record: | |||
| raise ValueError("App not found") | |||
| workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) | |||
| if not workflow: | |||
| raise ValueError("Workflow not initialized") | |||
| user_id: str | None = None | |||
| if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: | |||
| end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() | |||
| if end_user: | |||
| user_id = end_user.session_id | |||
| else: | |||
| user_id = self.application_generate_entity.user_id | |||
| workflow_callbacks: list[WorkflowCallback] = [] | |||
| if dify_config.DEBUG: | |||
| workflow_callbacks.append(WorkflowLoggingCallback()) | |||
| @@ -80,14 +83,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| if self.application_generate_entity.single_iteration_run: | |||
| # if only single iteration run is requested | |||
| graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( | |||
| workflow=workflow, | |||
| workflow=self._workflow, | |||
| node_id=self.application_generate_entity.single_iteration_run.node_id, | |||
| user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs), | |||
| ) | |||
| elif self.application_generate_entity.single_loop_run: | |||
| # if only single loop run is requested | |||
| graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( | |||
| workflow=workflow, | |||
| workflow=self._workflow, | |||
| node_id=self.application_generate_entity.single_loop_run.node_id, | |||
| user_inputs=dict(self.application_generate_entity.single_loop_run.inputs), | |||
| ) | |||
| @@ -98,7 +101,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| # moderation | |||
| if self.handle_input_moderation( | |||
| app_record=app_record, | |||
| app_record=self._app, | |||
| app_generate_entity=self.application_generate_entity, | |||
| inputs=inputs, | |||
| query=query, | |||
| @@ -108,7 +111,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| # annotation reply | |||
| if self.handle_annotation_reply( | |||
| app_record=app_record, | |||
| app_record=self._app, | |||
| message=self.message, | |||
| query=query, | |||
| app_generate_entity=self.application_generate_entity, | |||
| @@ -128,7 +131,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| ConversationVariable.from_variable( | |||
| app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable | |||
| ) | |||
| for variable in workflow.conversation_variables | |||
| for variable in self._workflow.conversation_variables | |||
| ] | |||
| session.add_all(db_conversation_variables) | |||
| # Convert database entities to variables. | |||
| @@ -141,7 +144,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| query=query, | |||
| files=files, | |||
| conversation_id=self.conversation.id, | |||
| user_id=user_id, | |||
| user_id=self.system_user_id, | |||
| dialogue_count=self._dialogue_count, | |||
| app_id=app_config.app_id, | |||
| workflow_id=app_config.workflow_id, | |||
| @@ -152,25 +155,25 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| variable_pool = VariablePool( | |||
| system_variables=system_inputs, | |||
| user_inputs=inputs, | |||
| environment_variables=workflow.environment_variables, | |||
| environment_variables=self._workflow.environment_variables, | |||
| # Based on the definition of `VariableUnion`, | |||
| # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. | |||
| conversation_variables=cast(list[VariableUnion], conversation_variables), | |||
| ) | |||
| # init graph | |||
| graph = self._init_graph(graph_config=workflow.graph_dict) | |||
| graph = self._init_graph(graph_config=self._workflow.graph_dict) | |||
| db.session.close() | |||
| # RUN WORKFLOW | |||
| workflow_entry = WorkflowEntry( | |||
| tenant_id=workflow.tenant_id, | |||
| app_id=workflow.app_id, | |||
| workflow_id=workflow.id, | |||
| workflow_type=WorkflowType.value_of(workflow.type), | |||
| tenant_id=self._workflow.tenant_id, | |||
| app_id=self._workflow.app_id, | |||
| workflow_id=self._workflow.id, | |||
| workflow_type=WorkflowType.value_of(self._workflow.type), | |||
| graph=graph, | |||
| graph_config=workflow.graph_dict, | |||
| graph_config=self._workflow.graph_dict, | |||
| user_id=self.application_generate_entity.user_id, | |||
| user_from=( | |||
| UserFrom.ACCOUNT | |||
| @@ -241,3 +244,51 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): | |||
| self._publish_event(QueueTextChunkEvent(text=text)) | |||
| self._publish_event(QueueStopEvent(stopped_by=stopped_by)) | |||
| def query_app_annotations_to_reply( | |||
| self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom | |||
| ) -> Optional[MessageAnnotation]: | |||
| """ | |||
| Query app annotations to reply | |||
| :param app_record: app record | |||
| :param message: message | |||
| :param query: query | |||
| :param user_id: user id | |||
| :param invoke_from: invoke from | |||
| :return: | |||
| """ | |||
| annotation_reply_feature = AnnotationReplyFeature() | |||
| return annotation_reply_feature.query( | |||
| app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from | |||
| ) | |||
| def moderation_for_inputs( | |||
| self, | |||
| *, | |||
| app_id: str, | |||
| tenant_id: str, | |||
| app_generate_entity: AppGenerateEntity, | |||
| inputs: Mapping[str, Any], | |||
| query: str | None = None, | |||
| message_id: str, | |||
| ) -> tuple[bool, Mapping[str, Any], str]: | |||
| """ | |||
| Process sensitive_word_avoidance. | |||
| :param app_id: app id | |||
| :param tenant_id: tenant id | |||
| :param app_generate_entity: app generate entity | |||
| :param inputs: inputs | |||
| :param query: query | |||
| :param message_id: message id | |||
| :return: | |||
| """ | |||
| moderation_feature = InputModeration() | |||
| return moderation_feature.check( | |||
| app_id=app_id, | |||
| tenant_id=tenant_id, | |||
| app_config=app_generate_entity.app_config, | |||
| inputs=dict(inputs), | |||
| query=query or "", | |||
| message_id=message_id, | |||
| trace_manager=app_generate_entity.trace_manager, | |||
| ) | |||
| @@ -559,6 +559,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| outputs=event.outputs, | |||
| conversation_id=self._conversation_id, | |||
| trace_manager=trace_manager, | |||
| external_trace_id=self._application_generate_entity.extras.get("external_trace_id"), | |||
| ) | |||
| workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( | |||
| session=session, | |||
| @@ -590,6 +591,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| exceptions_count=event.exceptions_count, | |||
| conversation_id=None, | |||
| trace_manager=trace_manager, | |||
| external_trace_id=self._application_generate_entity.extras.get("external_trace_id"), | |||
| ) | |||
| workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( | |||
| session=session, | |||
| @@ -622,6 +624,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| conversation_id=self._conversation_id, | |||
| trace_manager=trace_manager, | |||
| exceptions_count=event.exceptions_count, | |||
| external_trace_id=self._application_generate_entity.extras.get("external_trace_id"), | |||
| ) | |||
| workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( | |||
| session=session, | |||
| @@ -653,6 +656,7 @@ class AdvancedChatAppGenerateTaskPipeline: | |||
| error_message=event.get_stop_reason(), | |||
| conversation_id=self._conversation_id, | |||
| trace_manager=trace_manager, | |||
| external_trace_id=self._application_generate_entity.extras.get("external_trace_id"), | |||
| ) | |||
| workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( | |||
| session=session, | |||
| @@ -45,7 +45,7 @@ class AgentChatAppRunner(AppRunner): | |||
| app_config = application_generate_entity.app_config | |||
| app_config = cast(AgentChatAppConfig, app_config) | |||
| app_record = db.session.query(App).filter(App.id == app_config.app_id).first() | |||
| app_record = db.session.query(App).where(App.id == app_config.app_id).first() | |||
| if not app_record: | |||
| raise ValueError("App not found") | |||
| @@ -183,10 +183,10 @@ class AgentChatAppRunner(AppRunner): | |||
| if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []): | |||
| agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING | |||
| conversation_result = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() | |||
| conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first() | |||
| if conversation_result is None: | |||
| raise ValueError("Conversation not found") | |||
| message_result = db.session.query(Message).filter(Message.id == message.id).first() | |||
| message_result = db.session.query(Message).where(Message.id == message.id).first() | |||
| if message_result is None: | |||
| raise ValueError("Message not found") | |||
| db.session.close() | |||
| @@ -43,7 +43,7 @@ class ChatAppRunner(AppRunner): | |||
| app_config = application_generate_entity.app_config | |||
| app_config = cast(ChatAppConfig, app_config) | |||
| app_record = db.session.query(App).filter(App.id == app_config.app_id).first() | |||
| app_record = db.session.query(App).where(App.id == app_config.app_id).first() | |||
| if not app_record: | |||
| raise ValueError("App not found") | |||
| @@ -248,7 +248,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): | |||
| """ | |||
| message = ( | |||
| db.session.query(Message) | |||
| .filter( | |||
| .where( | |||
| Message.id == message_id, | |||
| Message.app_id == app_model.id, | |||
| Message.from_source == ("api" if isinstance(user, EndUser) else "console"), | |||
| @@ -36,7 +36,7 @@ class CompletionAppRunner(AppRunner): | |||
| app_config = application_generate_entity.app_config | |||
| app_config = cast(CompletionAppConfig, app_config) | |||
| app_record = db.session.query(App).filter(App.id == app_config.app_id).first() | |||
| app_record = db.session.query(App).where(App.id == app_config.app_id).first() | |||
| if not app_record: | |||
| raise ValueError("App not found") | |||
| @@ -85,7 +85,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): | |||
| if conversation: | |||
| app_model_config = ( | |||
| db.session.query(AppModelConfig) | |||
| .filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id) | |||
| .where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id) | |||
| .first() | |||
| ) | |||
| @@ -151,13 +151,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): | |||
| introduction = self._get_conversation_introduction(application_generate_entity) | |||
| # get conversation name | |||
| if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity): | |||
| query = application_generate_entity.query or "New conversation" | |||
| else: | |||
| query = next(iter(application_generate_entity.inputs.values()), "New conversation") | |||
| if isinstance(query, int): | |||
| query = str(query) | |||
| query = query or "New conversation" | |||
| query = application_generate_entity.query or "New conversation" | |||
| conversation_name = (query[:20] + "…") if len(query) > 20 else query | |||
| if not conversation: | |||
| @@ -259,7 +253,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): | |||
| :param conversation_id: conversation id | |||
| :return: conversation | |||
| """ | |||
| conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() | |||
| conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first() | |||
| if not conversation: | |||
| raise ConversationNotExistsError("Conversation not exists") | |||
| @@ -272,7 +266,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): | |||
| :param message_id: message id | |||
| :return: message | |||
| """ | |||
| message = db.session.query(Message).filter(Message.id == message_id).first() | |||
| message = db.session.query(Message).where(Message.id == message_id).first() | |||
| if message is None: | |||
| raise MessageNotExistsError("Message not exists") | |||
| @@ -7,7 +7,8 @@ from typing import Any, Literal, Optional, Union, overload | |||
| from flask import Flask, current_app | |||
| from pydantic import ValidationError | |||
| from sqlalchemy.orm import sessionmaker | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session, sessionmaker | |||
| import contexts | |||
| from configs import dify_config | |||
| @@ -22,6 +23,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera | |||
| from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline | |||
| from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity | |||
| from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse | |||
| from core.helper.trace_id_helper import extract_external_trace_id_from_args | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||
| from core.ops.ops_trace_manager import TraceQueueManager | |||
| from core.repositories import DifyCoreRepositoryFactory | |||
| @@ -123,6 +125,10 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| ) | |||
| inputs: Mapping[str, Any] = args["inputs"] | |||
| extras = { | |||
| **extract_external_trace_id_from_args(args), | |||
| } | |||
| workflow_run_id = str(uuid.uuid4()) | |||
| # init application generate entity | |||
| application_generate_entity = WorkflowAppGenerateEntity( | |||
| @@ -142,6 +148,7 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| call_depth=call_depth, | |||
| trace_manager=trace_manager, | |||
| workflow_execution_id=workflow_run_id, | |||
| extras=extras, | |||
| ) | |||
| contexts.plugin_tool_providers.set({}) | |||
| @@ -439,17 +446,44 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| """ | |||
| with preserve_flask_contexts(flask_app, context_vars=context): | |||
| try: | |||
| # workflow app | |||
| runner = WorkflowAppRunner( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| workflow_thread_pool_id=workflow_thread_pool_id, | |||
| variable_loader=variable_loader, | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| workflow = session.scalar( | |||
| select(Workflow).where( | |||
| Workflow.tenant_id == application_generate_entity.app_config.tenant_id, | |||
| Workflow.app_id == application_generate_entity.app_config.app_id, | |||
| Workflow.id == application_generate_entity.app_config.workflow_id, | |||
| ) | |||
| ) | |||
| if workflow is None: | |||
| raise ValueError("Workflow not found") | |||
| # Determine system_user_id based on invocation source | |||
| is_external_api_call = application_generate_entity.invoke_from in { | |||
| InvokeFrom.WEB_APP, | |||
| InvokeFrom.SERVICE_API, | |||
| } | |||
| if is_external_api_call: | |||
| # For external API calls, use end user's session ID | |||
| end_user = session.scalar(select(EndUser).where(EndUser.id == application_generate_entity.user_id)) | |||
| system_user_id = end_user.session_id if end_user else "" | |||
| else: | |||
| # For internal calls, use the original user ID | |||
| system_user_id = application_generate_entity.user_id | |||
| runner = WorkflowAppRunner( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| workflow_thread_pool_id=workflow_thread_pool_id, | |||
| variable_loader=variable_loader, | |||
| workflow=workflow, | |||
| system_user_id=system_user_id, | |||
| ) | |||
| try: | |||
| runner.run() | |||
| except GenerateTaskStoppedError: | |||
| except GenerateTaskStoppedError as e: | |||
| logger.warning(f"Task stopped: {str(e)}") | |||
| pass | |||
| except InvokeAuthorizationError: | |||
| queue_manager.publish_error( | |||
| @@ -465,8 +499,6 @@ class WorkflowAppGenerator(BaseAppGenerator): | |||
| except Exception as e: | |||
| logger.exception("Unknown Error when generating") | |||
| queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) | |||
| finally: | |||
| db.session.close() | |||
| def _handle_response( | |||
| self, | |||
| @@ -14,10 +14,8 @@ from core.workflow.entities.variable_pool import VariablePool | |||
| from core.workflow.system_variable import SystemVariable | |||
| from core.workflow.variable_loader import VariableLoader | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from extensions.ext_database import db | |||
| from models.enums import UserFrom | |||
| from models.model import App, EndUser | |||
| from models.workflow import WorkflowType | |||
| from models.workflow import Workflow, WorkflowType | |||
| logger = logging.getLogger(__name__) | |||
| @@ -29,22 +27,23 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): | |||
| def __init__( | |||
| self, | |||
| *, | |||
| application_generate_entity: WorkflowAppGenerateEntity, | |||
| queue_manager: AppQueueManager, | |||
| variable_loader: VariableLoader, | |||
| workflow_thread_pool_id: Optional[str] = None, | |||
| workflow: Workflow, | |||
| system_user_id: str, | |||
| ) -> None: | |||
| """ | |||
| :param application_generate_entity: application generate entity | |||
| :param queue_manager: application queue manager | |||
| :param workflow_thread_pool_id: workflow thread pool id | |||
| :param workflow: workflow record to execute | |||
| :param system_user_id: resolved user id for the workflow's system variables | |||
| """ | |||
| super().__init__(queue_manager, variable_loader) | |||
| super().__init__( | |||
| queue_manager=queue_manager, | |||
| variable_loader=variable_loader, | |||
| app_id=application_generate_entity.app_config.app_id, | |||
| ) | |||
| self.application_generate_entity = application_generate_entity | |||
| self.workflow_thread_pool_id = workflow_thread_pool_id | |||
| def _get_app_id(self) -> str: | |||
| return self.application_generate_entity.app_config.app_id | |||
| self._workflow = workflow | |||
| self._sys_user_id = system_user_id | |||
| def run(self) -> None: | |||
| """ | |||
| @@ -53,24 +52,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): | |||
| app_config = self.application_generate_entity.app_config | |||
| app_config = cast(WorkflowAppConfig, app_config) | |||
| user_id = None | |||
| if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: | |||
| end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() | |||
| if end_user: | |||
| user_id = end_user.session_id | |||
| else: | |||
| user_id = self.application_generate_entity.user_id | |||
| app_record = db.session.query(App).filter(App.id == app_config.app_id).first() | |||
| if not app_record: | |||
| raise ValueError("App not found") | |||
| workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) | |||
| if not workflow: | |||
| raise ValueError("Workflow not initialized") | |||
| db.session.close() | |||
| workflow_callbacks: list[WorkflowCallback] = [] | |||
| if dify_config.DEBUG: | |||
| workflow_callbacks.append(WorkflowLoggingCallback()) | |||
| @@ -79,14 +60,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): | |||
| if self.application_generate_entity.single_iteration_run: | |||
| # if only single iteration run is requested | |||
| graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( | |||
| workflow=workflow, | |||
| workflow=self._workflow, | |||
| node_id=self.application_generate_entity.single_iteration_run.node_id, | |||
| user_inputs=self.application_generate_entity.single_iteration_run.inputs, | |||
| ) | |||
| elif self.application_generate_entity.single_loop_run: | |||
| # if only single loop run is requested | |||
| graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( | |||
| workflow=workflow, | |||
| workflow=self._workflow, | |||
| node_id=self.application_generate_entity.single_loop_run.node_id, | |||
| user_inputs=self.application_generate_entity.single_loop_run.inputs, | |||
| ) | |||
| @@ -98,7 +79,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): | |||
| system_inputs = SystemVariable( | |||
| files=files, | |||
| user_id=user_id, | |||
| user_id=self._sys_user_id, | |||
| app_id=app_config.app_id, | |||
| workflow_id=app_config.workflow_id, | |||
| workflow_execution_id=self.application_generate_entity.workflow_execution_id, | |||
| @@ -107,21 +88,21 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): | |||
| variable_pool = VariablePool( | |||
| system_variables=system_inputs, | |||
| user_inputs=inputs, | |||
| environment_variables=workflow.environment_variables, | |||
| environment_variables=self._workflow.environment_variables, | |||
| conversation_variables=[], | |||
| ) | |||
| # init graph | |||
| graph = self._init_graph(graph_config=workflow.graph_dict) | |||
| graph = self._init_graph(graph_config=self._workflow.graph_dict) | |||
| # RUN WORKFLOW | |||
| workflow_entry = WorkflowEntry( | |||
| tenant_id=workflow.tenant_id, | |||
| app_id=workflow.app_id, | |||
| workflow_id=workflow.id, | |||
| workflow_type=WorkflowType.value_of(workflow.type), | |||
| tenant_id=self._workflow.tenant_id, | |||
| app_id=self._workflow.app_id, | |||
| workflow_id=self._workflow.id, | |||
| workflow_type=WorkflowType.value_of(self._workflow.type), | |||
| graph=graph, | |||
| graph_config=workflow.graph_dict, | |||
| graph_config=self._workflow.graph_dict, | |||
| user_id=self.application_generate_entity.user_id, | |||
| user_from=( | |||
| UserFrom.ACCOUNT | |||
| @@ -490,6 +490,7 @@ class WorkflowAppGenerateTaskPipeline: | |||
| outputs=event.outputs, | |||
| conversation_id=None, | |||
| trace_manager=trace_manager, | |||
| external_trace_id=self._application_generate_entity.extras.get("external_trace_id"), | |||
| ) | |||
| # save workflow app log | |||
| @@ -524,6 +525,7 @@ class WorkflowAppGenerateTaskPipeline: | |||
| exceptions_count=event.exceptions_count, | |||
| conversation_id=None, | |||
| trace_manager=trace_manager, | |||
| external_trace_id=self._application_generate_entity.extras.get("external_trace_id"), | |||
| ) | |||
| # save workflow app log | |||
| @@ -561,6 +563,7 @@ class WorkflowAppGenerateTaskPipeline: | |||
| conversation_id=None, | |||
| trace_manager=trace_manager, | |||
| exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, | |||
| external_trace_id=self._application_generate_entity.extras.get("external_trace_id"), | |||
| ) | |||
| # save workflow app log | |||
| @@ -1,8 +1,7 @@ | |||
| from collections.abc import Mapping | |||
| from typing import Any, Optional, cast | |||
| from typing import Any, cast | |||
| from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom | |||
| from core.app.apps.base_app_runner import AppRunner | |||
| from core.app.entities.queue_entities import ( | |||
| AppQueueEvent, | |||
| QueueAgentLogEvent, | |||
| @@ -65,18 +64,20 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING | |||
| from core.workflow.system_variable import SystemVariable | |||
| from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool | |||
| from core.workflow.workflow_entry import WorkflowEntry | |||
| from extensions.ext_database import db | |||
| from models.model import App | |||
| from models.workflow import Workflow | |||
| class WorkflowBasedAppRunner(AppRunner): | |||
| def __init__(self, queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER) -> None: | |||
| self.queue_manager = queue_manager | |||
| class WorkflowBasedAppRunner: | |||
| def __init__( | |||
| self, | |||
| *, | |||
| queue_manager: AppQueueManager, | |||
| variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, | |||
| app_id: str, | |||
| ) -> None: | |||
| self._queue_manager = queue_manager | |||
| self._variable_loader = variable_loader | |||
| def _get_app_id(self) -> str: | |||
| raise NotImplementedError("not implemented") | |||
| self._app_id = app_id | |||
| def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph: | |||
| """ | |||
| @@ -693,21 +694,5 @@ class WorkflowBasedAppRunner(AppRunner): | |||
| ) | |||
| ) | |||
| def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: | |||
| """ | |||
| Get workflow | |||
| """ | |||
| # fetch workflow by workflow_id | |||
| workflow = ( | |||
| db.session.query(Workflow) | |||
| .filter( | |||
| Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id | |||
| ) | |||
| .first() | |||
| ) | |||
| # return workflow | |||
| return workflow | |||
| def _publish_event(self, event: AppQueueEvent) -> None: | |||
| self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) | |||
| self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) | |||
| @@ -26,7 +26,7 @@ class AnnotationReplyFeature: | |||
| :return: | |||
| """ | |||
| annotation_setting = ( | |||
| db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_record.id).first() | |||
| db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first() | |||
| ) | |||
| if not annotation_setting: | |||
| @@ -471,7 +471,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): | |||
| :return: | |||
| """ | |||
| agent_thought: Optional[MessageAgentThought] = ( | |||
| db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first() | |||
| db.session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first() | |||
| ) | |||
| if agent_thought: | |||
| @@ -81,7 +81,7 @@ class MessageCycleManager: | |||
| def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str): | |||
| with flask_app.app_context(): | |||
| # get conversation and message | |||
| conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() | |||
| conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first() | |||
| if not conversation: | |||
| return | |||
| @@ -140,7 +140,7 @@ class MessageCycleManager: | |||
| :param event: event | |||
| :return: | |||
| """ | |||
| message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first() | |||
| message_file = db.session.query(MessageFile).where(MessageFile.id == event.message_file_id).first() | |||
| if message_file and message_file.url is not None: | |||
| # get tool file id | |||
| @@ -49,7 +49,7 @@ class DatasetIndexToolCallbackHandler: | |||
| for document in documents: | |||
| if document.metadata is not None: | |||
| document_id = document.metadata["document_id"] | |||
| dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() | |||
| dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() | |||
| if not dataset_document: | |||
| _logger.warning( | |||
| "Expected DatasetDocument record to exist, but none was found, document_id=%s", | |||
| @@ -59,7 +59,7 @@ class DatasetIndexToolCallbackHandler: | |||
| if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | |||
| child_chunk = ( | |||
| db.session.query(ChildChunk) | |||
| .filter( | |||
| .where( | |||
| ChildChunk.index_node_id == document.metadata["doc_id"], | |||
| ChildChunk.dataset_id == dataset_document.dataset_id, | |||
| ChildChunk.document_id == dataset_document.id, | |||
| @@ -69,18 +69,18 @@ class DatasetIndexToolCallbackHandler: | |||
| if child_chunk: | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.id == child_chunk.segment_id) | |||
| .where(DocumentSegment.id == child_chunk.segment_id) | |||
| .update( | |||
| {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False | |||
| ) | |||
| ) | |||
| else: | |||
| query = db.session.query(DocumentSegment).filter( | |||
| query = db.session.query(DocumentSegment).where( | |||
| DocumentSegment.index_node_id == document.metadata["doc_id"] | |||
| ) | |||
| if "dataset_id" in document.metadata: | |||
| query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) | |||
| query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"]) | |||
| # add hit count to document segment | |||
| query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) | |||
| @@ -191,7 +191,7 @@ class ProviderConfiguration(BaseModel): | |||
| provider_record = ( | |||
| db.session.query(Provider) | |||
| .filter( | |||
| .where( | |||
| Provider.tenant_id == self.tenant_id, | |||
| Provider.provider_type == ProviderType.CUSTOM.value, | |||
| Provider.provider_name.in_(provider_names), | |||
| @@ -351,7 +351,7 @@ class ProviderConfiguration(BaseModel): | |||
| provider_model_record = ( | |||
| db.session.query(ProviderModel) | |||
| .filter( | |||
| .where( | |||
| ProviderModel.tenant_id == self.tenant_id, | |||
| ProviderModel.provider_name.in_(provider_names), | |||
| ProviderModel.model_name == model, | |||
| @@ -481,7 +481,7 @@ class ProviderConfiguration(BaseModel): | |||
| return ( | |||
| db.session.query(ProviderModelSetting) | |||
| .filter( | |||
| .where( | |||
| ProviderModelSetting.tenant_id == self.tenant_id, | |||
| ProviderModelSetting.provider_name.in_(provider_names), | |||
| ProviderModelSetting.model_type == model_type.to_origin_model_type(), | |||
| @@ -560,7 +560,7 @@ class ProviderConfiguration(BaseModel): | |||
| return ( | |||
| db.session.query(LoadBalancingModelConfig) | |||
| .filter( | |||
| .where( | |||
| LoadBalancingModelConfig.tenant_id == self.tenant_id, | |||
| LoadBalancingModelConfig.provider_name.in_(provider_names), | |||
| LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), | |||
| @@ -583,7 +583,7 @@ class ProviderConfiguration(BaseModel): | |||
| load_balancing_config_count = ( | |||
| db.session.query(LoadBalancingModelConfig) | |||
| .filter( | |||
| .where( | |||
| LoadBalancingModelConfig.tenant_id == self.tenant_id, | |||
| LoadBalancingModelConfig.provider_name.in_(provider_names), | |||
| LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), | |||
| @@ -627,7 +627,7 @@ class ProviderConfiguration(BaseModel): | |||
| model_setting = ( | |||
| db.session.query(ProviderModelSetting) | |||
| .filter( | |||
| .where( | |||
| ProviderModelSetting.tenant_id == self.tenant_id, | |||
| ProviderModelSetting.provider_name.in_(provider_names), | |||
| ProviderModelSetting.model_type == model_type.to_origin_model_type(), | |||
| @@ -693,7 +693,7 @@ class ProviderConfiguration(BaseModel): | |||
| preferred_model_provider = ( | |||
| db.session.query(TenantPreferredModelProvider) | |||
| .filter( | |||
| .where( | |||
| TenantPreferredModelProvider.tenant_id == self.tenant_id, | |||
| TenantPreferredModelProvider.provider_name.in_(provider_names), | |||
| ) | |||
| @@ -32,7 +32,7 @@ class ApiExternalDataTool(ExternalDataTool): | |||
| # get api_based_extension | |||
| api_based_extension = ( | |||
| db.session.query(APIBasedExtension) | |||
| .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) | |||
| .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) | |||
| .first() | |||
| ) | |||
| @@ -56,7 +56,7 @@ class ApiExternalDataTool(ExternalDataTool): | |||
| # get api_based_extension | |||
| api_based_extension = ( | |||
| db.session.query(APIBasedExtension) | |||
| .filter(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id) | |||
| .where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id) | |||
| .first() | |||
| ) | |||
| @@ -15,7 +15,7 @@ def encrypt_token(tenant_id: str, token: str): | |||
| from models.account import Tenant | |||
| from models.engine import db | |||
| if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()): | |||
| if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()): | |||
| raise ValueError(f"Tenant with id {tenant_id} not found") | |||
| encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) | |||
| return base64.b64encode(encrypted_token).decode() | |||
| @@ -25,9 +25,29 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP | |||
| url = str(marketplace_api_url / "api/v1/plugins/batch") | |||
| response = requests.post(url, json={"plugin_ids": plugin_ids}) | |||
| response.raise_for_status() | |||
| return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]] | |||
| def batch_fetch_plugin_manifests_ignore_deserialization_error( | |||
| plugin_ids: list[str], | |||
| ) -> Sequence[MarketplacePluginDeclaration]: | |||
| if len(plugin_ids) == 0: | |||
| return [] | |||
| url = str(marketplace_api_url / "api/v1/plugins/batch") | |||
| response = requests.post(url, json={"plugin_ids": plugin_ids}) | |||
| response.raise_for_status() | |||
| result: list[MarketplacePluginDeclaration] = [] | |||
| for plugin in response.json()["data"]["plugins"]: | |||
| try: | |||
| result.append(MarketplacePluginDeclaration(**plugin)) | |||
| except Exception: | |||
| # skip manifests that fail to deserialize instead of failing the whole batch | |||
| pass | |||
| return result | |||
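| # Hedged usage note: callers rendering marketplace listings can prefer this | |||
| # lenient variant so one malformed manifest cannot hide the rest of the batch, | |||
| # e.g. manifests = batch_fetch_plugin_manifests_ignore_deserialization_error(ids) | |||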
| def record_install_plugin_event(plugin_unique_identifier: str): | |||
| url = str(marketplace_api_url / "api/v1/stats/plugins/install_count") | |||
| response = requests.post(url, json={"unique_identifier": plugin_unique_identifier}) | |||
| @@ -0,0 +1,42 @@ | |||
| import re | |||
| from collections.abc import Mapping | |||
| from typing import Any, Optional | |||
| def is_valid_trace_id(trace_id: str) -> bool: | |||
| """ | |||
| Check if the trace_id is valid. | |||
| Requirements: 1-128 characters, only letters, numbers, '-', and '_'. | |||
| """ | |||
| return bool(re.match(r"^[a-zA-Z0-9\-_]{1,128}$", trace_id)) | |||
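| # Illustrative checks (derived from the regex above): | |||
| #   is_valid_trace_id("req-123_abc")  -> True | |||
| #   is_valid_trace_id("bad id!")      -> False  (space and '!' are not allowed) | |||
| #   is_valid_trace_id("")             -> False  (at least one character required) | |||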
| def get_external_trace_id(request: Any) -> Optional[str]: | |||
| """ | |||
| Retrieve the trace_id from the request. | |||
| Priority: header ('X-Trace-Id'), then query parameter ('trace_id'), then JSON body ('trace_id'). Returns None if not provided or invalid. | |||
| """ | |||
| trace_id = request.headers.get("X-Trace-Id") | |||
| if not trace_id: | |||
| trace_id = request.args.get("trace_id") | |||
| if not trace_id and getattr(request, "is_json", False): | |||
| json_data = getattr(request, "json", None) | |||
| if json_data: | |||
| trace_id = json_data.get("trace_id") | |||
| if isinstance(trace_id, str) and is_valid_trace_id(trace_id): | |||
| return trace_id | |||
| return None | |||
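| # Illustrative precedence: a request sent with both | |||
| #   X-Trace-Id: trace-a   and   ?trace_id=trace-b | |||
| # resolves to "trace-a", because the header is consulted first. | |||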
| def extract_external_trace_id_from_args(args: Mapping[str, Any]) -> dict: | |||
| """ | |||
| Extract 'external_trace_id' from args. | |||
| Returns a dict suitable for use in extras. Returns an empty dict if not found. | |||
| """ | |||
| trace_id = args.get("external_trace_id") | |||
| if trace_id: | |||
| return {"external_trace_id": trace_id} | |||
| return {} | |||
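| # e.g. extract_external_trace_id_from_args({"external_trace_id": "t-1"}) | |||
| # yields {"external_trace_id": "t-1"}; a missing or falsy value yields {}. | |||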
| @@ -59,7 +59,7 @@ class IndexingRunner: | |||
| # get the process rule | |||
| processing_rule = ( | |||
| db.session.query(DatasetProcessRule) | |||
| .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) | |||
| .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) | |||
| .first() | |||
| ) | |||
| if not processing_rule: | |||
| @@ -119,12 +119,12 @@ class IndexingRunner: | |||
| db.session.delete(document_segment) | |||
| if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | |||
| # delete child chunks | |||
| db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete() | |||
| db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete() | |||
| db.session.commit() | |||
| # get the process rule | |||
| processing_rule = ( | |||
| db.session.query(DatasetProcessRule) | |||
| .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) | |||
| .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) | |||
| .first() | |||
| ) | |||
| if not processing_rule: | |||
| @@ -212,7 +212,7 @@ class IndexingRunner: | |||
| # get the process rule | |||
| processing_rule = ( | |||
| db.session.query(DatasetProcessRule) | |||
| .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) | |||
| .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) | |||
| .first() | |||
| ) | |||
| @@ -316,7 +316,7 @@ class IndexingRunner: | |||
| # delete image files and related db records | |||
| image_upload_file_ids = get_image_upload_file_ids(document.page_content) | |||
| for upload_file_id in image_upload_file_ids: | |||
| image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() | |||
| image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() | |||
| if image_file is None: | |||
| continue | |||
| try: | |||
| @@ -346,7 +346,7 @@ class IndexingRunner: | |||
| raise ValueError("no upload file found") | |||
| file_detail = ( | |||
| db.session.query(UploadFile).filter(UploadFile.id == data_source_info["upload_file_id"]).one_or_none() | |||
| db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none() | |||
| ) | |||
| if file_detail: | |||
| @@ -599,7 +599,7 @@ class IndexingRunner: | |||
| keyword.create(documents) | |||
| if dataset.indexing_technique != "high_quality": | |||
| document_ids = [document.metadata["doc_id"] for document in documents] | |||
| db.session.query(DocumentSegment).filter( | |||
| db.session.query(DocumentSegment).where( | |||
| DocumentSegment.document_id == document_id, | |||
| DocumentSegment.dataset_id == dataset_id, | |||
| DocumentSegment.index_node_id.in_(document_ids), | |||
| @@ -630,7 +630,7 @@ class IndexingRunner: | |||
| index_processor.load(dataset, chunk_documents, with_keywords=False) | |||
| document_ids = [document.metadata["doc_id"] for document in chunk_documents] | |||
| db.session.query(DocumentSegment).filter( | |||
| db.session.query(DocumentSegment).where( | |||
| DocumentSegment.document_id == dataset_document.id, | |||
| DocumentSegment.dataset_id == dataset.id, | |||
| DocumentSegment.index_node_id.in_(document_ids), | |||
| @@ -672,8 +672,7 @@ class IndexingRunner: | |||
| if extra_update_params: | |||
| update_params.update(extra_update_params) | |||
| db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params) | |||
| db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params) # type: ignore | |||
| db.session.commit() | |||
| @staticmethod | |||
| @@ -114,7 +114,8 @@ class LLMGenerator: | |||
| ), | |||
| ) | |||
| questions = output_parser.parse(cast(str, response.message.content)) | |||
| text_content = response.message.get_text_content() | |||
| questions = output_parser.parse(text_content) if text_content else [] | |||
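| # get_text_content() flattens list-shaped (multi-part) message content to | |||
| # plain text, so parsing no longer relies on an unchecked cast(str, ...). | |||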
| except InvokeError: | |||
| questions = [] | |||
| except Exception: | |||
| @@ -15,5 +15,4 @@ class SuggestedQuestionsAfterAnswerOutputParser: | |||
| json_obj = json.loads(action_match.group(0).strip()) | |||
| else: | |||
| json_obj = [] | |||
| return json_obj | |||
| @@ -28,7 +28,7 @@ class MCPServerStreamableHTTPRequestHandler: | |||
| ): | |||
| self.app = app | |||
| self.request = request | |||
| mcp_server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.app.id).first() | |||
| mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == self.app.id).first() | |||
| if not mcp_server: | |||
| raise ValueError("MCP server not found") | |||
| self.mcp_server: AppMCPServer = mcp_server | |||
| @@ -192,7 +192,7 @@ class MCPServerStreamableHTTPRequestHandler: | |||
| def retrieve_end_user(self): | |||
| return ( | |||
| db.session.query(EndUser) | |||
| .filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp") | |||
| .where(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp") | |||
| .first() | |||
| ) | |||
| @@ -67,7 +67,7 @@ class TokenBufferMemory: | |||
| prompt_messages: list[PromptMessage] = [] | |||
| for message in messages: | |||
| files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() | |||
| files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all() | |||
| if files: | |||
| file_extra_config = None | |||
| if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}: | |||
| @@ -156,6 +156,23 @@ class PromptMessage(ABC, BaseModel): | |||
| """ | |||
| return not self.content | |||
| def get_text_content(self) -> str: | |||
| """ | |||
| Get text content from prompt message. | |||
| :return: Text content as string, empty string if no text content | |||
| """ | |||
| if isinstance(self.content, str): | |||
| return self.content | |||
| elif isinstance(self.content, list): | |||
| text_parts = [] | |||
| for item in self.content: | |||
| if isinstance(item, TextPromptMessageContent): | |||
| text_parts.append(item.data) | |||
| return "".join(text_parts) | |||
| else: | |||
| return "" | |||
| @field_validator("content", mode="before") | |||
| @classmethod | |||
| def validate_content(cls, v): | |||
| @@ -89,7 +89,7 @@ class ApiModeration(Moderation): | |||
| def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]: | |||
| extension = ( | |||
| db.session.query(APIBasedExtension) | |||
| .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) | |||
| .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) | |||
| .first() | |||
| ) | |||
| @@ -101,7 +101,8 @@ class AliyunDataTrace(BaseTraceInstance): | |||
| raise ValueError(f"Aliyun get run url failed: {str(e)}") | |||
| def workflow_trace(self, trace_info: WorkflowTraceInfo): | |||
| trace_id = convert_to_trace_id(trace_info.workflow_run_id) | |||
| external_trace_id = trace_info.metadata.get("external_trace_id") | |||
| trace_id = external_trace_id or convert_to_trace_id(trace_info.workflow_run_id) | |||
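| # a caller-supplied external_trace_id (see get_external_trace_id) takes | |||
| # precedence, so spans stay joinable with the caller's own tracing system | |||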
| workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow") | |||
| self.add_workflow_span(trace_id, workflow_span_id, trace_info) | |||
| @@ -119,7 +120,7 @@ class AliyunDataTrace(BaseTraceInstance): | |||
| user_id = message_data.from_account_id | |||
| if message_data.from_end_user_id: | |||
| end_user_data: Optional[EndUser] = ( | |||
| db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() | |||
| db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() | |||
| ) | |||
| if end_user_data is not None: | |||
| user_id = end_user_data.session_id | |||
| @@ -243,14 +244,14 @@ class AliyunDataTrace(BaseTraceInstance): | |||
| if not app_id: | |||
| raise ValueError("No app_id found in trace_info metadata") | |||
| app = session.query(App).filter(App.id == app_id).first() | |||
| app = session.query(App).where(App.id == app_id).first() | |||
| if not app: | |||
| raise ValueError(f"App with id {app_id} not found") | |||
| if not app.created_by: | |||
| raise ValueError(f"App with id {app_id} has no creator (created_by is None)") | |||
| service_account = session.query(Account).filter(Account.id == app.created_by).first() | |||
| service_account = session.query(Account).where(Account.id == app.created_by).first() | |||
| if not service_account: | |||
| raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") | |||
| current_tenant = ( | |||
| @@ -153,7 +153,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): | |||
| } | |||
| workflow_metadata.update(trace_info.metadata) | |||
| trace_id = uuid_to_trace_id(trace_info.workflow_run_id) | |||
| external_trace_id = trace_info.metadata.get("external_trace_id") | |||
| trace_id = external_trace_id or uuid_to_trace_id(trace_info.workflow_run_id) | |||
| span_id = RandomIdGenerator().generate_span_id() | |||
| context = SpanContext( | |||
| trace_id=trace_id, | |||
| @@ -296,7 +297,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): | |||
| # Add end user data if available | |||
| if trace_info.message_data.from_end_user_id: | |||
| end_user_data: Optional[EndUser] = ( | |||
| db.session.query(EndUser).filter(EndUser.id == trace_info.message_data.from_end_user_id).first() | |||
| db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first() | |||
| ) | |||
| if end_user_data is not None: | |||
| message_metadata["end_user_id"] = end_user_data.session_id | |||
| @@ -702,7 +703,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): | |||
| WorkflowNodeExecutionModel.process_data, | |||
| WorkflowNodeExecutionModel.execution_metadata, | |||
| ) | |||
| .filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) | |||
| .where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) | |||
| .all() | |||
| ) | |||
| return workflow_nodes | |||
| @@ -44,14 +44,14 @@ class BaseTraceInstance(ABC): | |||
| """ | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| # Get the app to find its creator | |||
| app = session.query(App).filter(App.id == app_id).first() | |||
| app = session.query(App).where(App.id == app_id).first() | |||
| if not app: | |||
| raise ValueError(f"App with id {app_id} not found") | |||
| if not app.created_by: | |||
| raise ValueError(f"App with id {app_id} has no creator (created_by is None)") | |||
| service_account = session.query(Account).filter(Account.id == app.created_by).first() | |||
| service_account = session.query(Account).where(Account.id == app.created_by).first() | |||
| if not service_account: | |||
| raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") | |||
| @@ -67,13 +67,14 @@ class LangFuseDataTrace(BaseTraceInstance): | |||
| self.generate_name_trace(trace_info) | |||
| def workflow_trace(self, trace_info: WorkflowTraceInfo): | |||
| trace_id = trace_info.workflow_run_id | |||
| external_trace_id = trace_info.metadata.get("external_trace_id") | |||
| trace_id = external_trace_id or trace_info.workflow_run_id | |||
| user_id = trace_info.metadata.get("user_id") | |||
| metadata = trace_info.metadata | |||
| metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id | |||
| if trace_info.message_id: | |||
| trace_id = trace_info.message_id | |||
| trace_id = external_trace_id or trace_info.message_id | |||
| name = TraceTaskName.MESSAGE_TRACE.value | |||
| trace_data = LangfuseTrace( | |||
| id=trace_id, | |||
| @@ -243,7 +244,7 @@ class LangFuseDataTrace(BaseTraceInstance): | |||
| user_id = message_data.from_account_id | |||
| if message_data.from_end_user_id: | |||
| end_user_data: Optional[EndUser] = ( | |||
| db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() | |||
| db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() | |||
| ) | |||
| if end_user_data is not None: | |||
| user_id = end_user_data.session_id | |||
| @@ -65,7 +65,8 @@ class LangSmithDataTrace(BaseTraceInstance): | |||
| self.generate_name_trace(trace_info) | |||
| def workflow_trace(self, trace_info: WorkflowTraceInfo): | |||
| trace_id = trace_info.message_id or trace_info.workflow_run_id | |||
| external_trace_id = trace_info.metadata.get("external_trace_id") | |||
| trace_id = external_trace_id or trace_info.message_id or trace_info.workflow_run_id | |||
| if trace_info.start_time is None: | |||
| trace_info.start_time = datetime.now() | |||
| message_dotted_order = ( | |||
| @@ -261,7 +262,7 @@ class LangSmithDataTrace(BaseTraceInstance): | |||
| if message_data.from_end_user_id: | |||
| end_user_data: Optional[EndUser] = ( | |||
| db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() | |||
| db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() | |||
| ) | |||
| if end_user_data is not None: | |||
| end_user_id = end_user_data.session_id | |||
| @@ -96,7 +96,8 @@ class OpikDataTrace(BaseTraceInstance): | |||
| self.generate_name_trace(trace_info) | |||
| def workflow_trace(self, trace_info: WorkflowTraceInfo): | |||
| dify_trace_id = trace_info.workflow_run_id | |||
| external_trace_id = trace_info.metadata.get("external_trace_id") | |||
| dify_trace_id = external_trace_id or trace_info.workflow_run_id | |||
| opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id) | |||
| workflow_metadata = wrap_metadata( | |||
| trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id | |||
| @@ -104,7 +105,7 @@ class OpikDataTrace(BaseTraceInstance): | |||
| root_span_id = None | |||
| if trace_info.message_id: | |||
| dify_trace_id = trace_info.message_id | |||
| dify_trace_id = external_trace_id or trace_info.message_id | |||
| opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id) | |||
| trace_data = { | |||
| @@ -283,7 +284,7 @@ class OpikDataTrace(BaseTraceInstance): | |||
| if message_data.from_end_user_id: | |||
| end_user_data: Optional[EndUser] = ( | |||
| db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() | |||
| db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() | |||
| ) | |||
| if end_user_data is not None: | |||
| end_user_id = end_user_data.session_id | |||
| @@ -218,7 +218,7 @@ class OpsTraceManager: | |||
| """ | |||
| trace_config_data: Optional[TraceAppConfig] = ( | |||
| db.session.query(TraceAppConfig) | |||
| .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) | |||
| .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) | |||
| .first() | |||
| ) | |||
| @@ -226,7 +226,7 @@ class OpsTraceManager: | |||
| return None | |||
| # decrypt_token | |||
| app = db.session.query(App).filter(App.id == app_id).first() | |||
| app = db.session.query(App).where(App.id == app_id).first() | |||
| if not app: | |||
| raise ValueError("App not found") | |||
| @@ -253,7 +253,7 @@ class OpsTraceManager: | |||
| if app_id is None: | |||
| return None | |||
| app: Optional[App] = db.session.query(App).filter(App.id == app_id).first() | |||
| app: Optional[App] = db.session.query(App).where(App.id == app_id).first() | |||
| if app is None: | |||
| return None | |||
| @@ -293,18 +293,18 @@ class OpsTraceManager: | |||
| @classmethod | |||
| def get_app_config_through_message_id(cls, message_id: str): | |||
| app_model_config = None | |||
| message_data = db.session.query(Message).filter(Message.id == message_id).first() | |||
| message_data = db.session.query(Message).where(Message.id == message_id).first() | |||
| if not message_data: | |||
| return None | |||
| conversation_id = message_data.conversation_id | |||
| conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() | |||
| conversation_data = db.session.query(Conversation).where(Conversation.id == conversation_id).first() | |||
| if not conversation_data: | |||
| return None | |||
| if conversation_data.app_model_config_id: | |||
| app_model_config = ( | |||
| db.session.query(AppModelConfig) | |||
| .filter(AppModelConfig.id == conversation_data.app_model_config_id) | |||
| .where(AppModelConfig.id == conversation_data.app_model_config_id) | |||
| .first() | |||
| ) | |||
| elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs: | |||
| @@ -331,7 +331,7 @@ class OpsTraceManager: | |||
| if tracing_provider is not None: | |||
| raise ValueError(f"Invalid tracing provider: {tracing_provider}") | |||
| app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first() | |||
| app_config: Optional[App] = db.session.query(App).where(App.id == app_id).first() | |||
| if not app_config: | |||
| raise ValueError("App not found") | |||
| app_config.tracing = json.dumps( | |||
| @@ -349,7 +349,7 @@ class OpsTraceManager: | |||
| :param app_id: app id | |||
| :return: | |||
| """ | |||
| app: Optional[App] = db.session.query(App).filter(App.id == app_id).first() | |||
| app: Optional[App] = db.session.query(App).where(App.id == app_id).first() | |||
| if not app: | |||
| raise ValueError("App not found") | |||
| if not app.tracing: | |||
| @@ -520,6 +520,10 @@ class TraceTask: | |||
| "app_id": workflow_run.app_id, | |||
| } | |||
| external_trace_id = self.kwargs.get("external_trace_id") | |||
| if external_trace_id: | |||
| metadata["external_trace_id"] = external_trace_id | |||
| workflow_trace_info = WorkflowTraceInfo( | |||
| workflow_data=workflow_run.to_dict(), | |||
| conversation_id=conversation_id, | |||
| @@ -3,6 +3,8 @@ from datetime import datetime | |||
| from typing import Optional, Union | |||
| from urllib.parse import urlparse | |||
| from sqlalchemy import select | |||
| from extensions.ext_database import db | |||
| from models.model import Message | |||
| @@ -20,7 +22,7 @@ def filter_none_values(data: dict): | |||
| def get_message_data(message_id: str): | |||
| return db.session.query(Message).filter(Message.id == message_id).first() | |||
| return db.session.scalar(select(Message).where(Message.id == message_id)) | |||
| @contextmanager | |||
| @@ -87,7 +87,8 @@ class WeaveDataTrace(BaseTraceInstance): | |||
| self.generate_name_trace(trace_info) | |||
| def workflow_trace(self, trace_info: WorkflowTraceInfo): | |||
| trace_id = trace_info.message_id or trace_info.workflow_run_id | |||
| external_trace_id = trace_info.metadata.get("external_trace_id") | |||
| trace_id = external_trace_id or trace_info.message_id or trace_info.workflow_run_id | |||
| if trace_info.start_time is None: | |||
| trace_info.start_time = datetime.now() | |||
| @@ -234,7 +235,7 @@ class WeaveDataTrace(BaseTraceInstance): | |||
| if message_data.from_end_user_id: | |||
| end_user_data: Optional[EndUser] = ( | |||
| db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() | |||
| db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() | |||
| ) | |||
| if end_user_data is not None: | |||
| end_user_id = end_user_data.session_id | |||
| @@ -193,9 +193,9 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): | |||
| get the user by user id | |||
| """ | |||
| user = db.session.query(EndUser).filter(EndUser.id == user_id).first() | |||
| user = db.session.query(EndUser).where(EndUser.id == user_id).first() | |||
| if not user: | |||
| user = db.session.query(Account).filter(Account.id == user_id).first() | |||
| user = db.session.query(Account).where(Account.id == user_id).first() | |||
| if not user: | |||
| raise ValueError("user not found") | |||
| @@ -208,7 +208,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): | |||
| get app | |||
| """ | |||
| try: | |||
| app = db.session.query(App).filter(App.id == app_id).filter(App.tenant_id == tenant_id).first() | |||
| app = db.session.query(App).where(App.id == app_id).where(App.tenant_id == tenant_id).first() | |||
| except Exception: | |||
| raise ValueError("app not found") | |||
| @@ -194,6 +194,7 @@ class PluginOAuthCredentialsResponse(BaseModel): | |||
| metadata: Mapping[str, Any] = Field( | |||
| default_factory=dict, description="The metadata of the OAuth, like avatar url, name, etc." | |||
| ) | |||
| expires_at: int = Field(default=-1, description="The expires at time of the credentials. UTC timestamp.") | |||
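| # by convention here, -1 marks credentials that never expire; consumers such | |||
| # as ToolManager refresh shortly before a real timestamp lapses | |||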
| credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.") | |||
| @@ -84,6 +84,41 @@ class OAuthHandler(BasePluginClient): | |||
| except Exception as e: | |||
| raise ValueError(f"Error getting credentials: {e}") | |||
| def refresh_credentials( | |||
| self, | |||
| tenant_id: str, | |||
| user_id: str, | |||
| plugin_id: str, | |||
| provider: str, | |||
| redirect_uri: str, | |||
| system_credentials: Mapping[str, Any], | |||
| credentials: Mapping[str, Any], | |||
| ) -> PluginOAuthCredentialsResponse: | |||
| try: | |||
| response = self._request_with_plugin_daemon_response_stream( | |||
| "POST", | |||
| f"plugin/{tenant_id}/dispatch/oauth/refresh_credentials", | |||
| PluginOAuthCredentialsResponse, | |||
| data={ | |||
| "user_id": user_id, | |||
| "data": { | |||
| "provider": provider, | |||
| "redirect_uri": redirect_uri, | |||
| "system_credentials": system_credentials, | |||
| "credentials": credentials, | |||
| }, | |||
| }, | |||
| headers={ | |||
| "X-Plugin-ID": plugin_id, | |||
| "Content-Type": "application/json", | |||
| }, | |||
| ) | |||
| for resp in response: | |||
| return resp | |||
| raise ValueError("No response received from plugin daemon for refresh credentials request.") | |||
| except Exception as e: | |||
| raise ValueError(f"Error refreshing credentials: {e}") | |||
| def _convert_request_to_raw_data(self, request: Request) -> bytes: | |||
| """ | |||
| Convert a Request object to raw HTTP data. | |||
| @@ -275,7 +275,7 @@ class ProviderManager: | |||
| # Get the corresponding TenantDefaultModel record | |||
| default_model = ( | |||
| db.session.query(TenantDefaultModel) | |||
| .filter( | |||
| .where( | |||
| TenantDefaultModel.tenant_id == tenant_id, | |||
| TenantDefaultModel.model_type == model_type.to_origin_model_type(), | |||
| ) | |||
| @@ -367,7 +367,7 @@ class ProviderManager: | |||
| # Get the list of available models from get_configurations and check if it is LLM | |||
| default_model = ( | |||
| db.session.query(TenantDefaultModel) | |||
| .filter( | |||
| .where( | |||
| TenantDefaultModel.tenant_id == tenant_id, | |||
| TenantDefaultModel.model_type == model_type.to_origin_model_type(), | |||
| ) | |||
| @@ -541,7 +541,7 @@ class ProviderManager: | |||
| db.session.rollback() | |||
| existed_provider_record = ( | |||
| db.session.query(Provider) | |||
| .filter( | |||
| .where( | |||
| Provider.tenant_id == tenant_id, | |||
| Provider.provider_name == ModelProviderID(provider_name).provider_name, | |||
| Provider.provider_type == ProviderType.SYSTEM.value, | |||
| @@ -94,11 +94,11 @@ class Jieba(BaseKeyword): | |||
| documents = [] | |||
| for chunk_index in sorted_chunk_indices: | |||
| segment_query = db.session.query(DocumentSegment).filter( | |||
| segment_query = db.session.query(DocumentSegment).where( | |||
| DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index | |||
| ) | |||
| if document_ids_filter: | |||
| segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter)) | |||
| segment_query = segment_query.where(DocumentSegment.document_id.in_(document_ids_filter)) | |||
| segment = segment_query.first() | |||
| if segment: | |||
| @@ -215,7 +215,7 @@ class Jieba(BaseKeyword): | |||
| def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]): | |||
| document_segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id) | |||
| .where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id) | |||
| .first() | |||
| ) | |||
| if document_segment: | |||
| @@ -127,7 +127,7 @@ class RetrievalService: | |||
| external_retrieval_model: Optional[dict] = None, | |||
| metadata_filtering_conditions: Optional[dict] = None, | |||
| ): | |||
| dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | |||
| dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| return [] | |||
| metadata_condition = ( | |||
| @@ -145,7 +145,7 @@ class RetrievalService: | |||
| @classmethod | |||
| def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]: | |||
| with Session(db.engine) as session: | |||
| return session.query(Dataset).filter(Dataset.id == dataset_id).first() | |||
| return session.query(Dataset).where(Dataset.id == dataset_id).first() | |||
| @classmethod | |||
| def keyword_search( | |||
| @@ -294,7 +294,7 @@ class RetrievalService: | |||
| dataset_documents = { | |||
| doc.id: doc | |||
| for doc in db.session.query(DatasetDocument) | |||
| .filter(DatasetDocument.id.in_(document_ids)) | |||
| .where(DatasetDocument.id.in_(document_ids)) | |||
| .options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id)) | |||
| .all() | |||
| } | |||
| @@ -318,7 +318,7 @@ class RetrievalService: | |||
| child_index_node_id = document.metadata.get("doc_id") | |||
| child_chunk = ( | |||
| db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first() | |||
| db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first() | |||
| ) | |||
| if not child_chunk: | |||
| @@ -326,7 +326,7 @@ class RetrievalService: | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter( | |||
| .where( | |||
| DocumentSegment.dataset_id == dataset_document.dataset_id, | |||
| DocumentSegment.enabled == True, | |||
| DocumentSegment.status == "completed", | |||
| @@ -381,7 +381,7 @@ class RetrievalService: | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter( | |||
| .where( | |||
| DocumentSegment.dataset_id == dataset_document.dataset_id, | |||
| DocumentSegment.enabled == True, | |||
| DocumentSegment.status == "completed", | |||
| @@ -6,7 +6,7 @@ from uuid import UUID, uuid4 | |||
| from numpy import ndarray | |||
| from pgvecto_rs.sqlalchemy import VECTOR # type: ignore | |||
| from pydantic import BaseModel, model_validator | |||
| from sqlalchemy import Float, String, create_engine, insert, select, text | |||
| from sqlalchemy import Float, create_engine, insert, select, text | |||
| from sqlalchemy import text as sql_text | |||
| from sqlalchemy.dialects import postgresql | |||
| from sqlalchemy.orm import Mapped, Session, mapped_column | |||
| @@ -67,7 +67,7 @@ class PGVectoRS(BaseVector): | |||
| postgresql.UUID(as_uuid=True), | |||
| primary_key=True, | |||
| ) | |||
| text: Mapped[str] = mapped_column(String) | |||
| text: Mapped[str] | |||
| meta: Mapped[dict] = mapped_column(postgresql.JSONB) | |||
| vector: Mapped[ndarray] = mapped_column(VECTOR(dim)) | |||
| @@ -443,7 +443,7 @@ class QdrantVectorFactory(AbstractVectorFactory): | |||
| if dataset.collection_binding_id: | |||
| dataset_collection_binding = ( | |||
| db.session.query(DatasetCollectionBinding) | |||
| .filter(DatasetCollectionBinding.id == dataset.collection_binding_id) | |||
| .where(DatasetCollectionBinding.id == dataset.collection_binding_id) | |||
| .one_or_none() | |||
| ) | |||
| if dataset_collection_binding: | |||
| @@ -118,10 +118,21 @@ class TableStoreVector(BaseVector): | |||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | |||
| top_k = kwargs.get("top_k", 4) | |||
| return self._search_by_vector(query_vector, top_k) | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| filtered_list = None | |||
| if document_ids_filter: | |||
| filtered_list = ["document_id=" + item for item in document_ids_filter] | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| return self._search_by_vector(query_vector, filtered_list, top_k, score_threshold) | |||
| def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | |||
| return self._search_by_full_text(query) | |||
| top_k = kwargs.get("top_k", 4) | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| filtered_list = None | |||
| if document_ids_filter: | |||
| filtered_list = ["document_id=" + item for item in document_ids_filter] | |||
| return self._search_by_full_text(query, filtered_list, top_k) | |||
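| # the tags field stores "key=value" strings (e.g. "document_id=<id>"), so the | |||
| # same encoding drives _search_by_metadata lookups and these TermsQuery filters | |||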
| def delete(self) -> None: | |||
| self._delete_table_if_exist() | |||
| @@ -230,32 +241,51 @@ class TableStoreVector(BaseVector): | |||
| primary_key = [("id", id)] | |||
| row = tablestore.Row(primary_key) | |||
| self._tablestore_client.delete_row(self._table_name, row, None) | |||
| logging.info("Tablestore delete row successfully. id:%s", id) | |||
| def _search_by_metadata(self, key: str, value: str) -> list[str]: | |||
| query = tablestore.SearchQuery( | |||
| tablestore.TermQuery(self._tags_field, str(key) + "=" + str(value)), | |||
| limit=100, | |||
| limit=1000, | |||
| get_total_count=False, | |||
| ) | |||
| rows: list[str] = [] | |||
| next_token = None | |||
| while True: | |||
| if next_token is not None: | |||
| query.next_token = next_token | |||
| search_response = self._tablestore_client.search( | |||
| table_name=self._table_name, | |||
| index_name=self._index_name, | |||
| search_query=query, | |||
| columns_to_get=tablestore.ColumnsToGet( | |||
| column_names=[Field.PRIMARY_KEY.value], return_type=tablestore.ColumnReturnType.SPECIFIED | |||
| ), | |||
| ) | |||
| search_response = self._tablestore_client.search( | |||
| table_name=self._table_name, | |||
| index_name=self._index_name, | |||
| search_query=query, | |||
| columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX), | |||
| ) | |||
| if search_response is not None: | |||
| rows.extend([row[0][0][1] for row in search_response.rows]) | |||
| return [row[0][0][1] for row in search_response.rows] | |||
| if search_response is None or search_response.next_token == b"": | |||
| break | |||
| else: | |||
| next_token = search_response.next_token | |||
| def _search_by_vector(self, query_vector: list[float], top_k: int) -> list[Document]: | |||
| ots_query = tablestore.KnnVectorQuery( | |||
| return rows | |||
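| # Pagination note: Tablestore marks the final page with an empty (b"") | |||
| # next_token, so the loop drains every match rather than stopping at the | |||
| # first page of at most `limit` rows. | |||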
| def _search_by_vector( | |||
| self, query_vector: list[float], document_ids_filter: list[str] | None, top_k: int, score_threshold: float | |||
| ) -> list[Document]: | |||
| knn_vector_query = tablestore.KnnVectorQuery( | |||
| field_name=Field.VECTOR.value, | |||
| top_k=top_k, | |||
| float32_query_vector=query_vector, | |||
| ) | |||
| if document_ids_filter: | |||
| knn_vector_query.filter = tablestore.TermsQuery(self._tags_field, document_ids_filter) | |||
| sort = tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)]) | |||
| search_query = tablestore.SearchQuery(ots_query, limit=top_k, get_total_count=False, sort=sort) | |||
| search_query = tablestore.SearchQuery(knn_vector_query, limit=top_k, get_total_count=False, sort=sort) | |||
| search_response = self._tablestore_client.search( | |||
| table_name=self._table_name, | |||
| @@ -263,30 +293,42 @@ class TableStoreVector(BaseVector): | |||
| search_query=search_query, | |||
| columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX), | |||
| ) | |||
| logging.info( | |||
| "Tablestore search successfully. request_id:%s", | |||
| search_response.request_id, | |||
| ) | |||
| return self._to_query_result(search_response) | |||
| def _to_query_result(self, search_response: tablestore.SearchResponse) -> list[Document]: | |||
| documents = [] | |||
| for row in search_response.rows: | |||
| documents.append( | |||
| Document( | |||
| page_content=row[1][2][1], | |||
| vector=json.loads(row[1][3][1]), | |||
| metadata=json.loads(row[1][0][1]), | |||
| ) | |||
| ) | |||
| for search_hit in search_response.search_hits: | |||
| if search_hit.score > score_threshold: | |||
| ots_column_map = {} | |||
| for col in search_hit.row[1]: | |||
| ots_column_map[col[0]] = col[1] | |||
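| # each attribute column arrives as a (name, value[, timestamp]) tuple, so the | |||
| # map above gives name -> value lookups for the JSON decoding below | |||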
| vector_str = ots_column_map.get(Field.VECTOR.value) | |||
| metadata_str = ots_column_map.get(Field.METADATA_KEY.value) | |||
| vector = json.loads(vector_str) if vector_str else None | |||
| metadata = json.loads(metadata_str) if metadata_str else {} | |||
| metadata["score"] = search_hit.score | |||
| documents.append( | |||
| Document( | |||
| page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "", | |||
| vector=vector, | |||
| metadata=metadata, | |||
| ) | |||
| ) | |||
| documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True) | |||
| return documents | |||
| def _search_by_full_text(self, query: str) -> list[Document]: | |||
| def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]: | |||
| bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[]) | |||
| bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value)) | |||
| if document_ids_filter: | |||
| bool_query.filter_queries.append(tablestore.TermsQuery(self._tags_field, document_ids_filter)) | |||
| search_query = tablestore.SearchQuery( | |||
| query=tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value), | |||
| query=bool_query, | |||
| sort=tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)]), | |||
| limit=100, | |||
| limit=top_k, | |||
| ) | |||
| search_response = self._tablestore_client.search( | |||
| table_name=self._table_name, | |||
| @@ -295,7 +337,25 @@ class TableStoreVector(BaseVector): | |||
| columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX), | |||
| ) | |||
| return self._to_query_result(search_response) | |||
| documents = [] | |||
| for search_hit in search_response.search_hits: | |||
| ots_column_map = {} | |||
| for col in search_hit.row[1]: | |||
| ots_column_map[col[0]] = col[1] | |||
| vector_str = ots_column_map.get(Field.VECTOR.value) | |||
| metadata_str = ots_column_map.get(Field.METADATA_KEY.value) | |||
| vector = json.loads(vector_str) if vector_str else None | |||
| metadata = json.loads(metadata_str) if metadata_str else {} | |||
| documents.append( | |||
| Document( | |||
| page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "", | |||
| vector=vector, | |||
| metadata=metadata, | |||
| ) | |||
| ) | |||
| return documents | |||
| class TableStoreVectorFactory(AbstractVectorFactory): | |||
| @@ -284,7 +284,8 @@ class TencentVector(BaseVector): | |||
| # Compatible with version 1.1.3 and below. | |||
| meta = json.loads(meta) | |||
| score = 1 - result.get("score", 0.0) | |||
| score = result.get("score", 0.0) | |||
| else: | |||
| score = result.get("score", 0.0) | |||
| if score > score_threshold: | |||
| meta["score"] = score | |||
| doc = Document(page_content=result.get(self.field_text), metadata=meta) | |||
| @@ -418,13 +418,13 @@ class TidbOnQdrantVector(BaseVector): | |||
| class TidbOnQdrantVectorFactory(AbstractVectorFactory): | |||
| def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector: | |||
| tidb_auth_binding = ( | |||
| db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none() | |||
| db.session.query(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none() | |||
| ) | |||
| if not tidb_auth_binding: | |||
| with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): | |||
| tidb_auth_binding = ( | |||
| db.session.query(TidbAuthBinding) | |||
| .filter(TidbAuthBinding.tenant_id == dataset.tenant_id) | |||
| .where(TidbAuthBinding.tenant_id == dataset.tenant_id) | |||
| .one_or_none() | |||
| ) | |||
| if tidb_auth_binding: | |||
| @@ -433,7 +433,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): | |||
| else: | |||
| idle_tidb_auth_binding = ( | |||
| db.session.query(TidbAuthBinding) | |||
| .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE") | |||
| .where(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE") | |||
| .limit(1) | |||
| .one_or_none() | |||
| ) | |||
| @@ -47,7 +47,7 @@ class Vector: | |||
| if dify_config.VECTOR_STORE_WHITELIST_ENABLE: | |||
| whitelist = ( | |||
| db.session.query(Whitelist) | |||
| .filter(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db") | |||
| .where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db") | |||
| .one_or_none() | |||
| ) | |||
| if whitelist: | |||
| @@ -42,7 +42,7 @@ class DatasetDocumentStore: | |||
| @property | |||
| def docs(self) -> dict[str, Document]: | |||
| document_segments = ( | |||
| db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == self._dataset.id).all() | |||
| db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id).all() | |||
| ) | |||
| output = {} | |||
| @@ -63,7 +63,7 @@ class DatasetDocumentStore: | |||
| def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None: | |||
| max_position = ( | |||
| db.session.query(func.max(DocumentSegment.position)) | |||
| .filter(DocumentSegment.document_id == self._document_id) | |||
| .where(DocumentSegment.document_id == self._document_id) | |||
| .scalar() | |||
| ) | |||
| @@ -147,7 +147,7 @@ class DatasetDocumentStore: | |||
| segment_document.tokens = tokens | |||
| if save_child and doc.children: | |||
| # delete the existing child chunks | |||
| db.session.query(ChildChunk).filter( | |||
| db.session.query(ChildChunk).where( | |||
| ChildChunk.tenant_id == self._dataset.tenant_id, | |||
| ChildChunk.dataset_id == self._dataset.id, | |||
| ChildChunk.document_id == self._document_id, | |||
| @@ -230,7 +230,7 @@ class DatasetDocumentStore: | |||
| def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]: | |||
| document_segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id) | |||
| .where(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id) | |||
| .first() | |||
| ) | |||
| @@ -331,9 +331,10 @@ class NotionExtractor(BaseExtractor): | |||
| last_edited_time = self.get_notion_last_edited_time() | |||
| data_source_info = document_model.data_source_info_dict | |||
| data_source_info["last_edited_time"] = last_edited_time | |||
| update_params = {DocumentModel.data_source_info: json.dumps(data_source_info)} | |||
| db.session.query(DocumentModel).filter_by(id=document_model.id).update(update_params) | |||
| db.session.query(DocumentModel).filter_by(id=document_model.id).update( | |||
| {DocumentModel.data_source_info: json.dumps(data_source_info)} | |||
| ) # type: ignore | |||
| db.session.commit() | |||
| def get_notion_last_edited_time(self) -> str: | |||
| @@ -365,7 +366,7 @@ class NotionExtractor(BaseExtractor): | |||
| def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: | |||
| data_source_binding = ( | |||
| db.session.query(DataSourceOauthBinding) | |||
| .filter( | |||
| .where( | |||
| db.and_( | |||
| DataSourceOauthBinding.tenant_id == tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| @@ -121,7 +121,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): | |||
| child_node_ids = ( | |||
| db.session.query(ChildChunk.index_node_id) | |||
| .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) | |||
| .filter( | |||
| .where( | |||
| DocumentSegment.dataset_id == dataset.id, | |||
| DocumentSegment.index_node_id.in_(node_ids), | |||
| ChildChunk.dataset_id == dataset.id, | |||
| @@ -131,7 +131,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): | |||
| child_node_ids = [child_node_id[0] for child_node_id in child_node_ids] | |||
| vector.delete_by_ids(child_node_ids) | |||
| if delete_child_chunks: | |||
| db.session.query(ChildChunk).filter( | |||
| db.session.query(ChildChunk).where( | |||
| ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids) | |||
| ).delete() | |||
| db.session.commit() | |||
| @@ -139,7 +139,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): | |||
| vector.delete() | |||
| if delete_child_chunks: | |||
| db.session.query(ChildChunk).filter(ChildChunk.dataset_id == dataset.id).delete() | |||
| db.session.query(ChildChunk).where(ChildChunk.dataset_id == dataset.id).delete() | |||
| db.session.commit() | |||
| def retrieve( | |||
| @@ -135,7 +135,7 @@ class DatasetRetrieval: | |||
| available_datasets = [] | |||
| for dataset_id in dataset_ids: | |||
| # get dataset from dataset id | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| # pass if dataset is not available | |||
| if not dataset: | |||
| @@ -242,7 +242,7 @@ class DatasetRetrieval: | |||
| dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() | |||
| document = ( | |||
| db.session.query(DatasetDocument) | |||
| .filter( | |||
| .where( | |||
| DatasetDocument.id == segment.document_id, | |||
| DatasetDocument.enabled == True, | |||
| DatasetDocument.archived == False, | |||
| @@ -327,7 +327,7 @@ class DatasetRetrieval: | |||
| if dataset_id: | |||
| # get retrieval model config | |||
| dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | |||
| dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() | |||
| if dataset: | |||
| results = [] | |||
| if dataset.provider == "external": | |||
| @@ -516,14 +516,14 @@ class DatasetRetrieval: | |||
| if document.metadata is not None: | |||
| dataset_document = ( | |||
| db.session.query(DatasetDocument) | |||
| .filter(DatasetDocument.id == document.metadata["document_id"]) | |||
| .where(DatasetDocument.id == document.metadata["document_id"]) | |||
| .first() | |||
| ) | |||
| if dataset_document: | |||
| if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | |||
| child_chunk = ( | |||
| db.session.query(ChildChunk) | |||
| .filter( | |||
| .where( | |||
| ChildChunk.index_node_id == document.metadata["doc_id"], | |||
| ChildChunk.dataset_id == dataset_document.dataset_id, | |||
| ChildChunk.document_id == dataset_document.id, | |||
| @@ -533,7 +533,7 @@ class DatasetRetrieval: | |||
| if child_chunk: | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.id == child_chunk.segment_id) | |||
| .where(DocumentSegment.id == child_chunk.segment_id) | |||
| .update( | |||
| {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, | |||
| synchronize_session=False, | |||
| @@ -541,13 +541,13 @@ class DatasetRetrieval: | |||
| ) | |||
| db.session.commit() | |||
| else: | |||
| query = db.session.query(DocumentSegment).filter( | |||
| query = db.session.query(DocumentSegment).where( | |||
| DocumentSegment.index_node_id == document.metadata["doc_id"] | |||
| ) | |||
| if "dataset_id" in document.metadata: | |||
| query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) | |||
| query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"]) | |||
| # add hit count to document segment | |||
| query.update( | |||
| @@ -600,7 +600,7 @@ class DatasetRetrieval: | |||
| ): | |||
| with flask_app.app_context(): | |||
| with Session(db.engine) as session: | |||
| dataset = session.query(Dataset).filter(Dataset.id == dataset_id).first() | |||
| dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| return [] | |||
| @@ -685,7 +685,7 @@ class DatasetRetrieval: | |||
| available_datasets = [] | |||
| for dataset_id in dataset_ids: | |||
| # get dataset from dataset id | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| # pass if dataset is not available | |||
| if not dataset: | |||
| @@ -862,7 +862,7 @@ class DatasetRetrieval: | |||
| metadata_filtering_conditions: Optional[MetadataFilteringCondition], | |||
| inputs: dict, | |||
| ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]: | |||
| document_query = db.session.query(DatasetDocument).filter( | |||
| document_query = db.session.query(DatasetDocument).where( | |||
| DatasetDocument.dataset_id.in_(dataset_ids), | |||
| DatasetDocument.indexing_status == "completed", | |||
| DatasetDocument.enabled == True, | |||
| @@ -930,9 +930,9 @@ class DatasetRetrieval: | |||
| raise ValueError("Invalid metadata filtering mode") | |||
| if filters: | |||
| if metadata_filtering_conditions and metadata_filtering_conditions.logical_operator == "and": # type: ignore | |||
| document_query = document_query.filter(and_(*filters)) | |||
| document_query = document_query.where(and_(*filters)) | |||
| else: | |||
| document_query = document_query.filter(or_(*filters)) | |||
| document_query = document_query.where(or_(*filters)) | |||
| documents = document_query.all() | |||
| # group by dataset_id | |||
| metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore | |||
| @@ -958,7 +958,7 @@ class DatasetRetrieval: | |||
| self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig | |||
| ) -> Optional[list[dict[str, Any]]]: | |||
| # get all metadata field | |||
| metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all() | |||
| metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all() | |||
| all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] | |||
| # get metadata model config | |||
| if metadata_model_config is None: | |||
| @@ -178,7 +178,7 @@ class ApiToolProviderController(ToolProviderController): | |||
| # get tenant api providers | |||
| db_providers: list[ApiToolProvider] = ( | |||
| db.session.query(ApiToolProvider) | |||
| .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name) | |||
| .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name) | |||
| .all() | |||
| ) | |||
| @@ -160,7 +160,7 @@ class ToolFileManager: | |||
| with Session(self._engine, expire_on_commit=False) as session: | |||
| tool_file: ToolFile | None = ( | |||
| session.query(ToolFile) | |||
| .filter( | |||
| .where( | |||
| ToolFile.id == id, | |||
| ) | |||
| .first() | |||
| @@ -184,7 +184,7 @@ class ToolFileManager: | |||
| with Session(self._engine, expire_on_commit=False) as session: | |||
| message_file: MessageFile | None = ( | |||
| session.query(MessageFile) | |||
| .filter( | |||
| .where( | |||
| MessageFile.id == id, | |||
| ) | |||
| .first() | |||
| @@ -204,7 +204,7 @@ class ToolFileManager: | |||
| tool_file: ToolFile | None = ( | |||
| session.query(ToolFile) | |||
| .filter( | |||
| .where( | |||
| ToolFile.id == tool_file_id, | |||
| ) | |||
| .first() | |||
| @@ -228,7 +228,7 @@ class ToolFileManager: | |||
| with Session(self._engine, expire_on_commit=False) as session: | |||
| tool_file: ToolFile | None = ( | |||
| session.query(ToolFile) | |||
| .filter( | |||
| .where( | |||
| ToolFile.id == tool_file_id, | |||
| ) | |||
| .first() | |||
| @@ -29,7 +29,7 @@ class ToolLabelManager: | |||
| raise ValueError("Unsupported tool type") | |||
| # delete old labels | |||
| db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete() | |||
| db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete() | |||
| # insert new labels | |||
| for label in labels: | |||
| @@ -57,7 +57,7 @@ class ToolLabelManager: | |||
| labels = ( | |||
| db.session.query(ToolLabelBinding.label_name) | |||
| .filter( | |||
| .where( | |||
| ToolLabelBinding.tool_id == provider_id, | |||
| ToolLabelBinding.tool_type == controller.provider_type.value, | |||
| ) | |||
| @@ -90,7 +90,7 @@ class ToolLabelManager: | |||
| provider_ids.append(controller.provider_id) | |||
| labels: list[ToolLabelBinding] = ( | |||
| db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all() | |||
| db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids)).all() | |||
| ) | |||
| tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels} | |||
| @@ -1,16 +1,19 @@ | |||
| import json | |||
| import logging | |||
| import mimetypes | |||
| from collections.abc import Generator | |||
| import time | |||
| from collections.abc import Generator, Mapping | |||
| from os import listdir, path | |||
| from threading import Lock | |||
| from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast | |||
| from pydantic import TypeAdapter | |||
| from yarl import URL | |||
| import contexts | |||
| from core.helper.provider_cache import ToolProviderCredentialsCache | |||
| from core.plugin.entities.plugin import ToolProviderID | |||
| from core.plugin.impl.oauth import OAuthHandler | |||
| from core.plugin.impl.tool import PluginToolManager | |||
| from core.tools.__base.tool_provider import ToolProviderController | |||
| from core.tools.__base.tool_runtime import ToolRuntime | |||
| @@ -195,7 +198,7 @@ class ToolManager: | |||
| try: | |||
| builtin_provider = ( | |||
| db.session.query(BuiltinToolProvider) | |||
| .filter( | |||
| .where( | |||
| BuiltinToolProvider.tenant_id == tenant_id, | |||
| BuiltinToolProvider.id == credential_id, | |||
| ) | |||
| @@ -213,7 +216,7 @@ class ToolManager: | |||
| # use the default provider | |||
| builtin_provider = ( | |||
| db.session.query(BuiltinToolProvider) | |||
| .filter( | |||
| .where( | |||
| BuiltinToolProvider.tenant_id == tenant_id, | |||
| (BuiltinToolProvider.provider == str(provider_id_entity)) | |||
| | (BuiltinToolProvider.provider == provider_id_entity.provider_name), | |||
| @@ -226,7 +229,7 @@ class ToolManager: | |||
| else: | |||
| builtin_provider = ( | |||
| db.session.query(BuiltinToolProvider) | |||
| .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) | |||
| .where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) | |||
| .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) | |||
| .first() | |||
| ) | |||
| @@ -244,12 +247,47 @@ class ToolManager: | |||
| tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id | |||
| ), | |||
| ) | |||
| # decrypt the credentials | |||
| decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials) | |||
| # check whether the credentials are expired; expires_at == -1 means they | |||
| # never expire, and real timestamps are refreshed 60 seconds early | |||
| if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()): | |||
| # TODO: imported locally to work around a circular import | |||
| from services.tools.builtin_tools_manage_service import BuiltinToolManageService | |||
| # assemble the OAuth context needed for the refresh | |||
| tool_provider = ToolProviderID(provider_id) | |||
| provider_name = tool_provider.provider_name | |||
| redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback" | |||
| system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id) | |||
| oauth_handler = OAuthHandler() | |||
| # refresh the credentials | |||
| refreshed_credentials = oauth_handler.refresh_credentials( | |||
| tenant_id=tenant_id, | |||
| user_id=builtin_provider.user_id, | |||
| plugin_id=tool_provider.plugin_id, | |||
| provider=provider_name, | |||
| redirect_uri=redirect_uri, | |||
| system_credentials=system_credentials or {}, | |||
| credentials=decrypted_credentials, | |||
| ) | |||
| # update the credentials | |||
| builtin_provider.encrypted_credentials = ( | |||
| TypeAdapter(dict[str, Any]) | |||
| .dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials))) | |||
| .decode("utf-8") | |||
| ) | |||
| builtin_provider.expires_at = refreshed_credentials.expires_at | |||
| db.session.commit() | |||
| decrypted_credentials = refreshed_credentials.credentials | |||
| return cast( | |||
| BuiltinTool, | |||
| builtin_tool.fork_tool_runtime( | |||
| runtime=ToolRuntime( | |||
| tenant_id=tenant_id, | |||
| credentials=encrypter.decrypt(builtin_provider.credentials), | |||
| credentials=dict(decrypted_credentials), | |||
| credential_type=CredentialType.of(builtin_provider.credential_type), | |||
| runtime_parameters={}, | |||
| invoke_from=invoke_from, | |||
| @@ -278,7 +316,7 @@ class ToolManager: | |||
| elif provider_type == ToolProviderType.WORKFLOW: | |||
| workflow_provider = ( | |||
| db.session.query(WorkflowToolProvider) | |||
| .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) | |||
| .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) | |||
| .first() | |||
| ) | |||
| @@ -578,7 +616,7 @@ class ToolManager: | |||
| ORDER BY tenant_id, provider, is_default DESC, created_at DESC | |||
| """ | |||
| ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()] | |||
| return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all() | |||
| return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all() | |||
| @classmethod | |||
| def list_providers_from_api( | |||
| @@ -626,7 +664,7 @@ class ToolManager: | |||
| # get db api providers | |||
| if "api" in filters: | |||
| db_api_providers: list[ApiToolProvider] = ( | |||
| db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() | |||
| db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() | |||
| ) | |||
| api_provider_controllers: list[dict[str, Any]] = [ | |||
| @@ -649,7 +687,7 @@ class ToolManager: | |||
| if "workflow" in filters: | |||
| # get workflow providers | |||
| workflow_providers: list[WorkflowToolProvider] = ( | |||
| db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() | |||
| db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all() | |||
| ) | |||
| workflow_provider_controllers: list[WorkflowToolProviderController] = [] | |||
| @@ -693,7 +731,7 @@ class ToolManager: | |||
| """ | |||
| provider: ApiToolProvider | None = ( | |||
| db.session.query(ApiToolProvider) | |||
| .filter( | |||
| .where( | |||
| ApiToolProvider.id == provider_id, | |||
| ApiToolProvider.tenant_id == tenant_id, | |||
| ) | |||
| @@ -730,7 +768,7 @@ class ToolManager: | |||
| """ | |||
| provider: MCPToolProvider | None = ( | |||
| db.session.query(MCPToolProvider) | |||
| .filter( | |||
| .where( | |||
| MCPToolProvider.server_identifier == provider_id, | |||
| MCPToolProvider.tenant_id == tenant_id, | |||
| ) | |||
| @@ -755,7 +793,7 @@ class ToolManager: | |||
| provider_name = provider | |||
| provider_obj: ApiToolProvider | None = ( | |||
| db.session.query(ApiToolProvider) | |||
| .filter( | |||
| .where( | |||
| ApiToolProvider.tenant_id == tenant_id, | |||
| ApiToolProvider.name == provider, | |||
| ) | |||
| @@ -847,7 +885,7 @@ class ToolManager: | |||
| try: | |||
| workflow_provider: WorkflowToolProvider | None = ( | |||
| db.session.query(WorkflowToolProvider) | |||
| .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) | |||
| .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) | |||
| .first() | |||
| ) | |||
| @@ -864,7 +902,7 @@ class ToolManager: | |||
| try: | |||
| api_provider: ApiToolProvider | None = ( | |||
| db.session.query(ApiToolProvider) | |||
| .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) | |||
| .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) | |||
| .first() | |||
| ) | |||
| @@ -881,7 +919,7 @@ class ToolManager: | |||
| try: | |||
| mcp_provider: MCPToolProvider | None = ( | |||
| db.session.query(MCPToolProvider) | |||
| .filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id) | |||
| .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id) | |||
| .first() | |||
| ) | |||
| @@ -973,7 +1011,9 @@ class ToolManager: | |||
| if variable is None: | |||
| raise ToolParameterError(f"Variable {tool_input.value} does not exist") | |||
| parameter_value = variable.value | |||
| elif tool_input.type in {"mixed", "constant"}: | |||
| elif tool_input.type == "constant": | |||
| parameter_value = tool_input.value | |||
| elif tool_input.type == "mixed": | |||
| segment_group = variable_pool.convert_template(str(tool_input.value)) | |||
| parameter_value = segment_group.text | |||
| else: | |||
| @@ -87,7 +87,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): | |||
| index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] | |||
| segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter( | |||
| .where( | |||
| DocumentSegment.dataset_id.in_(self.dataset_ids), | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.status == "completed", | |||
| @@ -114,7 +114,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): | |||
| dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() | |||
| document = ( | |||
| db.session.query(Document) | |||
| .filter( | |||
| .where( | |||
| Document.id == segment.document_id, | |||
| Document.enabled == True, | |||
| Document.archived == False, | |||
| @@ -163,7 +163,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): | |||
| ): | |||
| with flask_app.app_context(): | |||
| dataset = ( | |||
| db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first() | |||
| db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first() | |||
| ) | |||
| if not dataset: | |||
| @@ -1,5 +1,5 @@ | |||
| from abc import abstractmethod | |||
| from typing import Any, Optional | |||
| from typing import Optional | |||
| from msal_extensions.persistence import ABC # type: ignore | |||
| from pydantic import BaseModel, ConfigDict | |||
| @@ -21,11 +21,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC): | |||
| model_config = ConfigDict(arbitrary_types_allowed=True) | |||
| @abstractmethod | |||
| def _run( | |||
| self, | |||
| *args: Any, | |||
| **kwargs: Any, | |||
| ) -> Any: | |||
| def _run(self, query: str) -> str: | |||
| """Use the tool. | |||
| Add run_manager: Optional[CallbackManagerForToolRun] = None | |||