Selaa lähdekoodia

refactor(models): Use the SQLAlchemy base model. (#19435)

Signed-off-by: -LAN- <laipz8200@outlook.com>
tags/1.4.0
-LAN- 5 kuukautta sitten
vanhempi
commit
792b321a81
No account linked to committer's email address

+ 16
- 18
api/models/account.py Näytä tiedosto

@@ -1,5 +1,6 @@
import enum
import json
from typing import cast

from flask_login import UserMixin # type: ignore
from sqlalchemy import func
@@ -46,7 +47,6 @@ class Account(UserMixin, Base):

@property
def current_tenant(self):
# FIXME: fix the type error later, because the type is important maybe cause some bugs
return self._current_tenant # type: ignore

@current_tenant.setter
@@ -64,25 +64,23 @@ class Account(UserMixin, Base):
def current_tenant_id(self) -> str | None:
return self._current_tenant.id if self._current_tenant else None

@current_tenant_id.setter
def current_tenant_id(self, value: str):
try:
tenant_account_join = (
def set_tenant_id(self, tenant_id: str):
tenant_account_join = cast(
tuple[Tenant, TenantAccountJoin],
(
db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == value)
.filter(Tenant.id == tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.account_id == self.id)
.one_or_none()
)
),
)

if tenant_account_join:
tenant, ta = tenant_account_join
tenant.current_role = ta.role
else:
tenant = None
except Exception:
tenant = None
if not tenant_account_join:
return

tenant, join = tenant_account_join
tenant.current_role = join.role
self._current_tenant = tenant

@property
@@ -191,7 +189,7 @@ class TenantAccountRole(enum.StrEnum):
}


class Tenant(db.Model): # type: ignore[name-defined]
class Tenant(Base):
__tablename__ = "tenants"
__table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)

@@ -220,7 +218,7 @@ class Tenant(db.Model): # type: ignore[name-defined]
self.custom_config = json.dumps(value)


class TenantAccountJoin(db.Model): # type: ignore[name-defined]
class TenantAccountJoin(Base):
__tablename__ = "tenant_account_joins"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
@@ -239,7 +237,7 @@ class TenantAccountJoin(db.Model): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class AccountIntegrate(db.Model): # type: ignore[name-defined]
class AccountIntegrate(Base):
__tablename__ = "account_integrates"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
@@ -256,7 +254,7 @@ class AccountIntegrate(db.Model): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class InvitationCode(db.Model): # type: ignore[name-defined]
class InvitationCode(Base):
__tablename__ = "invitation_codes"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="invitation_code_pkey"),

+ 2
- 1
api/models/api_based_extension.py Näytä tiedosto

@@ -2,6 +2,7 @@ import enum

from sqlalchemy import func

from .base import Base
from .engine import db
from .types import StringUUID

@@ -13,7 +14,7 @@ class APIBasedExtensionPoint(enum.Enum):
APP_MODERATION_OUTPUT = "app.moderation.output"


class APIBasedExtension(db.Model): # type: ignore[name-defined]
class APIBasedExtension(Base):
__tablename__ = "api_based_extensions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"),

+ 20
- 19
api/models/dataset.py Näytä tiedosto

@@ -22,6 +22,7 @@ from extensions.ext_storage import storage
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule

from .account import Account
from .base import Base
from .engine import db
from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID
@@ -33,7 +34,7 @@ class DatasetPermissionEnum(enum.StrEnum):
PARTIAL_TEAM = "partial_members"


class Dataset(db.Model): # type: ignore[name-defined]
class Dataset(Base):
__tablename__ = "datasets"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_pkey"),
@@ -255,7 +256,7 @@ class Dataset(db.Model): # type: ignore[name-defined]
return f"Vector_index_{normalized_dataset_id}_Node"


class DatasetProcessRule(db.Model): # type: ignore[name-defined]
class DatasetProcessRule(Base):
__tablename__ = "dataset_process_rules"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
@@ -295,7 +296,7 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined]
return None


class Document(db.Model): # type: ignore[name-defined]
class Document(Base):
__tablename__ = "documents"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="document_pkey"),
@@ -635,7 +636,7 @@ class Document(db.Model): # type: ignore[name-defined]
)


class DocumentSegment(db.Model): # type: ignore[name-defined]
class DocumentSegment(Base):
__tablename__ = "document_segments"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="document_segment_pkey"),
@@ -786,7 +787,7 @@ class DocumentSegment(db.Model): # type: ignore[name-defined]
return text


class ChildChunk(db.Model): # type: ignore[name-defined]
class ChildChunk(Base):
__tablename__ = "child_chunks"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
@@ -829,7 +830,7 @@ class ChildChunk(db.Model): # type: ignore[name-defined]
return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first()


