Co-authored-by: -LAN- <laipz8200@outlook.com>tags/0.8.0
| @@ -4,7 +4,7 @@ from .model import App, AppMode, Message | |||
| from .types import StringUUID | |||
| from .workflow import ConversationVariable, Workflow, WorkflowNodeExecutionStatus | |||
| __all__ = ['ConversationVariable', 'StringUUID', 'AppMode', 'WorkflowNodeExecutionStatus', 'Workflow', 'App', 'Message'] | |||
| __all__ = ["ConversationVariable", "StringUUID", "AppMode", "WorkflowNodeExecutionStatus", "Workflow", "App", "Message"] | |||
| class CreatedByRole(Enum): | |||
| @@ -12,11 +12,11 @@ class CreatedByRole(Enum): | |||
| Enum class for createdByRole | |||
| """ | |||
| ACCOUNT = 'account' | |||
| END_USER = 'end_user' | |||
| ACCOUNT = "account" | |||
| END_USER = "end_user" | |||
| @classmethod | |||
| def value_of(cls, value: str) -> 'CreatedByRole': | |||
| def value_of(cls, value: str) -> "CreatedByRole": | |||
| """ | |||
| Get value of given mode. | |||
| @@ -26,4 +26,4 @@ class CreatedByRole(Enum): | |||
| for role in cls: | |||
| if role.value == value: | |||
| return role | |||
| raise ValueError(f'invalid createdByRole value {value}') | |||
| raise ValueError(f"invalid createdByRole value {value}") | |||
| @@ -9,21 +9,18 @@ from .types import StringUUID | |||
| class AccountStatus(str, enum.Enum): | |||
| PENDING = 'pending' | |||
| UNINITIALIZED = 'uninitialized' | |||
| ACTIVE = 'active' | |||
| BANNED = 'banned' | |||
| CLOSED = 'closed' | |||
| PENDING = "pending" | |||
| UNINITIALIZED = "uninitialized" | |||
| ACTIVE = "active" | |||
| BANNED = "banned" | |||
| CLOSED = "closed" | |||
| class Account(UserMixin, db.Model): | |||
| __tablename__ = 'accounts' | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='account_pkey'), | |||
| db.Index('account_email_idx', 'email') | |||
| ) | |||
| __tablename__ = "accounts" | |||
| __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| name = db.Column(db.String(255), nullable=False) | |||
| email = db.Column(db.String(255), nullable=False) | |||
| password = db.Column(db.String(255), nullable=True) | |||
| @@ -34,11 +31,11 @@ class Account(UserMixin, db.Model): | |||
| timezone = db.Column(db.String(255)) | |||
| last_login_at = db.Column(db.DateTime) | |||
| last_login_ip = db.Column(db.String(255)) | |||
| last_active_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| last_active_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| status = db.Column(db.String(16), nullable=False, server_default=db.text("'active'::character varying")) | |||
| initialized_at = db.Column(db.DateTime) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| @property | |||
| def is_password_set(self): | |||
| @@ -65,11 +62,13 @@ class Account(UserMixin, db.Model): | |||
| @current_tenant_id.setter | |||
| def current_tenant_id(self, value: str): | |||
| try: | |||
| tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \ | |||
| .filter(Tenant.id == value) \ | |||
| .filter(TenantAccountJoin.tenant_id == Tenant.id) \ | |||
| .filter(TenantAccountJoin.account_id == self.id) \ | |||
| tenant_account_join = ( | |||
| db.session.query(Tenant, TenantAccountJoin) | |||
| .filter(Tenant.id == value) | |||
| .filter(TenantAccountJoin.tenant_id == Tenant.id) | |||
| .filter(TenantAccountJoin.account_id == self.id) | |||
| .one_or_none() | |||
| ) | |||
| if tenant_account_join: | |||
| tenant, ta = tenant_account_join | |||
| @@ -91,20 +90,18 @@ class Account(UserMixin, db.Model): | |||
| @classmethod | |||
| def get_by_openid(cls, provider: str, open_id: str) -> db.Model: | |||
| account_integrate = db.session.query(AccountIntegrate). \ | |||
| filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id). \ | |||
| one_or_none() | |||
| account_integrate = ( | |||
| db.session.query(AccountIntegrate) | |||
| .filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) | |||
| .one_or_none() | |||
| ) | |||
| if account_integrate: | |||
| return db.session.query(Account). \ | |||
| filter(Account.id == account_integrate.account_id). \ | |||
| one_or_none() | |||
| return db.session.query(Account).filter(Account.id == account_integrate.account_id).one_or_none() | |||
| return None | |||
| def get_integrates(self) -> list[db.Model]: | |||
| ai = db.Model | |||
| return db.session.query(ai).filter( | |||
| ai.account_id == self.id | |||
| ).all() | |||
| return db.session.query(ai).filter(ai.account_id == self.id).all() | |||
| # check current_user.current_tenant.current_role in ['admin', 'owner'] | |||
| @property | |||
| @@ -123,61 +120,75 @@ class Account(UserMixin, db.Model): | |||
| def is_dataset_operator(self): | |||
| return self._current_tenant.current_role == TenantAccountRole.DATASET_OPERATOR | |||
| class TenantStatus(str, enum.Enum): | |||
| NORMAL = 'normal' | |||
| ARCHIVE = 'archive' | |||
| NORMAL = "normal" | |||
| ARCHIVE = "archive" | |||
| class TenantAccountRole(str, enum.Enum): | |||
| OWNER = 'owner' | |||
| ADMIN = 'admin' | |||
| EDITOR = 'editor' | |||
| NORMAL = 'normal' | |||
| DATASET_OPERATOR = 'dataset_operator' | |||
| OWNER = "owner" | |||
| ADMIN = "admin" | |||
| EDITOR = "editor" | |||
| NORMAL = "normal" | |||
| DATASET_OPERATOR = "dataset_operator" | |||
| @staticmethod | |||
| def is_valid_role(role: str) -> bool: | |||
| return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, | |||
| TenantAccountRole.NORMAL, TenantAccountRole.DATASET_OPERATOR} | |||
| return role and role in { | |||
| TenantAccountRole.OWNER, | |||
| TenantAccountRole.ADMIN, | |||
| TenantAccountRole.EDITOR, | |||
| TenantAccountRole.NORMAL, | |||
| TenantAccountRole.DATASET_OPERATOR, | |||
| } | |||
| @staticmethod | |||
| def is_privileged_role(role: str) -> bool: | |||
| return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN} | |||
| @staticmethod | |||
| def is_non_owner_role(role: str) -> bool: | |||
| return role and role in {TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, TenantAccountRole.NORMAL, | |||
| TenantAccountRole.DATASET_OPERATOR} | |||
| return role and role in { | |||
| TenantAccountRole.ADMIN, | |||
| TenantAccountRole.EDITOR, | |||
| TenantAccountRole.NORMAL, | |||
| TenantAccountRole.DATASET_OPERATOR, | |||
| } | |||
| @staticmethod | |||
| def is_editing_role(role: str) -> bool: | |||
| return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR} | |||
| @staticmethod | |||
| def is_dataset_edit_role(role: str) -> bool: | |||
| return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR, | |||
| TenantAccountRole.DATASET_OPERATOR} | |||
| return role and role in { | |||
| TenantAccountRole.OWNER, | |||
| TenantAccountRole.ADMIN, | |||
| TenantAccountRole.EDITOR, | |||
| TenantAccountRole.DATASET_OPERATOR, | |||
| } | |||
| class Tenant(db.Model): | |||
| __tablename__ = 'tenants' | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='tenant_pkey'), | |||
| ) | |||
| __tablename__ = "tenants" | |||
| __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| name = db.Column(db.String(255), nullable=False) | |||
| encrypt_public_key = db.Column(db.Text) | |||
| plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying")) | |||
| status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) | |||
| custom_config = db.Column(db.Text) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| def get_accounts(self) -> list[Account]: | |||
| return db.session.query(Account).filter( | |||
| Account.id == TenantAccountJoin.account_id, | |||
| TenantAccountJoin.tenant_id == self.id | |||
| ).all() | |||
| return ( | |||
| db.session.query(Account) | |||
| .filter(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id) | |||
| .all() | |||
| ) | |||
| @property | |||
| def custom_config_dict(self) -> dict: | |||
| @@ -189,54 +200,54 @@ class Tenant(db.Model): | |||
| class TenantAccountJoinRole(enum.Enum): | |||
| OWNER = 'owner' | |||
| ADMIN = 'admin' | |||
| NORMAL = 'normal' | |||
| DATASET_OPERATOR = 'dataset_operator' | |||
| OWNER = "owner" | |||
| ADMIN = "admin" | |||
| NORMAL = "normal" | |||
| DATASET_OPERATOR = "dataset_operator" | |||
| class TenantAccountJoin(db.Model): | |||
| __tablename__ = 'tenant_account_joins' | |||
| __tablename__ = "tenant_account_joins" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='tenant_account_join_pkey'), | |||
| db.Index('tenant_account_join_account_id_idx', 'account_id'), | |||
| db.Index('tenant_account_join_tenant_id_idx', 'tenant_id'), | |||
| db.UniqueConstraint('tenant_id', 'account_id', name='unique_tenant_account_join') | |||
| db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), | |||
| db.Index("tenant_account_join_account_id_idx", "account_id"), | |||
| db.Index("tenant_account_join_tenant_id_idx", "tenant_id"), | |||
| db.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| account_id = db.Column(StringUUID, nullable=False) | |||
| current = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) | |||
| role = db.Column(db.String(16), nullable=False, server_default='normal') | |||
| current = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) | |||
| role = db.Column(db.String(16), nullable=False, server_default="normal") | |||
| invited_by = db.Column(StringUUID, nullable=True) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| class AccountIntegrate(db.Model): | |||
| __tablename__ = 'account_integrates' | |||
| __tablename__ = "account_integrates" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='account_integrate_pkey'), | |||
| db.UniqueConstraint('account_id', 'provider', name='unique_account_provider'), | |||
| db.UniqueConstraint('provider', 'open_id', name='unique_provider_open_id') | |||
| db.PrimaryKeyConstraint("id", name="account_integrate_pkey"), | |||
| db.UniqueConstraint("account_id", "provider", name="unique_account_provider"), | |||
| db.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| account_id = db.Column(StringUUID, nullable=False) | |||
| provider = db.Column(db.String(16), nullable=False) | |||
| open_id = db.Column(db.String(255), nullable=False) | |||
| encrypted_token = db.Column(db.String(255), nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| class InvitationCode(db.Model): | |||
| __tablename__ = 'invitation_codes' | |||
| __tablename__ = "invitation_codes" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='invitation_code_pkey'), | |||
| db.Index('invitation_codes_batch_idx', 'batch'), | |||
| db.Index('invitation_codes_code_idx', 'code', 'status') | |||
| db.PrimaryKeyConstraint("id", name="invitation_code_pkey"), | |||
| db.Index("invitation_codes_batch_idx", "batch"), | |||
| db.Index("invitation_codes_code_idx", "code", "status"), | |||
| ) | |||
| id = db.Column(db.Integer, nullable=False) | |||
| @@ -247,4 +258,4 @@ class InvitationCode(db.Model): | |||
| used_by_tenant_id = db.Column(StringUUID) | |||
| used_by_account_id = db.Column(StringUUID) | |||
| deprecated_at = db.Column(db.DateTime) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| @@ -6,22 +6,22 @@ from .types import StringUUID | |||
| class APIBasedExtensionPoint(enum.Enum): | |||
| APP_EXTERNAL_DATA_TOOL_QUERY = 'app.external_data_tool.query' | |||
| PING = 'ping' | |||
| APP_MODERATION_INPUT = 'app.moderation.input' | |||
| APP_MODERATION_OUTPUT = 'app.moderation.output' | |||
| APP_EXTERNAL_DATA_TOOL_QUERY = "app.external_data_tool.query" | |||
| PING = "ping" | |||
| APP_MODERATION_INPUT = "app.moderation.input" | |||
| APP_MODERATION_OUTPUT = "app.moderation.output" | |||
| class APIBasedExtension(db.Model): | |||
| __tablename__ = 'api_based_extensions' | |||
| __tablename__ = "api_based_extensions" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='api_based_extension_pkey'), | |||
| db.Index('api_based_extension_tenant_idx', 'tenant_id'), | |||
| db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"), | |||
| db.Index("api_based_extension_tenant_idx", "tenant_id"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| name = db.Column(db.String(255), nullable=False) | |||
| api_endpoint = db.Column(db.String(255), nullable=False) | |||
| api_key = db.Column(db.Text, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| @@ -24,37 +24,34 @@ from .types import StringUUID | |||
| class DatasetPermissionEnum(str, enum.Enum): | |||
| ONLY_ME = 'only_me' | |||
| ALL_TEAM = 'all_team_members' | |||
| PARTIAL_TEAM = 'partial_members' | |||
| ONLY_ME = "only_me" | |||
| ALL_TEAM = "all_team_members" | |||
| PARTIAL_TEAM = "partial_members" | |||
| class Dataset(db.Model): | |||
| __tablename__ = 'datasets' | |||
| __tablename__ = "datasets" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='dataset_pkey'), | |||
| db.Index('dataset_tenant_idx', 'tenant_id'), | |||
| db.Index('retrieval_model_idx', "retrieval_model", postgresql_using='gin') | |||
| db.PrimaryKeyConstraint("id", name="dataset_pkey"), | |||
| db.Index("dataset_tenant_idx", "tenant_id"), | |||
| db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"), | |||
| ) | |||
| INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None] | |||
| INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None] | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| name = db.Column(db.String(255), nullable=False) | |||
| description = db.Column(db.Text, nullable=True) | |||
| provider = db.Column(db.String(255), nullable=False, | |||
| server_default=db.text("'vendor'::character varying")) | |||
| permission = db.Column(db.String(255), nullable=False, | |||
| server_default=db.text("'only_me'::character varying")) | |||
| provider = db.Column(db.String(255), nullable=False, server_default=db.text("'vendor'::character varying")) | |||
| permission = db.Column(db.String(255), nullable=False, server_default=db.text("'only_me'::character varying")) | |||
| data_source_type = db.Column(db.String(255)) | |||
| indexing_technique = db.Column(db.String(255), nullable=True) | |||
| index_struct = db.Column(db.Text, nullable=True) | |||
| created_by = db.Column(StringUUID, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, | |||
| server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_by = db.Column(StringUUID, nullable=True) | |||
| updated_at = db.Column(db.DateTime, nullable=False, | |||
| server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| embedding_model = db.Column(db.String(255), nullable=True) | |||
| embedding_model_provider = db.Column(db.String(255), nullable=True) | |||
| collection_binding_id = db.Column(StringUUID, nullable=True) | |||
| @@ -62,8 +59,9 @@ class Dataset(db.Model): | |||
| @property | |||
| def dataset_keyword_table(self): | |||
| dataset_keyword_table = db.session.query(DatasetKeywordTable).filter( | |||
| DatasetKeywordTable.dataset_id == self.id).first() | |||
| dataset_keyword_table = ( | |||
| db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == self.id).first() | |||
| ) | |||
| if dataset_keyword_table: | |||
| return dataset_keyword_table | |||
| @@ -79,13 +77,19 @@ class Dataset(db.Model): | |||
| @property | |||
| def latest_process_rule(self): | |||
| return DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id) \ | |||
| .order_by(DatasetProcessRule.created_at.desc()).first() | |||
| return ( | |||
| DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id) | |||
| .order_by(DatasetProcessRule.created_at.desc()) | |||
| .first() | |||
| ) | |||
| @property | |||
| def app_count(self): | |||
| return db.session.query(func.count(AppDatasetJoin.id)).filter(AppDatasetJoin.dataset_id == self.id, | |||
| App.id == AppDatasetJoin.app_id).scalar() | |||
| return ( | |||
| db.session.query(func.count(AppDatasetJoin.id)) | |||
| .filter(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id) | |||
| .scalar() | |||
| ) | |||
| @property | |||
| def document_count(self): | |||
| @@ -93,30 +97,40 @@ class Dataset(db.Model): | |||
| @property | |||
| def available_document_count(self): | |||
| return db.session.query(func.count(Document.id)).filter( | |||
| Document.dataset_id == self.id, | |||
| Document.indexing_status == 'completed', | |||
| Document.enabled == True, | |||
| Document.archived == False | |||
| ).scalar() | |||
| return ( | |||
| db.session.query(func.count(Document.id)) | |||
| .filter( | |||
| Document.dataset_id == self.id, | |||
| Document.indexing_status == "completed", | |||
| Document.enabled == True, | |||
| Document.archived == False, | |||
| ) | |||
| .scalar() | |||
| ) | |||
| @property | |||
| def available_segment_count(self): | |||
| return db.session.query(func.count(DocumentSegment.id)).filter( | |||
| DocumentSegment.dataset_id == self.id, | |||
| DocumentSegment.status == 'completed', | |||
| DocumentSegment.enabled == True | |||
| ).scalar() | |||
| return ( | |||
| db.session.query(func.count(DocumentSegment.id)) | |||
| .filter( | |||
| DocumentSegment.dataset_id == self.id, | |||
| DocumentSegment.status == "completed", | |||
| DocumentSegment.enabled == True, | |||
| ) | |||
| .scalar() | |||
| ) | |||
| @property | |||
| def word_count(self): | |||
| return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \ | |||
| .filter(Document.dataset_id == self.id).scalar() | |||
| return ( | |||
| Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) | |||
| .filter(Document.dataset_id == self.id) | |||
| .scalar() | |||
| ) | |||
| @property | |||
| def doc_form(self): | |||
| document = db.session.query(Document).filter( | |||
| Document.dataset_id == self.id).first() | |||
| document = db.session.query(Document).filter(Document.dataset_id == self.id).first() | |||
| if document: | |||
| return document.doc_form | |||
| return None | |||
| @@ -124,76 +138,68 @@ class Dataset(db.Model): | |||
| @property | |||
| def retrieval_model_dict(self): | |||
| default_retrieval_model = { | |||
| 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| 'reranking_enable': False, | |||
| 'reranking_model': { | |||
| 'reranking_provider_name': '', | |||
| 'reranking_model_name': '' | |||
| }, | |||
| 'top_k': 2, | |||
| 'score_threshold_enabled': False | |||
| "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| "reranking_enable": False, | |||
| "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, | |||
| "top_k": 2, | |||
| "score_threshold_enabled": False, | |||
| } | |||
| return self.retrieval_model if self.retrieval_model else default_retrieval_model | |||
| @property | |||
| def tags(self): | |||
| tags = db.session.query(Tag).join( | |||
| TagBinding, | |||
| Tag.id == TagBinding.tag_id | |||
| ).filter( | |||
| TagBinding.target_id == self.id, | |||
| TagBinding.tenant_id == self.tenant_id, | |||
| Tag.tenant_id == self.tenant_id, | |||
| Tag.type == 'knowledge' | |||
| ).all() | |||
| tags = ( | |||
| db.session.query(Tag) | |||
| .join(TagBinding, Tag.id == TagBinding.tag_id) | |||
| .filter( | |||
| TagBinding.target_id == self.id, | |||
| TagBinding.tenant_id == self.tenant_id, | |||
| Tag.tenant_id == self.tenant_id, | |||
| Tag.type == "knowledge", | |||
| ) | |||
| .all() | |||
| ) | |||
| return tags if tags else [] | |||
| @staticmethod | |||
| def gen_collection_name_by_id(dataset_id: str) -> str: | |||
| normalized_dataset_id = dataset_id.replace("-", "_") | |||
| return f'Vector_index_{normalized_dataset_id}_Node' | |||
| return f"Vector_index_{normalized_dataset_id}_Node" | |||
| class DatasetProcessRule(db.Model): | |||
| __tablename__ = 'dataset_process_rules' | |||
| __tablename__ = "dataset_process_rules" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='dataset_process_rule_pkey'), | |||
| db.Index('dataset_process_rule_dataset_id_idx', 'dataset_id'), | |||
| db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), | |||
| db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"), | |||
| ) | |||
| id = db.Column(StringUUID, nullable=False, | |||
| server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) | |||
| dataset_id = db.Column(StringUUID, nullable=False) | |||
| mode = db.Column(db.String(255), nullable=False, | |||
| server_default=db.text("'automatic'::character varying")) | |||
| mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) | |||
| rules = db.Column(db.Text, nullable=True) | |||
| created_by = db.Column(StringUUID, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, | |||
| server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| MODES = ['automatic', 'custom'] | |||
| PRE_PROCESSING_RULES = ['remove_stopwords', 'remove_extra_spaces', 'remove_urls_emails'] | |||
| MODES = ["automatic", "custom"] | |||
| PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] | |||
| AUTOMATIC_RULES = { | |||
| 'pre_processing_rules': [ | |||
| {'id': 'remove_extra_spaces', 'enabled': True}, | |||
| {'id': 'remove_urls_emails', 'enabled': False} | |||
| "pre_processing_rules": [ | |||
| {"id": "remove_extra_spaces", "enabled": True}, | |||
| {"id": "remove_urls_emails", "enabled": False}, | |||
| ], | |||
| 'segmentation': { | |||
| 'delimiter': '\n', | |||
| 'max_tokens': 500, | |||
| 'chunk_overlap': 50 | |||
| } | |||
| "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, | |||
| } | |||
| def to_dict(self): | |||
| return { | |||
| 'id': self.id, | |||
| 'dataset_id': self.dataset_id, | |||
| 'mode': self.mode, | |||
| 'rules': self.rules_dict, | |||
| 'created_by': self.created_by, | |||
| 'created_at': self.created_at, | |||
| "id": self.id, | |||
| "dataset_id": self.dataset_id, | |||
| "mode": self.mode, | |||
| "rules": self.rules_dict, | |||
| "created_by": self.created_by, | |||
| "created_at": self.created_at, | |||
| } | |||
| @property | |||
| @@ -205,17 +211,16 @@ class DatasetProcessRule(db.Model): | |||
| class Document(db.Model): | |||
| __tablename__ = 'documents' | |||
| __tablename__ = "documents" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='document_pkey'), | |||
| db.Index('document_dataset_id_idx', 'dataset_id'), | |||
| db.Index('document_is_paused_idx', 'is_paused'), | |||
| db.Index('document_tenant_idx', 'tenant_id'), | |||
| db.PrimaryKeyConstraint("id", name="document_pkey"), | |||
| db.Index("document_dataset_id_idx", "dataset_id"), | |||
| db.Index("document_is_paused_idx", "is_paused"), | |||
| db.Index("document_tenant_idx", "tenant_id"), | |||
| ) | |||
| # initial fields | |||
| id = db.Column(StringUUID, nullable=False, | |||
| server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| dataset_id = db.Column(StringUUID, nullable=False) | |||
| position = db.Column(db.Integer, nullable=False) | |||
| @@ -227,8 +232,7 @@ class Document(db.Model): | |||
| created_from = db.Column(db.String(255), nullable=False) | |||
| created_by = db.Column(StringUUID, nullable=False) | |||
| created_api_request_id = db.Column(StringUUID, nullable=True) | |||
| created_at = db.Column(db.DateTime, nullable=False, | |||
| server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| # start processing | |||
| processing_started_at = db.Column(db.DateTime, nullable=True) | |||
| @@ -250,7 +254,7 @@ class Document(db.Model): | |||
| completed_at = db.Column(db.DateTime, nullable=True) | |||
| # pause | |||
| is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) | |||
| is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) | |||
| paused_by = db.Column(StringUUID, nullable=True) | |||
| paused_at = db.Column(db.DateTime, nullable=True) | |||
| @@ -259,44 +263,39 @@ class Document(db.Model): | |||
| stopped_at = db.Column(db.DateTime, nullable=True) | |||
| # basic fields | |||
| indexing_status = db.Column(db.String( | |||
| 255), nullable=False, server_default=db.text("'waiting'::character varying")) | |||
| enabled = db.Column(db.Boolean, nullable=False, | |||
| server_default=db.text('true')) | |||
| indexing_status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) | |||
| enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) | |||
| disabled_at = db.Column(db.DateTime, nullable=True) | |||
| disabled_by = db.Column(StringUUID, nullable=True) | |||
| archived = db.Column(db.Boolean, nullable=False, | |||
| server_default=db.text('false')) | |||
| archived = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) | |||
| archived_reason = db.Column(db.String(255), nullable=True) | |||
| archived_by = db.Column(StringUUID, nullable=True) | |||
| archived_at = db.Column(db.DateTime, nullable=True) | |||
| updated_at = db.Column(db.DateTime, nullable=False, | |||
| server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| doc_type = db.Column(db.String(40), nullable=True) | |||
| doc_metadata = db.Column(db.JSON, nullable=True) | |||
| doc_form = db.Column(db.String( | |||
| 255), nullable=False, server_default=db.text("'text_model'::character varying")) | |||
| doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) | |||
| doc_language = db.Column(db.String(255), nullable=True) | |||
| DATA_SOURCES = ['upload_file', 'notion_import', 'website_crawl'] | |||
| DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"] | |||
| @property | |||
| def display_status(self): | |||
| status = None | |||
| if self.indexing_status == 'waiting': | |||
| status = 'queuing' | |||
| elif self.indexing_status not in ['completed', 'error', 'waiting'] and self.is_paused: | |||
| status = 'paused' | |||
| elif self.indexing_status in ['parsing', 'cleaning', 'splitting', 'indexing']: | |||
| status = 'indexing' | |||
| elif self.indexing_status == 'error': | |||
| status = 'error' | |||
| elif self.indexing_status == 'completed' and not self.archived and self.enabled: | |||
| status = 'available' | |||
| elif self.indexing_status == 'completed' and not self.archived and not self.enabled: | |||
| status = 'disabled' | |||
| elif self.indexing_status == 'completed' and self.archived: | |||
| status = 'archived' | |||
| if self.indexing_status == "waiting": | |||
| status = "queuing" | |||
| elif self.indexing_status not in ["completed", "error", "waiting"] and self.is_paused: | |||
| status = "paused" | |||
| elif self.indexing_status in ["parsing", "cleaning", "splitting", "indexing"]: | |||
| status = "indexing" | |||
| elif self.indexing_status == "error": | |||
| status = "error" | |||
| elif self.indexing_status == "completed" and not self.archived and self.enabled: | |||
| status = "available" | |||
| elif self.indexing_status == "completed" and not self.archived and not self.enabled: | |||
| status = "disabled" | |||
| elif self.indexing_status == "completed" and self.archived: | |||
| status = "archived" | |||
| return status | |||
| @property | |||
| @@ -313,24 +312,26 @@ class Document(db.Model): | |||
| @property | |||
| def data_source_detail_dict(self): | |||
| if self.data_source_info: | |||
| if self.data_source_type == 'upload_file': | |||
| if self.data_source_type == "upload_file": | |||
| data_source_info_dict = json.loads(self.data_source_info) | |||
| file_detail = db.session.query(UploadFile). \ | |||
| filter(UploadFile.id == data_source_info_dict['upload_file_id']). \ | |||
| one_or_none() | |||
| file_detail = ( | |||
| db.session.query(UploadFile) | |||
| .filter(UploadFile.id == data_source_info_dict["upload_file_id"]) | |||
| .one_or_none() | |||
| ) | |||
| if file_detail: | |||
| return { | |||
| 'upload_file': { | |||
| 'id': file_detail.id, | |||
| 'name': file_detail.name, | |||
| 'size': file_detail.size, | |||
| 'extension': file_detail.extension, | |||
| 'mime_type': file_detail.mime_type, | |||
| 'created_by': file_detail.created_by, | |||
| 'created_at': file_detail.created_at.timestamp() | |||
| "upload_file": { | |||
| "id": file_detail.id, | |||
| "name": file_detail.name, | |||
| "size": file_detail.size, | |||
| "extension": file_detail.extension, | |||
| "mime_type": file_detail.mime_type, | |||
| "created_by": file_detail.created_by, | |||
| "created_at": file_detail.created_at.timestamp(), | |||
| } | |||
| } | |||
| elif self.data_source_type == 'notion_import' or self.data_source_type == 'website_crawl': | |||
| elif self.data_source_type == "notion_import" or self.data_source_type == "website_crawl": | |||
| return json.loads(self.data_source_info) | |||
| return {} | |||
| @@ -356,120 +357,123 @@ class Document(db.Model): | |||
| @property | |||
| def hit_count(self): | |||
| return DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) \ | |||
| .filter(DocumentSegment.document_id == self.id).scalar() | |||
| return ( | |||
| DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) | |||
| .filter(DocumentSegment.document_id == self.id) | |||
| .scalar() | |||
| ) | |||
| def to_dict(self): | |||
| return { | |||
| 'id': self.id, | |||
| 'tenant_id': self.tenant_id, | |||
| 'dataset_id': self.dataset_id, | |||
| 'position': self.position, | |||
| 'data_source_type': self.data_source_type, | |||
| 'data_source_info': self.data_source_info, | |||
| 'dataset_process_rule_id': self.dataset_process_rule_id, | |||
| 'batch': self.batch, | |||
| 'name': self.name, | |||
| 'created_from': self.created_from, | |||
| 'created_by': self.created_by, | |||
| 'created_api_request_id': self.created_api_request_id, | |||
| 'created_at': self.created_at, | |||
| 'processing_started_at': self.processing_started_at, | |||
| 'file_id': self.file_id, | |||
| 'word_count': self.word_count, | |||
| 'parsing_completed_at': self.parsing_completed_at, | |||
| 'cleaning_completed_at': self.cleaning_completed_at, | |||
| 'splitting_completed_at': self.splitting_completed_at, | |||
| 'tokens': self.tokens, | |||
| 'indexing_latency': self.indexing_latency, | |||
| 'completed_at': self.completed_at, | |||
| 'is_paused': self.is_paused, | |||
| 'paused_by': self.paused_by, | |||
| 'paused_at': self.paused_at, | |||
| 'error': self.error, | |||
| 'stopped_at': self.stopped_at, | |||
| 'indexing_status': self.indexing_status, | |||
| 'enabled': self.enabled, | |||
| 'disabled_at': self.disabled_at, | |||
| 'disabled_by': self.disabled_by, | |||
| 'archived': self.archived, | |||
| 'archived_reason': self.archived_reason, | |||
| 'archived_by': self.archived_by, | |||
| 'archived_at': self.archived_at, | |||
| 'updated_at': self.updated_at, | |||
| 'doc_type': self.doc_type, | |||
| 'doc_metadata': self.doc_metadata, | |||
| 'doc_form': self.doc_form, | |||
| 'doc_language': self.doc_language, | |||
| 'display_status': self.display_status, | |||
| 'data_source_info_dict': self.data_source_info_dict, | |||
| 'average_segment_length': self.average_segment_length, | |||
| 'dataset_process_rule': self.dataset_process_rule.to_dict() if self.dataset_process_rule else None, | |||
| 'dataset': self.dataset.to_dict() if self.dataset else None, | |||
| 'segment_count': self.segment_count, | |||
| 'hit_count': self.hit_count | |||
| "id": self.id, | |||
| "tenant_id": self.tenant_id, | |||
| "dataset_id": self.dataset_id, | |||
| "position": self.position, | |||
| "data_source_type": self.data_source_type, | |||
| "data_source_info": self.data_source_info, | |||
| "dataset_process_rule_id": self.dataset_process_rule_id, | |||
| "batch": self.batch, | |||
| "name": self.name, | |||
| "created_from": self.created_from, | |||
| "created_by": self.created_by, | |||
| "created_api_request_id": self.created_api_request_id, | |||
| "created_at": self.created_at, | |||
| "processing_started_at": self.processing_started_at, | |||
| "file_id": self.file_id, | |||
| "word_count": self.word_count, | |||
| "parsing_completed_at": self.parsing_completed_at, | |||
| "cleaning_completed_at": self.cleaning_completed_at, | |||
| "splitting_completed_at": self.splitting_completed_at, | |||
| "tokens": self.tokens, | |||
| "indexing_latency": self.indexing_latency, | |||
| "completed_at": self.completed_at, | |||
| "is_paused": self.is_paused, | |||
| "paused_by": self.paused_by, | |||
| "paused_at": self.paused_at, | |||
| "error": self.error, | |||
| "stopped_at": self.stopped_at, | |||
| "indexing_status": self.indexing_status, | |||
| "enabled": self.enabled, | |||
| "disabled_at": self.disabled_at, | |||
| "disabled_by": self.disabled_by, | |||
| "archived": self.archived, | |||
| "archived_reason": self.archived_reason, | |||
| "archived_by": self.archived_by, | |||
| "archived_at": self.archived_at, | |||
| "updated_at": self.updated_at, | |||
| "doc_type": self.doc_type, | |||
| "doc_metadata": self.doc_metadata, | |||
| "doc_form": self.doc_form, | |||
| "doc_language": self.doc_language, | |||
| "display_status": self.display_status, | |||
| "data_source_info_dict": self.data_source_info_dict, | |||
| "average_segment_length": self.average_segment_length, | |||
| "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None, | |||
| "dataset": self.dataset.to_dict() if self.dataset else None, | |||
| "segment_count": self.segment_count, | |||
| "hit_count": self.hit_count, | |||
| } | |||
| @classmethod | |||
| def from_dict(cls, data: dict): | |||
| return cls( | |||
| id=data.get('id'), | |||
| tenant_id=data.get('tenant_id'), | |||
| dataset_id=data.get('dataset_id'), | |||
| position=data.get('position'), | |||
| data_source_type=data.get('data_source_type'), | |||
| data_source_info=data.get('data_source_info'), | |||
| dataset_process_rule_id=data.get('dataset_process_rule_id'), | |||
| batch=data.get('batch'), | |||
| name=data.get('name'), | |||
| created_from=data.get('created_from'), | |||
| created_by=data.get('created_by'), | |||
| created_api_request_id=data.get('created_api_request_id'), | |||
| created_at=data.get('created_at'), | |||
| processing_started_at=data.get('processing_started_at'), | |||
| file_id=data.get('file_id'), | |||
| word_count=data.get('word_count'), | |||
| parsing_completed_at=data.get('parsing_completed_at'), | |||
| cleaning_completed_at=data.get('cleaning_completed_at'), | |||
| splitting_completed_at=data.get('splitting_completed_at'), | |||
| tokens=data.get('tokens'), | |||
| indexing_latency=data.get('indexing_latency'), | |||
| completed_at=data.get('completed_at'), | |||
| is_paused=data.get('is_paused'), | |||
| paused_by=data.get('paused_by'), | |||
| paused_at=data.get('paused_at'), | |||
| error=data.get('error'), | |||
| stopped_at=data.get('stopped_at'), | |||
| indexing_status=data.get('indexing_status'), | |||
| enabled=data.get('enabled'), | |||
| disabled_at=data.get('disabled_at'), | |||
| disabled_by=data.get('disabled_by'), | |||
| archived=data.get('archived'), | |||
| archived_reason=data.get('archived_reason'), | |||
| archived_by=data.get('archived_by'), | |||
| archived_at=data.get('archived_at'), | |||
| updated_at=data.get('updated_at'), | |||
| doc_type=data.get('doc_type'), | |||
| doc_metadata=data.get('doc_metadata'), | |||
| doc_form=data.get('doc_form'), | |||
| doc_language=data.get('doc_language') | |||
| id=data.get("id"), | |||
| tenant_id=data.get("tenant_id"), | |||
| dataset_id=data.get("dataset_id"), | |||
| position=data.get("position"), | |||
| data_source_type=data.get("data_source_type"), | |||
| data_source_info=data.get("data_source_info"), | |||
| dataset_process_rule_id=data.get("dataset_process_rule_id"), | |||
| batch=data.get("batch"), | |||
| name=data.get("name"), | |||
| created_from=data.get("created_from"), | |||
| created_by=data.get("created_by"), | |||
| created_api_request_id=data.get("created_api_request_id"), | |||
| created_at=data.get("created_at"), | |||
| processing_started_at=data.get("processing_started_at"), | |||
| file_id=data.get("file_id"), | |||
| word_count=data.get("word_count"), | |||
| parsing_completed_at=data.get("parsing_completed_at"), | |||
| cleaning_completed_at=data.get("cleaning_completed_at"), | |||
| splitting_completed_at=data.get("splitting_completed_at"), | |||
| tokens=data.get("tokens"), | |||
| indexing_latency=data.get("indexing_latency"), | |||
| completed_at=data.get("completed_at"), | |||
| is_paused=data.get("is_paused"), | |||
| paused_by=data.get("paused_by"), | |||
| paused_at=data.get("paused_at"), | |||
| error=data.get("error"), | |||
| stopped_at=data.get("stopped_at"), | |||
| indexing_status=data.get("indexing_status"), | |||
| enabled=data.get("enabled"), | |||
| disabled_at=data.get("disabled_at"), | |||
| disabled_by=data.get("disabled_by"), | |||
| archived=data.get("archived"), | |||
| archived_reason=data.get("archived_reason"), | |||
| archived_by=data.get("archived_by"), | |||
| archived_at=data.get("archived_at"), | |||
| updated_at=data.get("updated_at"), | |||
| doc_type=data.get("doc_type"), | |||
| doc_metadata=data.get("doc_metadata"), | |||
| doc_form=data.get("doc_form"), | |||
| doc_language=data.get("doc_language"), | |||
| ) | |||
| class DocumentSegment(db.Model): | |||
| __tablename__ = 'document_segments' | |||
| __tablename__ = "document_segments" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='document_segment_pkey'), | |||
| db.Index('document_segment_dataset_id_idx', 'dataset_id'), | |||
| db.Index('document_segment_document_id_idx', 'document_id'), | |||
| db.Index('document_segment_tenant_dataset_idx', 'dataset_id', 'tenant_id'), | |||
| db.Index('document_segment_tenant_document_idx', 'document_id', 'tenant_id'), | |||
| db.Index('document_segment_dataset_node_idx', 'dataset_id', 'index_node_id'), | |||
| db.Index('document_segment_tenant_idx', 'tenant_id'), | |||
| db.PrimaryKeyConstraint("id", name="document_segment_pkey"), | |||
| db.Index("document_segment_dataset_id_idx", "dataset_id"), | |||
| db.Index("document_segment_document_id_idx", "document_id"), | |||
| db.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"), | |||
| db.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"), | |||
| db.Index("document_segment_dataset_node_idx", "dataset_id", "index_node_id"), | |||
| db.Index("document_segment_tenant_idx", "tenant_id"), | |||
| ) | |||
| # initial fields | |||
| id = db.Column(StringUUID, nullable=False, | |||
| server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| dataset_id = db.Column(StringUUID, nullable=False) | |||
| document_id = db.Column(StringUUID, nullable=False) | |||
| @@ -486,18 +490,14 @@ class DocumentSegment(db.Model): | |||
| # basic fields | |||
| hit_count = db.Column(db.Integer, nullable=False, default=0) | |||
| enabled = db.Column(db.Boolean, nullable=False, | |||
| server_default=db.text('true')) | |||
| enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) | |||
| disabled_at = db.Column(db.DateTime, nullable=True) | |||
| disabled_by = db.Column(StringUUID, nullable=True) | |||
| status = db.Column(db.String(255), nullable=False, | |||
| server_default=db.text("'waiting'::character varying")) | |||
| status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) | |||
| created_by = db.Column(StringUUID, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, | |||
| server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_by = db.Column(StringUUID, nullable=True) | |||
| updated_at = db.Column(db.DateTime, nullable=False, | |||
| server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| indexing_at = db.Column(db.DateTime, nullable=True) | |||
| completed_at = db.Column(db.DateTime, nullable=True) | |||
| error = db.Column(db.Text, nullable=True) | |||
| @@ -513,17 +513,19 @@ class DocumentSegment(db.Model): | |||
| @property | |||
| def previous_segment(self): | |||
| return db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.document_id == self.document_id, | |||
| DocumentSegment.position == self.position - 1 | |||
| ).first() | |||
| return ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1) | |||
| .first() | |||
| ) | |||
| @property | |||
| def next_segment(self): | |||
| return db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.document_id == self.document_id, | |||
| DocumentSegment.position == self.position + 1 | |||
| ).first() | |||
| return ( | |||
| db.session.query(DocumentSegment) | |||
| .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1) | |||
| .first() | |||
| ) | |||
| def get_sign_content(self): | |||
| pattern = r"/files/([a-f0-9\-]+)/image-preview" | |||
| @@ -535,7 +537,7 @@ class DocumentSegment(db.Model): | |||
| nonce = os.urandom(16).hex() | |||
| timestamp = str(int(time.time())) | |||
| data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" | |||
| secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b'' | |||
| secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" | |||
| sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() | |||
| encoded_sign = base64.urlsafe_b64encode(sign).decode() | |||
| @@ -546,21 +548,20 @@ class DocumentSegment(db.Model): | |||
| # Reconstruct the text with signed URLs | |||
| offset = 0 | |||
| for start, end, signed_url in signed_urls: | |||
| text = text[:start + offset] + signed_url + text[end + offset:] | |||
| text = text[: start + offset] + signed_url + text[end + offset :] | |||
| offset += len(signed_url) - (end - start) | |||
| return text | |||
| class AppDatasetJoin(db.Model): | |||
| __tablename__ = 'app_dataset_joins' | |||
| __tablename__ = "app_dataset_joins" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='app_dataset_join_pkey'), | |||
| db.Index('app_dataset_join_app_dataset_idx', 'dataset_id', 'app_id'), | |||
| db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"), | |||
| db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"), | |||
| ) | |||
| id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) | |||
| app_id = db.Column(StringUUID, nullable=False) | |||
| dataset_id = db.Column(StringUUID, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) | |||
| @@ -571,13 +572,13 @@ class AppDatasetJoin(db.Model): | |||
| class DatasetQuery(db.Model): | |||
| __tablename__ = 'dataset_queries' | |||
| __tablename__ = "dataset_queries" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='dataset_query_pkey'), | |||
| db.Index('dataset_query_dataset_id_idx', 'dataset_id'), | |||
| db.PrimaryKeyConstraint("id", name="dataset_query_pkey"), | |||
| db.Index("dataset_query_dataset_id_idx", "dataset_id"), | |||
| ) | |||
| id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()")) | |||
| dataset_id = db.Column(StringUUID, nullable=False) | |||
| content = db.Column(db.Text, nullable=False) | |||
| source = db.Column(db.String(255), nullable=False) | |||
| @@ -588,17 +589,18 @@ class DatasetQuery(db.Model): | |||
| class DatasetKeywordTable(db.Model): | |||
| __tablename__ = 'dataset_keyword_tables' | |||
| __tablename__ = "dataset_keyword_tables" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='dataset_keyword_table_pkey'), | |||
| db.Index('dataset_keyword_table_dataset_id_idx', 'dataset_id'), | |||
| db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"), | |||
| db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"), | |||
| ) | |||
| id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) | |||
| dataset_id = db.Column(StringUUID, nullable=False, unique=True) | |||
| keyword_table = db.Column(db.Text, nullable=False) | |||
| data_source_type = db.Column(db.String(255), nullable=False, | |||
| server_default=db.text("'database'::character varying")) | |||
| data_source_type = db.Column( | |||
| db.String(255), nullable=False, server_default=db.text("'database'::character varying") | |||
| ) | |||
| @property | |||
| def keyword_table_dict(self): | |||
| @@ -614,19 +616,17 @@ class DatasetKeywordTable(db.Model): | |||
| return dct | |||
| # get dataset | |||
| dataset = Dataset.query.filter_by( | |||
| id=self.dataset_id | |||
| ).first() | |||
| dataset = Dataset.query.filter_by(id=self.dataset_id).first() | |||
| if not dataset: | |||
| return None | |||
| if self.data_source_type == 'database': | |||
| if self.data_source_type == "database": | |||
| return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None | |||
| else: | |||
| file_key = 'keyword_files/' + dataset.tenant_id + '/' + self.dataset_id + '.txt' | |||
| file_key = "keyword_files/" + dataset.tenant_id + "/" + self.dataset_id + ".txt" | |||
| try: | |||
| keyword_table_text = storage.load_once(file_key) | |||
| if keyword_table_text: | |||
| return json.loads(keyword_table_text.decode('utf-8'), cls=SetDecoder) | |||
| return json.loads(keyword_table_text.decode("utf-8"), cls=SetDecoder) | |||
| return None | |||
| except Exception as e: | |||
| logging.exception(str(e)) | |||
| @@ -634,21 +634,21 @@ class DatasetKeywordTable(db.Model): | |||
| class Embedding(db.Model): | |||
| __tablename__ = 'embeddings' | |||
| __tablename__ = "embeddings" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='embedding_pkey'), | |||
| db.UniqueConstraint('model_name', 'hash', 'provider_name', name='embedding_hash_idx'), | |||
| db.Index('created_at_idx', 'created_at') | |||
| db.PrimaryKeyConstraint("id", name="embedding_pkey"), | |||
| db.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"), | |||
| db.Index("created_at_idx", "created_at"), | |||
| ) | |||
| id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) | |||
| model_name = db.Column(db.String(255), nullable=False, | |||
| server_default=db.text("'text-embedding-ada-002'::character varying")) | |||
| id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) | |||
| model_name = db.Column( | |||
| db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying") | |||
| ) | |||
| hash = db.Column(db.String(64), nullable=False) | |||
| embedding = db.Column(db.LargeBinary, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| provider_name = db.Column(db.String(255), nullable=False, | |||
| server_default=db.text("''::character varying")) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying")) | |||
| def set_embedding(self, embedding_data: list[float]): | |||
| self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) | |||
| @@ -658,33 +658,32 @@ class Embedding(db.Model): | |||
| class DatasetCollectionBinding(db.Model): | |||
| __tablename__ = 'dataset_collection_bindings' | |||
| __tablename__ = "dataset_collection_bindings" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey'), | |||
| db.Index('provider_model_name_idx', 'provider_name', 'model_name') | |||
| db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"), | |||
| db.Index("provider_model_name_idx", "provider_name", "model_name"), | |||
| ) | |||
| id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) | |||
| provider_name = db.Column(db.String(40), nullable=False) | |||
| model_name = db.Column(db.String(255), nullable=False) | |||
| type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) | |||
| collection_name = db.Column(db.String(64), nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| class DatasetPermission(db.Model): | |||
| __tablename__ = 'dataset_permissions' | |||
| __tablename__ = "dataset_permissions" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='dataset_permission_pkey'), | |||
| db.Index('idx_dataset_permissions_dataset_id', 'dataset_id'), | |||
| db.Index('idx_dataset_permissions_account_id', 'account_id'), | |||
| db.Index('idx_dataset_permissions_tenant_id', 'tenant_id') | |||
| db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"), | |||
| db.Index("idx_dataset_permissions_dataset_id", "dataset_id"), | |||
| db.Index("idx_dataset_permissions_account_id", "account_id"), | |||
| db.Index("idx_dataset_permissions_tenant_id", "tenant_id"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'), primary_key=True) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True) | |||
| dataset_id = db.Column(StringUUID, nullable=False) | |||
| account_id = db.Column(StringUUID, nullable=False) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| @@ -6,8 +6,8 @@ from .types import StringUUID | |||
| class ProviderType(Enum): | |||
| CUSTOM = 'custom' | |||
| SYSTEM = 'system' | |||
| CUSTOM = "custom" | |||
| SYSTEM = "system" | |||
| @staticmethod | |||
| def value_of(value): | |||
| @@ -18,13 +18,13 @@ class ProviderType(Enum): | |||
| class ProviderQuotaType(Enum): | |||
| PAID = 'paid' | |||
| PAID = "paid" | |||
| """hosted paid quota""" | |||
| FREE = 'free' | |||
| FREE = "free" | |||
| """third-party free quota""" | |||
| TRIAL = 'trial' | |||
| TRIAL = "trial" | |||
| """hosted trial quota""" | |||
| @staticmethod | |||
| @@ -39,27 +39,30 @@ class Provider(db.Model): | |||
| """ | |||
| Provider model representing the API providers and their configurations. | |||
| """ | |||
| __tablename__ = 'providers' | |||
| __tablename__ = "providers" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='provider_pkey'), | |||
| db.Index('provider_tenant_id_provider_idx', 'tenant_id', 'provider_name'), | |||
| db.UniqueConstraint('tenant_id', 'provider_name', 'provider_type', 'quota_type', name='unique_provider_name_type_quota') | |||
| db.PrimaryKeyConstraint("id", name="provider_pkey"), | |||
| db.Index("provider_tenant_id_provider_idx", "tenant_id", "provider_name"), | |||
| db.UniqueConstraint( | |||
| "tenant_id", "provider_name", "provider_type", "quota_type", name="unique_provider_name_type_quota" | |||
| ), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| provider_name = db.Column(db.String(255), nullable=False) | |||
| provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying")) | |||
| encrypted_config = db.Column(db.Text, nullable=True) | |||
| is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) | |||
| is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) | |||
| last_used = db.Column(db.DateTime, nullable=True) | |||
| quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying")) | |||
| quota_limit = db.Column(db.BigInteger, nullable=True) | |||
| quota_used = db.Column(db.BigInteger, default=0) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| def __repr__(self): | |||
| return f"<Provider(id={self.id}, tenant_id={self.tenant_id}, provider_name='{self.provider_name}', provider_type='{self.provider_type}')>" | |||
| @@ -67,8 +70,8 @@ class Provider(db.Model): | |||
| @property | |||
| def token_is_set(self): | |||
| """ | |||
| Returns True if the encrypted_config is not None, indicating that the token is set. | |||
| """ | |||
| Returns True if the encrypted_config is not None, indicating that the token is set. | |||
| """ | |||
| return self.encrypted_config is not None | |||
| @property | |||
| @@ -86,118 +89,123 @@ class ProviderModel(db.Model): | |||
| """ | |||
| Provider model representing the API provider_models and their configurations. | |||
| """ | |||
| __tablename__ = 'provider_models' | |||
| __tablename__ = "provider_models" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='provider_model_pkey'), | |||
| db.Index('provider_model_tenant_id_provider_idx', 'tenant_id', 'provider_name'), | |||
| db.UniqueConstraint('tenant_id', 'provider_name', 'model_name', 'model_type', name='unique_provider_model_name') | |||
| db.PrimaryKeyConstraint("id", name="provider_model_pkey"), | |||
| db.Index("provider_model_tenant_id_provider_idx", "tenant_id", "provider_name"), | |||
| db.UniqueConstraint( | |||
| "tenant_id", "provider_name", "model_name", "model_type", name="unique_provider_model_name" | |||
| ), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| provider_name = db.Column(db.String(255), nullable=False) | |||
| model_name = db.Column(db.String(255), nullable=False) | |||
| model_type = db.Column(db.String(40), nullable=False) | |||
| encrypted_config = db.Column(db.Text, nullable=True) | |||
| is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| class TenantDefaultModel(db.Model): | |||
| __tablename__ = 'tenant_default_models' | |||
| __tablename__ = "tenant_default_models" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='tenant_default_model_pkey'), | |||
| db.Index('tenant_default_model_tenant_id_provider_type_idx', 'tenant_id', 'provider_name', 'model_type'), | |||
| db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"), | |||
| db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| provider_name = db.Column(db.String(255), nullable=False) | |||
| model_name = db.Column(db.String(255), nullable=False) | |||
| model_type = db.Column(db.String(40), nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| class TenantPreferredModelProvider(db.Model): | |||
| __tablename__ = 'tenant_preferred_model_providers' | |||
| __tablename__ = "tenant_preferred_model_providers" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='tenant_preferred_model_provider_pkey'), | |||
| db.Index('tenant_preferred_model_provider_tenant_provider_idx', 'tenant_id', 'provider_name'), | |||
| db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"), | |||
| db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| provider_name = db.Column(db.String(255), nullable=False) | |||
| preferred_provider_type = db.Column(db.String(40), nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| class ProviderOrder(db.Model): | |||
| __tablename__ = 'provider_orders' | |||
| __tablename__ = "provider_orders" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='provider_order_pkey'), | |||
| db.Index('provider_order_tenant_provider_idx', 'tenant_id', 'provider_name'), | |||
| db.PrimaryKeyConstraint("id", name="provider_order_pkey"), | |||
| db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| provider_name = db.Column(db.String(255), nullable=False) | |||
| account_id = db.Column(StringUUID, nullable=False) | |||
| payment_product_id = db.Column(db.String(191), nullable=False) | |||
| payment_id = db.Column(db.String(191)) | |||
| transaction_id = db.Column(db.String(191)) | |||
| quantity = db.Column(db.Integer, nullable=False, server_default=db.text('1')) | |||
| quantity = db.Column(db.Integer, nullable=False, server_default=db.text("1")) | |||
| currency = db.Column(db.String(40)) | |||
| total_amount = db.Column(db.Integer) | |||
| payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying")) | |||
| paid_at = db.Column(db.DateTime) | |||
| pay_failed_at = db.Column(db.DateTime) | |||
| refunded_at = db.Column(db.DateTime) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| class ProviderModelSetting(db.Model): | |||
| """ | |||
| Provider model settings for record the model enabled status and load balancing status. | |||
| """ | |||
| __tablename__ = 'provider_model_settings' | |||
| __tablename__ = "provider_model_settings" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='provider_model_setting_pkey'), | |||
| db.Index('provider_model_setting_tenant_provider_model_idx', 'tenant_id', 'provider_name', 'model_type'), | |||
| db.PrimaryKeyConstraint("id", name="provider_model_setting_pkey"), | |||
| db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| provider_name = db.Column(db.String(255), nullable=False) | |||
| model_name = db.Column(db.String(255), nullable=False) | |||
| model_type = db.Column(db.String(40), nullable=False) | |||
| enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) | |||
| load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) | |||
| load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| class LoadBalancingModelConfig(db.Model): | |||
| """ | |||
| Configurations for load balancing models. | |||
| """ | |||
| __tablename__ = 'load_balancing_model_configs' | |||
| __tablename__ = "load_balancing_model_configs" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='load_balancing_model_config_pkey'), | |||
| db.Index('load_balancing_model_config_tenant_provider_model_idx', 'tenant_id', 'provider_name', 'model_type'), | |||
| db.PrimaryKeyConstraint("id", name="load_balancing_model_config_pkey"), | |||
| db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| provider_name = db.Column(db.String(255), nullable=False) | |||
| model_name = db.Column(db.String(255), nullable=False) | |||
| model_type = db.Column(db.String(40), nullable=False) | |||
| name = db.Column(db.String(255), nullable=False) | |||
| encrypted_config = db.Column(db.Text, nullable=True) | |||
| enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| @@ -8,48 +8,48 @@ from .types import StringUUID | |||
| class DataSourceOauthBinding(db.Model): | |||
| __tablename__ = 'data_source_oauth_bindings' | |||
| __tablename__ = "data_source_oauth_bindings" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='source_binding_pkey'), | |||
| db.Index('source_binding_tenant_id_idx', 'tenant_id'), | |||
| db.Index('source_info_idx', "source_info", postgresql_using='gin') | |||
| db.PrimaryKeyConstraint("id", name="source_binding_pkey"), | |||
| db.Index("source_binding_tenant_id_idx", "tenant_id"), | |||
| db.Index("source_info_idx", "source_info", postgresql_using="gin"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| access_token = db.Column(db.String(255), nullable=False) | |||
| provider = db.Column(db.String(255), nullable=False) | |||
| source_info = db.Column(JSONB, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) | |||
| class DataSourceApiKeyAuthBinding(db.Model): | |||
| __tablename__ = 'data_source_api_key_auth_bindings' | |||
| __tablename__ = "data_source_api_key_auth_bindings" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='data_source_api_key_auth_binding_pkey'), | |||
| db.Index('data_source_api_key_auth_binding_tenant_id_idx', 'tenant_id'), | |||
| db.Index('data_source_api_key_auth_binding_provider_idx', 'provider'), | |||
| db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"), | |||
| db.Index("data_source_api_key_auth_binding_tenant_id_idx", "tenant_id"), | |||
| db.Index("data_source_api_key_auth_binding_provider_idx", "provider"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| category = db.Column(db.String(255), nullable=False) | |||
| provider = db.Column(db.String(255), nullable=False) | |||
| credentials = db.Column(db.Text, nullable=True) # JSON | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) | |||
| def to_dict(self): | |||
| return { | |||
| 'id': self.id, | |||
| 'tenant_id': self.tenant_id, | |||
| 'category': self.category, | |||
| 'provider': self.provider, | |||
| 'credentials': json.loads(self.credentials), | |||
| 'created_at': self.created_at.timestamp(), | |||
| 'updated_at': self.updated_at.timestamp(), | |||
| 'disabled': self.disabled | |||
| "id": self.id, | |||
| "tenant_id": self.tenant_id, | |||
| "category": self.category, | |||
| "provider": self.provider, | |||
| "credentials": json.loads(self.credentials), | |||
| "created_at": self.created_at.timestamp(), | |||
| "updated_at": self.updated_at.timestamp(), | |||
| "disabled": self.disabled, | |||
| } | |||
| @@ -8,15 +8,18 @@ from extensions.ext_database import db | |||
| class CeleryTask(db.Model): | |||
| """Task result/status.""" | |||
| __tablename__ = 'celery_taskmeta' | |||
| __tablename__ = "celery_taskmeta" | |||
| id = db.Column(db.Integer, db.Sequence('task_id_sequence'), | |||
| primary_key=True, autoincrement=True) | |||
| id = db.Column(db.Integer, db.Sequence("task_id_sequence"), primary_key=True, autoincrement=True) | |||
| task_id = db.Column(db.String(155), unique=True) | |||
| status = db.Column(db.String(50), default=states.PENDING) | |||
| result = db.Column(db.PickleType, nullable=True) | |||
| date_done = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), | |||
| onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None), nullable=True) | |||
| date_done = db.Column( | |||
| db.DateTime, | |||
| default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), | |||
| onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None), | |||
| nullable=True, | |||
| ) | |||
| traceback = db.Column(db.Text, nullable=True) | |||
| name = db.Column(db.String(155), nullable=True) | |||
| args = db.Column(db.LargeBinary, nullable=True) | |||
| @@ -29,11 +32,9 @@ class CeleryTask(db.Model): | |||
| class CeleryTaskSet(db.Model): | |||
| """TaskSet result.""" | |||
| __tablename__ = 'celery_tasksetmeta' | |||
| __tablename__ = "celery_tasksetmeta" | |||
| id = db.Column(db.Integer, db.Sequence('taskset_id_sequence'), | |||
| autoincrement=True, primary_key=True) | |||
| id = db.Column(db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True) | |||
| taskset_id = db.Column(db.String(155), unique=True) | |||
| result = db.Column(db.PickleType, nullable=True) | |||
| date_done = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), | |||
| nullable=True) | |||
| date_done = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), nullable=True) | |||
| @@ -7,7 +7,7 @@ from .types import StringUUID | |||
| class ToolProviderName(Enum): | |||
| SERPAPI = 'serpapi' | |||
| SERPAPI = "serpapi" | |||
| @staticmethod | |||
| def value_of(value): | |||
| @@ -18,25 +18,25 @@ class ToolProviderName(Enum): | |||
| class ToolProvider(db.Model): | |||
| __tablename__ = 'tool_providers' | |||
| __tablename__ = "tool_providers" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='tool_provider_pkey'), | |||
| db.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name') | |||
| db.PrimaryKeyConstraint("id", name="tool_provider_pkey"), | |||
| db.UniqueConstraint("tenant_id", "tool_name", name="unique_tool_provider_tool_name"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| tool_name = db.Column(db.String(40), nullable=False) | |||
| encrypted_credentials = db.Column(db.Text, nullable=True) | |||
| is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| @property | |||
| def credentials_is_set(self): | |||
| """ | |||
| Returns True if the encrypted_config is not None, indicating that the token is set. | |||
| """ | |||
| Returns True if the encrypted_config is not None, indicating that the token is set. | |||
| """ | |||
| return self.encrypted_credentials is not None | |||
| @property | |||
| @@ -15,15 +15,16 @@ class BuiltinToolProvider(db.Model): | |||
| """ | |||
| This table stores the tool provider information for built-in tools for each tenant. | |||
| """ | |||
| __tablename__ = 'tool_builtin_providers' | |||
| __tablename__ = "tool_builtin_providers" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='tool_builtin_provider_pkey'), | |||
| db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"), | |||
| # one tenant can only have one tool provider with the same name | |||
| db.UniqueConstraint('tenant_id', 'provider', name='unique_builtin_tool_provider') | |||
| db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_tool_provider"), | |||
| ) | |||
| # id of the tool provider | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| # id of the tenant | |||
| tenant_id = db.Column(StringUUID, nullable=True) | |||
| # who created this tool provider | |||
| @@ -32,27 +33,29 @@ class BuiltinToolProvider(db.Model): | |||
| provider = db.Column(db.String(40), nullable=False) | |||
| # credential of the tool provider | |||
| encrypted_credentials = db.Column(db.Text, nullable=True) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| @property | |||
| def credentials(self) -> dict: | |||
| return json.loads(self.encrypted_credentials) | |||
| class PublishedAppTool(db.Model): | |||
| """ | |||
| The table stores the apps published as a tool for each person. | |||
| """ | |||
| __tablename__ = 'tool_published_apps' | |||
| __tablename__ = "tool_published_apps" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='published_app_tool_pkey'), | |||
| db.UniqueConstraint('app_id', 'user_id', name='unique_published_app_tool') | |||
| db.PrimaryKeyConstraint("id", name="published_app_tool_pkey"), | |||
| db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"), | |||
| ) | |||
| # id of the tool provider | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| # id of the app | |||
| app_id = db.Column(StringUUID, ForeignKey('apps.id'), nullable=False) | |||
| app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False) | |||
| # who published this tool | |||
| user_id = db.Column(StringUUID, nullable=False) | |||
| # description of the tool, stored in i18n format, for human | |||
| @@ -67,28 +70,30 @@ class PublishedAppTool(db.Model): | |||
| tool_name = db.Column(db.String(40), nullable=False) | |||
| # author | |||
| author = db.Column(db.String(40), nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| @property | |||
| def description_i18n(self) -> I18nObject: | |||
| return I18nObject(**json.loads(self.description)) | |||
| @property | |||
| def app(self) -> App: | |||
| return db.session.query(App).filter(App.id == self.app_id).first() | |||
| class ApiToolProvider(db.Model): | |||
| """ | |||
| The table stores the api providers. | |||
| """ | |||
| __tablename__ = 'tool_api_providers' | |||
| __tablename__ = "tool_api_providers" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='tool_api_provider_pkey'), | |||
| db.UniqueConstraint('name', 'tenant_id', name='unique_api_tool_provider') | |||
| db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"), | |||
| db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| # name of the api provider | |||
| name = db.Column(db.String(40), nullable=False) | |||
| # icon | |||
| @@ -111,21 +116,21 @@ class ApiToolProvider(db.Model): | |||
| # custom_disclaimer | |||
| custom_disclaimer = db.Column(db.String(255), nullable=True) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| @property | |||
| def schema_type(self) -> ApiProviderSchemaType: | |||
| return ApiProviderSchemaType.value_of(self.schema_type_str) | |||
| @property | |||
| def tools(self) -> list[ApiToolBundle]: | |||
| return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)] | |||
| @property | |||
| def credentials(self) -> dict: | |||
| return json.loads(self.credentials_str) | |||
| @property | |||
| def user(self) -> Account: | |||
| return db.session.query(Account).filter(Account.id == self.user_id).first() | |||
| @@ -134,17 +139,19 @@ class ApiToolProvider(db.Model): | |||
| def tenant(self) -> Tenant: | |||
| return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() | |||
| class ToolLabelBinding(db.Model): | |||
| """ | |||
| The table stores the labels for tools. | |||
| """ | |||
| __tablename__ = 'tool_label_bindings' | |||
| __tablename__ = "tool_label_bindings" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='tool_label_bind_pkey'), | |||
| db.UniqueConstraint('tool_id', 'label_name', name='unique_tool_label_bind'), | |||
| db.PrimaryKeyConstraint("id", name="tool_label_bind_pkey"), | |||
| db.UniqueConstraint("tool_id", "label_name", name="unique_tool_label_bind"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| # tool id | |||
| tool_id = db.Column(db.String(64), nullable=False) | |||
| # tool type | |||
| @@ -152,28 +159,30 @@ class ToolLabelBinding(db.Model): | |||
| # label name | |||
| label_name = db.Column(db.String(40), nullable=False) | |||
| class WorkflowToolProvider(db.Model): | |||
| """ | |||
| The table stores the workflow providers. | |||
| """ | |||
| __tablename__ = 'tool_workflow_providers' | |||
| __tablename__ = "tool_workflow_providers" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='tool_workflow_provider_pkey'), | |||
| db.UniqueConstraint('name', 'tenant_id', name='unique_workflow_tool_provider'), | |||
| db.UniqueConstraint('tenant_id', 'app_id', name='unique_workflow_tool_provider_app_id'), | |||
| db.PrimaryKeyConstraint("id", name="tool_workflow_provider_pkey"), | |||
| db.UniqueConstraint("name", "tenant_id", name="unique_workflow_tool_provider"), | |||
| db.UniqueConstraint("tenant_id", "app_id", name="unique_workflow_tool_provider_app_id"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| # name of the workflow provider | |||
| name = db.Column(db.String(40), nullable=False) | |||
| # label of the workflow provider | |||
| label = db.Column(db.String(255), nullable=False, server_default='') | |||
| label = db.Column(db.String(255), nullable=False, server_default="") | |||
| # icon | |||
| icon = db.Column(db.String(255), nullable=False) | |||
| # app id of the workflow provider | |||
| app_id = db.Column(StringUUID, nullable=False) | |||
| # version of the workflow provider | |||
| version = db.Column(db.String(255), nullable=False, server_default='') | |||
| version = db.Column(db.String(255), nullable=False, server_default="") | |||
| # who created this tool | |||
| user_id = db.Column(StringUUID, nullable=False) | |||
| # tenant id | |||
| @@ -181,17 +190,17 @@ class WorkflowToolProvider(db.Model): | |||
| # description of the provider | |||
| description = db.Column(db.Text, nullable=False) | |||
| # parameter configuration | |||
| parameter_configuration = db.Column(db.Text, nullable=False, server_default='[]') | |||
| parameter_configuration = db.Column(db.Text, nullable=False, server_default="[]") | |||
| # privacy policy | |||
| privacy_policy = db.Column(db.String(255), nullable=True, server_default='') | |||
| privacy_policy = db.Column(db.String(255), nullable=True, server_default="") | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| @property | |||
| def schema_type(self) -> ApiProviderSchemaType: | |||
| return ApiProviderSchemaType.value_of(self.schema_type_str) | |||
| @property | |||
| def user(self) -> Account: | |||
| return db.session.query(Account).filter(Account.id == self.user_id).first() | |||
| @@ -199,28 +208,25 @@ class WorkflowToolProvider(db.Model): | |||
| @property | |||
| def tenant(self) -> Tenant: | |||
| return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() | |||
| @property | |||
| def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]: | |||
| return [ | |||
| WorkflowToolParameterConfiguration(**config) | |||
| for config in json.loads(self.parameter_configuration) | |||
| ] | |||
| return [WorkflowToolParameterConfiguration(**config) for config in json.loads(self.parameter_configuration)] | |||
| @property | |||
| def app(self) -> App: | |||
| return db.session.query(App).filter(App.id == self.app_id).first() | |||
| class ToolModelInvoke(db.Model): | |||
| """ | |||
| store the invoke logs from tool invoke | |||
| """ | |||
| __tablename__ = "tool_model_invokes" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='tool_model_invoke_pkey'), | |||
| ) | |||
| __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| # who invoke this tool | |||
| user_id = db.Column(StringUUID, nullable=False) | |||
| # tenant id | |||
| @@ -238,29 +244,31 @@ class ToolModelInvoke(db.Model): | |||
| # invoke response | |||
| model_response = db.Column(db.Text, nullable=False) | |||
| prompt_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) | |||
| answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) | |||
| prompt_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) | |||
| answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) | |||
| answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) | |||
| answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) | |||
| provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text('0')) | |||
| answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) | |||
| provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0")) | |||
| total_price = db.Column(db.Numeric(10, 7)) | |||
| currency = db.Column(db.String(255), nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| class ToolConversationVariables(db.Model): | |||
| """ | |||
| store the conversation variables from tool invoke | |||
| """ | |||
| __tablename__ = "tool_conversation_variables" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='tool_conversation_variables_pkey'), | |||
| db.PrimaryKeyConstraint("id", name="tool_conversation_variables_pkey"), | |||
| # add index for user_id and conversation_id | |||
| db.Index('user_id_idx', 'user_id'), | |||
| db.Index('conversation_id_idx', 'conversation_id'), | |||
| db.Index("user_id_idx", "user_id"), | |||
| db.Index("conversation_id_idx", "conversation_id"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| # conversation user id | |||
| user_id = db.Column(StringUUID, nullable=False) | |||
| # tenant id | |||
| @@ -270,25 +278,27 @@ class ToolConversationVariables(db.Model): | |||
| # variables pool | |||
| variables_str = db.Column(db.Text, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| @property | |||
| def variables(self) -> dict: | |||
| return json.loads(self.variables_str) | |||
| class ToolFile(db.Model): | |||
| """ | |||
| store the file created by agent | |||
| """ | |||
| __tablename__ = "tool_files" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='tool_file_pkey'), | |||
| db.PrimaryKeyConstraint("id", name="tool_file_pkey"), | |||
| # add index for conversation_id | |||
| db.Index('tool_file_conversation_id_idx', 'conversation_id'), | |||
| db.Index("tool_file_conversation_id_idx", "conversation_id"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| # conversation user id | |||
| user_id = db.Column(StringUUID, nullable=False) | |||
| # tenant id | |||
| @@ -300,4 +310,4 @@ class ToolFile(db.Model): | |||
| # mime type | |||
| mimetype = db.Column(db.String(255), nullable=False) | |||
| # original url | |||
| original_url = db.Column(db.String(2048), nullable=True) | |||
| original_url = db.Column(db.String(2048), nullable=True) | |||
| @@ -9,13 +9,13 @@ class StringUUID(TypeDecorator): | |||
| def process_bind_param(self, value, dialect): | |||
| if value is None: | |||
| return value | |||
| elif dialect.name == 'postgresql': | |||
| elif dialect.name == "postgresql": | |||
| return str(value) | |||
| else: | |||
| return value.hex | |||
| def load_dialect_impl(self, dialect): | |||
| if dialect.name == 'postgresql': | |||
| if dialect.name == "postgresql": | |||
| return dialect.type_descriptor(UUID()) | |||
| else: | |||
| return dialect.type_descriptor(CHAR(36)) | |||
| @@ -23,4 +23,4 @@ class StringUUID(TypeDecorator): | |||
| def process_result_value(self, value, dialect): | |||
| if value is None: | |||
| return value | |||
| return str(value) | |||
| return str(value) | |||
| @@ -1,4 +1,3 @@ | |||
| from extensions.ext_database import db | |||
| from .model import Message | |||
| @@ -6,18 +5,18 @@ from .types import StringUUID | |||
| class SavedMessage(db.Model): | |||
| __tablename__ = 'saved_messages' | |||
| __tablename__ = "saved_messages" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='saved_message_pkey'), | |||
| db.Index('saved_message_message_idx', 'app_id', 'message_id', 'created_by_role', 'created_by'), | |||
| db.PrimaryKeyConstraint("id", name="saved_message_pkey"), | |||
| db.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| app_id = db.Column(StringUUID, nullable=False) | |||
| message_id = db.Column(StringUUID, nullable=False) | |||
| created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) | |||
| created_by = db.Column(StringUUID, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| @property | |||
| def message(self): | |||
| @@ -25,15 +24,15 @@ class SavedMessage(db.Model): | |||
| class PinnedConversation(db.Model): | |||
| __tablename__ = 'pinned_conversations' | |||
| __tablename__ = "pinned_conversations" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='pinned_conversation_pkey'), | |||
| db.Index('pinned_conversation_conversation_idx', 'app_id', 'conversation_id', 'created_by_role', 'created_by'), | |||
| db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"), | |||
| db.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| app_id = db.Column(StringUUID, nullable=False) | |||
| conversation_id = db.Column(StringUUID, nullable=False) | |||
| created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) | |||
| created_by = db.Column(StringUUID, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| @@ -22,11 +22,12 @@ class CreatedByRole(Enum): | |||
| """ | |||
| Created By Role Enum | |||
| """ | |||
| ACCOUNT = 'account' | |||
| END_USER = 'end_user' | |||
| ACCOUNT = "account" | |||
| END_USER = "end_user" | |||
| @classmethod | |||
| def value_of(cls, value: str) -> 'CreatedByRole': | |||
| def value_of(cls, value: str) -> "CreatedByRole": | |||
| """ | |||
| Get value of given mode. | |||
| @@ -36,18 +37,19 @@ class CreatedByRole(Enum): | |||
| for mode in cls: | |||
| if mode.value == value: | |||
| return mode | |||
| raise ValueError(f'invalid created by role value {value}') | |||
| raise ValueError(f"invalid created by role value {value}") | |||
| class WorkflowType(Enum): | |||
| """ | |||
| Workflow Type Enum | |||
| """ | |||
| WORKFLOW = 'workflow' | |||
| CHAT = 'chat' | |||
| WORKFLOW = "workflow" | |||
| CHAT = "chat" | |||
| @classmethod | |||
| def value_of(cls, value: str) -> 'WorkflowType': | |||
| def value_of(cls, value: str) -> "WorkflowType": | |||
| """ | |||
| Get value of given mode. | |||
| @@ -57,10 +59,10 @@ class WorkflowType(Enum): | |||
| for mode in cls: | |||
| if mode.value == value: | |||
| return mode | |||
| raise ValueError(f'invalid workflow type value {value}') | |||
| raise ValueError(f"invalid workflow type value {value}") | |||
| @classmethod | |||
| def from_app_mode(cls, app_mode: Union[str, 'AppMode']) -> 'WorkflowType': | |||
| def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType": | |||
| """ | |||
| Get workflow type from app mode. | |||
| @@ -68,6 +70,7 @@ class WorkflowType(Enum): | |||
| :return: workflow type | |||
| """ | |||
| from models.model import AppMode | |||
| app_mode = app_mode if isinstance(app_mode, AppMode) else AppMode.value_of(app_mode) | |||
| return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT | |||
| @@ -105,13 +108,13 @@ class Workflow(db.Model): | |||
| - updated_at (timestamp) `optional` Last update time | |||
| """ | |||
| __tablename__ = 'workflows' | |||
| __tablename__ = "workflows" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='workflow_pkey'), | |||
| db.Index('workflow_version_idx', 'tenant_id', 'app_id', 'version'), | |||
| db.PrimaryKeyConstraint("id", name="workflow_pkey"), | |||
| db.Index("workflow_version_idx", "tenant_id", "app_id", "version"), | |||
| ) | |||
| id: Mapped[str] = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) | |||
| app_id: Mapped[str] = db.Column(StringUUID, nullable=False) | |||
| type: Mapped[str] = db.Column(db.String(255), nullable=False) | |||
| @@ -119,15 +122,31 @@ class Workflow(db.Model): | |||
| graph: Mapped[str] = db.Column(db.Text) | |||
| features: Mapped[str] = db.Column(db.Text) | |||
| created_by: Mapped[str] = db.Column(StringUUID, nullable=False) | |||
| created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at: Mapped[datetime] = db.Column( | |||
| db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") | |||
| ) | |||
| updated_by: Mapped[str] = db.Column(StringUUID) | |||
| updated_at: Mapped[datetime] = db.Column(db.DateTime) | |||
| _environment_variables: Mapped[str] = db.Column('environment_variables', db.Text, nullable=False, server_default='{}') | |||
| _conversation_variables: Mapped[str] = db.Column('conversation_variables', db.Text, nullable=False, server_default='{}') | |||
| _environment_variables: Mapped[str] = db.Column( | |||
| "environment_variables", db.Text, nullable=False, server_default="{}" | |||
| ) | |||
| _conversation_variables: Mapped[str] = db.Column( | |||
| "conversation_variables", db.Text, nullable=False, server_default="{}" | |||
| ) | |||
| def __init__(self, *, tenant_id: str, app_id: str, type: str, version: str, graph: str, | |||
| features: str, created_by: str, environment_variables: Sequence[Variable], | |||
| conversation_variables: Sequence[Variable]): | |||
| def __init__( | |||
| self, | |||
| *, | |||
| tenant_id: str, | |||
| app_id: str, | |||
| type: str, | |||
| version: str, | |||
| graph: str, | |||
| features: str, | |||
| created_by: str, | |||
| environment_variables: Sequence[Variable], | |||
| conversation_variables: Sequence[Variable], | |||
| ): | |||
| self.tenant_id = tenant_id | |||
| self.app_id = app_id | |||
| self.type = type | |||
| @@ -160,22 +179,20 @@ class Workflow(db.Model): | |||
| return [] | |||
| graph_dict = self.graph_dict | |||
| if 'nodes' not in graph_dict: | |||
| if "nodes" not in graph_dict: | |||
| return [] | |||
| start_node = next((node for node in graph_dict['nodes'] if node['data']['type'] == 'start'), None) | |||
| start_node = next((node for node in graph_dict["nodes"] if node["data"]["type"] == "start"), None) | |||
| if not start_node: | |||
| return [] | |||
| # get user_input_form from start node | |||
| variables = start_node.get('data', {}).get('variables', []) | |||
| variables = start_node.get("data", {}).get("variables", []) | |||
| if to_old_structure: | |||
| old_structure_variables = [] | |||
| for variable in variables: | |||
| old_structure_variables.append({ | |||
| variable['type']: variable | |||
| }) | |||
| old_structure_variables.append({variable["type"]: variable}) | |||
| return old_structure_variables | |||
| @@ -188,25 +205,24 @@ class Workflow(db.Model): | |||
| :return: hash | |||
| """ | |||
| entity = { | |||
| 'graph': self.graph_dict, | |||
| 'features': self.features_dict | |||
| } | |||
| entity = {"graph": self.graph_dict, "features": self.features_dict} | |||
| return helper.generate_text_hash(json.dumps(entity, sort_keys=True)) | |||
| @property | |||
| def tool_published(self) -> bool: | |||
| from models.tools import WorkflowToolProvider | |||
| return db.session.query(WorkflowToolProvider).filter( | |||
| WorkflowToolProvider.app_id == self.app_id | |||
| ).first() is not None | |||
| return ( | |||
| db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.app_id == self.app_id).first() | |||
| is not None | |||
| ) | |||
| @property | |||
| def environment_variables(self) -> Sequence[Variable]: | |||
| # TODO: find some way to init `self._environment_variables` when instance created. | |||
| if self._environment_variables is None: | |||
| self._environment_variables = '{}' | |||
| self._environment_variables = "{}" | |||
| tenant_id = contexts.tenant_id.get() | |||
| @@ -215,9 +231,7 @@ class Workflow(db.Model): | |||
| # decrypt secret variables value | |||
| decrypt_func = ( | |||
| lambda var: var.model_copy( | |||
| update={'value': encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)} | |||
| ) | |||
| lambda var: var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) | |||
| if isinstance(var, SecretVariable) | |||
| else var | |||
| ) | |||
| @@ -230,19 +244,17 @@ class Workflow(db.Model): | |||
| value = list(value) | |||
| if any(var for var in value if not var.id): | |||
| raise ValueError('environment variable require a unique id') | |||
| raise ValueError("environment variable require a unique id") | |||
| # Compare inputs and origin variables, if the value is HIDDEN_VALUE, use the origin variable value (only update `name`). | |||
| origin_variables_dictionary = {var.id: var for var in self.environment_variables} | |||
| for i, variable in enumerate(value): | |||
| if variable.id in origin_variables_dictionary and variable.value == HIDDEN_VALUE: | |||
| value[i] = origin_variables_dictionary[variable.id].model_copy(update={'name': variable.name}) | |||
| value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name}) | |||
| # encrypt secret variables value | |||
| encrypt_func = ( | |||
| lambda var: var.model_copy( | |||
| update={'value': encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)} | |||
| ) | |||
| lambda var: var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)}) | |||
| if isinstance(var, SecretVariable) | |||
| else var | |||
| ) | |||
| @@ -256,15 +268,15 @@ class Workflow(db.Model): | |||
| def to_dict(self, *, include_secret: bool = False) -> Mapping[str, Any]: | |||
| environment_variables = list(self.environment_variables) | |||
| environment_variables = [ | |||
| v if not isinstance(v, SecretVariable) or include_secret else v.model_copy(update={'value': ''}) | |||
| v if not isinstance(v, SecretVariable) or include_secret else v.model_copy(update={"value": ""}) | |||
| for v in environment_variables | |||
| ] | |||
| result = { | |||
| 'graph': self.graph_dict, | |||
| 'features': self.features_dict, | |||
| 'environment_variables': [var.model_dump(mode='json') for var in environment_variables], | |||
| 'conversation_variables': [var.model_dump(mode='json') for var in self.conversation_variables], | |||
| "graph": self.graph_dict, | |||
| "features": self.features_dict, | |||
| "environment_variables": [var.model_dump(mode="json") for var in environment_variables], | |||
| "conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables], | |||
| } | |||
| return result | |||
| @@ -272,7 +284,7 @@ class Workflow(db.Model): | |||
| def conversation_variables(self) -> Sequence[Variable]: | |||
| # TODO: find some way to init `self._conversation_variables` when instance created. | |||
| if self._conversation_variables is None: | |||
| self._conversation_variables = '{}' | |||
| self._conversation_variables = "{}" | |||
| variables_dict: dict[str, Any] = json.loads(self._conversation_variables) | |||
| results = [factory.build_variable_from_mapping(v) for v in variables_dict.values()] | |||
| @@ -290,11 +302,12 @@ class WorkflowRunTriggeredFrom(Enum): | |||
| """ | |||
| Workflow Run Triggered From Enum | |||
| """ | |||
| DEBUGGING = 'debugging' | |||
| APP_RUN = 'app-run' | |||
| DEBUGGING = "debugging" | |||
| APP_RUN = "app-run" | |||
| @classmethod | |||
| def value_of(cls, value: str) -> 'WorkflowRunTriggeredFrom': | |||
| def value_of(cls, value: str) -> "WorkflowRunTriggeredFrom": | |||
| """ | |||
| Get value of given mode. | |||
| @@ -304,20 +317,21 @@ class WorkflowRunTriggeredFrom(Enum): | |||
| for mode in cls: | |||
| if mode.value == value: | |||
| return mode | |||
| raise ValueError(f'invalid workflow run triggered from value {value}') | |||
| raise ValueError(f"invalid workflow run triggered from value {value}") | |||
| class WorkflowRunStatus(Enum): | |||
| """ | |||
| Workflow Run Status Enum | |||
| """ | |||
| RUNNING = 'running' | |||
| SUCCEEDED = 'succeeded' | |||
| FAILED = 'failed' | |||
| STOPPED = 'stopped' | |||
| RUNNING = "running" | |||
| SUCCEEDED = "succeeded" | |||
| FAILED = "failed" | |||
| STOPPED = "stopped" | |||
| @classmethod | |||
| def value_of(cls, value: str) -> 'WorkflowRunStatus': | |||
| def value_of(cls, value: str) -> "WorkflowRunStatus": | |||
| """ | |||
| Get value of given mode. | |||
| @@ -327,7 +341,7 @@ class WorkflowRunStatus(Enum): | |||
| for mode in cls: | |||
| if mode.value == value: | |||
| return mode | |||
| raise ValueError(f'invalid workflow run status value {value}') | |||
| raise ValueError(f"invalid workflow run status value {value}") | |||
| class WorkflowRun(db.Model): | |||
| @@ -368,14 +382,14 @@ class WorkflowRun(db.Model): | |||
| - finished_at (timestamp) End time | |||
| """ | |||
| __tablename__ = 'workflow_runs' | |||
| __tablename__ = "workflow_runs" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='workflow_run_pkey'), | |||
| db.Index('workflow_run_triggerd_from_idx', 'tenant_id', 'app_id', 'triggered_from'), | |||
| db.Index('workflow_run_tenant_app_sequence_idx', 'tenant_id', 'app_id', 'sequence_number'), | |||
| db.PrimaryKeyConstraint("id", name="workflow_run_pkey"), | |||
| db.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"), | |||
| db.Index("workflow_run_tenant_app_sequence_idx", "tenant_id", "app_id", "sequence_number"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| app_id = db.Column(StringUUID, nullable=False) | |||
| sequence_number = db.Column(db.Integer, nullable=False) | |||
| @@ -388,26 +402,25 @@ class WorkflowRun(db.Model): | |||
| status = db.Column(db.String(255), nullable=False) | |||
| outputs = db.Column(db.Text) | |||
| error = db.Column(db.Text) | |||
| elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0')) | |||
| total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) | |||
| total_steps = db.Column(db.Integer, server_default=db.text('0')) | |||
| elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) | |||
| total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) | |||
| total_steps = db.Column(db.Integer, server_default=db.text("0")) | |||
| created_by_role = db.Column(db.String(255), nullable=False) | |||
| created_by = db.Column(StringUUID, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| finished_at = db.Column(db.DateTime) | |||
| @property | |||
| def created_by_account(self): | |||
| created_by_role = CreatedByRole.value_of(self.created_by_role) | |||
| return db.session.get(Account, self.created_by) \ | |||
| if created_by_role == CreatedByRole.ACCOUNT else None | |||
| return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None | |||
| @property | |||
| def created_by_end_user(self): | |||
| from models.model import EndUser | |||
| created_by_role = CreatedByRole.value_of(self.created_by_role) | |||
| return db.session.get(EndUser, self.created_by) \ | |||
| if created_by_role == CreatedByRole.END_USER else None | |||
| return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None | |||
| @property | |||
| def graph_dict(self): | |||
| @@ -422,12 +435,12 @@ class WorkflowRun(db.Model): | |||
| return json.loads(self.outputs) if self.outputs else None | |||
| @property | |||
| def message(self) -> Optional['Message']: | |||
| def message(self) -> Optional["Message"]: | |||
| from models.model import Message | |||
| return db.session.query(Message).filter( | |||
| Message.app_id == self.app_id, | |||
| Message.workflow_run_id == self.id | |||
| ).first() | |||
| return ( | |||
| db.session.query(Message).filter(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first() | |||
| ) | |||
| @property | |||
| def workflow(self): | |||
| @@ -435,51 +448,51 @@ class WorkflowRun(db.Model): | |||
| def to_dict(self): | |||
| return { | |||
| 'id': self.id, | |||
| 'tenant_id': self.tenant_id, | |||
| 'app_id': self.app_id, | |||
| 'sequence_number': self.sequence_number, | |||
| 'workflow_id': self.workflow_id, | |||
| 'type': self.type, | |||
| 'triggered_from': self.triggered_from, | |||
| 'version': self.version, | |||
| 'graph': self.graph_dict, | |||
| 'inputs': self.inputs_dict, | |||
| 'status': self.status, | |||
| 'outputs': self.outputs_dict, | |||
| 'error': self.error, | |||
| 'elapsed_time': self.elapsed_time, | |||
| 'total_tokens': self.total_tokens, | |||
| 'total_steps': self.total_steps, | |||
| 'created_by_role': self.created_by_role, | |||
| 'created_by': self.created_by, | |||
| 'created_at': self.created_at, | |||
| 'finished_at': self.finished_at, | |||
| "id": self.id, | |||
| "tenant_id": self.tenant_id, | |||
| "app_id": self.app_id, | |||
| "sequence_number": self.sequence_number, | |||
| "workflow_id": self.workflow_id, | |||
| "type": self.type, | |||
| "triggered_from": self.triggered_from, | |||
| "version": self.version, | |||
| "graph": self.graph_dict, | |||
| "inputs": self.inputs_dict, | |||
| "status": self.status, | |||
| "outputs": self.outputs_dict, | |||
| "error": self.error, | |||
| "elapsed_time": self.elapsed_time, | |||
| "total_tokens": self.total_tokens, | |||
| "total_steps": self.total_steps, | |||
| "created_by_role": self.created_by_role, | |||
| "created_by": self.created_by, | |||
| "created_at": self.created_at, | |||
| "finished_at": self.finished_at, | |||
| } | |||
| @classmethod | |||
| def from_dict(cls, data: dict) -> 'WorkflowRun': | |||
| def from_dict(cls, data: dict) -> "WorkflowRun": | |||
| return cls( | |||
| id=data.get('id'), | |||
| tenant_id=data.get('tenant_id'), | |||
| app_id=data.get('app_id'), | |||
| sequence_number=data.get('sequence_number'), | |||
| workflow_id=data.get('workflow_id'), | |||
| type=data.get('type'), | |||
| triggered_from=data.get('triggered_from'), | |||
| version=data.get('version'), | |||
| graph=json.dumps(data.get('graph')), | |||
| inputs=json.dumps(data.get('inputs')), | |||
| status=data.get('status'), | |||
| outputs=json.dumps(data.get('outputs')), | |||
| error=data.get('error'), | |||
| elapsed_time=data.get('elapsed_time'), | |||
| total_tokens=data.get('total_tokens'), | |||
| total_steps=data.get('total_steps'), | |||
| created_by_role=data.get('created_by_role'), | |||
| created_by=data.get('created_by'), | |||
| created_at=data.get('created_at'), | |||
| finished_at=data.get('finished_at'), | |||
| id=data.get("id"), | |||
| tenant_id=data.get("tenant_id"), | |||
| app_id=data.get("app_id"), | |||
| sequence_number=data.get("sequence_number"), | |||
| workflow_id=data.get("workflow_id"), | |||
| type=data.get("type"), | |||
| triggered_from=data.get("triggered_from"), | |||
| version=data.get("version"), | |||
| graph=json.dumps(data.get("graph")), | |||
| inputs=json.dumps(data.get("inputs")), | |||
| status=data.get("status"), | |||
| outputs=json.dumps(data.get("outputs")), | |||
| error=data.get("error"), | |||
| elapsed_time=data.get("elapsed_time"), | |||
| total_tokens=data.get("total_tokens"), | |||
| total_steps=data.get("total_steps"), | |||
| created_by_role=data.get("created_by_role"), | |||
| created_by=data.get("created_by"), | |||
| created_at=data.get("created_at"), | |||
| finished_at=data.get("finished_at"), | |||
| ) | |||
| @@ -487,11 +500,12 @@ class WorkflowNodeExecutionTriggeredFrom(Enum): | |||
| """ | |||
| Workflow Node Execution Triggered From Enum | |||
| """ | |||
| SINGLE_STEP = 'single-step' | |||
| WORKFLOW_RUN = 'workflow-run' | |||
| SINGLE_STEP = "single-step" | |||
| WORKFLOW_RUN = "workflow-run" | |||
| @classmethod | |||
| def value_of(cls, value: str) -> 'WorkflowNodeExecutionTriggeredFrom': | |||
| def value_of(cls, value: str) -> "WorkflowNodeExecutionTriggeredFrom": | |||
| """ | |||
| Get value of given mode. | |||
| @@ -501,19 +515,20 @@ class WorkflowNodeExecutionTriggeredFrom(Enum): | |||
| for mode in cls: | |||
| if mode.value == value: | |||
| return mode | |||
| raise ValueError(f'invalid workflow node execution triggered from value {value}') | |||
| raise ValueError(f"invalid workflow node execution triggered from value {value}") | |||
| class WorkflowNodeExecutionStatus(Enum): | |||
| """ | |||
| Workflow Node Execution Status Enum | |||
| """ | |||
| RUNNING = 'running' | |||
| SUCCEEDED = 'succeeded' | |||
| FAILED = 'failed' | |||
| RUNNING = "running" | |||
| SUCCEEDED = "succeeded" | |||
| FAILED = "failed" | |||
| @classmethod | |||
| def value_of(cls, value: str) -> 'WorkflowNodeExecutionStatus': | |||
| def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus": | |||
| """ | |||
| Get value of given mode. | |||
| @@ -523,7 +538,7 @@ class WorkflowNodeExecutionStatus(Enum): | |||
| for mode in cls: | |||
| if mode.value == value: | |||
| return mode | |||
| raise ValueError(f'invalid workflow node execution status value {value}') | |||
| raise ValueError(f"invalid workflow node execution status value {value}") | |||
| class WorkflowNodeExecution(db.Model): | |||
| @@ -574,18 +589,31 @@ class WorkflowNodeExecution(db.Model): | |||
| - finished_at (timestamp) End time | |||
| """ | |||
| __tablename__ = 'workflow_node_executions' | |||
| __tablename__ = "workflow_node_executions" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='workflow_node_execution_pkey'), | |||
| db.Index('workflow_node_execution_workflow_run_idx', 'tenant_id', 'app_id', 'workflow_id', | |||
| 'triggered_from', 'workflow_run_id'), | |||
| db.Index('workflow_node_execution_node_run_idx', 'tenant_id', 'app_id', 'workflow_id', | |||
| 'triggered_from', 'node_id'), | |||
| db.Index('workflow_node_execution_id_idx', 'tenant_id', 'app_id', 'workflow_id', | |||
| 'triggered_from', 'node_execution_id'), | |||
| db.PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"), | |||
| db.Index( | |||
| "workflow_node_execution_workflow_run_idx", | |||
| "tenant_id", | |||
| "app_id", | |||
| "workflow_id", | |||
| "triggered_from", | |||
| "workflow_run_id", | |||
| ), | |||
| db.Index( | |||
| "workflow_node_execution_node_run_idx", "tenant_id", "app_id", "workflow_id", "triggered_from", "node_id" | |||
| ), | |||
| db.Index( | |||
| "workflow_node_execution_id_idx", | |||
| "tenant_id", | |||
| "app_id", | |||
| "workflow_id", | |||
| "triggered_from", | |||
| "node_execution_id", | |||
| ), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| app_id = db.Column(StringUUID, nullable=False) | |||
| workflow_id = db.Column(StringUUID, nullable=False) | |||
| @@ -602,9 +630,9 @@ class WorkflowNodeExecution(db.Model): | |||
| outputs = db.Column(db.Text) | |||
| status = db.Column(db.String(255), nullable=False) | |||
| error = db.Column(db.Text) | |||
| elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text('0')) | |||
| elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) | |||
| execution_metadata = db.Column(db.Text) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| created_by_role = db.Column(db.String(255), nullable=False) | |||
| created_by = db.Column(StringUUID, nullable=False) | |||
| finished_at = db.Column(db.DateTime) | |||
| @@ -612,15 +640,14 @@ class WorkflowNodeExecution(db.Model): | |||
| @property | |||
| def created_by_account(self): | |||
| created_by_role = CreatedByRole.value_of(self.created_by_role) | |||
| return db.session.get(Account, self.created_by) \ | |||
| if created_by_role == CreatedByRole.ACCOUNT else None | |||
| return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None | |||
| @property | |||
| def created_by_end_user(self): | |||
| from models.model import EndUser | |||
| created_by_role = CreatedByRole.value_of(self.created_by_role) | |||
| return db.session.get(EndUser, self.created_by) \ | |||
| if created_by_role == CreatedByRole.END_USER else None | |||
| return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None | |||
| @property | |||
| def inputs_dict(self): | |||
| @@ -641,15 +668,17 @@ class WorkflowNodeExecution(db.Model): | |||
| @property | |||
| def extras(self): | |||
| from core.tools.tool_manager import ToolManager | |||
| extras = {} | |||
| if self.execution_metadata_dict: | |||
| from core.workflow.entities.node_entities import NodeType | |||
| if self.node_type == NodeType.TOOL.value and 'tool_info' in self.execution_metadata_dict: | |||
| tool_info = self.execution_metadata_dict['tool_info'] | |||
| extras['icon'] = ToolManager.get_tool_icon( | |||
| if self.node_type == NodeType.TOOL.value and "tool_info" in self.execution_metadata_dict: | |||
| tool_info = self.execution_metadata_dict["tool_info"] | |||
| extras["icon"] = ToolManager.get_tool_icon( | |||
| tenant_id=self.tenant_id, | |||
| provider_type=tool_info['provider_type'], | |||
| provider_id=tool_info['provider_id'] | |||
| provider_type=tool_info["provider_type"], | |||
| provider_id=tool_info["provider_id"], | |||
| ) | |||
| return extras | |||
| @@ -659,12 +688,13 @@ class WorkflowAppLogCreatedFrom(Enum): | |||
| """ | |||
| Workflow App Log Created From Enum | |||
| """ | |||
| SERVICE_API = 'service-api' | |||
| WEB_APP = 'web-app' | |||
| INSTALLED_APP = 'installed-app' | |||
| SERVICE_API = "service-api" | |||
| WEB_APP = "web-app" | |||
| INSTALLED_APP = "installed-app" | |||
| @classmethod | |||
| def value_of(cls, value: str) -> 'WorkflowAppLogCreatedFrom': | |||
| def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom": | |||
| """ | |||
| Get value of given mode. | |||
| @@ -674,7 +704,7 @@ class WorkflowAppLogCreatedFrom(Enum): | |||
| for mode in cls: | |||
| if mode.value == value: | |||
| return mode | |||
| raise ValueError(f'invalid workflow app log created from value {value}') | |||
| raise ValueError(f"invalid workflow app log created from value {value}") | |||
| class WorkflowAppLog(db.Model): | |||
| @@ -706,13 +736,13 @@ class WorkflowAppLog(db.Model): | |||
| - created_at (timestamp) Creation time | |||
| """ | |||
| __tablename__ = 'workflow_app_logs' | |||
| __tablename__ = "workflow_app_logs" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='workflow_app_log_pkey'), | |||
| db.Index('workflow_app_log_app_idx', 'tenant_id', 'app_id'), | |||
| db.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"), | |||
| db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| app_id = db.Column(StringUUID, nullable=False) | |||
| workflow_id = db.Column(StringUUID, nullable=False) | |||
| @@ -720,7 +750,7 @@ class WorkflowAppLog(db.Model): | |||
| created_from = db.Column(db.String(255), nullable=False) | |||
| created_by_role = db.Column(db.String(255), nullable=False) | |||
| created_by = db.Column(StringUUID, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| @property | |||
| def workflow_run(self): | |||
| @@ -729,26 +759,27 @@ class WorkflowAppLog(db.Model): | |||
| @property | |||
| def created_by_account(self): | |||
| created_by_role = CreatedByRole.value_of(self.created_by_role) | |||
| return db.session.get(Account, self.created_by) \ | |||
| if created_by_role == CreatedByRole.ACCOUNT else None | |||
| return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None | |||
| @property | |||
| def created_by_end_user(self): | |||
| from models.model import EndUser | |||
| created_by_role = CreatedByRole.value_of(self.created_by_role) | |||
| return db.session.get(EndUser, self.created_by) \ | |||
| if created_by_role == CreatedByRole.END_USER else None | |||
| return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None | |||
| class ConversationVariable(db.Model): | |||
| __tablename__ = 'workflow_conversation_variables' | |||
| __tablename__ = "workflow_conversation_variables" | |||
| id: Mapped[str] = db.Column(StringUUID, primary_key=True) | |||
| conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True) | |||
| app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True) | |||
| data = db.Column(db.Text, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()) | |||
| created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_at = db.Column( | |||
| db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() | |||
| ) | |||
| def __init__(self, *, id: str, app_id: str, conversation_id: str, data: str) -> None: | |||
| self.id = id | |||
| @@ -757,7 +788,7 @@ class ConversationVariable(db.Model): | |||
| self.data = data | |||
| @classmethod | |||
| def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> 'ConversationVariable': | |||
| def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> "ConversationVariable": | |||
| obj = cls( | |||
| id=variable.id, | |||
| app_id=app_id, | |||
| @@ -68,7 +68,6 @@ ignore = [ | |||
| [tool.ruff.format] | |||
| exclude = [ | |||
| "models/**/*.py", | |||
| "migrations/**/*", | |||
| ] | |||