You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

provider.py 10KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. from datetime import datetime
  2. from enum import Enum
  3. from typing import Optional
  4. import sqlalchemy as sa
  5. from sqlalchemy import DateTime, String, func, text
  6. from sqlalchemy.orm import Mapped, mapped_column
  7. from .base import Base
  8. from .types import StringUUID
  9. class ProviderType(Enum):
  10. CUSTOM = "custom"
  11. SYSTEM = "system"
  12. @staticmethod
  13. def value_of(value):
  14. for member in ProviderType:
  15. if member.value == value:
  16. return member
  17. raise ValueError(f"No matching enum found for value '{value}'")
  18. class ProviderQuotaType(Enum):
  19. PAID = "paid"
  20. """hosted paid quota"""
  21. FREE = "free"
  22. """third-party free quota"""
  23. TRIAL = "trial"
  24. """hosted trial quota"""
  25. @staticmethod
  26. def value_of(value):
  27. for member in ProviderQuotaType:
  28. if member.value == value:
  29. return member
  30. raise ValueError(f"No matching enum found for value '{value}'")
  31. class Provider(Base):
  32. """
  33. Provider model representing the API providers and their configurations.
  34. """
  35. __tablename__ = "providers"
  36. __table_args__ = (
  37. sa.PrimaryKeyConstraint("id", name="provider_pkey"),
  38. sa.Index("provider_tenant_id_provider_idx", "tenant_id", "provider_name"),
  39. sa.UniqueConstraint(
  40. "tenant_id", "provider_name", "provider_type", "quota_type", name="unique_provider_name_type_quota"
  41. ),
  42. )
  43. id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
  44. tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
  45. provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
  46. provider_type: Mapped[str] = mapped_column(
  47. String(40), nullable=False, server_default=text("'custom'::character varying")
  48. )
  49. encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
  50. is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
  51. last_used: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
  52. quota_type: Mapped[Optional[str]] = mapped_column(
  53. String(40), nullable=True, server_default=text("''::character varying")
  54. )
  55. quota_limit: Mapped[Optional[int]] = mapped_column(sa.BigInteger, nullable=True)
  56. quota_used: Mapped[Optional[int]] = mapped_column(sa.BigInteger, default=0)
  57. created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
  58. updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
  59. def __repr__(self):
  60. return (
  61. f"<Provider(id={self.id}, tenant_id={self.tenant_id}, provider_name='{self.provider_name}',"
  62. f" provider_type='{self.provider_type}')>"
  63. )
  64. @property
  65. def token_is_set(self):
  66. """
  67. Returns True if the encrypted_config is not None, indicating that the token is set.
  68. """
  69. return self.encrypted_config is not None
  70. @property
  71. def is_enabled(self):
  72. """
  73. Returns True if the provider is enabled.
  74. """
  75. if self.provider_type == ProviderType.SYSTEM.value:
  76. return self.is_valid
  77. else:
  78. return self.is_valid and self.token_is_set
