You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

provider_configuration.py 75KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785
  1. import json
  2. import logging
  3. from collections import defaultdict
  4. from collections.abc import Iterator, Sequence
  5. from json import JSONDecodeError
  6. from typing import Optional
  7. from pydantic import BaseModel, ConfigDict, Field
  8. from sqlalchemy import func, select
  9. from sqlalchemy.orm import Session
  10. from constants import HIDDEN_VALUE
  11. from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
  12. from core.entities.provider_entities import (
  13. CustomConfiguration,
  14. ModelSettings,
  15. SystemConfiguration,
  16. SystemConfigurationStatus,
  17. )
  18. from core.helper import encrypter
  19. from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
  20. from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
  21. from core.model_runtime.entities.provider_entities import (
  22. ConfigurateMethod,
  23. CredentialFormSchema,
  24. FormType,
  25. ProviderEntity,
  26. )
  27. from core.model_runtime.model_providers.__base.ai_model import AIModel
  28. from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
  29. from core.plugin.entities.plugin import ModelProviderID
  30. from extensions.ext_database import db
  31. from libs.datetime_utils import naive_utc_now
  32. from models.provider import (
  33. LoadBalancingModelConfig,
  34. Provider,
  35. ProviderCredential,
  36. ProviderModel,
  37. ProviderModelCredential,
  38. ProviderModelSetting,
  39. ProviderType,
  40. TenantPreferredModelProvider,
  41. )