class AppDatasetJoin(db.Model): # type: ignore[name-defined]
class AppDatasetJoin(Base):
__tablename__ = "app_dataset_joins"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
@@ -846,7 +847,7 @@ class AppDatasetJoin(db.Model): # type: ignore[name-defined]
return db.session.get(App, self.app_id)


class DatasetQuery(db.Model): # type: ignore[name-defined]
class DatasetQuery(Base):
__tablename__ = "dataset_queries"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
@@ -863,7 +864,7 @@ class DatasetQuery(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())


class DatasetKeywordTable(db.Model): # type: ignore[name-defined]
class DatasetKeywordTable(Base):
__tablename__ = "dataset_keyword_tables"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
@@ -908,7 +909,7 @@ class DatasetKeywordTable(db.Model): # type: ignore[name-defined]
return None


class Embedding(db.Model): # type: ignore[name-defined]
class Embedding(Base):
__tablename__ = "embeddings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="embedding_pkey"),
@@ -932,7 +933,7 @@ class Embedding(db.Model): # type: ignore[name-defined]
return cast(list[float], pickle.loads(self.embedding)) # noqa: S301


class DatasetCollectionBinding(db.Model): # type: ignore[name-defined]
class DatasetCollectionBinding(Base):
__tablename__ = "dataset_collection_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
@@ -947,7 +948,7 @@ class DatasetCollectionBinding(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class TidbAuthBinding(db.Model): # type: ignore[name-defined]
class TidbAuthBinding(Base):
__tablename__ = "tidb_auth_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
@@ -967,7 +968,7 @@ class TidbAuthBinding(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class Whitelist(db.Model): # type: ignore[name-defined]
class Whitelist(Base):
__tablename__ = "whitelists"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
@@ -979,7 +980,7 @@ class Whitelist(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class DatasetPermission(db.Model): # type: ignore[name-defined]
class DatasetPermission(Base):
__tablename__ = "dataset_permissions"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
@@ -996,7 +997,7 @@ class DatasetPermission(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined]
class ExternalKnowledgeApis(Base):
__tablename__ = "external_knowledge_apis"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
@@ -1049,7 +1050,7 @@ class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined]
return dataset_bindings


class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined]
class ExternalKnowledgeBindings(Base):
__tablename__ = "external_knowledge_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
@@ -1070,7 +1071,7 @@ class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined]
class DatasetAutoDisableLog(Base):
__tablename__ = "dataset_auto_disable_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
@@ -1087,7 +1088,7 @@ class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))


class RateLimitLog(db.Model): # type: ignore[name-defined]
class RateLimitLog(Base):
__tablename__ = "rate_limit_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"),
@@ -1102,7 +1103,7 @@ class RateLimitLog(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))


class DatasetMetadata(db.Model): # type: ignore[name-defined]
class DatasetMetadata(Base):
__tablename__ = "dataset_metadatas"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"),
@@ -1121,7 +1122,7 @@ class DatasetMetadata(db.Model): # type: ignore[name-defined]
updated_by = db.Column(StringUUID, nullable=True)


class DatasetMetadataBinding(db.Model): # type: ignore[name-defined]
class DatasetMetadataBinding(Base):
__tablename__ = "dataset_metadata_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"),

+ 13
- 33
api/models/model.py Näytä tiedosto

@@ -16,7 +16,7 @@ if TYPE_CHECKING:

import sqlalchemy as sa
from flask import request
from flask_login import UserMixin # type: ignore
from flask_login import UserMixin
from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text
from sqlalchemy.orm import Mapped, Session, mapped_column

@@ -25,13 +25,13 @@ from constants import DEFAULT_FILE_NUMBER_LIMITS
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from core.file import helpers as file_helpers
from libs.helper import generate_string
from models.base import Base
from models.enums import CreatedByRole
from models.workflow import WorkflowRunStatus

from .account import Account, Tenant
from .base import Base
from .engine import db
from .enums import CreatedByRole
from .types import StringUUID
from .workflow import WorkflowRunStatus

if TYPE_CHECKING:
from .workflow import Workflow
@@ -602,7 +602,7 @@ class InstalledApp(Base):
return tenant


class Conversation(db.Model): # type: ignore[name-defined]
class Conversation(Base):
__tablename__ = "conversations"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="conversation_pkey"),
@@ -794,7 +794,7 @@ class Conversation(db.Model): # type: ignore[name-defined]

for message in messages:
if message.workflow_run:
status_counts[message.workflow_run.status] += 1
status_counts[WorkflowRunStatus(message.workflow_run.status)] += 1

