You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

account.py 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. import enum
  2. import json
  3. from datetime import datetime
  4. from typing import Optional, cast
  5. import sqlalchemy as sa
  6. from flask_login import UserMixin # type: ignore
  7. from sqlalchemy import DateTime, String, func, select
  8. from sqlalchemy.orm import Mapped, mapped_column, reconstructor
  9. from models.base import Base
  10. from .engine import db
  11. from .types import StringUUID
  12. class TenantAccountRole(enum.StrEnum):
  13. OWNER = "owner"
  14. ADMIN = "admin"
  15. EDITOR = "editor"
  16. NORMAL = "normal"
  17. DATASET_OPERATOR = "dataset_operator"
  18. @staticmethod
  19. def is_valid_role(role: str) -> bool:
  20. if not role:
  21. return False
  22. return role in {
  23. TenantAccountRole.OWNER,
  24. TenantAccountRole.ADMIN,
  25. TenantAccountRole.EDITOR,
  26. TenantAccountRole.NORMAL,
  27. TenantAccountRole.DATASET_OPERATOR,
  28. }
  29. @staticmethod
  30. def is_privileged_role(role: Optional["TenantAccountRole"]) -> bool:
  31. if not role:
  32. return False
  33. return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN}
  34. @staticmethod
  35. def is_admin_role(role: Optional["TenantAccountRole"]) -> bool:
  36. if not role:
  37. return False
  38. return role == TenantAccountRole.ADMIN
  39. @staticmethod
  40. def is_non_owner_role(role: Optional["TenantAccountRole"]) -> bool:
  41. if not role:
  42. return False
  43. return role in {
  44. TenantAccountRole.ADMIN,
  45. TenantAccountRole.EDITOR,
  46. TenantAccountRole.NORMAL,
  47. TenantAccountRole.DATASET_OPERATOR,
  48. }
  49. @staticmethod
  50. def is_editing_role(role: Optional["TenantAccountRole"]) -> bool:
  51. if not role:
  52. return False
  53. return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR}
  54. @staticmethod
  55. def is_dataset_edit_role(role: Optional["TenantAccountRole"]) -> bool:
  56. if not role:
  57. return False
  58. return role in {
  59. TenantAccountRole.OWNER,
  60. TenantAccountRole.ADMIN,
  61. TenantAccountRole.EDITOR,
  62. TenantAccountRole.DATASET_OPERATOR,
  63. }
  64. class AccountStatus(enum.StrEnum):
  65. PENDING = "pending"
  66. UNINITIALIZED = "uninitialized"
  67. ACTIVE = "active"
  68. BANNED = "banned"
  69. CLOSED = "closed"
  class Account(UserMixin, Base):
      """A user account (``accounts`` table) plus a non-persistent tenant context.

      ``role`` and ``_current_tenant`` are plain instance attributes, not
      columns. ``init_on_load`` resets them whenever SQLAlchemy loads the row.
      The ``current_tenant`` setter and ``set_tenant_id`` fill them from the
      account's ``TenantAccountJoin`` membership row. All ``is_*`` role
      properties read ``role``.

      NOTE(review): ``@reconstructor`` only runs on load, so an instance built
      directly with ``Account(...)`` has no ``role``/``_current_tenant`` until
      the setter or ``set_tenant_id`` assigns them.
      """

      __tablename__ = "accounts"
      __table_args__ = (sa.PrimaryKeyConstraint("id", name="account_pkey"), sa.Index("account_email_idx", "email"))

      id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
      name: Mapped[str] = mapped_column(String(255))
      email: Mapped[str] = mapped_column(String(255))
      password: Mapped[Optional[str]] = mapped_column(String(255))
      password_salt: Mapped[Optional[str]] = mapped_column(String(255))
      avatar: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
      interface_language: Mapped[Optional[str]] = mapped_column(String(255))
      interface_theme: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
      timezone: Mapped[Optional[str]] = mapped_column(String(255))
      last_login_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
      last_login_ip: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
      last_active_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
      # Raw status string; use get_status() to obtain an AccountStatus.
      status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'::character varying"))
      initialized_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
      created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
      updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)

      @reconstructor
      def init_on_load(self):
          """Reset the tenant context each time this row is loaded from the database."""
          self.role: Optional[TenantAccountRole] = None
          self._current_tenant: Optional[Tenant] = None

      @property
      def is_password_set(self) -> bool:
          """True when ``password`` is not None."""
          return self.password is not None

      @property
      def current_tenant(self):
          """The tenant selected via the setter or ``set_tenant_id``, or None."""
          return self._current_tenant

      @current_tenant.setter
      def current_tenant(self, tenant: "Tenant"):
          """Select ``tenant`` if this account is a member of it.

          Looks up the ``TenantAccountJoin`` row for (tenant, this account).
          If found, ``role`` comes from that row and ``tenant`` becomes
          current. Otherwise both the current tenant and the role are cleared.
          """
          ta = db.session.scalar(select(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).limit(1))
          if ta:
              self.role = TenantAccountRole(ta.role)
              self._current_tenant = tenant
              return
          # Clear the role as well. Keeping the role of a previously selected
          # tenant would let is_admin_or_owner / is_editor / ... answer for a
          # tenant this account does not belong to.
          self.role = None
          self._current_tenant = None

      @property
      def current_tenant_id(self) -> str | None:
          """Id of the current tenant, or None when no tenant is selected."""
          return self._current_tenant.id if self._current_tenant else None

      def set_tenant_id(self, tenant_id: str):
          """Select the tenant with ``tenant_id`` if this account is a member of it.

          Unlike the ``current_tenant`` setter, a missing tenant or membership
          is a no-op: the previously selected tenant and role are kept.
          """
          tenant_account_join = cast(
              # one_or_none() yields None when there is no matching row.
              Optional[tuple[Tenant, TenantAccountJoin]],
              (
                  db.session.query(Tenant, TenantAccountJoin)
                  .where(Tenant.id == tenant_id)
                  .where(TenantAccountJoin.tenant_id == Tenant.id)
                  .where(TenantAccountJoin.account_id == self.id)
                  .one_or_none()
              ),
          )
          if not tenant_account_join:
              return

          tenant, join = tenant_account_join
          self.role = TenantAccountRole(join.role)
          self._current_tenant = tenant

      @property
      def current_role(self):
          """Alias for ``role`` (the role in the current tenant, or None)."""
          return self.role

      def get_status(self) -> AccountStatus:
          """Parse ``status`` into ``AccountStatus``; raises ValueError for an unknown value."""
          return AccountStatus(self.status)

      @classmethod
      def get_by_openid(cls, provider: str, open_id: str) -> Optional["Account"]:
          """Return the account linked to ``(provider, open_id)`` via ``AccountIntegrate``, or None."""
          account_integrate = (
              db.session.query(AccountIntegrate)
              .where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
              .one_or_none()
          )
          if not account_integrate:
              return None
          return db.session.query(Account).where(Account.id == account_integrate.account_id).one_or_none()

      @property
      def is_admin_or_owner(self) -> bool:
          """True when the role in the current tenant is OWNER or ADMIN."""
          return TenantAccountRole.is_privileged_role(self.role)

      @property
      def is_admin(self) -> bool:
          """True only when the role in the current tenant is ADMIN."""
          return TenantAccountRole.is_admin_role(self.role)

      @property
      def is_editor(self) -> bool:
          """True when the role in the current tenant is OWNER, ADMIN or EDITOR."""
          return TenantAccountRole.is_editing_role(self.role)

      @property
      def is_dataset_editor(self) -> bool:
          """True when the role in the current tenant is anything but NORMAL (and a role is set)."""
          return TenantAccountRole.is_dataset_edit_role(self.role)

      @property
      def is_dataset_operator(self) -> bool:
          """True when the role in the current tenant is DATASET_OPERATOR."""
          return self.role == TenantAccountRole.DATASET_OPERATOR
  158. class TenantStatus(enum.StrEnum):
  159. NORMAL = "normal"
  160. ARCHIVE = "archive"
  class Tenant(Base):
      """A tenant (``tenants`` table); accounts belong to it through ``TenantAccountJoin`` rows."""

      __tablename__ = "tenants"
      __table_args__ = (sa.PrimaryKeyConstraint("id", name="tenant_pkey"),)

      id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
      name: Mapped[str] = mapped_column(String(255))
      # Typed like every other column; Optional[...] keeps it nullable, exactly
      # as the previous untyped db.Column(sa.Text) was.
      encrypt_public_key: Mapped[Optional[str]] = mapped_column(sa.Text)
      plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'::character varying"))
      status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'::character varying"))
      # JSON text; read/write it through custom_config_dict.
      custom_config: Mapped[Optional[str]] = mapped_column(sa.Text)
      created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp(), nullable=False)
      updated_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp())

      def get_accounts(self) -> list[Account]:
          """Return every ``Account`` that has a ``TenantAccountJoin`` row for this tenant."""
          return (
              db.session.query(Account)
              .where(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id)
              .all()
          )

      @property
      def custom_config_dict(self) -> dict:
          """``custom_config`` decoded from JSON, or ``{}`` when it is unset or empty."""
          return json.loads(self.custom_config) if self.custom_config else {}

      @custom_config_dict.setter
      def custom_config_dict(self, value: dict):
          """Store ``value`` into ``custom_config`` serialised with ``json.dumps``."""
          self.custom_config = json.dumps(value)
  class TenantAccountJoin(Base):
      """Membership row linking one account to one tenant, together with its role there.

      The ``(tenant_id, account_id)`` unique constraint allows at most one
      membership, and therefore one role, per account and tenant.
      """

      __tablename__ = "tenant_account_joins"
      __table_args__ = (
          sa.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
          sa.Index("tenant_account_join_account_id_idx", "account_id"),
          sa.Index("tenant_account_join_tenant_id_idx", "tenant_id"),
          sa.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"),
      )

      id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
      tenant_id: Mapped[str] = mapped_column(StringUUID)
      account_id: Mapped[str] = mapped_column(StringUUID)
      # NOTE(review): presumably flags the account's currently selected tenant — confirm with callers.
      current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))
      # A TenantAccountRole value; Account converts it with TenantAccountRole(...).
      role: Mapped[str] = mapped_column(sa.String(16), server_default="normal")
      invited_by: Mapped[Optional[str]] = mapped_column(StringUUID)
      created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=sa.func.current_timestamp())
      updated_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=sa.func.current_timestamp())
  class AccountIntegrate(Base):
      """Link between an account and an external identity ``(provider, open_id)``.

      Unique constraints allow one row per (account, provider) and one account
      per (provider, open_id); ``Account.get_by_openid`` looks accounts up
      through this table.
      """

      __tablename__ = "account_integrates"
      __table_args__ = (
          sa.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
          sa.UniqueConstraint("account_id", "provider", name="unique_account_provider"),
          sa.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"),
      )

      id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
      account_id: Mapped[str] = mapped_column(StringUUID)
      provider: Mapped[str] = mapped_column(sa.String(16))
      open_id: Mapped[str] = mapped_column(sa.String(255))
      encrypted_token: Mapped[str] = mapped_column(sa.String(255))
      created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=sa.func.current_timestamp())
      updated_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=sa.func.current_timestamp())
  class InvitationCode(Base):
      """An invitation code (``invitation_codes`` table), grouped by ``batch``.

      ``status`` defaults to ``'unused'`` on the server; the ``used_*`` columns
      record when and by whom a code was consumed.
      """

      __tablename__ = "invitation_codes"
      __table_args__ = (
          sa.PrimaryKeyConstraint("id", name="invitation_code_pkey"),
          sa.Index("invitation_codes_batch_idx", "batch"),
          sa.Index("invitation_codes_code_idx", "code", "status"),
      )

      id: Mapped[int] = mapped_column(sa.Integer)
      batch: Mapped[str] = mapped_column(sa.String(255))
      code: Mapped[str] = mapped_column(sa.String(32))
      status: Mapped[str] = mapped_column(sa.String(16), server_default=sa.text("'unused'::character varying"))
      used_at: Mapped[Optional[datetime]] = mapped_column(sa.DateTime)
      used_by_tenant_id: Mapped[Optional[str]] = mapped_column(StringUUID)
      used_by_account_id: Mapped[Optional[str]] = mapped_column(StringUUID)
      deprecated_at: Mapped[Optional[datetime]] = mapped_column(sa.DateTime, nullable=True)
      created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
  class TenantPluginPermission(Base):
      """Per-tenant plugin permission settings (table ``account_plugin_permissions``).

      ``tenant_id`` is unique, so a tenant has at most one row.
      """

      class InstallPermission(enum.StrEnum):
          """Allowed values of ``install_permission``."""

          EVERYONE = "everyone"
          ADMINS = "admins"
          NOBODY = "noone"  # the stored spelling is "noone"

      class DebugPermission(enum.StrEnum):
          """Allowed values of ``debug_permission``."""

          EVERYONE = "everyone"
          ADMINS = "admins"
          NOBODY = "noone"  # the stored spelling is "noone"

      __tablename__ = "account_plugin_permissions"
      __table_args__ = (
          sa.PrimaryKeyConstraint("id", name="account_plugin_permission_pkey"),
          sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin"),
      )

      id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
      tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
      # Stored as VARCHAR(16); rows load as plain str, which compares equal to the StrEnum members.
      install_permission: Mapped[InstallPermission] = mapped_column(
          sa.String(16), nullable=False, server_default="everyone"
      )
      debug_permission: Mapped[DebugPermission] = mapped_column(sa.String(16), nullable=False, server_default="noone")
  class TenantPluginAutoUpgradeStrategy(Base):
      """Per-tenant plugin auto-upgrade settings; ``tenant_id`` is unique (one row per tenant)."""

      class StrategySetting(enum.StrEnum):
          """Allowed values of ``strategy_setting``."""

          DISABLED = "disabled"
          FIX_ONLY = "fix_only"
          LATEST = "latest"

      class UpgradeMode(enum.StrEnum):
          """Allowed values of ``upgrade_mode``."""

          ALL = "all"
          PARTIAL = "partial"
          EXCLUDE = "exclude"

      __tablename__ = "tenant_plugin_auto_upgrade_strategies"
      __table_args__ = (
          sa.PrimaryKeyConstraint("id", name="tenant_plugin_auto_upgrade_strategy_pkey"),
          sa.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"),
      )

      id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
      tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
      strategy_setting: Mapped[StrategySetting] = mapped_column(String(16), nullable=False, server_default="fix_only")
      upgrade_time_of_day: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)  # seconds of the day
      upgrade_mode: Mapped[UpgradeMode] = mapped_column(String(16), nullable=False, server_default="exclude")
      exclude_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False)  # plugin_id (author/name)
      include_plugins: Mapped[list[str]] = mapped_column(sa.ARRAY(String(255)), nullable=False)  # plugin_id (author/name)
      # Typed mapped columns like the rest of the model (previously untyped db.Column);
      # same type, nullability and server default, so the DDL is unchanged.
      created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
      updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
  270. updated_at = db.Column(DateTime, nullable=False, server_default=func.current_timestamp())