class ProviderModel(Base):
    """
    Provider model representing the API provider_models and their configurations.

    One row per (tenant_id, provider_name, model_name, model_type), enforced by the
    ``unique_provider_model_name`` constraint. Each row holds an optional encrypted
    per-model configuration and a validity flag.
    """

    __tablename__ = "provider_models"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="provider_model_pkey"),
        # Non-unique index on (tenant_id, provider_name).
        sa.Index("provider_model_tenant_id_provider_idx", "tenant_id", "provider_name"),
        sa.UniqueConstraint(
            "tenant_id", "provider_name", "model_name", "model_type", name="unique_provider_model_name"
        ),
    )

    # Primary key, generated by the database's uuid_generate_v4() (assumes the PostgreSQL uuid-ossp extension).
    id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
    model_name: Mapped[str] = mapped_column(String(255), nullable=False)
    model_type: Mapped[str] = mapped_column(String(40), nullable=False)
    # Encrypted model configuration; may be NULL.
    encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
    # Defaults to false in the database.
    is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
    # Both timestamps default to the database's current time on INSERT. No onupdate is configured,
    # so updated_at is not refreshed automatically. NOTE(review): presumably callers set it explicitly; verify.
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class TenantDefaultModel(Base):
    """
    A tenant's default model: the provider_name/model_name recorded for a given model_type.

    NOTE(review): presumably there is one row per (tenant_id, model_type), but no unique
    constraint in this mapping enforces that; confirm against the code that writes it.
    """

    __tablename__ = "tenant_default_models"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"),
        # Despite "provider_type" in its name, this index covers (tenant_id, provider_name, model_type).
        # The name is kept as-is because it is the database object's identifier.
        sa.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"),
    )

    # Primary key, generated by the database's uuid_generate_v4() (assumes the PostgreSQL uuid-ossp extension).
    id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
    model_name: Mapped[str] = mapped_column(String(255), nullable=False)
    model_type: Mapped[str] = mapped_column(String(40), nullable=False)
    # Both timestamps default to the database's current time on INSERT; no onupdate is configured.
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class TenantPreferredModelProvider(Base):
    """
    Records which provider type a tenant prefers for a given provider_name.

    NOTE(review): ``preferred_provider_type`` presumably holds a ``ProviderType`` value
    ('custom' or 'system'); this mapping does not constrain it. Confirm with the writers.
    """

    __tablename__ = "tenant_preferred_model_providers"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"),
        # Non-unique index on (tenant_id, provider_name).
        sa.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"),
    )

    # Primary key, generated by the database's uuid_generate_v4() (assumes the PostgreSQL uuid-ossp extension).
    id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
    preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False)
    # Both timestamps default to the database's current time on INSERT; no onupdate is configured.
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class ProviderOrder(Base):
    """
    Payment order placed by an account of a tenant for a provider.

    Tracks the product, payment/transaction identifiers, quantity, amount and currency,
    plus a payment status and the timestamps of payment, failure and refund.
    """

    __tablename__ = "provider_orders"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="provider_order_pkey"),
        # Non-unique index on (tenant_id, provider_name).
        sa.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"),
    )

    # Primary key, generated by the database's uuid_generate_v4() (assumes the PostgreSQL uuid-ossp extension).
    id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
    # Account that placed the order.
    account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    payment_product_id: Mapped[str] = mapped_column(String(191), nullable=False)
    # Payment and transaction identifiers are nullable (inferred from Optional) and may be filled in later.
    payment_id: Mapped[Optional[str]] = mapped_column(String(191))
    transaction_id: Mapped[Optional[str]] = mapped_column(String(191))
    # Defaults to 1 in the database.
    quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1"))
    currency: Mapped[Optional[str]] = mapped_column(String(40))
    # Stored as an integer. NOTE(review): units (e.g. minor currency units) are not visible here; confirm.
    total_amount: Mapped[Optional[int]] = mapped_column(sa.Integer)
    # Defaults to 'wait_pay' in the database. Other status values are set elsewhere.
    payment_status: Mapped[str] = mapped_column(
        String(40), nullable=False, server_default=text("'wait_pay'::character varying")
    )
    paid_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
    pay_failed_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
    refunded_at: Mapped[Optional[datetime]] = mapped_column(DateTime)
    # Both timestamps default to the database's current time on INSERT; no onupdate is configured.
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class ProviderModelSetting(Base):
    """
    Per-model settings that record whether a model is enabled and whether load balancing
    is enabled for it.

    A model is identified by (tenant_id, provider_name, model_name, model_type). No unique
    constraint on those columns is declared in this mapping.
    """

    __tablename__ = "provider_model_settings"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="provider_model_setting_pkey"),
        # Covers (tenant_id, provider_name, model_type) only; model_name is not part of this index.
        sa.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
    )

    # Primary key, generated by the database's uuid_generate_v4() (assumes the PostgreSQL uuid-ossp extension).
    id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
    model_name: Mapped[str] = mapped_column(String(255), nullable=False)
    model_type: Mapped[str] = mapped_column(String(40), nullable=False)
    # Models are enabled by default (database default true).
    enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
    # Load balancing is off by default (database default false).
    load_balancing_enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"))
    # Both timestamps default to the database's current time on INSERT; no onupdate is configured.
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
class LoadBalancingModelConfig(Base):
    """
    Configurations for load balancing models.

    Each row is one named configuration entry (``name`` plus optional ``encrypted_config``)
    for the model identified by (tenant_id, provider_name, model_name, model_type).
    No unique constraint is declared here, so several entries may exist per model.
    NOTE(review): presumably these are the individual load-balancing targets; confirm.
    """

    __tablename__ = "load_balancing_model_configs"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="load_balancing_model_config_pkey"),
        # Covers (tenant_id, provider_name, model_type) only; model_name is not part of this index.
        sa.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
    )

    # Primary key, generated by the database's uuid_generate_v4() (assumes the PostgreSQL uuid-ossp extension).
    id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
    model_name: Mapped[str] = mapped_column(String(255), nullable=False)
    model_type: Mapped[str] = mapped_column(String(40), nullable=False)
    # Label for this configuration entry.
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    # Encrypted configuration for this entry; may be NULL.
    encrypted_config: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
    # Entries are enabled by default (database default true).
    enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"))
    # Both timestamps default to the database's current time on INSERT; no onupdate is configured.
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())