Signed-off-by: -LAN- <laipz8200@outlook.com>tags/1.4.0
| @@ -1,5 +1,6 @@ | |||
| import enum | |||
| import json | |||
| from typing import cast | |||
| from flask_login import UserMixin # type: ignore | |||
| from sqlalchemy import func | |||
| @@ -46,7 +47,6 @@ class Account(UserMixin, Base): | |||
| @property | |||
| def current_tenant(self): | |||
| # FIXME: fix the type error later, because the type is important maybe cause some bugs | |||
| return self._current_tenant # type: ignore | |||
| @current_tenant.setter | |||
| @@ -64,25 +64,23 @@ class Account(UserMixin, Base): | |||
| def current_tenant_id(self) -> str | None: | |||
| return self._current_tenant.id if self._current_tenant else None | |||
| @current_tenant_id.setter | |||
| def current_tenant_id(self, value: str): | |||
| try: | |||
| tenant_account_join = ( | |||
| def set_tenant_id(self, tenant_id: str): | |||
| tenant_account_join = cast( | |||
| tuple[Tenant, TenantAccountJoin], | |||
| ( | |||
| db.session.query(Tenant, TenantAccountJoin) | |||
| .filter(Tenant.id == value) | |||
| .filter(Tenant.id == tenant_id) | |||
| .filter(TenantAccountJoin.tenant_id == Tenant.id) | |||
| .filter(TenantAccountJoin.account_id == self.id) | |||
| .one_or_none() | |||
| ) | |||
| ), | |||
| ) | |||
| if tenant_account_join: | |||
| tenant, ta = tenant_account_join | |||
| tenant.current_role = ta.role | |||
| else: | |||
| tenant = None | |||
| except Exception: | |||
| tenant = None | |||
| if not tenant_account_join: | |||
| return | |||
| tenant, join = tenant_account_join | |||
| tenant.current_role = join.role | |||
| self._current_tenant = tenant | |||
| @property | |||
| @@ -191,7 +189,7 @@ class TenantAccountRole(enum.StrEnum): | |||
| } | |||
| class Tenant(db.Model): # type: ignore[name-defined] | |||
| class Tenant(Base): | |||
| __tablename__ = "tenants" | |||
| __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) | |||
| @@ -220,7 +218,7 @@ class Tenant(db.Model): # type: ignore[name-defined] | |||
| self.custom_config = json.dumps(value) | |||
| class TenantAccountJoin(db.Model): # type: ignore[name-defined] | |||
| class TenantAccountJoin(Base): | |||
| __tablename__ = "tenant_account_joins" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), | |||
| @@ -239,7 +237,7 @@ class TenantAccountJoin(db.Model): # type: ignore[name-defined] | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| class AccountIntegrate(db.Model): # type: ignore[name-defined] | |||
| class AccountIntegrate(Base): | |||
| __tablename__ = "account_integrates" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="account_integrate_pkey"), | |||
| @@ -256,7 +254,7 @@ class AccountIntegrate(db.Model): # type: ignore[name-defined] | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| class InvitationCode(db.Model): # type: ignore[name-defined] | |||
| class InvitationCode(Base): | |||
| __tablename__ = "invitation_codes" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="invitation_code_pkey"), | |||
| @@ -2,6 +2,7 @@ import enum | |||
| from sqlalchemy import func | |||
| from .base import Base | |||
| from .engine import db | |||
| from .types import StringUUID | |||
| @@ -13,7 +14,7 @@ class APIBasedExtensionPoint(enum.Enum): | |||
| APP_MODERATION_OUTPUT = "app.moderation.output" | |||
| class APIBasedExtension(db.Model): # type: ignore[name-defined] | |||
| class APIBasedExtension(Base): | |||
| __tablename__ = "api_based_extensions" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"), | |||
| @@ -22,6 +22,7 @@ from extensions.ext_storage import storage | |||
| from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule | |||
| from .account import Account | |||
| from .base import Base | |||
| from .engine import db | |||
| from .model import App, Tag, TagBinding, UploadFile | |||
| from .types import StringUUID | |||
| @@ -33,7 +34,7 @@ class DatasetPermissionEnum(enum.StrEnum): | |||
| PARTIAL_TEAM = "partial_members" | |||
| class Dataset(db.Model): # type: ignore[name-defined] | |||
| class Dataset(Base): | |||
| __tablename__ = "datasets" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="dataset_pkey"), | |||
| @@ -255,7 +256,7 @@ class Dataset(db.Model): # type: ignore[name-defined] | |||
| return f"Vector_index_{normalized_dataset_id}_Node" | |||
| class DatasetProcessRule(db.Model): # type: ignore[name-defined] | |||
| class DatasetProcessRule(Base): | |||
| __tablename__ = "dataset_process_rules" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), | |||
| @@ -295,7 +296,7 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined] | |||
| return None | |||
| class Document(db.Model): # type: ignore[name-defined] | |||
| class Document(Base): | |||
| __tablename__ = "documents" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="document_pkey"), | |||
| @@ -635,7 +636,7 @@ class Document(db.Model): # type: ignore[name-defined] | |||
| ) | |||
| class DocumentSegment(db.Model): # type: ignore[name-defined] | |||
| class DocumentSegment(Base): | |||
| __tablename__ = "document_segments" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="document_segment_pkey"), | |||
| @@ -786,7 +787,7 @@ class DocumentSegment(db.Model): # type: ignore[name-defined] | |||
| return text | |||
| class ChildChunk(db.Model): # type: ignore[name-defined] | |||
| class ChildChunk(Base): | |||
| __tablename__ = "child_chunks" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="child_chunk_pkey"), | |||
| @@ -829,7 +830,7 @@ class ChildChunk(db.Model): # type: ignore[name-defined] | |||
| return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first() | |||
| class AppDatasetJoin(db.Model): # type: ignore[name-defined] | |||
| class AppDatasetJoin(Base): | |||
| __tablename__ = "app_dataset_joins" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"), | |||
| @@ -846,7 +847,7 @@ class AppDatasetJoin(db.Model): # type: ignore[name-defined] | |||
| return db.session.get(App, self.app_id) | |||
| class DatasetQuery(db.Model): # type: ignore[name-defined] | |||
| class DatasetQuery(Base): | |||
| __tablename__ = "dataset_queries" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="dataset_query_pkey"), | |||
| @@ -863,7 +864,7 @@ class DatasetQuery(db.Model): # type: ignore[name-defined] | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) | |||
| class DatasetKeywordTable(db.Model): # type: ignore[name-defined] | |||
| class DatasetKeywordTable(Base): | |||
| __tablename__ = "dataset_keyword_tables" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"), | |||
| @@ -908,7 +909,7 @@ class DatasetKeywordTable(db.Model): # type: ignore[name-defined] | |||
| return None | |||
| class Embedding(db.Model): # type: ignore[name-defined] | |||
| class Embedding(Base): | |||
| __tablename__ = "embeddings" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="embedding_pkey"), | |||
| @@ -932,7 +933,7 @@ class Embedding(db.Model): # type: ignore[name-defined] | |||
| return cast(list[float], pickle.loads(self.embedding)) # noqa: S301 | |||
| class DatasetCollectionBinding(db.Model): # type: ignore[name-defined] | |||
| class DatasetCollectionBinding(Base): | |||
| __tablename__ = "dataset_collection_bindings" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"), | |||
| @@ -947,7 +948,7 @@ class DatasetCollectionBinding(db.Model): # type: ignore[name-defined] | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| class TidbAuthBinding(db.Model): # type: ignore[name-defined] | |||
| class TidbAuthBinding(Base): | |||
| __tablename__ = "tidb_auth_bindings" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"), | |||
| @@ -967,7 +968,7 @@ class TidbAuthBinding(db.Model): # type: ignore[name-defined] | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| class Whitelist(db.Model): # type: ignore[name-defined] | |||
| class Whitelist(Base): | |||
| __tablename__ = "whitelists" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="whitelists_pkey"), | |||
| @@ -979,7 +980,7 @@ class Whitelist(db.Model): # type: ignore[name-defined] | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| class DatasetPermission(db.Model): # type: ignore[name-defined] | |||
| class DatasetPermission(Base): | |||
| __tablename__ = "dataset_permissions" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"), | |||
| @@ -996,7 +997,7 @@ class DatasetPermission(db.Model): # type: ignore[name-defined] | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined] | |||
| class ExternalKnowledgeApis(Base): | |||
| __tablename__ = "external_knowledge_apis" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"), | |||
| @@ -1049,7 +1050,7 @@ class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined] | |||
| return dataset_bindings | |||
| class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined] | |||
| class ExternalKnowledgeBindings(Base): | |||
| __tablename__ = "external_knowledge_bindings" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"), | |||
| @@ -1070,7 +1071,7 @@ class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined] | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined] | |||
| class DatasetAutoDisableLog(Base): | |||
| __tablename__ = "dataset_auto_disable_logs" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"), | |||
| @@ -1087,7 +1088,7 @@ class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined] | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| class RateLimitLog(db.Model): # type: ignore[name-defined] | |||
| class RateLimitLog(Base): | |||
| __tablename__ = "rate_limit_logs" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"), | |||
| @@ -1102,7 +1103,7 @@ class RateLimitLog(db.Model): # type: ignore[name-defined] | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| class DatasetMetadata(db.Model): # type: ignore[name-defined] | |||
| class DatasetMetadata(Base): | |||
| __tablename__ = "dataset_metadatas" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"), | |||
| @@ -1121,7 +1122,7 @@ class DatasetMetadata(db.Model): # type: ignore[name-defined] | |||
| updated_by = db.Column(StringUUID, nullable=True) | |||
| class DatasetMetadataBinding(db.Model): # type: ignore[name-defined] | |||
| class DatasetMetadataBinding(Base): | |||
| __tablename__ = "dataset_metadata_bindings" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"), | |||
| @@ -16,7 +16,7 @@ if TYPE_CHECKING: | |||
| import sqlalchemy as sa | |||
| from flask import request | |||
| from flask_login import UserMixin # type: ignore | |||
| from flask_login import UserMixin | |||
| from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text | |||
| from sqlalchemy.orm import Mapped, Session, mapped_column | |||
| @@ -25,13 +25,13 @@ from constants import DEFAULT_FILE_NUMBER_LIMITS | |||
| from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType | |||
| from core.file import helpers as file_helpers | |||
| from libs.helper import generate_string | |||
| from models.base import Base | |||
| from models.enums import CreatedByRole | |||
| from models.workflow import WorkflowRunStatus | |||
| from .account import Account, Tenant | |||
| from .base import Base | |||
| from .engine import db | |||
| from .enums import CreatedByRole | |||
| from .types import StringUUID | |||
| from .workflow import WorkflowRunStatus | |||
| if TYPE_CHECKING: | |||
| from .workflow import Workflow | |||
| @@ -602,7 +602,7 @@ class InstalledApp(Base): | |||
| return tenant | |||
| class Conversation(db.Model): # type: ignore[name-defined] | |||
| class Conversation(Base): | |||
| __tablename__ = "conversations" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="conversation_pkey"), | |||
| @@ -794,7 +794,7 @@ class Conversation(db.Model): # type: ignore[name-defined] | |||
| for message in messages: | |||
| if message.workflow_run: | |||
| status_counts[message.workflow_run.status] += 1 | |||
| status_counts[WorkflowRunStatus(message.workflow_run.status)] += 1 | |||
| return ( | |||
| { | |||
| @@ -864,7 +864,7 @@ class Conversation(db.Model): # type: ignore[name-defined] | |||
| } | |||
| class Message(db.Model): # type: ignore[name-defined] | |||
| class Message(Base): | |||
| __tablename__ = "messages" | |||
| __table_args__ = ( | |||
| PrimaryKeyConstraint("id", name="message_pkey"), | |||
| @@ -1211,7 +1211,7 @@ class Message(db.Model): # type: ignore[name-defined] | |||
| ) | |||
| class MessageFeedback(db.Model): # type: ignore[name-defined] | |||
| class MessageFeedback(Base): | |||
| __tablename__ = "message_feedbacks" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="message_feedback_pkey"), | |||
| @@ -1238,7 +1238,7 @@ class MessageFeedback(db.Model): # type: ignore[name-defined] | |||
| return account | |||
| class MessageFile(db.Model): # type: ignore[name-defined] | |||
| class MessageFile(Base): | |||
| __tablename__ = "message_files" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="message_file_pkey"), | |||
| @@ -1279,7 +1279,7 @@ class MessageFile(db.Model): # type: ignore[name-defined] | |||
| created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| class MessageAnnotation(db.Model): # type: ignore[name-defined] | |||
| class MessageAnnotation(Base): | |||
| __tablename__ = "message_annotations" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="message_annotation_pkey"), | |||
| @@ -1310,7 +1310,7 @@ class MessageAnnotation(db.Model): # type: ignore[name-defined] | |||
| return account | |||
| class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined] | |||
| class AppAnnotationHitHistory(Base): | |||
| __tablename__ = "app_annotation_hit_histories" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"), | |||
| @@ -1322,7 +1322,7 @@ class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined] | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| app_id = db.Column(StringUUID, nullable=False) | |||
| annotation_id = db.Column(StringUUID, nullable=False) | |||
| annotation_id: Mapped[str] = db.Column(StringUUID, nullable=False) | |||
| source = db.Column(db.Text, nullable=False) | |||
| question = db.Column(db.Text, nullable=False) | |||
| account_id = db.Column(StringUUID, nullable=False) | |||
| @@ -1348,7 +1348,7 @@ class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined] | |||
| return account | |||
| class AppAnnotationSetting(db.Model): # type: ignore[name-defined] | |||
| class AppAnnotationSetting(Base): | |||
| __tablename__ = "app_annotation_settings" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"), | |||
| @@ -1364,26 +1364,6 @@ class AppAnnotationSetting(db.Model): # type: ignore[name-defined] | |||
| updated_user_id = db.Column(StringUUID, nullable=False) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| @property | |||
| def created_account(self): | |||
| account = ( | |||
| db.session.query(Account) | |||
| .join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id) | |||
| .filter(AppAnnotationSetting.id == self.annotation_id) | |||
| .first() | |||
| ) | |||
| return account | |||
| @property | |||
| def updated_account(self): | |||
| account = ( | |||
| db.session.query(Account) | |||
| .join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id) | |||
| .filter(AppAnnotationSetting.id == self.annotation_id) | |||
| .first() | |||
| ) | |||
| return account | |||
| @property | |||
| def collection_binding_detail(self): | |||
| from .dataset import DatasetCollectionBinding | |||
| @@ -2,8 +2,7 @@ from enum import Enum | |||
| from sqlalchemy import func | |||
| from models.base import Base | |||
| from .base import Base | |||
| from .engine import db | |||
| from .types import StringUUID | |||
| @@ -9,7 +9,7 @@ from .engine import db | |||
| from .types import StringUUID | |||
| class DataSourceOauthBinding(db.Model): # type: ignore[name-defined] | |||
| class DataSourceOauthBinding(Base): | |||
| __tablename__ = "data_source_oauth_bindings" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="source_binding_pkey"), | |||
| @@ -9,7 +9,7 @@ if TYPE_CHECKING: | |||
| from models.model import AppMode | |||
| import sqlalchemy as sa | |||
| from sqlalchemy import Index, PrimaryKeyConstraint, func | |||
| from sqlalchemy import func | |||
| from sqlalchemy.orm import Mapped, mapped_column | |||
| import contexts | |||
| @@ -18,11 +18,11 @@ from core.helper import encrypter | |||
| from core.variables import SecretVariable, Variable | |||
| from factories import variable_factory | |||
| from libs import helper | |||
| from models.base import Base | |||
| from models.enums import CreatedByRole | |||
| from .account import Account | |||
| from .base import Base | |||
| from .engine import db | |||
| from .enums import CreatedByRole | |||
| from .types import StringUUID | |||
| if TYPE_CHECKING: | |||
| @@ -768,17 +768,12 @@ class WorkflowAppLog(Base): | |||
| class ConversationVariable(Base): | |||
| __tablename__ = "workflow_conversation_variables" | |||
| __table_args__ = ( | |||
| PrimaryKeyConstraint("id", "conversation_id", name="workflow_conversation_variables_pkey"), | |||
| Index("workflow__conversation_variables_app_id_idx", "app_id"), | |||
| Index("workflow__conversation_variables_created_at_idx", "created_at"), | |||
| ) | |||
| id: Mapped[str] = mapped_column(StringUUID, primary_key=True) | |||
| conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True) | |||
| app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) | |||
| app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True) | |||
| data = mapped_column(db.Text, nullable=False) | |||
| created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True) | |||
| updated_at = mapped_column( | |||
| db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() | |||
| ) | |||
| @@ -110,7 +110,7 @@ class AccountService: | |||
| current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() | |||
| if current_tenant: | |||
| account.current_tenant_id = current_tenant.tenant_id | |||
| account.set_tenant_id(current_tenant.tenant_id) | |||
| else: | |||
| available_ta = ( | |||
| TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() | |||
| @@ -118,7 +118,7 @@ class AccountService: | |||
| if not available_ta: | |||
| return None | |||
| account.current_tenant_id = available_ta.tenant_id | |||
| account.set_tenant_id(available_ta.tenant_id) | |||
| available_ta.current = True | |||
| db.session.commit() | |||
| @@ -700,7 +700,7 @@ class TenantService: | |||
| ).update({"current": False}) | |||
| tenant_account_join.current = True | |||
| # Set the current tenant for the account | |||
| account.current_tenant_id = tenant_account_join.tenant_id | |||
| account.set_tenant_id(tenant_account_join.tenant_id) | |||
| db.session.commit() | |||
| @staticmethod | |||