Преглед на файлове

one example of Session (#24135)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
tags/1.9.1
Asuka Minato преди 1 месец
родител
ревизия
25c69ac540
No account linked to committer's email address

+ 76
- 76
api/commands.py Целия файл

from pydantic import TypeAdapter from pydantic import TypeAdapter
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker


from configs import dify_config from configs import dify_config
from constants.languages import languages from constants.languages import languages
if str(new_password).strip() != str(password_confirm).strip(): if str(new_password).strip() != str(password_confirm).strip():
click.echo(click.style("Passwords do not match.", fg="red")) click.echo(click.style("Passwords do not match.", fg="red"))
return return
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = session.query(Account).where(Account.email == email).one_or_none()


account = db.session.query(Account).where(Account.email == email).one_or_none()

if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return


try:
valid_password(new_password)
except:
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
return
try:
valid_password(new_password)
except:
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
return


# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()


# encrypt password with salt
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
AccountService.reset_login_error_rate_limit(email)
click.echo(click.style("Password reset successfully.", fg="green"))
# encrypt password with salt
password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
AccountService.reset_login_error_rate_limit(email)
click.echo(click.style("Password reset successfully.", fg="green"))




@click.command("reset-email", help="Reset the account email.") @click.command("reset-email", help="Reset the account email.")
if str(new_email).strip() != str(email_confirm).strip(): if str(new_email).strip() != str(email_confirm).strip():
click.echo(click.style("New emails do not match.", fg="red")) click.echo(click.style("New emails do not match.", fg="red"))
return return
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = session.query(Account).where(Account.email == email).one_or_none()


account = db.session.query(Account).where(Account.email == email).one_or_none()

if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return


try:
email_validate(new_email)
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
try:
email_validate(new_email)
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return


account.email = new_email
db.session.commit()
click.echo(click.style("Email updated successfully.", fg="green"))
account.email = new_email
click.echo(click.style("Email updated successfully.", fg="green"))




@click.command( @click.command(
if dify_config.EDITION != "SELF_HOSTED": if dify_config.EDITION != "SELF_HOSTED":
click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red")) click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
return return
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
tenants = session.query(Tenant).all()
for tenant in tenants:
if not tenant:
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
return


tenants = db.session.query(Tenant).all()
for tenant in tenants:
if not tenant:
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
return

tenant.encrypt_public_key = generate_key_pair(tenant.id)
tenant.encrypt_public_key = generate_key_pair(tenant.id)


db.session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
db.session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
db.session.commit()
session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()


click.echo(
click.style(
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
fg="green",
click.echo(
click.style(
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
fg="green",
)
) )
)




@click.command("vdb-migrate", help="Migrate vector db.") @click.command("vdb-migrate", help="Migrate vector db.")
try: try:
# get apps info # get apps info
per_page = 50 per_page = 50
apps = (
db.session.query(App)
.where(App.status == "normal")
.order_by(App.created_at.desc())
.limit(per_page)
.offset((page - 1) * per_page)
.all()
)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
apps = (
session.query(App)
.where(App.status == "normal")
.order_by(App.created_at.desc())
.limit(per_page)
.offset((page - 1) * per_page)
.all()
)
if not apps: if not apps:
break break
except SQLAlchemyError: except SQLAlchemyError:
) )
try: try:
click.echo(f"Creating app annotation index: {app.id}") click.echo(f"Creating app annotation index: {app.id}")
app_annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
app_annotation_setting = (
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
)


if not app_annotation_setting:
skipped_count = skipped_count + 1
click.echo(f"App annotation setting disabled: {app.id}")
continue
# get dataset_collection_binding info
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.first()
)
if not dataset_collection_binding:
click.echo(f"App annotation collection binding not found: {app.id}")
continue
annotations = db.session.scalars(
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
).all()
if not app_annotation_setting:
skipped_count = skipped_count + 1
click.echo(f"App annotation setting disabled: {app.id}")
continue
# get dataset_collection_binding info
dataset_collection_binding = (
session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.first()
)
if not dataset_collection_binding:
click.echo(f"App annotation collection binding not found: {app.id}")
continue
annotations = session.scalars(
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
).all()
dataset = Dataset( dataset = Dataset(
id=app.id, id=app.id,
tenant_id=app.tenant_id, tenant_id=app.tenant_id,

+ 3
- 2
api/controllers/console/app/conversation.py Целия файл

from datetime import datetime from datetime import datetime


import pytz # pip install pytz import pytz # pip install pytz
import sqlalchemy as sa
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range from flask_restx.inputs import int_range
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args() args = parser.parse_args()


query = db.select(Conversation).where(
query = sa.select(Conversation).where(
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False) Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
) )


.subquery() .subquery()
) )


query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))


if args["keyword"]: if args["keyword"]:
keyword_filter = f"%{args['keyword']}%" keyword_filter = f"%{args['keyword']}%"

+ 3
- 2
api/controllers/console/datasets/datasets_document.py Целия файл

from collections.abc import Sequence from collections.abc import Sequence
from typing import Literal, cast from typing import Literal, cast


import sqlalchemy as sa
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with, reqparse


if sort == "hit_count": if sort == "hit_count":
sub_query = ( sub_query = (
db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
.group_by(DocumentSegment.document_id) .group_by(DocumentSegment.document_id)
.subquery() .subquery()
) )


query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by( query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)),
sort_logic(Document.position), sort_logic(Document.position),
) )
elif sort == "created_at": elif sort == "created_at":

+ 2
- 2
api/models/dataset.py Целия файл

