Переглянути джерело

one example of Session (#24135)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
tags/1.9.1
Asuka Minato 1 місяць тому
джерело
коміт
25c69ac540
Аккаунт користувача з таким Email не знайдено

+ 76
- 76
api/commands.py Переглянути файл

@@ -10,6 +10,7 @@ from flask import current_app
from pydantic import TypeAdapter
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker

from configs import dify_config
from constants.languages import languages
@@ -61,31 +62,30 @@ def reset_password(email, new_password, password_confirm):
if str(new_password).strip() != str(password_confirm).strip():
click.echo(click.style("Passwords do not match.", fg="red"))
return
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = session.query(Account).where(Account.email == email).one_or_none()

account = db.session.query(Account).where(Account.email == email).one_or_none()

if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return

try:
valid_password(new_password)
except:
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
return
try:
valid_password(new_password)
except:
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
return

# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()

# encrypt password with salt
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
AccountService.reset_login_error_rate_limit(email)
click.echo(click.style("Password reset successfully.", fg="green"))
# encrypt password with salt
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
AccountService.reset_login_error_rate_limit(email)
click.echo(click.style("Password reset successfully.", fg="green"))


@click.command("reset-email", help="Reset the account email.")
@@ -100,22 +100,21 @@ def reset_email(email, new_email, email_confirm):
if str(new_email).strip() != str(email_confirm).strip():
click.echo(click.style("New emails do not match.", fg="red"))
return
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = session.query(Account).where(Account.email == email).one_or_none()

account = db.session.query(Account).where(Account.email == email).one_or_none()

if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return

try:
email_validate(new_email)
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
try:
email_validate(new_email)
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return

account.email = new_email
db.session.commit()
click.echo(click.style("Email updated successfully.", fg="green"))
account.email = new_email
click.echo(click.style("Email updated successfully.", fg="green"))


@click.command(
@@ -139,25 +138,24 @@ def reset_encrypt_key_pair():
if dify_config.EDITION != "SELF_HOSTED":
click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
return
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
tenants = session.query(Tenant).all()
for tenant in tenants:
if not tenant:
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
return

tenants = db.session.query(Tenant).all()
for tenant in tenants:
if not tenant:
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
return

tenant.encrypt_public_key = generate_key_pair(tenant.id)
tenant.encrypt_public_key = generate_key_pair(tenant.id)

db.session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
db.session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
db.session.commit()
session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()

click.echo(
click.style(
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
fg="green",
click.echo(
click.style(
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
fg="green",
)
)
)


@click.command("vdb-migrate", help="Migrate vector db.")
@@ -182,14 +180,15 @@ def migrate_annotation_vector_database():
try:
# get apps info
per_page = 50
apps = (
db.session.query(App)
.where(App.status == "normal")
.order_by(App.created_at.desc())
.limit(per_page)
.offset((page - 1) * per_page)
.all()
)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
apps = (
session.query(App)
.where(App.status == "normal")
.order_by(App.created_at.desc())
.limit(per_page)
.offset((page - 1) * per_page)
.all()
)
if not apps:
break
except SQLAlchemyError:
@@ -203,26 +202,27 @@ def migrate_annotation_vector_database():
)
try:
click.echo(f"Creating app annotation index: {app.id}")
app_annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
app_annotation_setting = (
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
)

if not app_annotation_setting:
skipped_count = skipped_count + 1
click.echo(f"App annotation setting disabled: {app.id}")
continue
# get dataset_collection_binding info
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.first()
)
if not dataset_collection_binding:
click.echo(f"App annotation collection binding not found: {app.id}")
continue
annotations = db.session.scalars(
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
).all()
if not app_annotation_setting:
skipped_count = skipped_count + 1
click.echo(f"App annotation setting disabled: {app.id}")
continue
# get dataset_collection_binding info
dataset_collection_binding = (
session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.first()
)
if not dataset_collection_binding:
click.echo(f"App annotation collection binding not found: {app.id}")
continue
annotations = session.scalars(
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
).all()
dataset = Dataset(
id=app.id,
tenant_id=app.tenant_id,

+ 3
- 2
api/controllers/console/app/conversation.py Переглянути файл

@@ -1,6 +1,7 @@
from datetime import datetime

import pytz # pip install pytz
import sqlalchemy as sa
from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range
@@ -70,7 +71,7 @@ class CompletionConversationApi(Resource):
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()

query = db.select(Conversation).where(
query = sa.select(Conversation).where(
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
)

@@ -236,7 +237,7 @@ class ChatConversationApi(Resource):
.subquery()
)

query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))

if args["keyword"]:
keyword_filter = f"%{args['keyword']}%"

+ 3
- 2
api/controllers/console/datasets/datasets_document.py Переглянути файл

