浏览代码

Fix Empty Collection WHERE Filter Issue (#23086)

tags/1.7.2
NeatGuyCoding 3 个月前
父节点
当前提交
47cc951841
没有帐户链接到提交者的电子邮件

+ 3
- 2
api/services/app_service.py 查看文件

if args.get("name"): if args.get("name"):
name = args["name"][:30] name = args["name"][:30]
filters.append(App.name.ilike(f"%{name}%")) filters.append(App.name.ilike(f"%{name}%"))
if args.get("tag_ids"):
# Check if tag_ids is not empty to avoid WHERE false condition
if args.get("tag_ids") and len(args["tag_ids"]) > 0:
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"]) target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"])
if target_ids:
if target_ids and len(target_ids) > 0:
filters.append(App.id.in_(target_ids)) filters.append(App.id.in_(target_ids))
else: else:
return None return None

+ 4
- 2
api/services/conversation_service.py 查看文件

Conversation.from_account_id == (user.id if isinstance(user, Account) else None), Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value), or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value),
) )
if include_ids is not None:
# Check if include_ids is not None and not empty to avoid WHERE false condition
if include_ids is not None and len(include_ids) > 0:
stmt = stmt.where(Conversation.id.in_(include_ids)) stmt = stmt.where(Conversation.id.in_(include_ids))
if exclude_ids is not None:
# Check if exclude_ids is not None and not empty to avoid WHERE false condition
if exclude_ids is not None and len(exclude_ids) > 0:
stmt = stmt.where(~Conversation.id.in_(exclude_ids)) stmt = stmt.where(~Conversation.id.in_(exclude_ids))


# define sort fields and directions # define sort fields and directions

+ 21
- 5
api/services/dataset_service.py 查看文件



if user.current_role == TenantAccountRole.DATASET_OPERATOR: if user.current_role == TenantAccountRole.DATASET_OPERATOR:
# only show datasets that the user has permission to access # only show datasets that the user has permission to access
if permitted_dataset_ids:
# Check if permitted_dataset_ids is not empty to avoid WHERE false condition
if permitted_dataset_ids and len(permitted_dataset_ids) > 0:
query = query.where(Dataset.id.in_(permitted_dataset_ids)) query = query.where(Dataset.id.in_(permitted_dataset_ids))
else: else:
return [], 0 return [], 0
else: else:
if user.current_role != TenantAccountRole.OWNER or not include_all: if user.current_role != TenantAccountRole.OWNER or not include_all:
# show all datasets that the user has permission to access # show all datasets that the user has permission to access
if permitted_dataset_ids:
# Check if permitted_dataset_ids is not empty to avoid WHERE false condition
if permitted_dataset_ids and len(permitted_dataset_ids) > 0:
query = query.where( query = query.where(
db.or_( db.or_(
Dataset.permission == DatasetPermissionEnum.ALL_TEAM, Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
if search: if search:
query = query.where(Dataset.name.ilike(f"%{search}%")) query = query.where(Dataset.name.ilike(f"%{search}%"))


if tag_ids:
# Check if tag_ids is not empty to avoid WHERE false condition
if tag_ids and len(tag_ids) > 0:
target_ids = TagService.get_target_ids_by_tag_ids("knowledge", tenant_id, tag_ids) target_ids = TagService.get_target_ids_by_tag_ids("knowledge", tenant_id, tag_ids)
if target_ids:
if target_ids and len(target_ids) > 0:
query = query.where(Dataset.id.in_(target_ids)) query = query.where(Dataset.id.in_(target_ids))
else: else:
return [], 0 return [], 0


@staticmethod @staticmethod
def get_datasets_by_ids(ids, tenant_id): def get_datasets_by_ids(ids, tenant_id):
# Check if ids is not empty to avoid WHERE false condition
if not ids or len(ids) == 0:
return [], 0
stmt = select(Dataset).where(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id) stmt = select(Dataset).where(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id)


datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False) datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)


@staticmethod @staticmethod
def delete_documents(dataset: Dataset, document_ids: list[str]): def delete_documents(dataset: Dataset, document_ids: list[str]):
# Check if document_ids is not empty to avoid WHERE false condition
if not document_ids or len(document_ids) == 0:
return
documents = db.session.query(Document).where(Document.id.in_(document_ids)).all() documents = db.session.query(Document).where(Document.id.in_(document_ids)).all()
file_ids = [ file_ids = [
document.data_source_info_dict["upload_file_id"] document.data_source_info_dict["upload_file_id"]


@classmethod @classmethod
def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
# Check if segment_ids is not empty to avoid WHERE false condition
if not segment_ids or len(segment_ids) == 0:
return
index_node_ids = ( index_node_ids = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.with_entities(DocumentSegment.index_node_id) .with_entities(DocumentSegment.index_node_id)


@classmethod @classmethod
def update_segments_status(cls, segment_ids: list, action: str, dataset: Dataset, document: Document): def update_segments_status(cls, segment_ids: list, action: str, dataset: Dataset, document: Document):
# Check if segment_ids is not empty to avoid WHERE false condition
if not segment_ids or len(segment_ids) == 0:
return
if action == "enable": if action == "enable":
segments = ( segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id
) )


if status_list:
# Check if status_list is not empty to avoid WHERE false condition
if status_list and len(status_list) > 0:
query = query.where(DocumentSegment.status.in_(status_list)) query = query.where(DocumentSegment.status.in_(status_list))


if keyword: if keyword:

+ 2
- 1
api/services/message_service.py 查看文件



base_query = base_query.where(Message.conversation_id == conversation.id) base_query = base_query.where(Message.conversation_id == conversation.id)


if include_ids is not None:
# Check if include_ids is not None and not empty to avoid WHERE false condition
if include_ids is not None and len(include_ids) > 0:
base_query = base_query.where(Message.id.in_(include_ids)) base_query = base_query.where(Message.id.in_(include_ids))


if last_id: if last_id:

+ 6
- 0
api/services/tag_service.py 查看文件



@staticmethod @staticmethod
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list: def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
# Check if tag_ids is not empty to avoid WHERE false condition
if not tag_ids or len(tag_ids) == 0:
return []
tags = ( tags = (
db.session.query(Tag) db.session.query(Tag)
.where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type) .where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
if not tags: if not tags:
return [] return []
tag_ids = [tag.id for tag in tags] tag_ids = [tag.id for tag in tags]
# Check if tag_ids is not empty to avoid WHERE false condition
if not tag_ids or len(tag_ids) == 0:
return []
tag_bindings = ( tag_bindings = (
db.session.query(TagBinding.target_id) db.session.query(TagBinding.target_id)
.where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id) .where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)

正在加载...
取消
保存