Signed-off-by: -LAN- <laipz8200@outlook.com>tags/0.14.2
| @@ -1,12 +1,14 @@ | |||
| from flask_login import current_user | |||
| from flask_restful import marshal_with, reqparse | |||
| from flask_restful.inputs import int_range | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import NotFound | |||
| from controllers.console import api | |||
| from controllers.console.explore.error import NotChatAppError | |||
| from controllers.console.explore.wraps import InstalledAppResource | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from extensions.ext_database import db | |||
| from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields | |||
| from libs.helper import uuid_value | |||
| from models.model import AppMode | |||
| @@ -34,14 +36,16 @@ class ConversationListApi(InstalledAppResource): | |||
| pinned = True if args["pinned"] == "true" else False | |||
| try: | |||
| return WebConversationService.pagination_by_last_id( | |||
| app_model=app_model, | |||
| user=current_user, | |||
| last_id=args["last_id"], | |||
| limit=args["limit"], | |||
| invoke_from=InvokeFrom.EXPLORE, | |||
| pinned=pinned, | |||
| ) | |||
| with Session(db.engine) as session: | |||
| return WebConversationService.pagination_by_last_id( | |||
| session=session, | |||
| app_model=app_model, | |||
| user=current_user, | |||
| last_id=args["last_id"], | |||
| limit=args["limit"], | |||
| invoke_from=InvokeFrom.EXPLORE, | |||
| pinned=pinned, | |||
| ) | |||
| except LastConversationNotExistsError: | |||
| raise NotFound("Last Conversation Not Exists.") | |||
| @@ -1,5 +1,6 @@ | |||
| from flask_restful import Resource, marshal_with, reqparse | |||
| from flask_restful.inputs import int_range | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import NotFound | |||
| import services | |||
| @@ -7,6 +8,7 @@ from controllers.service_api import api | |||
| from controllers.service_api.app.error import NotChatAppError | |||
| from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from extensions.ext_database import db | |||
| from fields.conversation_fields import ( | |||
| conversation_delete_fields, | |||
| conversation_infinite_scroll_pagination_fields, | |||
| @@ -39,14 +41,16 @@ class ConversationApi(Resource): | |||
| args = parser.parse_args() | |||
| try: | |||
| return ConversationService.pagination_by_last_id( | |||
| app_model=app_model, | |||
| user=end_user, | |||
| last_id=args["last_id"], | |||
| limit=args["limit"], | |||
| invoke_from=InvokeFrom.SERVICE_API, | |||
| sort_by=args["sort_by"], | |||
| ) | |||
| with Session(db.engine) as session: | |||
| return ConversationService.pagination_by_last_id( | |||
| session=session, | |||
| app_model=app_model, | |||
| user=end_user, | |||
| last_id=args["last_id"], | |||
| limit=args["limit"], | |||
| invoke_from=InvokeFrom.SERVICE_API, | |||
| sort_by=args["sort_by"], | |||
| ) | |||
| except services.errors.conversation.LastConversationNotExistsError: | |||
| raise NotFound("Last Conversation Not Exists.") | |||
| @@ -1,11 +1,13 @@ | |||
| from flask_restful import marshal_with, reqparse | |||
| from flask_restful.inputs import int_range | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import NotFound | |||
| from controllers.web import api | |||
| from controllers.web.error import NotChatAppError | |||
| from controllers.web.wraps import WebApiResource | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from extensions.ext_database import db | |||
| from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields | |||
| from libs.helper import uuid_value | |||
| from models.model import AppMode | |||
| @@ -40,15 +42,17 @@ class ConversationListApi(WebApiResource): | |||
| pinned = True if args["pinned"] == "true" else False | |||
| try: | |||
| return WebConversationService.pagination_by_last_id( | |||
| app_model=app_model, | |||
| user=end_user, | |||
| last_id=args["last_id"], | |||
| limit=args["limit"], | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| pinned=pinned, | |||
| sort_by=args["sort_by"], | |||
| ) | |||
| with Session(db.engine) as session: | |||
| return WebConversationService.pagination_by_last_id( | |||
| session=session, | |||
| app_model=app_model, | |||
| user=end_user, | |||
| last_id=args["last_id"], | |||
| limit=args["limit"], | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| pinned=pinned, | |||
| sort_by=args["sort_by"], | |||
| ) | |||
| except LastConversationNotExistsError: | |||
| raise NotFound("Last Conversation Not Exists.") | |||
| @@ -1,4 +1,5 @@ | |||
| from sqlalchemy import func | |||
| from sqlalchemy.orm import Mapped, mapped_column | |||
| from .engine import db | |||
| from .model import Message | |||
| @@ -33,7 +34,7 @@ class PinnedConversation(db.Model): | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| app_id = db.Column(StringUUID, nullable=False) | |||
| conversation_id = db.Column(StringUUID, nullable=False) | |||
| conversation_id: Mapped[str] = mapped_column(StringUUID) | |||
| created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) | |||
| created_by = db.Column(StringUUID, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| @@ -1,8 +1,9 @@ | |||
| from collections.abc import Callable | |||
| from collections.abc import Callable, Sequence | |||
| from datetime import UTC, datetime | |||
| from typing import Optional, Union | |||
| from sqlalchemy import asc, desc, or_ | |||
| from sqlalchemy import asc, desc, func, or_, select | |||
| from sqlalchemy.orm import Session | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from core.llm_generator.llm_generator import LLMGenerator | |||
| @@ -18,19 +19,21 @@ class ConversationService: | |||
| @classmethod | |||
| def pagination_by_last_id( | |||
| cls, | |||
| *, | |||
| session: Session, | |||
| app_model: App, | |||
| user: Optional[Union[Account, EndUser]], | |||
| last_id: Optional[str], | |||
| limit: int, | |||
| invoke_from: InvokeFrom, | |||
| include_ids: Optional[list] = None, | |||
| exclude_ids: Optional[list] = None, | |||
| include_ids: Optional[Sequence[str]] = None, | |||
| exclude_ids: Optional[Sequence[str]] = None, | |||
| sort_by: str = "-updated_at", | |||
| ) -> InfiniteScrollPagination: | |||
| if not user: | |||
| return InfiniteScrollPagination(data=[], limit=limit, has_more=False) | |||
| base_query = db.session.query(Conversation).filter( | |||
| stmt = select(Conversation).where( | |||
| Conversation.is_deleted == False, | |||
| Conversation.app_id == app_model.id, | |||
| Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), | |||
| @@ -38,37 +41,40 @@ class ConversationService: | |||
| Conversation.from_account_id == (user.id if isinstance(user, Account) else None), | |||
| or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value), | |||
| ) | |||
| if include_ids is not None: | |||
| base_query = base_query.filter(Conversation.id.in_(include_ids)) | |||
| stmt = stmt.where(Conversation.id.in_(include_ids)) | |||
| if exclude_ids is not None: | |||
| base_query = base_query.filter(~Conversation.id.in_(exclude_ids)) | |||
| stmt = stmt.where(~Conversation.id.in_(exclude_ids)) | |||
| # define sort fields and directions | |||
| sort_field, sort_direction = cls._get_sort_params(sort_by) | |||
| if last_id: | |||
| last_conversation = base_query.filter(Conversation.id == last_id).first() | |||
| last_conversation = session.scalar(stmt.where(Conversation.id == last_id)) | |||
| if not last_conversation: | |||
| raise LastConversationNotExistsError() | |||
| # build filters based on sorting | |||
| filter_condition = cls._build_filter_condition(sort_field, sort_direction, last_conversation) | |||
| base_query = base_query.filter(filter_condition) | |||
| base_query = base_query.order_by(sort_direction(getattr(Conversation, sort_field))) | |||
| conversations = base_query.limit(limit).all() | |||
| filter_condition = cls._build_filter_condition( | |||
| sort_field=sort_field, | |||
| sort_direction=sort_direction, | |||
| reference_conversation=last_conversation, | |||
| ) | |||
| stmt = stmt.where(filter_condition) | |||
| query_stmt = stmt.order_by(sort_direction(getattr(Conversation, sort_field))).limit(limit) | |||
| conversations = session.scalars(query_stmt).all() | |||
| has_more = False | |||
| if len(conversations) == limit: | |||
| current_page_last_conversation = conversations[-1] | |||
| rest_filter_condition = cls._build_filter_condition( | |||
| sort_field, sort_direction, current_page_last_conversation, is_next_page=True | |||
| sort_field=sort_field, | |||
| sort_direction=sort_direction, | |||
| reference_conversation=current_page_last_conversation, | |||
| ) | |||
| rest_count = base_query.filter(rest_filter_condition).count() | |||
| count_stmt = stmt.where(rest_filter_condition) | |||
| count_stmt = select(func.count()).select_from(count_stmt.subquery()) | |||
| rest_count = session.scalar(count_stmt) or 0 | |||
| if rest_count > 0: | |||
| has_more = True | |||
| @@ -81,11 +87,9 @@ class ConversationService: | |||
| return sort_by, asc | |||
| @classmethod | |||
| def _build_filter_condition( | |||
| cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation, is_next_page: bool = False | |||
| ): | |||
| def _build_filter_condition(cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation): | |||
| field_value = getattr(reference_conversation, sort_field) | |||
| if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page): | |||
| if sort_direction == desc: | |||
| return getattr(Conversation, sort_field) < field_value | |||
| else: | |||
| return getattr(Conversation, sort_field) > field_value | |||
| @@ -1,5 +1,8 @@ | |||
| from typing import Optional, Union | |||
| from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from extensions.ext_database import db | |||
| from libs.infinite_scroll_pagination import InfiniteScrollPagination | |||
| @@ -13,6 +16,8 @@ class WebConversationService: | |||
| @classmethod | |||
| def pagination_by_last_id( | |||
| cls, | |||
| *, | |||
| session: Session, | |||
| app_model: App, | |||
| user: Optional[Union[Account, EndUser]], | |||
| last_id: Optional[str], | |||
| @@ -23,24 +28,25 @@ class WebConversationService: | |||
| ) -> InfiniteScrollPagination: | |||
| include_ids = None | |||
| exclude_ids = None | |||
| if pinned is not None: | |||
| pinned_conversations = ( | |||
| db.session.query(PinnedConversation) | |||
| .filter( | |||
| if pinned is not None and user: | |||
| stmt = ( | |||
| select(PinnedConversation.conversation_id) | |||
| .where( | |||
| PinnedConversation.app_id == app_model.id, | |||
| PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), | |||
| PinnedConversation.created_by == user.id, | |||
| ) | |||
| .order_by(PinnedConversation.created_at.desc()) | |||
| .all() | |||
| ) | |||
| pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations] | |||
| pinned_conversation_ids = session.scalars(stmt).all() | |||
| if pinned: | |||
| include_ids = pinned_conversation_ids | |||
| else: | |||
| exclude_ids = pinned_conversation_ids | |||
| return ConversationService.pagination_by_last_id( | |||
| session=session, | |||
| app_model=app_model, | |||
| user=user, | |||
| last_id=last_id, | |||