@@ -4,6 +4,7 @@ from argparse import ArgumentTypeError
from collections.abc import Sequence
from typing import Literal, cast

import sqlalchemy as sa
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
@@ -211,13 +212,13 @@ class DatasetDocumentListApi(Resource):

if sort == "hit_count":
sub_query = (
db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
.group_by(DocumentSegment.document_id)
.subquery()
)

query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)),
sort_logic(Document.position),
)
elif sort == "created_at":

+ 2
- 2
api/models/dataset.py Переглянути файл

@@ -910,7 +910,7 @@ class AppDatasetJoin(Base):
id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())

@property
def app(self):
@@ -931,7 +931,7 @@ class DatasetQuery(Base):
source_app_id = mapped_column(StringUUID, nullable=True)
created_by_role = mapped_column(String, nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())


class DatasetKeywordTable(Base):

+ 3
- 3
api/models/model.py Переглянути файл

@@ -1731,7 +1731,7 @@ class MessageChain(Base):
type: Mapped[str] = mapped_column(String(255), nullable=False)
input = mapped_column(sa.Text, nullable=True)
output = mapped_column(sa.Text, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())


class MessageAgentThought(Base):
@@ -1769,7 +1769,7 @@ class MessageAgentThought(Base):
latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
created_by_role = mapped_column(String, nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())

@property
def files(self) -> list[Any]:
@@ -1872,7 +1872,7 @@ class DatasetRetrieverResource(Base):
index_node_hash = mapped_column(sa.Text, nullable=True)
retriever_from = mapped_column(sa.Text, nullable=False)
created_by = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())


class Tag(Base):

+ 2
- 1
api/services/app_service.py Переглянути файл

@@ -2,6 +2,7 @@ import json
import logging
from typing import TypedDict, cast

import sqlalchemy as sa
from flask_sqlalchemy.pagination import Pagination

from configs import dify_config
@@ -65,7 +66,7 @@ class AppService:
return None

app_models = db.paginate(
db.select(App).where(*filters).order_by(App.created_at.desc()),
sa.select(App).where(*filters).order_by(App.created_at.desc()),
page=args["page"],
per_page=args["limit"],
error_out=False,

+ 6
- 6
api/services/dataset_service.py Переглянути файл

@@ -115,12 +115,12 @@ class DatasetService:
# Check if permitted_dataset_ids is not empty to avoid WHERE false condition
if permitted_dataset_ids and len(permitted_dataset_ids) > 0:
query = query.where(
db.or_(
sa.or_(
Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
db.and_(
sa.and_(
Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id
),
db.and_(
sa.and_(
Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM,
Dataset.id.in_(permitted_dataset_ids),
),
@@ -128,9 +128,9 @@ class DatasetService:
)
else:
query = query.where(
db.or_(
sa.or_(
Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
db.and_(
sa.and_(
Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id
),
)
@@ -1879,7 +1879,7 @@ class DocumentService:
# for notion_info in notion_info_list:
# workspace_id = notion_info.workspace_id
# data_source_binding = DataSourceOauthBinding.query.filter(
# db.and_(
# sa.and_(
# DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
# DataSourceOauthBinding.provider == "notion",
# DataSourceOauthBinding.disabled == False,

+ 1
- 1
api/services/plugin/plugin_migration.py Переглянути файл

@@ -471,7 +471,7 @@ class PluginMigration:
total_failed_tenant = 0
while True:
# paginate
tenants = db.paginate(db.select(Tenant).order_by(Tenant.created_at.desc()), page=page, per_page=100)
tenants = db.paginate(sa.select(Tenant).order_by(Tenant.created_at.desc()), page=page, per_page=100)
if tenants.items is None or len(tenants.items) == 0:
break


+ 2
- 1
api/services/tag_service.py Переглянути файл

@@ -1,5 +1,6 @@
import uuid

import sqlalchemy as sa
from flask_login import current_user
from sqlalchemy import func, select
from werkzeug.exceptions import NotFound
@@ -18,7 +19,7 @@ class TagService:
.where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
)
if keyword:
query = query.where(db.and_(Tag.name.ilike(f"%{keyword}%")))
query = query.where(sa.and_(Tag.name.ilike(f"%{keyword}%")))
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
results: list = query.order_by(Tag.created_at.desc()).all()
return results

+ 2
- 1
api/tasks/document_indexing_sync_task.py Переглянути файл

@@ -2,6 +2,7 @@ import logging
import time

import click
import sqlalchemy as sa
from celery import shared_task
from sqlalchemy import select

@@ -51,7 +52,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.where(
db.and_(
sa.and_(
DataSourceOauthBinding.tenant_id == document.tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,

Завантаження…
Відмінити
Зберегти