id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=sa.text("uuid_generate_v4()")) id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False) app_id = mapped_column(StringUUID, nullable=False)
dataset_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())


@property @property
def app(self): def app(self):
source_app_id = mapped_column(StringUUID, nullable=True) source_app_id = mapped_column(StringUUID, nullable=True)
created_by_role = mapped_column(String, nullable=False) created_by_role = mapped_column(String, nullable=False)
created_by = mapped_column(StringUUID, nullable=False) created_by = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=db.func.current_timestamp())
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=sa.func.current_timestamp())




class DatasetKeywordTable(Base): class DatasetKeywordTable(Base):

+ 3
- 3
api/models/model.py Целия файл

type: Mapped[str] = mapped_column(String(255), nullable=False) type: Mapped[str] = mapped_column(String(255), nullable=False)
input = mapped_column(sa.Text, nullable=True) input = mapped_column(sa.Text, nullable=True)
output = mapped_column(sa.Text, nullable=True) output = mapped_column(sa.Text, nullable=True)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())




class MessageAgentThought(Base): class MessageAgentThought(Base):
latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True) latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
created_by_role = mapped_column(String, nullable=False) created_by_role = mapped_column(String, nullable=False)
created_by = mapped_column(StringUUID, nullable=False) created_by = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())


@property @property
def files(self) -> list[Any]: def files(self) -> list[Any]:
index_node_hash = mapped_column(sa.Text, nullable=True) index_node_hash = mapped_column(sa.Text, nullable=True)
retriever_from = mapped_column(sa.Text, nullable=False) retriever_from = mapped_column(sa.Text, nullable=False)
created_by = mapped_column(StringUUID, nullable=False) created_by = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=db.func.current_timestamp())
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.current_timestamp())




class Tag(Base): class Tag(Base):

+ 2
- 1
api/services/app_service.py Целия файл

import logging import logging
from typing import TypedDict, cast from typing import TypedDict, cast


import sqlalchemy as sa
from flask_sqlalchemy.pagination import Pagination from flask_sqlalchemy.pagination import Pagination


from configs import dify_config from configs import dify_config
return None return None


app_models = db.paginate( app_models = db.paginate(
db.select(App).where(*filters).order_by(App.created_at.desc()),
sa.select(App).where(*filters).order_by(App.created_at.desc()),
page=args["page"], page=args["page"],
per_page=args["limit"], per_page=args["limit"],
error_out=False, error_out=False,

+ 6
- 6
api/services/dataset_service.py Целия файл

# Check if permitted_dataset_ids is not empty to avoid WHERE false condition # Check if permitted_dataset_ids is not empty to avoid WHERE false condition
if permitted_dataset_ids and len(permitted_dataset_ids) > 0: if permitted_dataset_ids and len(permitted_dataset_ids) > 0:
query = query.where( query = query.where(
db.or_(
sa.or_(
Dataset.permission == DatasetPermissionEnum.ALL_TEAM, Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
db.and_(
sa.and_(
Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id
), ),
db.and_(
sa.and_(
Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM, Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM,
Dataset.id.in_(permitted_dataset_ids), Dataset.id.in_(permitted_dataset_ids),
), ),
) )
else: else:
query = query.where( query = query.where(
db.or_(
sa.or_(
Dataset.permission == DatasetPermissionEnum.ALL_TEAM, Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
db.and_(
sa.and_(
Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id
), ),
) )
# for notion_info in notion_info_list: # for notion_info in notion_info_list:
# workspace_id = notion_info.workspace_id # workspace_id = notion_info.workspace_id
# data_source_binding = DataSourceOauthBinding.query.filter( # data_source_binding = DataSourceOauthBinding.query.filter(
# db.and_(
# sa.and_(
# DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, # DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
# DataSourceOauthBinding.provider == "notion", # DataSourceOauthBinding.provider == "notion",
# DataSourceOauthBinding.disabled == False, # DataSourceOauthBinding.disabled == False,

+ 1
- 1
api/services/plugin/plugin_migration.py Целия файл

total_failed_tenant = 0 total_failed_tenant = 0
while True: while True:
# paginate # paginate
tenants = db.paginate(db.select(Tenant).order_by(Tenant.created_at.desc()), page=page, per_page=100)
tenants = db.paginate(sa.select(Tenant).order_by(Tenant.created_at.desc()), page=page, per_page=100)
if tenants.items is None or len(tenants.items) == 0: if tenants.items is None or len(tenants.items) == 0:
break break



+ 2
- 1
api/services/tag_service.py Целия файл

import uuid import uuid


import sqlalchemy as sa
from flask_login import current_user from flask_login import current_user
from sqlalchemy import func, select from sqlalchemy import func, select
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
.where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id) .where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
) )
if keyword: if keyword:
query = query.where(db.and_(Tag.name.ilike(f"%{keyword}%")))
query = query.where(sa.and_(Tag.name.ilike(f"%{keyword}%")))
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at) query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
results: list = query.order_by(Tag.created_at.desc()).all() results: list = query.order_by(Tag.created_at.desc()).all()
return results return results

+ 2
- 1
api/tasks/document_indexing_sync_task.py Целия файл

import time import time


import click import click
import sqlalchemy as sa
from celery import shared_task from celery import shared_task
from sqlalchemy import select from sqlalchemy import select


data_source_binding = ( data_source_binding = (
db.session.query(DataSourceOauthBinding) db.session.query(DataSourceOauthBinding)
.where( .where(
db.and_(
sa.and_(
DataSourceOauthBinding.tenant_id == document.tenant_id, DataSourceOauthBinding.tenant_id == document.tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.disabled == False,

Loading…
Отказ
Запис