# Module-level logger for this file.
logger = logging.getLogger(__name__)
# Process-wide snapshot, keyed by provider name, of each provider's `configurate_methods`
# as they were when the first ProviderConfiguration for that provider was constructed.
# ProviderConfiguration.__init__ may append ConfigurateMethod.PREDEFINED_MODEL to
# `provider.configurate_methods`; this snapshot lets it decide from the original list
# instead of the possibly already-extended one.
original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {}
  44. class ProviderConfiguration(BaseModel):
  45. """
  46. Provider configuration entity for managing model provider settings.
  47. This class handles:
  48. - Provider credentials CRUD and switch
  49. - Custom Model credentials CRUD and switch
  50. - System vs custom provider switching
  51. - Load balancing configurations
  52. - Model enablement/disablement
  53. TODO: lots of logic in a BaseModel entity should be separated, the exceptions should be classified
  54. """
  55. tenant_id: str
  56. provider: ProviderEntity
  57. preferred_provider_type: ProviderType
  58. using_provider_type: ProviderType
  59. system_configuration: SystemConfiguration
  60. custom_configuration: CustomConfiguration
  61. model_settings: list[ModelSettings]
  62. # pydantic configs
  63. model_config = ConfigDict(protected_namespaces=())
  64. def __init__(self, **data):
  65. super().__init__(**data)
  66. if self.provider.provider not in original_provider_configurate_methods:
  67. original_provider_configurate_methods[self.provider.provider] = []
  68. for configurate_method in self.provider.configurate_methods:
  69. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  70. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  71. if (
  72. any(
  73. len(quota_configuration.restrict_models) > 0
  74. for quota_configuration in self.system_configuration.quota_configurations
  75. )
  76. and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
  77. ):
  78. self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
  79. def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
  80. """
  81. Get current credentials.
  82. :param model_type: model type
  83. :param model: model name
  84. :return:
  85. """
  86. if self.model_settings:
  87. # check if model is disabled by admin
  88. for model_setting in self.model_settings:
  89. if model_setting.model_type == model_type and model_setting.model == model:
  90. if not model_setting.enabled:
  91. raise ValueError(f"Model {model} is disabled.")
  92. if self.using_provider_type == ProviderType.SYSTEM:
  93. restrict_models = []
  94. for quota_configuration in self.system_configuration.quota_configurations:
  95. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  96. continue
  97. restrict_models = quota_configuration.restrict_models
  98. copy_credentials = (
  99. self.system_configuration.credentials.copy() if self.system_configuration.credentials else {}
  100. )
  101. if restrict_models:
  102. for restrict_model in restrict_models:
  103. if (
  104. restrict_model.model_type == model_type
  105. and restrict_model.model == model
  106. and restrict_model.base_model_name
  107. ):
  108. copy_credentials["base_model_name"] = restrict_model.base_model_name
  109. return copy_credentials
  110. else:
  111. credentials = None
  112. if self.custom_configuration.models:
  113. for model_configuration in self.custom_configuration.models:
  114. if model_configuration.model_type == model_type and model_configuration.model == model:
  115. credentials = model_configuration.credentials
  116. break
  117. if not credentials and self.custom_configuration.provider:
  118. credentials = self.custom_configuration.provider.credentials
  119. return credentials
  120. def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]:
  121. """
  122. Get system configuration status.
  123. :return:
  124. """
  125. if self.system_configuration.enabled is False:
  126. return SystemConfigurationStatus.UNSUPPORTED
  127. current_quota_type = self.system_configuration.current_quota_type
  128. current_quota_configuration = next(
  129. (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
  130. )
  131. if current_quota_configuration is None:
  132. return None
  133. if not current_quota_configuration:
  134. return SystemConfigurationStatus.UNSUPPORTED
  135. return (
  136. SystemConfigurationStatus.ACTIVE
  137. if current_quota_configuration.is_valid
  138. else SystemConfigurationStatus.QUOTA_EXCEEDED
  139. )
  140. def is_custom_configuration_available(self) -> bool:
  141. """
  142. Check custom configuration available.
  143. :return:
  144. """
  145. has_provider_credentials = (
  146. self.custom_configuration.provider is not None
  147. and len(self.custom_configuration.provider.available_credentials) > 0
  148. )
  149. has_model_configurations = len(self.custom_configuration.models) > 0
  150. return has_provider_credentials or has_model_configurations
  151. def _get_provider_record(self, session: Session) -> Provider | None:
  152. """
  153. Get custom provider record.
  154. """
  155. # get provider
  156. model_provider_id = ModelProviderID(self.provider.provider)
  157. provider_names = [self.provider.provider]
  158. if model_provider_id.is_langgenius():
  159. provider_names.append(model_provider_id.provider_name)
  160. stmt = select(Provider).where(
  161. Provider.tenant_id == self.tenant_id,
  162. Provider.provider_type == ProviderType.CUSTOM.value,
  163. Provider.provider_name.in_(provider_names),
  164. )
  165. return session.execute(stmt).scalar_one_or_none()
  166. def _get_specific_provider_credential(self, credential_id: str) -> dict | None:
  167. """
  168. Get a specific provider credential by ID.
  169. :param credential_id: Credential ID
  170. :return:
  171. """
  172. # Extract secret variables from provider credential schema
  173. credential_secret_variables = self.extract_secret_variables(
  174. self.provider.provider_credential_schema.credential_form_schemas
  175. if self.provider.provider_credential_schema
  176. else []
  177. )
  178. with Session(db.engine) as session:
  179. # Prefer the actual provider record name if exists (to handle aliased provider names)
  180. provider_record = self._get_provider_record(session)
  181. provider_name = provider_record.provider_name if provider_record else self.provider.provider
  182. stmt = select(ProviderCredential).where(
  183. ProviderCredential.id == credential_id,
  184. ProviderCredential.tenant_id == self.tenant_id,
  185. ProviderCredential.provider_name == provider_name,
  186. )
  187. credential = session.execute(stmt).scalar_one_or_none()
  188. if not credential or not credential.encrypted_config:
  189. raise ValueError(f"Credential with id {credential_id} not found.")
  190. try:
  191. credentials = json.loads(credential.encrypted_config)
  192. except JSONDecodeError:
  193. credentials = {}
  194. # Decrypt secret variables
  195. for key in credential_secret_variables:
  196. if key in credentials and credentials[key] is not None:
  197. try:
  198. credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=credentials[key])
  199. except Exception:
  200. pass
  201. return self.obfuscated_credentials(
  202. credentials=credentials,
  203. credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
  204. if self.provider.provider_credential_schema
  205. else [],
  206. )
  207. def _check_provider_credential_name_exists(
  208. self, credential_name: str, session: Session, exclude_id: str | None = None
  209. ) -> bool:
  210. """
  211. not allowed same name when create or update a credential
  212. """
  213. stmt = select(ProviderCredential.id).where(
  214. ProviderCredential.tenant_id == self.tenant_id,
  215. ProviderCredential.provider_name == self.provider.provider,
  216. ProviderCredential.credential_name == credential_name,
  217. )
  218. if exclude_id:
  219. stmt = stmt.where(ProviderCredential.id != exclude_id)
  220. return session.execute(stmt).scalar_one_or_none() is not None
  221. def get_provider_credential(self, credential_id: str | None = None) -> dict | None:
  222. """
  223. Get provider credentials.
  224. :param credential_id: if provided, return the specified credential
  225. :return:
  226. """
  227. if credential_id:
  228. return self._get_specific_provider_credential(credential_id)
  229. # Default behavior: return current active provider credentials
  230. credentials = self.custom_configuration.provider.credentials if self.custom_configuration.provider else {}
  231. return self.obfuscated_credentials(
  232. credentials=credentials,
  233. credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
  234. if self.provider.provider_credential_schema
  235. else [],
  236. )
  237. def validate_provider_credentials(
  238. self, credentials: dict, credential_id: str = "", session: Session | None = None
  239. ) -> dict:
  240. """
  241. Validate custom credentials.
  242. :param credentials: provider credentials
  243. :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
  244. :param session: optional database session
  245. :return:
  246. """
  247. def _validate(s: Session) -> dict:
  248. # Get provider credential secret variables
  249. provider_credential_secret_variables = self.extract_secret_variables(
  250. self.provider.provider_credential_schema.credential_form_schemas
  251. if self.provider.provider_credential_schema
  252. else []
  253. )
  254. if credential_id:
  255. try:
  256. stmt = select(ProviderCredential).where(
  257. ProviderCredential.tenant_id == self.tenant_id,
  258. ProviderCredential.provider_name == self.provider.provider,
  259. ProviderCredential.id == credential_id,
  260. )
  261. credential_record = s.execute(stmt).scalar_one_or_none()
  262. # fix origin data
  263. if credential_record and credential_record.encrypted_config:
  264. if not credential_record.encrypted_config.startswith("{"):
  265. original_credentials = {"openai_api_key": credential_record.encrypted_config}
  266. else:
  267. original_credentials = json.loads(credential_record.encrypted_config)
  268. else:
  269. original_credentials = {}
  270. except JSONDecodeError:
  271. original_credentials = {}
  272. # encrypt credentials
  273. for key, value in credentials.items():
  274. if key in provider_credential_secret_variables:
  275. # if send [__HIDDEN__] in secret input, it will be same as original value
  276. if value == HIDDEN_VALUE and key in original_credentials:
  277. credentials[key] = encrypter.decrypt_token(
  278. tenant_id=self.tenant_id, token=original_credentials[key]
  279. )
  280. model_provider_factory = ModelProviderFactory(self.tenant_id)
  281. validated_credentials = model_provider_factory.provider_credentials_validate(
  282. provider=self.provider.provider, credentials=credentials
  283. )
  284. for key, value in validated_credentials.items():
  285. if key in provider_credential_secret_variables:
  286. validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  287. return validated_credentials
  288. if session:
  289. return _validate(session)
  290. else:
  291. with Session(db.engine) as new_session:
  292. return _validate(new_session)
  293. def create_provider_credential(self, credentials: dict, credential_name: str) -> None:
  294. """
  295. Add custom provider credentials.
  296. :param credentials: provider credentials
  297. :param credential_name: credential name
  298. :return:
  299. """
  300. with Session(db.engine) as session:
  301. if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
  302. raise ValueError(f"Credential with name '{credential_name}' already exists.")
  303. credentials = self.validate_provider_credentials(credentials=credentials, session=session)
  304. provider_record = self._get_provider_record(session)
  305. try:
  306. new_record = ProviderCredential(
  307. tenant_id=self.tenant_id,
  308. provider_name=self.provider.provider,
  309. encrypted_config=json.dumps(credentials),
  310. credential_name=credential_name,
  311. )
  312. session.add(new_record)
  313. session.flush()
  314. if not provider_record:
  315. # If provider record does not exist, create it
  316. provider_record = Provider(
  317. tenant_id=self.tenant_id,
  318. provider_name=self.provider.provider,
  319. provider_type=ProviderType.CUSTOM.value,
  320. is_valid=True,
  321. credential_id=new_record.id,
  322. )
  323. session.add(provider_record)
  324. provider_model_credentials_cache = ProviderCredentialsCache(
  325. tenant_id=self.tenant_id,
  326. identity_id=provider_record.id,
  327. cache_type=ProviderCredentialsCacheType.PROVIDER,
  328. )
  329. provider_model_credentials_cache.delete()
  330. self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
  331. session.commit()
  332. except Exception:
  333. session.rollback()
  334. raise
  335. def update_provider_credential(
  336. self,
  337. credentials: dict,
  338. credential_id: str,
  339. credential_name: str,
  340. ) -> None:
  341. """
  342. update a saved provider credential (by credential_id).
  343. :param credentials: provider credentials
  344. :param credential_id: credential id
  345. :param credential_name: credential name
  346. :return:
  347. """
  348. with Session(db.engine) as session:
  349. if self._check_provider_credential_name_exists(
  350. credential_name=credential_name, session=session, exclude_id=credential_id
  351. ):
  352. raise ValueError(f"Credential with name '{credential_name}' already exists.")
  353. credentials = self.validate_provider_credentials(
  354. credentials=credentials, credential_id=credential_id, session=session
  355. )
  356. provider_record = self._get_provider_record(session)
  357. stmt = select(ProviderCredential).where(
  358. ProviderCredential.id == credential_id,
  359. ProviderCredential.tenant_id == self.tenant_id,
  360. ProviderCredential.provider_name == self.provider.provider,
  361. )
  362. # Get the credential record to update
  363. credential_record = session.execute(stmt).scalar_one_or_none()
  364. if not credential_record:
  365. raise ValueError("Credential record not found.")
  366. try:
  367. # Update credential
  368. credential_record.encrypted_config = json.dumps(credentials)
  369. credential_record.credential_name = credential_name
  370. credential_record.updated_at = naive_utc_now()
  371. session.commit()
  372. if provider_record and provider_record.credential_id == credential_id:
  373. provider_model_credentials_cache = ProviderCredentialsCache(
  374. tenant_id=self.tenant_id,
  375. identity_id=provider_record.id,
  376. cache_type=ProviderCredentialsCacheType.PROVIDER,
  377. )
  378. provider_model_credentials_cache.delete()
  379. self._update_load_balancing_configs_with_credential(
  380. credential_id=credential_id,
  381. credential_record=credential_record,
  382. credential_source="provider",
  383. session=session,
  384. )
  385. except Exception:
  386. session.rollback()
  387. raise
  388. def _update_load_balancing_configs_with_credential(
  389. self,
  390. credential_id: str,
  391. credential_record: ProviderCredential | ProviderModelCredential,
  392. credential_source: str,
  393. session: Session,
  394. ) -> None:
  395. """
  396. Update load balancing configurations that reference the given credential_id.
  397. :param credential_id: credential id
  398. :param credential_record: the encrypted_config and credential_name
  399. :param credential_source: the credential comes from the provider_credential(`provider`)
  400. or the provider_model_credential(`custom_model`)
  401. :param session: the database session
  402. :return:
  403. """
  404. # Find all load balancing configs that use this credential_id
  405. stmt = select(LoadBalancingModelConfig).where(
  406. LoadBalancingModelConfig.tenant_id == self.tenant_id,
  407. LoadBalancingModelConfig.provider_name == self.provider.provider,
  408. LoadBalancingModelConfig.credential_id == credential_id,
  409. LoadBalancingModelConfig.credential_source_type == credential_source,
  410. )
  411. load_balancing_configs = session.execute(stmt).scalars().all()
  412. if not load_balancing_configs:
  413. return
  414. # Update each load balancing config with the new credentials
  415. for lb_config in load_balancing_configs:
  416. # Update the encrypted_config with the new credentials
  417. lb_config.encrypted_config = credential_record.encrypted_config
  418. lb_config.name = credential_record.credential_name
  419. lb_config.updated_at = naive_utc_now()
  420. # Clear cache for this load balancing config
  421. lb_credentials_cache = ProviderCredentialsCache(
  422. tenant_id=self.tenant_id,
  423. identity_id=lb_config.id,
  424. cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
  425. )
  426. lb_credentials_cache.delete()
  427. session.commit()
  428. def delete_provider_credential(self, credential_id: str) -> None:
  429. """
  430. Delete a saved provider credential (by credential_id).
  431. :param credential_id: credential id
  432. :return:
  433. """
  434. with Session(db.engine) as session:
  435. stmt = select(ProviderCredential).where(
  436. ProviderCredential.id == credential_id,
  437. ProviderCredential.tenant_id == self.tenant_id,
  438. ProviderCredential.provider_name == self.provider.provider,
  439. )
  440. # Get the credential record to update
  441. credential_record = session.execute(stmt).scalar_one_or_none()
  442. if not credential_record:
  443. raise ValueError("Credential record not found.")
  444. # Check if this credential is used in load balancing configs
  445. lb_stmt = select(LoadBalancingModelConfig).where(
  446. LoadBalancingModelConfig.tenant_id == self.tenant_id,
  447. LoadBalancingModelConfig.provider_name == self.provider.provider,
  448. LoadBalancingModelConfig.credential_id == credential_id,
  449. LoadBalancingModelConfig.credential_source_type == "provider",
  450. )
  451. lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
  452. try:
  453. for lb_config in lb_configs_using_credential:
  454. lb_credentials_cache = ProviderCredentialsCache(
  455. tenant_id=self.tenant_id,
  456. identity_id=lb_config.id,
  457. cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
  458. )
  459. lb_credentials_cache.delete()
  460. lb_config.credential_id = None
  461. lb_config.encrypted_config = None
  462. lb_config.enabled = False
  463. lb_config.name = "__delete__"
  464. lb_config.updated_at = naive_utc_now()
  465. session.add(lb_config)
  466. # Check if this is the currently active credential
  467. provider_record = self._get_provider_record(session)
  468. # Check available credentials count BEFORE deleting
  469. # if this is the last credential, we need to delete the provider record
  470. count_stmt = select(func.count(ProviderCredential.id)).where(
  471. ProviderCredential.tenant_id == self.tenant_id,
  472. ProviderCredential.provider_name == self.provider.provider,
  473. )
  474. available_credentials_count = session.execute(count_stmt).scalar() or 0
  475. session.delete(credential_record)
  476. if provider_record and available_credentials_count <= 1:
  477. # If all credentials are deleted, delete the provider record
  478. session.delete(provider_record)
  479. provider_model_credentials_cache = ProviderCredentialsCache(
  480. tenant_id=self.tenant_id,
  481. identity_id=provider_record.id,
  482. cache_type=ProviderCredentialsCacheType.PROVIDER,
  483. )
  484. provider_model_credentials_cache.delete()
  485. self.switch_preferred_provider_type(provider_type=ProviderType.SYSTEM, session=session)
  486. elif provider_record and provider_record.credential_id == credential_id:
  487. provider_record.credential_id = None
  488. provider_record.updated_at = naive_utc_now()
  489. provider_model_credentials_cache = ProviderCredentialsCache(
  490. tenant_id=self.tenant_id,
  491. identity_id=provider_record.id,
  492. cache_type=ProviderCredentialsCacheType.PROVIDER,
  493. )
  494. provider_model_credentials_cache.delete()
  495. self.switch_preferred_provider_type(provider_type=ProviderType.SYSTEM, session=session)
  496. session.commit()
  497. except Exception:
  498. session.rollback()
  499. raise
  500. def switch_active_provider_credential(self, credential_id: str) -> None:
  501. """
  502. Switch active provider credential (copy the selected one into current active snapshot).
  503. :param credential_id: credential id
  504. :return:
  505. """
  506. with Session(db.engine) as session:
  507. stmt = select(ProviderCredential).where(
  508. ProviderCredential.id == credential_id,
  509. ProviderCredential.tenant_id == self.tenant_id,
  510. ProviderCredential.provider_name == self.provider.provider,
  511. )
  512. credential_record = session.execute(stmt).scalar_one_or_none()
  513. if not credential_record:
  514. raise ValueError("Credential record not found.")
  515. provider_record = self._get_provider_record(session)
  516. if not provider_record:
  517. raise ValueError("Provider record not found.")
  518. try:
  519. provider_record.credential_id = credential_record.id
  520. provider_record.updated_at = naive_utc_now()
  521. session.commit()
  522. provider_model_credentials_cache = ProviderCredentialsCache(
  523. tenant_id=self.tenant_id,
  524. identity_id=provider_record.id,
  525. cache_type=ProviderCredentialsCacheType.PROVIDER,
  526. )
  527. provider_model_credentials_cache.delete()
  528. self.switch_preferred_provider_type(ProviderType.CUSTOM, session=session)
  529. except Exception:
  530. session.rollback()
  531. raise
  532. def _get_custom_model_record(
  533. self,
  534. model_type: ModelType,
  535. model: str,
  536. session: Session,
  537. ) -> ProviderModel | None:
  538. """
  539. Get custom model credentials.
  540. """
  541. # get provider model
  542. model_provider_id = ModelProviderID(self.provider.provider)
  543. provider_names = [self.provider.provider]
  544. if model_provider_id.is_langgenius():
  545. provider_names.append(model_provider_id.provider_name)
  546. stmt = select(ProviderModel).where(
  547. ProviderModel.tenant_id == self.tenant_id,
  548. ProviderModel.provider_name.in_(provider_names),
  549. ProviderModel.model_name == model,
  550. ProviderModel.model_type == model_type.to_origin_model_type(),
  551. )
  552. return session.execute(stmt).scalar_one_or_none()
  def _get_specific_custom_model_credential(
      self, model_type: ModelType, model: str, credential_id: str
  ) -> dict | None:
      """
      Get a specific custom model credential by ID, with secret values obfuscated.

      Secret fields are decrypted best-effort: a value that cannot be decrypted is kept
      as stored (and then obfuscated), and the failure is logged without the value.

      :param model_type: model type the credential must belong to
      :param model: model name the credential must belong to
      :param credential_id: credential ID
      :return: dict with ``current_credential_id``, ``current_credential_name`` and the
          obfuscated ``credentials``
      :raises ValueError: if no matching credential with a stored config exists
      """
      credential_form_schemas = (
          self.provider.model_credential_schema.credential_form_schemas
          if self.provider.model_credential_schema
          else []
      )
      model_credential_secret_variables = self.extract_secret_variables(credential_form_schemas)

      with Session(db.engine) as session:
          stmt = select(ProviderModelCredential).where(
              ProviderModelCredential.id == credential_id,
              ProviderModelCredential.tenant_id == self.tenant_id,
              ProviderModelCredential.provider_name == self.provider.provider,
              ProviderModelCredential.model_name == model,
              ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          )
          credential_record = session.execute(stmt).scalar_one_or_none()
          if not credential_record or not credential_record.encrypted_config:
              raise ValueError(f"Credential with id {credential_id} not found.")

          try:
              credentials = json.loads(credential_record.encrypted_config)
          except JSONDecodeError:
              credentials = {}

          # Decrypt secret variables (best effort, see docstring).
          for key in model_credential_secret_variables:
              if key in credentials and credentials[key] is not None:
                  try:
                      credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=credentials[key])
                  except Exception as e:
                      # Keep going, but make the failure visible. Log the key name
                      # and the exception type only, never the secret value.
                      logger.warning(
                          "Failed to decrypt secret variable %s of model credential %s: %s",
                          key,
                          credential_id,
                          type(e).__name__,
                      )

          current_credential_id = credential_record.id
          current_credential_name = credential_record.credential_name
          credentials = self.obfuscated_credentials(
              credentials=credentials,
              credential_form_schemas=credential_form_schemas,
          )

          return {
              "current_credential_id": current_credential_id,
              "current_credential_name": current_credential_name,
              "credentials": credentials,
          }
  def _check_custom_model_credential_name_exists(
      self, model_type: ModelType, model: str, credential_name: str, session: Session, exclude_id: str | None = None
  ) -> bool:
      """
      Check whether a credential with ``credential_name`` already exists for this model.

      Duplicate names are not allowed when creating or updating a credential.

      :param model_type: model type
      :param model: model name
      :param credential_name: name to look for
      :param session: session used for the query
      :param exclude_id: credential ID to ignore (the credential being updated)
      :return: True if at least one other credential already uses the name
      """
      stmt = select(ProviderModelCredential.id).where(
          ProviderModelCredential.tenant_id == self.tenant_id,
          ProviderModelCredential.provider_name == self.provider.provider,
          ProviderModelCredential.model_name == model,
          ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          ProviderModelCredential.credential_name == credential_name,
      )
      if exclude_id:
          stmt = stmt.where(ProviderModelCredential.id != exclude_id)
      # Only existence matters. LIMIT 1 + first() stays correct when duplicate names
      # already exist, where scalar_one_or_none() would raise MultipleResultsFound.
      return session.execute(stmt.limit(1)).first() is not None
  def get_custom_model_credential(
      self, model_type: ModelType, model: str, credential_id: str | None
  ) -> Optional[dict]:
      """
      Return a custom model's credential with secret values obfuscated.

      :param model_type: model type
      :param model: model name
      :param credential_id: optional ID of a stored credential; when given, that
          specific credential is loaded from the database instead
      :return: dict with ``current_credential_id``, ``current_credential_name`` and
          ``credentials``, or None when no configured model with credentials matches
      """
      if credential_id:
          return self._get_specific_custom_model_credential(
              model_type=model_type, model=model, credential_id=credential_id
          )

      matched = next(
          (
              cfg
              for cfg in self.custom_configuration.models
              if cfg.model_type == model_type and cfg.model == model and cfg.credentials
          ),
          None,
      )
      if matched is None:
          return None

      credential_schema = self.provider.model_credential_schema
      masked = self.obfuscated_credentials(
          credentials=matched.credentials,
          credential_form_schemas=credential_schema.credential_form_schemas if credential_schema else [],
      )
      return {
          "current_credential_id": matched.current_credential_id,
          "current_credential_name": matched.current_credential_name,
          "credentials": masked,
      }
  def validate_custom_model_credentials(
      self,
      model_type: ModelType,
      model: str,
      credentials: dict,
      credential_id: str = "",
      session: Session | None = None,
  ) -> dict:
      """
      Validate custom model credentials and return them with secret fields encrypted.

      The caller's ``credentials`` dict is not modified. Validation runs on a shallow
      copy, so decrypted stored secrets never leak back into the caller's data.

      :param model_type: model type
      :param model: model name
      :param credentials: model credentials dict
      :param credential_id: (Optional) if provided, secret fields sent as HIDDEN_VALUE
          are replaced with the decrypted values stored for that credential
      :param session: optional existing session; a new one is opened when omitted
      :return: validated credentials with secret variables encrypted
      """

      def _validate(s: Session) -> dict:
          # Get provider credential secret variables
          provider_credential_secret_variables = self.extract_secret_variables(
              self.provider.model_credential_schema.credential_form_schemas
              if self.provider.model_credential_schema
              else []
          )

          # Work on a copy: the HIDDEN_VALUE substitution below writes plaintext secrets.
          credentials_to_validate = dict(credentials)

          if credential_id:
              try:
                  stmt = select(ProviderModelCredential).where(
                      ProviderModelCredential.id == credential_id,
                      ProviderModelCredential.tenant_id == self.tenant_id,
                      ProviderModelCredential.provider_name == self.provider.provider,
                      ProviderModelCredential.model_name == model,
                      ProviderModelCredential.model_type == model_type.to_origin_model_type(),
                  )
                  credential_record = s.execute(stmt).scalar_one_or_none()
                  original_credentials = (
                      json.loads(credential_record.encrypted_config)
                      if credential_record and credential_record.encrypted_config
                      else {}
                  )
              except JSONDecodeError:
                  original_credentials = {}

              # A secret sent back as [__HIDDEN__] means "keep the stored value".
              for key, value in credentials.items():
                  if (
                      key in provider_credential_secret_variables
                      and value == HIDDEN_VALUE
                      and key in original_credentials
                  ):
                      credentials_to_validate[key] = encrypter.decrypt_token(
                          tenant_id=self.tenant_id, token=original_credentials[key]
                      )

          model_provider_factory = ModelProviderFactory(self.tenant_id)
          validated_credentials = model_provider_factory.model_credentials_validate(
              provider=self.provider.provider,
              model_type=model_type,
              model=model,
              credentials=credentials_to_validate,
          )

          for key, value in validated_credentials.items():
              if key in provider_credential_secret_variables:
                  validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)

          return validated_credentials

      if session:
          return _validate(session)
      with Session(db.engine) as new_session:
          return _validate(new_session)
  def create_custom_model_credential(
      self, model_type: ModelType, model: str, credentials: dict, credential_name: str
  ) -> None:
      """
      Validate and store a new named credential for a custom model.

      If the model has no custom model record yet, one is created that points at the
      new credential. The model record's MODEL credentials cache is dropped after commit.

      :param model_type: model type
      :param model: model name
      :param credentials: model credentials dict (validated; secrets encrypted before storing)
      :param credential_name: display name, which must be unique for this model
      :raises ValueError: if a credential with the same name already exists
      """
      with Session(db.engine) as session:
          name_taken = self._check_custom_model_credential_name_exists(
              model=model, model_type=model_type, credential_name=credential_name, session=session
          )
          if name_taken:
              raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")

          encrypted_credentials = self.validate_custom_model_credentials(
              model_type=model_type, model=model, credentials=credentials, session=session
          )
          model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)

          try:
              new_credential = ProviderModelCredential(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_name=model,
                  model_type=model_type.to_origin_model_type(),
                  encrypted_config=json.dumps(encrypted_credentials),
                  credential_name=credential_name,
              )
              session.add(new_credential)
              # Flush so the new credential has its primary key before it is referenced.
              session.flush()

              if not model_record:
                  model_record = ProviderModel(
                      tenant_id=self.tenant_id,
                      provider_name=self.provider.provider,
                      model_name=model,
                      model_type=model_type.to_origin_model_type(),
                      credential_id=new_credential.id,
                      is_valid=True,
                  )
                  session.add(model_record)

              session.commit()

              ProviderCredentialsCache(
                  tenant_id=self.tenant_id,
                  identity_id=model_record.id,
                  cache_type=ProviderCredentialsCacheType.MODEL,
              ).delete()
          except Exception:
              session.rollback()
              raise
  def update_custom_model_credential(
      self, model_type: ModelType, model: str, credentials: dict, credential_name: str, credential_id: str
  ) -> None:
      """
      Validate and overwrite an existing custom model credential (config and name).

      If the credential is the model record's active one, the record's MODEL credentials
      cache is dropped. Load balancing configs using the credential are then updated.

      :param model_type: model type
      :param model: model name
      :param credentials: model credentials dict (HIDDEN_VALUE secrets keep stored values)
      :param credential_name: new credential name, which must be unique for this model
      :param credential_id: id of the credential to update
      :raises ValueError: if the name is taken or the credential does not exist
      """
      with Session(db.engine) as session:
          if self._check_custom_model_credential_name_exists(
              model=model,
              model_type=model_type,
              credential_name=credential_name,
              session=session,
              exclude_id=credential_id,
          ):
              raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")

          encrypted_credentials = self.validate_custom_model_credentials(
              model_type=model_type,
              model=model,
              credentials=credentials,
              credential_id=credential_id,
              session=session,
          )
          model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)

          lookup = select(ProviderModelCredential).where(
              ProviderModelCredential.id == credential_id,
              ProviderModelCredential.tenant_id == self.tenant_id,
              ProviderModelCredential.provider_name == self.provider.provider,
              ProviderModelCredential.model_name == model,
              ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          )
          target = session.execute(lookup).scalar_one_or_none()
          if not target:
              raise ValueError("Credential record not found.")

          try:
              target.encrypted_config = json.dumps(encrypted_credentials)
              target.credential_name = credential_name
              target.updated_at = naive_utc_now()
              session.commit()

              # Only the active credential is cached under the model record's id.
              if model_record and model_record.credential_id == credential_id:
                  ProviderCredentialsCache(
                      tenant_id=self.tenant_id,
                      identity_id=model_record.id,
                      cache_type=ProviderCredentialsCacheType.MODEL,
                  ).delete()

              self._update_load_balancing_configs_with_credential(
                  credential_id=credential_id,
                  credential_record=target,
                  credential_source="custom_model",
                  session=session,
              )
          except Exception:
              session.rollback()
              raise
  def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None:
      """
      Delete a saved custom model credential (by credential_id).

      Load balancing configs that use the credential have their caches dropped and are
      cleared, disabled and renamed "__delete__". When this is the model's last
      credential, the custom model record is deleted too. Otherwise, if it was the
      record's active credential, the record is detached from it. In both cases the
      record's MODEL credentials cache is invalidated after commit.

      :param model_type: model type
      :param model: model name
      :param credential_id: credential id
      :raises ValueError: if the credential does not exist for this model
      """
      with Session(db.engine) as session:
          stmt = select(ProviderModelCredential).where(
              ProviderModelCredential.id == credential_id,
              ProviderModelCredential.tenant_id == self.tenant_id,
              ProviderModelCredential.provider_name == self.provider.provider,
              ProviderModelCredential.model_name == model,
              ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          )
          credential_record = session.execute(stmt).scalar_one_or_none()
          if not credential_record:
              raise ValueError("Credential record not found.")

          lb_stmt = select(LoadBalancingModelConfig).where(
              LoadBalancingModelConfig.tenant_id == self.tenant_id,
              LoadBalancingModelConfig.provider_name == self.provider.provider,
              LoadBalancingModelConfig.credential_id == credential_id,
              LoadBalancingModelConfig.credential_source_type == "custom_model",
          )
          lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
          try:
              for lb_config in lb_configs_using_credential:
                  lb_credentials_cache = ProviderCredentialsCache(
                      tenant_id=self.tenant_id,
                      identity_id=lb_config.id,
                      cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
                  )
                  lb_credentials_cache.delete()
                  lb_config.credential_id = None
                  lb_config.encrypted_config = None
                  lb_config.enabled = False
                  lb_config.name = "__delete__"
                  lb_config.updated_at = naive_utc_now()
                  session.add(lb_config)

              provider_model_record = self._get_custom_model_record(model_type, model, session=session)

              # Count credentials BEFORE deleting. If this is the last one, the custom
              # model record has to be deleted as well.
              count_stmt = select(func.count(ProviderModelCredential.id)).where(
                  ProviderModelCredential.tenant_id == self.tenant_id,
                  ProviderModelCredential.provider_name == self.provider.provider,
                  ProviderModelCredential.model_name == model,
                  ProviderModelCredential.model_type == model_type.to_origin_model_type(),
              )
              available_credentials_count = session.execute(count_stmt).scalar() or 0
              session.delete(credential_record)

              # Id of the model record whose cached credentials become stale. It is
              # captured before commit so it is readable even if the record is deleted.
              stale_model_record_id = None
              if provider_model_record and available_credentials_count <= 1:
                  stale_model_record_id = provider_model_record.id
                  session.delete(provider_model_record)
              elif provider_model_record and provider_model_record.credential_id == credential_id:
                  stale_model_record_id = provider_model_record.id
                  provider_model_record.credential_id = None
                  provider_model_record.updated_at = naive_utc_now()

              session.commit()

              if stale_model_record_id is not None:
                  # This entry is keyed by the ProviderModel id, so it is the MODEL cache
                  # type (the PROVIDER type is keyed by provider records).
                  ProviderCredentialsCache(
                      tenant_id=self.tenant_id,
                      identity_id=stale_model_record_id,
                      cache_type=ProviderCredentialsCacheType.MODEL,
                  ).delete()
          except Exception:
              session.rollback()
              raise
  def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str) -> None:
      """
      Attach a stored credential to a custom model.

      If the custom model record exists, its credential is switched to ``credential_id``
      and the record's MODEL credentials cache is invalidated. Otherwise a new custom
      model record using the credential is created.

      :param model_type: model type
      :param model: model name
      :param credential_id: credential id
      :raises ValueError: if the credential does not exist, or is already the record's
          active credential
      """
      with Session(db.engine) as session:
          stmt = select(ProviderModelCredential).where(
              ProviderModelCredential.id == credential_id,
              ProviderModelCredential.tenant_id == self.tenant_id,
              ProviderModelCredential.provider_name == self.provider.provider,
              ProviderModelCredential.model_name == model,
              ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          )
          credential_record = session.execute(stmt).scalar_one_or_none()
          if not credential_record:
              raise ValueError("Credential record not found.")

          provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
          switched_existing_record = False
          if not provider_model_record:
              # create provider model record
              provider_model_record = ProviderModel(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_name=model,
                  model_type=model_type.to_origin_model_type(),
                  credential_id=credential_id,
              )
          else:
              if provider_model_record.credential_id == credential_record.id:
                  raise ValueError("Can't add same credential")
              provider_model_record.credential_id = credential_record.id
              provider_model_record.updated_at = naive_utc_now()
              switched_existing_record = True
          session.add(provider_model_record)
          session.commit()

          if switched_existing_record:
              # The cached credentials for this record belong to the previous credential.
              ProviderCredentialsCache(
                  tenant_id=self.tenant_id,
                  identity_id=provider_model_record.id,
                  cache_type=ProviderCredentialsCacheType.MODEL,
              ).delete()
  def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None:
      """
      Make ``credential_id`` the active credential of an existing custom model record.

      The record's MODEL credentials cache is invalidated after commit, so the
      previous credential's cached values are not served.

      :param model_type: model type
      :param model: model name
      :param credential_id: credential id
      :raises ValueError: if the credential or the custom model record does not exist
      """
      with Session(db.engine) as session:
          stmt = select(ProviderModelCredential).where(
              ProviderModelCredential.id == credential_id,
              ProviderModelCredential.tenant_id == self.tenant_id,
              ProviderModelCredential.provider_name == self.provider.provider,
              ProviderModelCredential.model_name == model,
              ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          )
          credential_record = session.execute(stmt).scalar_one_or_none()
          if not credential_record:
              raise ValueError("Credential record not found.")

          provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
          if not provider_model_record:
              raise ValueError("The custom model record not found.")

          provider_model_record.credential_id = credential_record.id
          provider_model_record.updated_at = naive_utc_now()
          session.add(provider_model_record)
          session.commit()

          # Cached credentials for this record belong to the previous credential.
          ProviderCredentialsCache(
              tenant_id=self.tenant_id,
              identity_id=provider_model_record.id,
              cache_type=ProviderCredentialsCacheType.MODEL,
          ).delete()
  def delete_custom_model(self, model_type: ModelType, model: str) -> None:
      """
      Remove the custom model record for a model, if there is one.

      After the delete is committed, the record's MODEL credentials cache is dropped.
      Nothing happens when no record exists.

      :param model_type: model type
      :param model: model name
      """
      with Session(db.engine) as session:
          record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
          if not record:
              return

          session.delete(record)
          session.commit()

          ProviderCredentialsCache(
              tenant_id=self.tenant_id,
              identity_id=record.id,
              cache_type=ProviderCredentialsCacheType.MODEL,
          ).delete()
  def _get_provider_model_setting(
      self, model_type: ModelType, model: str, session: Session
  ) -> ProviderModelSetting | None:
      """
      Look up this tenant's setting row for the given provider model.

      When the provider ID is a langgenius one, its short ``provider_name`` is also
      accepted as a provider-name match.

      :return: the first matching setting row, or None
      """
      provider_id = ModelProviderID(self.provider.provider)
      candidate_names = [self.provider.provider]
      if provider_id.is_langgenius():
          candidate_names.append(provider_id.provider_name)

      query = select(ProviderModelSetting).where(
          ProviderModelSetting.tenant_id == self.tenant_id,
          ProviderModelSetting.provider_name.in_(candidate_names),
          ProviderModelSetting.model_type == model_type.to_origin_model_type(),
          ProviderModelSetting.model_name == model,
      )
      rows = session.execute(query).scalars()
      return rows.first()
  def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Enable a model for this tenant and provider, creating its setting row if needed.

      :param model_type: model type
      :param model: model name
      :return: the persisted model setting, refreshed so its attributes can still be
          read after the session is closed
      """
      with Session(db.engine) as session:
          model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
          if model_setting:
              model_setting.enabled = True
              model_setting.updated_at = naive_utc_now()
          else:
              model_setting = ProviderModelSetting(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_type=model_type.to_origin_model_type(),
                  model_name=model,
                  enabled=True,
              )
          session.add(model_setting)
          session.commit()
          # commit() expires every loaded attribute (expire_on_commit defaults to True).
          # Reload them while the session is open. Otherwise reading the returned,
          # detached object raises DetachedInstanceError.
          session.refresh(model_setting)
          return model_setting
  def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Disable a model for this tenant and provider, creating its setting row if needed.

      :param model_type: model type
      :param model: model name
      :return: the persisted model setting, refreshed so its attributes can still be
          read after the session is closed
      """
      with Session(db.engine) as session:
          model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
          if model_setting:
              model_setting.enabled = False
              model_setting.updated_at = naive_utc_now()
          else:
              model_setting = ProviderModelSetting(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_type=model_type.to_origin_model_type(),
                  model_name=model,
                  enabled=False,
              )
          session.add(model_setting)
          session.commit()
          # commit() expires every loaded attribute (expire_on_commit defaults to True).
          # Reload them while the session is open. Otherwise reading the returned,
          # detached object raises DetachedInstanceError.
          session.refresh(model_setting)
          return model_setting
  def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
      """
      Fetch the stored model setting for this provider model in a short-lived session.

      :param model_type: model type
      :param model: model name
      :return: the setting row, or None when none exists
      """
      with Session(db.engine) as session:
          setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
      return setting
  def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Enable load balancing for a model, creating its setting row if needed.

      :param model_type: model type
      :param model: model name
      :return: the persisted model setting, refreshed so its attributes can still be
          read after the session is closed
      :raises ValueError: if the model has at most one load balancing config
      """
      model_provider_id = ModelProviderID(self.provider.provider)
      provider_names = [self.provider.provider]
      if model_provider_id.is_langgenius():
          provider_names.append(model_provider_id.provider_name)

      with Session(db.engine) as session:
          stmt = select(func.count(LoadBalancingModelConfig.id)).where(
              LoadBalancingModelConfig.tenant_id == self.tenant_id,
              LoadBalancingModelConfig.provider_name.in_(provider_names),
              LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
              LoadBalancingModelConfig.model_name == model,
          )
          load_balancing_config_count = session.execute(stmt).scalar() or 0
          if load_balancing_config_count <= 1:
              raise ValueError("Model load balancing configuration must be more than 1.")

          model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
          if model_setting:
              model_setting.load_balancing_enabled = True
              model_setting.updated_at = naive_utc_now()
          else:
              model_setting = ProviderModelSetting(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_type=model_type.to_origin_model_type(),
                  model_name=model,
                  load_balancing_enabled=True,
              )
          session.add(model_setting)
          session.commit()
          # commit() expires every loaded attribute. Reload them before the session
          # closes so the returned object stays readable.
          session.refresh(model_setting)
          return model_setting
  def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Disable load balancing for a model, creating its setting row if needed.

      :param model_type: model type
      :param model: model name
      :return: the persisted model setting, refreshed so its attributes can still be
          read after the session is closed
      """
      with Session(db.engine) as session:
          model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
          if model_setting:
              model_setting.load_balancing_enabled = False
              model_setting.updated_at = naive_utc_now()
          else:
              model_setting = ProviderModelSetting(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_type=model_type.to_origin_model_type(),
                  model_name=model,
                  load_balancing_enabled=False,
              )
          session.add(model_setting)
          session.commit()
          # commit() expires every loaded attribute. Reload them before the session
          # closes so the returned object stays readable.
          session.refresh(model_setting)
          return model_setting
  def get_model_type_instance(self, model_type: ModelType) -> AIModel:
      """
      Return this provider's implementation for ``model_type``.

      :param model_type: model type
      :return: model instance obtained from the tenant's model provider factory
      """
      factory = ModelProviderFactory(self.tenant_id)
      instance = factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
      return instance
  def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None:
      """
      Resolve the schema of ``model`` through the tenant's model provider factory.

      :param model_type: model type
      :param model: model name
      :param credentials: credentials passed through to the factory (may be None)
      :return: the model schema, or None when the factory returns none
      """
      factory = ModelProviderFactory(self.tenant_id)
      schema = factory.get_model_schema(
          provider=self.provider.provider,
          model_type=model_type,
          model=model,
          credentials=credentials,
      )
      return schema
  def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None) -> None:
      """
      Persist the tenant's preferred provider type for this provider.

      This is a no-op when ``provider_type`` is already preferred, or when switching
      to SYSTEM while the system configuration is disabled.

      :param provider_type: provider type to prefer
      :param session: optional existing session. It is committed by this method.
          When omitted, a short-lived session is opened.
      """
      if provider_type == self.preferred_provider_type:
          return
      if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
          return

      def _switch(s: Session) -> None:
          # Match both the full provider ID and, for langgenius providers, the short name.
          provider_id = ModelProviderID(self.provider.provider)
          candidate_names = [self.provider.provider]
          if provider_id.is_langgenius():
              candidate_names.append(provider_id.provider_name)

          query = select(TenantPreferredModelProvider).where(
              TenantPreferredModelProvider.tenant_id == self.tenant_id,
              TenantPreferredModelProvider.provider_name.in_(candidate_names),
          )
          preferred = s.execute(query).scalars().first()
          if preferred:
              preferred.preferred_provider_type = provider_type.value
          else:
              preferred = TenantPreferredModelProvider(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  preferred_provider_type=provider_type.value,
              )
          s.add(preferred)
          s.commit()

      if session:
          return _switch(session)
      with Session(db.engine) as own_session:
          return _switch(own_session)
  def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
      """
      Collect the variable names of all SECRET_INPUT form fields.

      :param credential_form_schemas: credential form schemas to scan
      :return: variable names of secret-input fields, in schema order
      """
      return [
          form_schema.variable
          for form_schema in credential_form_schemas
          if form_schema.type.value == FormType.SECRET_INPUT.value
      ]
  def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
      """
      Return a copy of ``credentials`` in which every secret-input value is obfuscated.

      The input dict is left untouched, and key order is preserved.

      :param credentials: credentials
      :param credential_form_schemas: credential form schemas that mark secret fields
      :return: new dict with secret values replaced by their obfuscated form
      """
      secret_keys = self.extract_secret_variables(credential_form_schemas)
      return {
          key: encrypter.obfuscated_token(value) if key in secret_keys else value
          for key, value in credentials.items()
      }
  def get_provider_model(
      self, model_type: ModelType, model: str, only_active: bool = False
  ) -> Optional[ModelWithProviderEntity]:
      """
      Find a single provider model by type and name.

      :param model_type: model type
      :param model: model name
      :param only_active: consider ACTIVE models only
      :return: the first matching model entity, or None
      """
      candidates = self.get_provider_models(model_type, only_active, model)
      return next((entry for entry in candidates if entry.model == model), None)
  def get_provider_models(
      self, model_type: Optional[ModelType] = None, only_active: bool = False, model: Optional[str] = None
  ) -> list[ModelWithProviderEntity]:
      """
      List this provider's models, optionally filtered by type, status and name.

      :param model_type: restrict to one model type; defaults to every type the
          provider schema supports
      :param only_active: keep only models whose status is ACTIVE
      :param model: model name, forwarded only to the custom-provider lookup
      :return: models sorted by model type value first, then by their index in the
          provider's ``position`` list for that type (unlisted models go last)
      """
      factory = ModelProviderFactory(self.tenant_id)
      provider_schema = factory.get_provider_schema(self.provider.provider)

      model_types: list[ModelType] = [model_type] if model_type else list(provider_schema.supported_model_types)

      # model type -> model name -> settings
      model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
      for setting in self.model_settings:
          model_setting_map[setting.model_type][setting.model] = setting

      if self.using_provider_type == ProviderType.SYSTEM:
          provider_models = self._get_system_provider_models(
              model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
          )
      else:
          provider_models = self._get_custom_provider_models(
              model_types=model_types,
              provider_schema=provider_schema,
              model_setting_map=model_setting_map,
              model=model,
          )

      if only_active:
          provider_models = [entry for entry in provider_models if entry.status == ModelStatus.ACTIVE]

      # provider.position is used as: model type value -> ordered list of model names.
      positions_by_type = self.provider.position if getattr(self.provider, "position", None) else {}

      def sort_key(entry: ModelWithProviderEntity) -> tuple:
          ordered_names = positions_by_type.get(entry.model_type.value, [])
          rank = ordered_names.index(entry.model) if entry.model in ordered_names else len(ordered_names)
          return entry.model_type.value, rank

      return sorted(provider_models, key=sort_key)
  def _get_system_provider_models(
      self,
      model_types: Sequence[ModelType],
      provider_schema: ProviderEntity,
      model_setting_map: dict[ModelType, dict[str, ModelSettings]],
  ) -> list[ModelWithProviderEntity]:
      """
      Get system provider models.

      Builds one entry per predefined model in ``provider_schema`` whose type is in
      ``model_types``. An entry is DISABLED when its model setting has
      ``enabled is False``, otherwise ACTIVE. The restrictions of the currently
      selected system quota are then applied, but only if it has restricted models:

      * for providers whose only configurate method is CUSTOMIZABLE_MODEL, the schema
        of each restricted model is fetched with the system credentials and appended;
      * LLM entries not in the quota's restricted list are marked NO_PERMISSION;
      * otherwise, entries are marked QUOTA_EXCEEDED when the quota is not valid.

      :param model_types: model types
      :param provider_schema: provider schema
      :param model_setting_map: model setting map (model type -> model name -> settings)
      :return: list of model entities with statuses set as described above
      """
      provider_models = []
      for model_type in model_types:
          for m in provider_schema.models:
              if m.model_type != model_type:
                  continue
              status = ModelStatus.ACTIVE
              if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
                  model_setting = model_setting_map[m.model_type][m.model]
                  if model_setting.enabled is False:
                      status = ModelStatus.DISABLED
              provider_models.append(
                  ModelWithProviderEntity(
                      model=m.model,
                      label=m.label,
                      model_type=m.model_type,
                      features=m.features,
                      fetch_from=m.fetch_from,
                      model_properties=m.model_properties,
                      deprecated=m.deprecated,
                      provider=SimpleModelProviderEntity(self.provider),
                      status=status,
                  )
              )

      # Record the schema's configurate methods the first time this provider is seen.
      # `original_provider_configurate_methods` is a global (not defined in this method)
      # that is mutated here. NOTE(review): there is no lock around this check-then-fill,
      # so confirm concurrent callers cannot interleave.
      if self.provider.provider not in original_provider_configurate_methods:
          original_provider_configurate_methods[self.provider.provider] = []
          for configurate_method in provider_schema.configurate_methods:
              original_provider_configurate_methods[self.provider.provider].append(configurate_method)

      should_use_custom_model = False
      if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
          should_use_custom_model = True

      for quota_configuration in self.system_configuration.quota_configurations:
          # Only the currently selected quota type is applied.
          if self.system_configuration.current_quota_type != quota_configuration.quota_type:
              continue
          restrict_models = quota_configuration.restrict_models
          # NOTE(review): with no restricted models this breaks before the `is_valid`
          # check below, so an invalid quota never marks models QUOTA_EXCEEDED here.
          # Confirm that this is intended.
          if len(restrict_models) == 0:
              break
          if should_use_custom_model:
              # Same condition as `should_use_custom_model` (the list is not changed in
              # between), so this inner check is always true here.
              if original_provider_configurate_methods[self.provider.provider] == [
                  ConfigurateMethod.CUSTOMIZABLE_MODEL
              ]:
                  # only customizable model
                  for restrict_model in restrict_models:
                      # Fresh copy per model, so setting base_model_name does not leak
                      # into the shared system credentials.
                      copy_credentials = (
                          self.system_configuration.credentials.copy()
                          if self.system_configuration.credentials
                          else {}
                      )
                      if restrict_model.base_model_name:
                          copy_credentials["base_model_name"] = restrict_model.base_model_name
                      try:
                          custom_model_schema = self.get_model_schema(
                              model_type=restrict_model.model_type,
                              model=restrict_model.model,
                              credentials=copy_credentials,
                          )
                      except Exception as ex:
                          # A model whose schema cannot be resolved is skipped; it does not fail the listing.
                          logger.warning("get custom model schema failed, %s", ex)
                          continue
                      if not custom_model_schema:
                          continue
                      if custom_model_schema.model_type not in model_types:
                          continue
                      status = ModelStatus.ACTIVE
                      if (
                          custom_model_schema.model_type in model_setting_map
                          and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
                      ):
                          model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
                          if model_setting.enabled is False:
                              status = ModelStatus.DISABLED
                      provider_models.append(
                          ModelWithProviderEntity(
                              model=custom_model_schema.model,
                              label=custom_model_schema.label,
                              model_type=custom_model_schema.model_type,
                              features=custom_model_schema.features,
                              fetch_from=FetchFrom.PREDEFINED_MODEL,
                              model_properties=custom_model_schema.model_properties,
                              deprecated=custom_model_schema.deprecated,
                              provider=SimpleModelProviderEntity(self.provider),
                              status=status,
                          )
                      )
          # LLMs that are not in the restricted list are marked NO_PERMISSION (they are
          # kept in the result, not removed). If the quota is invalid, the remaining
          # entries are marked QUOTA_EXCEEDED.
          restrict_model_names = [rm.model for rm in restrict_models]
          for model in provider_models:
              if model.model_type == ModelType.LLM and model.model not in restrict_model_names:
                  model.status = ModelStatus.NO_PERMISSION
              elif not quota_configuration.is_valid:
                  model.status = ModelStatus.QUOTA_EXCEEDED
      return provider_models
  def _get_custom_provider_models(
      self,
      model_types: Sequence[ModelType],
      provider_schema: ProviderEntity,
      model_setting_map: dict[ModelType, dict[str, ModelSettings]],
      model: Optional[str] = None,
  ) -> list[ModelWithProviderEntity]:
      """
      Get custom provider models.

      The result lists, in this order:

      1. The predefined models from ``provider_schema`` for each requested model
         type that the provider supports. Their base status is ``ACTIVE`` when the
         provider has custom credentials and ``NO_CONFIGURE`` otherwise.
      2. The customizable models from ``self.custom_configuration.models`` whose
         schema can be loaded with their own credentials (base status ``ACTIVE``).

      For both groups, a matching entry in ``model_setting_map`` may switch the
      status to ``DISABLED`` and determines the load-balancing flags.

      :param model_types: model types to include
      :param provider_schema: provider schema declaring the predefined models
      :param model_setting_map: model settings keyed by model type, then model name
      :param model: optional model name; when set, only the customizable model with
          this name is included from group 2 (group 1 is not filtered by it)
      :return: list of models with their provider and status
      """
      provider_models = []

      credentials = None
      if self.custom_configuration.provider:
          credentials = self.custom_configuration.provider.credentials

      # group 1: predefined models declared by the provider schema
      for model_type in model_types:
          if model_type not in self.provider.supported_model_types:
              continue

          for m in provider_schema.models:
              if m.model_type != model_type:
                  continue

              status, load_balancing_enabled, has_invalid_load_balancing_configs = self._resolve_model_setting_state(
                  model_setting_map=model_setting_map,
                  model_type=m.model_type,
                  model_name=m.model,
                  status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE,
                  # load-balancing configs sourced from "custom_model" do not apply to predefined models
                  excluded_credential_source_type="custom_model",
              )

              provider_models.append(
                  ModelWithProviderEntity(
                      model=m.model,
                      label=m.label,
                      model_type=m.model_type,
                      features=m.features,
                      fetch_from=m.fetch_from,
                      model_properties=m.model_properties,
                      deprecated=m.deprecated,
                      provider=SimpleModelProviderEntity(self.provider),
                      status=status,
                      load_balancing_enabled=load_balancing_enabled,
                      has_invalid_load_balancing_configs=has_invalid_load_balancing_configs,
                  )
              )

      # group 2: customizable models configured in the workspace
      for model_configuration in self.custom_configuration.models:
          if model_configuration.model_type not in model_types:
              continue
          if model and model != model_configuration.model:
              continue

          try:
              custom_model_schema = self.get_model_schema(
                  model_type=model_configuration.model_type,
                  model=model_configuration.model,
                  credentials=model_configuration.credentials,
              )
          except Exception as ex:
              # best effort: a model whose schema cannot be loaded is logged and skipped
              logger.warning("get custom model schema failed, %s", ex)
              continue

          if not custom_model_schema:
              continue

          status, load_balancing_enabled, has_invalid_load_balancing_configs = self._resolve_model_setting_state(
              model_setting_map=model_setting_map,
              model_type=custom_model_schema.model_type,
              model_name=custom_model_schema.model,
              status=ModelStatus.ACTIVE,
              # load-balancing configs sourced from "provider" do not apply to customizable models
              excluded_credential_source_type="provider",
          )

          # the model has stored credentials but none is currently set; this
          # deliberately overrides DISABLED as well
          if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials:
              status = ModelStatus.CREDENTIAL_REMOVED

          provider_models.append(
              ModelWithProviderEntity(
                  model=custom_model_schema.model,
                  label=custom_model_schema.label,
                  model_type=custom_model_schema.model_type,
                  features=custom_model_schema.features,
                  fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
                  model_properties=custom_model_schema.model_properties,
                  deprecated=custom_model_schema.deprecated,
                  provider=SimpleModelProviderEntity(self.provider),
                  status=status,
                  load_balancing_enabled=load_balancing_enabled,
                  has_invalid_load_balancing_configs=has_invalid_load_balancing_configs,
              )
          )

      return provider_models

  @staticmethod
  def _resolve_model_setting_state(
      model_setting_map: dict[ModelType, dict[str, ModelSettings]],
      model_type: ModelType,
      model_name: str,
      status: ModelStatus,
      excluded_credential_source_type: str,
  ) -> tuple[ModelStatus, bool, bool]:
      """
      Apply a model's entry in ``model_setting_map`` (if any) to its status and load-balancing flags.

      :param model_setting_map: model settings keyed by model type, then model name
      :param model_type: type of the model being resolved
      :param model_name: name of the model being resolved
      :param status: status to return unless the settings disable the model
      :param excluded_credential_source_type: load-balancing configs whose
          ``credential_source_type`` equals this value are ignored
      :return: ``(status, load_balancing_enabled, has_invalid_load_balancing_configs)``;
          ``(status, False, False)`` when the model has no settings entry
      """
      load_balancing_enabled = False
      has_invalid_load_balancing_configs = False

      if model_type in model_setting_map and model_name in model_setting_map[model_type]:
          model_setting = model_setting_map[model_type][model_name]
          if model_setting.enabled is False:
              status = ModelStatus.DISABLED

          lb_configs = [
              config
              for config in model_setting.load_balancing_configs
              if config.credential_source_type != excluded_credential_source_type
          ]
          # load balancing is reported as enabled only when more than one config remains
          if len(lb_configs) > 1:
              load_balancing_enabled = True
          # NOTE(review): "__delete__" looks like a placeholder name for a config whose
          # credential was removed — confirm against where configs are built
          if any(config.name == "__delete__" for config in lb_configs):
              has_invalid_load_balancing_configs = True

      return status, load_balancing_enabled, has_invalid_load_balancing_configs
  class ProviderConfigurations(BaseModel):
      """
      Model class for provider configuration dict.

      Maps provider-ID strings to ``ProviderConfiguration`` objects for one tenant.
      Item access (``obj[key]``, ``obj.get(key)`` and ``obj[key] = value``) accepts
      either a full key (any string containing "/") or a short provider name. Short
      names are expanded with ``ModelProviderID``, so both spellings address the
      same entry.
      """

      tenant_id: str
      configurations: dict[str, ProviderConfiguration] = Field(default_factory=dict)

      def __init__(self, tenant_id: str):
          super().__init__(tenant_id=tenant_id)

      @staticmethod
      def _normalize_key(key):
          """
          Return the dict key under which ``key`` is stored.

          A key without "/" is treated as a short provider name and converted with
          ``str(ModelProviderID(key))``; any other key is returned unchanged.
          """
          if "/" not in key:
              return str(ModelProviderID(key))
          return key

      def get_models(
          self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
      ) -> list[ModelWithProviderEntity]:
          """
          Get available models.

          If preferred provider type is `system`:
            Get the current **system mode** if provider supported,
            if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
            If there is no model configured in custom mode, it is treated as no_configure.
          system > custom > no_configure

          If preferred provider type is `custom`:
            If custom credentials are configured, it is treated as custom mode.
            Otherwise, get the current **system mode** if supported,
            If all system modes are not available (no quota), it is treated as no_configure.
          custom > system > no_configure

          If real mode is `system`, use system credentials to get models,
            paid quotas > provider free quotas > system free quotas
            include pre-defined models (exclude GPT-4, status marked as `no_permission`).
          If real mode is `custom`, use workspace custom credentials to get models,
            include pre-defined models, custom models(manual append).
          If real mode is `no_configure`, only return pre-defined models from `model runtime`.
            (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
          model status marked as `active` is available.

          :param provider: provider name; when set, only that provider's models are returned
          :param model_type: model type
          :param only_active: only active models
          :return: models from every matching provider configuration, in storage order
          """
          all_models = []
          for provider_configuration in self.values():
              # NOTE(review): `provider` is compared verbatim here, unlike the key
              # normalization used by __getitem__/get — confirm callers pass the
              # same form as `provider_configuration.provider.provider`.
              if provider and provider_configuration.provider.provider != provider:
                  continue
              all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
          return all_models

      def to_list(self) -> list[ProviderConfiguration]:
          """
          Convert to list.

          :return: the stored configurations, in insertion order
          """
          return list(self.values())

      def __getitem__(self, key):
          # raises KeyError when no configuration is stored under the normalized key
          return self.configurations[self._normalize_key(key)]

      def __setitem__(self, key, value):
          # normalize the same way as __getitem__/get so an entry assigned under a
          # short provider name can be read back with that same name
          self.configurations[self._normalize_key(key)] = value

      def __iter__(self):
          # yields the stored keys; this replaces BaseModel's own __iter__
          return iter(self.configurations)

      def values(self) -> Iterator[ProviderConfiguration]:
          """Return an iterator over the stored configurations."""
          return iter(self.configurations.values())

      def get(self, key, default=None) -> ProviderConfiguration | None:
          """Return the configuration for ``key`` (short or full form), or ``default``."""
          return self.configurations.get(self._normalize_key(key), default)  # type: ignore
  class ProviderModelBundle(BaseModel):
      """
      Pairs a provider configuration with the runtime model instance for one model type.
      """

      # Pydantic settings:
      # - arbitrary_types_allowed: accept field types that are not pydantic models
      #   (presumably needed for AIModel — confirm)
      # - protected_namespaces=(): silence the reserved "model_" prefix check, which
      #   `model_type_instance` would otherwise trigger
      model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True)

      configuration: ProviderConfiguration
      model_type_instance: AIModel