return (
{
@@ -864,7 +864,7 @@ class Conversation(db.Model): # type: ignore[name-defined]
}


class Message(db.Model): # type: ignore[name-defined]
class Message(Base):
__tablename__ = "messages"
__table_args__ = (
PrimaryKeyConstraint("id", name="message_pkey"),
@@ -1211,7 +1211,7 @@ class Message(db.Model): # type: ignore[name-defined]
)


class MessageFeedback(db.Model): # type: ignore[name-defined]
class MessageFeedback(Base):
__tablename__ = "message_feedbacks"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
@@ -1238,7 +1238,7 @@ class MessageFeedback(db.Model): # type: ignore[name-defined]
return account


class MessageFile(db.Model): # type: ignore[name-defined]
class MessageFile(Base):
__tablename__ = "message_files"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_file_pkey"),
@@ -1279,7 +1279,7 @@ class MessageFile(db.Model): # type: ignore[name-defined]
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class MessageAnnotation(db.Model): # type: ignore[name-defined]
class MessageAnnotation(Base):
__tablename__ = "message_annotations"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_annotation_pkey"),
@@ -1310,7 +1310,7 @@ class MessageAnnotation(db.Model): # type: ignore[name-defined]
return account


class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined]
class AppAnnotationHitHistory(Base):
__tablename__ = "app_annotation_hit_histories"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
@@ -1322,7 +1322,7 @@ class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined]

id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
app_id = db.Column(StringUUID, nullable=False)
annotation_id = db.Column(StringUUID, nullable=False)
annotation_id: Mapped[str] = db.Column(StringUUID, nullable=False)
source = db.Column(db.Text, nullable=False)
question = db.Column(db.Text, nullable=False)
account_id = db.Column(StringUUID, nullable=False)
@@ -1348,7 +1348,7 @@ class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined]
return account


class AppAnnotationSetting(db.Model): # type: ignore[name-defined]
class AppAnnotationSetting(Base):
__tablename__ = "app_annotation_settings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
@@ -1364,26 +1364,6 @@ class AppAnnotationSetting(db.Model): # type: ignore[name-defined]
updated_user_id = db.Column(StringUUID, nullable=False)
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

@property
def created_account(self):
account = (
db.session.query(Account)
.join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id)
.filter(AppAnnotationSetting.id == self.annotation_id)
.first()
)
return account

@property
def updated_account(self):
account = (
db.session.query(Account)
.join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id)
.filter(AppAnnotationSetting.id == self.annotation_id)
.first()
)
return account

@property
def collection_binding_detail(self):
from .dataset import DatasetCollectionBinding

+ 1
- 2
api/models/provider.py Näytä tiedosto

@@ -2,8 +2,7 @@ from enum import Enum

from sqlalchemy import func

from models.base import Base

from .base import Base
from .engine import db
from .types import StringUUID


+ 1
- 1
api/models/source.py Näytä tiedosto

@@ -9,7 +9,7 @@ from .engine import db
from .types import StringUUID


class DataSourceOauthBinding(db.Model): # type: ignore[name-defined]
class DataSourceOauthBinding(Base):
__tablename__ = "data_source_oauth_bindings"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="source_binding_pkey"),

+ 5
- 10
api/models/workflow.py Näytä tiedosto

@@ -9,7 +9,7 @@ if TYPE_CHECKING:
from models.model import AppMode

import sqlalchemy as sa
from sqlalchemy import Index, PrimaryKeyConstraint, func
from sqlalchemy import func
from sqlalchemy.orm import Mapped, mapped_column

import contexts
@@ -18,11 +18,11 @@ from core.helper import encrypter
from core.variables import SecretVariable, Variable
from factories import variable_factory
from libs import helper
from models.base import Base
from models.enums import CreatedByRole

from .account import Account
from .base import Base
from .engine import db
from .enums import CreatedByRole
from .types import StringUUID

if TYPE_CHECKING:
@@ -768,17 +768,12 @@ class WorkflowAppLog(Base):

class ConversationVariable(Base):
__tablename__ = "workflow_conversation_variables"
__table_args__ = (
PrimaryKeyConstraint("id", "conversation_id", name="workflow_conversation_variables_pkey"),
Index("workflow__conversation_variables_app_id_idx", "app_id"),
Index("workflow__conversation_variables_created_at_idx", "created_at"),
)

id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
data = mapped_column(db.Text, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True)
updated_at = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)

+ 3
- 3
api/services/account_service.py Näytä tiedosto

@@ -110,7 +110,7 @@ class AccountService:

current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
if current_tenant:
account.current_tenant_id = current_tenant.tenant_id
account.set_tenant_id(current_tenant.tenant_id)
else:
available_ta = (
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
@@ -118,7 +118,7 @@ class AccountService:
if not available_ta:
return None

account.current_tenant_id = available_ta.tenant_id
account.set_tenant_id(available_ta.tenant_id)
available_ta.current = True
db.session.commit()

@@ -700,7 +700,7 @@ class TenantService:
).update({"current": False})
tenant_account_join.current = True
# Set the current tenant for the account
account.current_tenant_id = tenant_account_join.tenant_id
account.set_tenant_id(tenant_account_join.tenant_id)
db.session.commit()

@staticmethod

Loading…
Peruuta
Tallenna