Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com>tags/1.9.1
| @@ -10,6 +10,7 @@ from flask import current_app | |||
| from pydantic import TypeAdapter | |||
| from sqlalchemy import select | |||
| from sqlalchemy.exc import SQLAlchemyError | |||
| from sqlalchemy.orm import sessionmaker | |||
| from configs import dify_config | |||
| from constants.languages import languages | |||
| @@ -61,31 +62,30 @@ def reset_password(email, new_password, password_confirm): | |||
| if str(new_password).strip() != str(password_confirm).strip(): | |||
| click.echo(click.style("Passwords do not match.", fg="red")) | |||
| return | |||
| with sessionmaker(db.engine, expire_on_commit=False).begin() as session: | |||
| account = session.query(Account).where(Account.email == email).one_or_none() | |||
| account = db.session.query(Account).where(Account.email == email).one_or_none() | |||
| if not account: | |||
| click.echo(click.style(f"Account not found for email: {email}", fg="red")) | |||
| return | |||
| if not account: | |||
| click.echo(click.style(f"Account not found for email: {email}", fg="red")) | |||
| return | |||
| try: | |||
| valid_password(new_password) | |||
| except: | |||
| click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red")) | |||
| return | |||
| try: | |||
| valid_password(new_password) | |||
| except: | |||
| click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red")) | |||
| return | |||
| # generate password salt | |||
| salt = secrets.token_bytes(16) | |||
| base64_salt = base64.b64encode(salt).decode() | |||
| # generate password salt | |||
| salt = secrets.token_bytes(16) | |||
| base64_salt = base64.b64encode(salt).decode() | |||
| # encrypt password with salt | |||
| password_hashed = hash_password(new_password, salt) | |||
| base64_password_hashed = base64.b64encode(password_hashed).decode() | |||
| account.password = base64_password_hashed | |||
| account.password_salt = base64_salt | |||
| db.session.commit() | |||
| AccountService.reset_login_error_rate_limit(email) | |||
| click.echo(click.style("Password reset successfully.", fg="green")) | |||
| # encrypt password with salt | |||
| password_hashed = hash_password(new_password, salt) | |||
| base64_password_hashed = base64.b64encode(password_hashed).decode() | |||
| account.password = base64_password_hashed | |||
| account.password_salt = base64_salt | |||
| AccountService.reset_login_error_rate_limit(email) | |||
| click.echo(click.style("Password reset successfully.", fg="green")) | |||
| @click.command("reset-email", help="Reset the account email.") | |||
| @@ -100,22 +100,21 @@ def reset_email(email, new_email, email_confirm): | |||
| if str(new_email).strip() != str(email_confirm).strip(): | |||
| click.echo(click.style("New emails do not match.", fg="red")) | |||
| return | |||
| with sessionmaker(db.engine, expire_on_commit=False).begin() as session: | |||
| account = session.query(Account).where(Account.email == email).one_or_none() | |||
| account = db.session.query(Account).where(Account.email == email).one_or_none() | |||
| if not account: | |||
| click.echo(click.style(f"Account not found for email: {email}", fg="red")) | |||
| return | |||
| if not account: | |||
| click.echo(click.style(f"Account not found for email: {email}", fg="red")) | |||
| return | |||
| try: | |||
| email_validate(new_email) | |||
| except: | |||
| click.echo(click.style(f"Invalid email: {new_email}", fg="red")) | |||
| return | |||
| try: | |||
| email_validate(new_email) | |||
| except: | |||
| click.echo(click.style(f"Invalid email: {new_email}", fg="red")) | |||
| return | |||
| account.email = new_email | |||
| db.session.commit() | |||
| click.echo(click.style("Email updated successfully.", fg="green")) | |||
| account.email = new_email | |||
| click.echo(click.style("Email updated successfully.", fg="green")) | |||
| @click.command( | |||
| @@ -139,25 +138,24 @@ def reset_encrypt_key_pair(): | |||
| if dify_config.EDITION != "SELF_HOSTED": | |||
| click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red")) | |||
| return | |||
| with sessionmaker(db.engine, expire_on_commit=False).begin() as session: | |||
| tenants = session.query(Tenant).all() | |||
| for tenant in tenants: | |||
| if not tenant: | |||
| click.echo(click.style("No workspaces found. Run /install first.", fg="red")) | |||
| return | |||
| tenants = db.session.query(Tenant).all() | |||
| for tenant in tenants: | |||
| if not tenant: | |||
| click.echo(click.style("No workspaces found. Run /install first.", fg="red")) | |||
| return | |||
| tenant.encrypt_public_key = generate_key_pair(tenant.id) | |||
| tenant.encrypt_public_key = generate_key_pair(tenant.id) | |||
| db.session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete() | |||
| db.session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete() | |||
| db.session.commit() | |||
| session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete() | |||
| session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete() | |||
| click.echo( | |||
| click.style( | |||
| f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.", | |||
| fg="green", | |||
| click.echo( | |||
| click.style( | |||
| f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.", | |||
| fg="green", | |||
| ) | |||
| ) | |||
| ) | |||
| @click.command("vdb-migrate", help="Migrate vector db.") | |||
| @@ -182,14 +180,15 @@ def migrate_annotation_vector_database(): | |||
| try: | |||
| # get apps info | |||
| per_page = 50 | |||
| apps = ( | |||
| db.session.query(App) | |||
| .where(App.status == "normal") | |||
| .order_by(App.created_at.desc()) | |||
| .limit(per_page) | |||
| .offset((page - 1) * per_page) | |||
| .all() | |||
| ) | |||
| with sessionmaker(db.engine, expire_on_commit=False).begin() as session: | |||
| apps = ( | |||
| session.query(App) | |||
| .where(App.status == "normal") | |||
| .order_by(App.created_at.desc()) | |||
| .limit(per_page) | |||
| .offset((page - 1) * per_page) | |||
| .all() | |||
| ) | |||
| if not apps: | |||
| break | |||
| except SQLAlchemyError: | |||
| @@ -203,26 +202,27 @@ def migrate_annotation_vector_database(): | |||
| ) | |||
| try: | |||
| click.echo(f"Creating app annotation index: {app.id}") | |||
| app_annotation_setting = ( | |||
| db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first() | |||
| ) | |||
| with sessionmaker(db.engine, expire_on_commit=False).begin() as session: | |||
| app_annotation_setting = ( | |||
| session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first() | |||
| ) | |||
| if not app_annotation_setting: | |||
| skipped_count = skipped_count + 1 | |||
| click.echo(f"App annotation setting disabled: {app.id}") | |||
| continue | |||
| # get dataset_collection_binding info | |||
| dataset_collection_binding = ( | |||
| db.session.query(DatasetCollectionBinding) | |||
| .where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id) | |||
| .first() | |||
| ) | |||
| if not dataset_collection_binding: | |||
| click.echo(f"App annotation collection binding not found: {app.id}") | |||
| continue | |||
| annotations = db.session.scalars( | |||
| select(MessageAnnotation).where(MessageAnnotation.app_id == app.id) | |||
| ).all() | |||
| if not app_annotation_setting: | |||
| skipped_count = skipped_count + 1 | |||
| click.echo(f"App annotation setting disabled: {app.id}") | |||
| continue | |||
| # get dataset_collection_binding info | |||
| dataset_collection_binding = ( | |||
| session.query(DatasetCollectionBinding) | |||
| .where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id) | |||
| .first() | |||
| ) | |||
| if not dataset_collection_binding: | |||
| click.echo(f"App annotation collection binding not found: {app.id}") | |||
| continue | |||
| annotations = session.scalars( | |||
| select(MessageAnnotation).where(MessageAnnotation.app_id == app.id) | |||
| ).all() | |||
| dataset = Dataset( | |||
| id=app.id, | |||
| tenant_id=app.tenant_id, | |||
| @@ -1,6 +1,7 @@ | |||
| from datetime import datetime | |||
| import pytz # pip install pytz | |||
| import sqlalchemy as sa | |||
| from flask_login import current_user | |||
| from flask_restx import Resource, marshal_with, reqparse | |||
| from flask_restx.inputs import int_range | |||
| @@ -70,7 +71,7 @@ class CompletionConversationApi(Resource): | |||
| parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") | |||
| args = parser.parse_args() | |||
| query = db.select(Conversation).where( | |||
| query = sa.select(Conversation).where( | |||
| Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False) | |||
| ) | |||
| @@ -236,7 +237,7 @@ class ChatConversationApi(Resource): | |||
| .subquery() | |||
| ) | |||
| query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False)) | |||
| query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False)) | |||
| if args["keyword"]: | |||
| keyword_filter = f"%{args['keyword']}%" | |||
| @@ -4,6 +4,7 @@ from argparse import ArgumentTypeError | |||
| from collections.abc import Sequence | |||
| from typing import Literal, cast | |||
| import sqlalchemy as sa | |||
| from flask import request | |||
| from flask_login import current_user | |||
| from flask_restx import Resource, fields, marshal, marshal_with, reqparse | |||
| @@ -211,13 +212,13 @@ class DatasetDocumentListApi(Resource): | |||
| if sort == "hit_count": | |||
| sub_query = ( | |||
| db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count")) | |||
| sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count")) | |||
| .group_by(DocumentSegment.document_id) | |||
| .subquery() | |||
| ) | |||
| query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by( | |||
| sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)), | |||
| sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)), | |||
| sort_logic(Document.position), | |||
| ) | |||
| elif sort == "created_at": | |||
| @@ -910,7 +910,7 @@ class AppDatasetJoin(Base): | |||
| id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=sa.text("uuid_generate_v4()")) | |||
| app_id = mapped_column(StringUUID, nullable=False) | |||
| dataset_id = mapped_column(StringUUID, nullable=False) | |||
| created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp()) | |||
| created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp()) | |||
| @property | |||
| def app(self): | |||
| @@ -931,7 +931,7 @@ class DatasetQuery(Base): | |||
| source_app_id = mapped_column(StringUUID, nullable=True) | |||
| created_by_role = mapped_column(String, nullable=False) | |||
| created_by = mapped_column(StringUUID, nullable=False) | |||
| created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp()) | |||
| created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp()) | |||
| class DatasetKeywordTable(Base): | |||
| @@ -1731,7 +1731,7 @@ class MessageChain(Base): | |||
| type: Mapped[str] = mapped_column(String(255), nullable=False) | |||
| input = mapped_column(sa.Text, nullable=True) | |||
| output = mapped_column(sa.Text, nullable=True) | |||
| created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) | |||
| created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp()) | |||
| class MessageAgentThought(Base): | |||
| @@ -1769,7 +1769,7 @@ class MessageAgentThought(Base): | |||
| latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True) | |||
| created_by_role = mapped_column(String, nullable=False) | |||
| created_by = mapped_column(StringUUID, nullable=False) | |||
| created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) | |||
| created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp()) | |||
| @property | |||
| def files(self) -> list[Any]: | |||
| @@ -1872,7 +1872,7 @@ class DatasetRetrieverResource(Base): | |||
| index_node_hash = mapped_column(sa.Text, nullable=True) | |||
| retriever_from = mapped_column(sa.Text, nullable=False) | |||
| created_by = mapped_column(StringUUID, nullable=False) | |||
| created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp()) | |||
| created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp()) | |||
| class Tag(Base): | |||
| @@ -2,6 +2,7 @@ import json | |||
| import logging | |||
| from typing import TypedDict, cast | |||
| import sqlalchemy as sa | |||
| from flask_sqlalchemy.pagination import Pagination | |||
| from configs import dify_config | |||
| @@ -65,7 +66,7 @@ class AppService: | |||
| return None | |||
| app_models = db.paginate( | |||
| db.select(App).where(*filters).order_by(App.created_at.desc()), | |||
| sa.select(App).where(*filters).order_by(App.created_at.desc()), | |||
| page=args["page"], | |||
| per_page=args["limit"], | |||
| error_out=False, | |||
| @@ -115,12 +115,12 @@ class DatasetService: | |||
| # Check if permitted_dataset_ids is not empty to avoid WHERE false condition | |||
| if permitted_dataset_ids and len(permitted_dataset_ids) > 0: | |||
| query = query.where( | |||
| db.or_( | |||
| sa.or_( | |||
| Dataset.permission == DatasetPermissionEnum.ALL_TEAM, | |||
| db.and_( | |||
| sa.and_( | |||
| Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id | |||
| ), | |||
| db.and_( | |||
| sa.and_( | |||
| Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM, | |||
| Dataset.id.in_(permitted_dataset_ids), | |||
| ), | |||
| @@ -128,9 +128,9 @@ class DatasetService: | |||
| ) | |||
| else: | |||
| query = query.where( | |||
| db.or_( | |||
| sa.or_( | |||
| Dataset.permission == DatasetPermissionEnum.ALL_TEAM, | |||
| db.and_( | |||
| sa.and_( | |||
| Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id | |||
| ), | |||
| ) | |||
| @@ -1879,7 +1879,7 @@ class DocumentService: | |||
| # for notion_info in notion_info_list: | |||
| # workspace_id = notion_info.workspace_id | |||
| # data_source_binding = DataSourceOauthBinding.query.filter( | |||
| # db.and_( | |||
| # sa.and_( | |||
| # DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| # DataSourceOauthBinding.provider == "notion", | |||
| # DataSourceOauthBinding.disabled == False, | |||
| @@ -471,7 +471,7 @@ class PluginMigration: | |||
| total_failed_tenant = 0 | |||
| while True: | |||
| # paginate | |||
| tenants = db.paginate(db.select(Tenant).order_by(Tenant.created_at.desc()), page=page, per_page=100) | |||
| tenants = db.paginate(sa.select(Tenant).order_by(Tenant.created_at.desc()), page=page, per_page=100) | |||
| if tenants.items is None or len(tenants.items) == 0: | |||
| break | |||
| @@ -1,5 +1,6 @@ | |||
| import uuid | |||
| import sqlalchemy as sa | |||
| from flask_login import current_user | |||
| from sqlalchemy import func, select | |||
| from werkzeug.exceptions import NotFound | |||
| @@ -18,7 +19,7 @@ class TagService: | |||
| .where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id) | |||
| ) | |||
| if keyword: | |||
| query = query.where(db.and_(Tag.name.ilike(f"%{keyword}%"))) | |||
| query = query.where(sa.and_(Tag.name.ilike(f"%{keyword}%"))) | |||
| query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at) | |||
| results: list = query.order_by(Tag.created_at.desc()).all() | |||
| return results | |||
| @@ -2,6 +2,7 @@ import logging | |||
| import time | |||
| import click | |||
| import sqlalchemy as sa | |||
| from celery import shared_task | |||
| from sqlalchemy import select | |||
| @@ -51,7 +52,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): | |||
| data_source_binding = ( | |||
| db.session.query(DataSourceOauthBinding) | |||
| .where( | |||
| db.and_( | |||
| sa.and_( | |||
| DataSourceOauthBinding.tenant_id == document.tenant_id, | |||
| DataSourceOauthBinding.provider == "notion", | |||
| DataSourceOauthBinding.disabled == False, | |||