You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

provider_configuration.py 44KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140
  1. import json
  2. import logging
  3. from collections import defaultdict
  4. from collections.abc import Iterator, Sequence
  5. from json import JSONDecodeError
  6. from typing import Optional
  7. from pydantic import BaseModel, ConfigDict, Field
  8. from constants import HIDDEN_VALUE
  9. from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
  10. from core.entities.provider_entities import (
  11. CustomConfiguration,
  12. ModelSettings,
  13. SystemConfiguration,
  14. SystemConfigurationStatus,
  15. )
  16. from core.helper import encrypter
  17. from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
  18. from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
  19. from core.model_runtime.entities.provider_entities import (
  20. ConfigurateMethod,
  21. CredentialFormSchema,
  22. FormType,
  23. ProviderEntity,
  24. )
  25. from core.model_runtime.model_providers.__base.ai_model import AIModel
  26. from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
  27. from core.plugin.entities.plugin import ModelProviderID
  28. from extensions.ext_database import db
  29. from libs.datetime_utils import naive_utc_now
  30. from models.provider import (
  31. LoadBalancingModelConfig,
  32. Provider,
  33. ProviderModel,
  34. ProviderModelSetting,
  35. ProviderType,
  36. TenantPreferredModelProvider,
  37. )
  38. logger = logging.getLogger(__name__)
  39. original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {}
  40. class ProviderConfiguration(BaseModel):
  41. """
  42. Model class for provider configuration.
  43. """
  44. tenant_id: str
  45. provider: ProviderEntity
  46. preferred_provider_type: ProviderType
  47. using_provider_type: ProviderType
  48. system_configuration: SystemConfiguration
  49. custom_configuration: CustomConfiguration
  50. model_settings: list[ModelSettings]
  51. # pydantic configs
  52. model_config = ConfigDict(protected_namespaces=())
  53. def __init__(self, **data):
  54. super().__init__(**data)
  55. if self.provider.provider not in original_provider_configurate_methods:
  56. original_provider_configurate_methods[self.provider.provider] = []
  57. for configurate_method in self.provider.configurate_methods:
  58. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  59. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  60. if (
  61. any(
  62. len(quota_configuration.restrict_models) > 0
  63. for quota_configuration in self.system_configuration.quota_configurations
  64. )
  65. and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
  66. ):
  67. self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
  68. def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
  69. """
  70. Get current credentials.
  71. :param model_type: model type
  72. :param model: model name
  73. :return:
  74. """
  75. if self.model_settings:
  76. # check if model is disabled by admin
  77. for model_setting in self.model_settings:
  78. if model_setting.model_type == model_type and model_setting.model == model:
  79. if not model_setting.enabled:
  80. raise ValueError(f"Model {model} is disabled.")
  81. if self.using_provider_type == ProviderType.SYSTEM:
  82. restrict_models = []
  83. for quota_configuration in self.system_configuration.quota_configurations:
  84. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  85. continue
  86. restrict_models = quota_configuration.restrict_models
  87. copy_credentials = (
  88. self.system_configuration.credentials.copy() if self.system_configuration.credentials else {}
  89. )
  90. if restrict_models:
  91. for restrict_model in restrict_models:
  92. if (
  93. restrict_model.model_type == model_type
  94. and restrict_model.model == model
  95. and restrict_model.base_model_name
  96. ):
  97. copy_credentials["base_model_name"] = restrict_model.base_model_name
  98. return copy_credentials
  99. else:
  100. credentials = None
  101. if self.custom_configuration.models:
  102. for model_configuration in self.custom_configuration.models:
  103. if model_configuration.model_type == model_type and model_configuration.model == model:
  104. credentials = model_configuration.credentials
  105. break
  106. if not credentials and self.custom_configuration.provider:
  107. credentials = self.custom_configuration.provider.credentials
  108. return credentials
  109. def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]:
  110. """
  111. Get system configuration status.
  112. :return:
  113. """
  114. if self.system_configuration.enabled is False:
  115. return SystemConfigurationStatus.UNSUPPORTED
  116. current_quota_type = self.system_configuration.current_quota_type
  117. current_quota_configuration = next(
  118. (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
  119. )
  120. if current_quota_configuration is None:
  121. return None
  122. if not current_quota_configuration:
  123. return SystemConfigurationStatus.UNSUPPORTED
  124. return (
  125. SystemConfigurationStatus.ACTIVE
  126. if current_quota_configuration.is_valid
  127. else SystemConfigurationStatus.QUOTA_EXCEEDED
  128. )
  129. def is_custom_configuration_available(self) -> bool:
  130. """
  131. Check custom configuration available.
  132. :return:
  133. """
  134. return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0
  135. def get_custom_credentials(self, obfuscated: bool = False) -> dict | None:
  136. """
  137. Get custom credentials.
  138. :param obfuscated: obfuscated secret data in credentials
  139. :return:
  140. """
  141. if self.custom_configuration.provider is None:
  142. return None
  143. credentials = self.custom_configuration.provider.credentials
  144. if not obfuscated:
  145. return credentials
  146. # Obfuscate credentials
  147. return self.obfuscated_credentials(
  148. credentials=credentials,
  149. credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
  150. if self.provider.provider_credential_schema
  151. else [],
  152. )
  153. def _get_custom_provider_credentials(self) -> Provider | None:
  154. """
  155. Get custom provider credentials.
  156. """
  157. # get provider
  158. model_provider_id = ModelProviderID(self.provider.provider)
  159. provider_names = [self.provider.provider]
  160. if model_provider_id.is_langgenius():
  161. provider_names.append(model_provider_id.provider_name)
  162. provider_record = (
  163. db.session.query(Provider)
  164. .where(
  165. Provider.tenant_id == self.tenant_id,
  166. Provider.provider_type == ProviderType.CUSTOM.value,
  167. Provider.provider_name.in_(provider_names),
  168. )
  169. .first()
  170. )
  171. return provider_record
  172. def custom_credentials_validate(self, credentials: dict) -> tuple[Provider | None, dict]:
  173. """
  174. Validate custom credentials.
  175. :param credentials: provider credentials
  176. :return:
  177. """
  178. provider_record = self._get_custom_provider_credentials()
  179. # Get provider credential secret variables
  180. provider_credential_secret_variables = self.extract_secret_variables(
  181. self.provider.provider_credential_schema.credential_form_schemas
  182. if self.provider.provider_credential_schema
  183. else []
  184. )
  185. if provider_record:
  186. try:
  187. # fix origin data
  188. if provider_record.encrypted_config:
  189. if not provider_record.encrypted_config.startswith("{"):
  190. original_credentials = {"openai_api_key": provider_record.encrypted_config}
  191. else:
  192. original_credentials = json.loads(provider_record.encrypted_config)
  193. else:
  194. original_credentials = {}
  195. except JSONDecodeError:
  196. original_credentials = {}
  197. # encrypt credentials
  198. for key, value in credentials.items():
  199. if key in provider_credential_secret_variables:
  200. # if send [__HIDDEN__] in secret input, it will be same as original value
  201. if value == HIDDEN_VALUE and key in original_credentials:
  202. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  203. model_provider_factory = ModelProviderFactory(self.tenant_id)
  204. credentials = model_provider_factory.provider_credentials_validate(
  205. provider=self.provider.provider, credentials=credentials
  206. )
  207. for key, value in credentials.items():
  208. if key in provider_credential_secret_variables:
  209. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  210. return provider_record, credentials
  211. def add_or_update_custom_credentials(self, credentials: dict) -> None:
  212. """
  213. Add or update custom provider credentials.
  214. :param credentials:
  215. :return:
  216. """
  217. # validate custom provider config
  218. provider_record, credentials = self.custom_credentials_validate(credentials)
  219. # save provider
  220. # Note: Do not switch the preferred provider, which allows users to use quotas first
  221. if provider_record:
  222. provider_record.encrypted_config = json.dumps(credentials)
  223. provider_record.is_valid = True
  224. provider_record.updated_at = naive_utc_now()
  225. db.session.commit()
  226. else:
  227. provider_record = Provider()
  228. provider_record.tenant_id = self.tenant_id
  229. provider_record.provider_name = self.provider.provider
  230. provider_record.provider_type = ProviderType.CUSTOM.value
  231. provider_record.encrypted_config = json.dumps(credentials)
  232. provider_record.is_valid = True
  233. db.session.add(provider_record)
  234. db.session.commit()
  235. provider_model_credentials_cache = ProviderCredentialsCache(
  236. tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER
  237. )
  238. provider_model_credentials_cache.delete()
  239. self.switch_preferred_provider_type(ProviderType.CUSTOM)
  240. def delete_custom_credentials(self) -> None:
  241. """
  242. Delete custom provider credentials.
  243. :return:
  244. """
  245. # get provider
  246. provider_record = self._get_custom_provider_credentials()
  247. # delete provider
  248. if provider_record:
  249. self.switch_preferred_provider_type(ProviderType.SYSTEM)
  250. db.session.delete(provider_record)
  251. db.session.commit()
  252. provider_model_credentials_cache = ProviderCredentialsCache(
  253. tenant_id=self.tenant_id,
  254. identity_id=provider_record.id,
  255. cache_type=ProviderCredentialsCacheType.PROVIDER,
  256. )
  257. provider_model_credentials_cache.delete()
  258. def get_custom_model_credentials(
  259. self, model_type: ModelType, model: str, obfuscated: bool = False
  260. ) -> Optional[dict]:
  261. """
  262. Get custom model credentials.
  263. :param model_type: model type
  264. :param model: model name
  265. :param obfuscated: obfuscated secret data in credentials
  266. :return:
  267. """
  268. if not self.custom_configuration.models:
  269. return None
  270. for model_configuration in self.custom_configuration.models:
  271. if model_configuration.model_type == model_type and model_configuration.model == model:
  272. credentials = model_configuration.credentials
  273. if not obfuscated:
  274. return credentials
  275. # Obfuscate credentials
  276. return self.obfuscated_credentials(
  277. credentials=credentials,
  278. credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
  279. if self.provider.model_credential_schema
  280. else [],
  281. )
  282. return None
  283. def _get_custom_model_credentials(
  284. self,
  285. model_type: ModelType,
  286. model: str,
  287. ) -> ProviderModel | None:
  288. """
  289. Get custom model credentials.
  290. """
  291. # get provider model
  292. model_provider_id = ModelProviderID(self.provider.provider)
  293. provider_names = [self.provider.provider]
  294. if model_provider_id.is_langgenius():
  295. provider_names.append(model_provider_id.provider_name)
  296. provider_model_record = (
  297. db.session.query(ProviderModel)
  298. .where(
  299. ProviderModel.tenant_id == self.tenant_id,
  300. ProviderModel.provider_name.in_(provider_names),
  301. ProviderModel.model_name == model,
  302. ProviderModel.model_type == model_type.to_origin_model_type(),
  303. )
  304. .first()
  305. )
  306. return provider_model_record
  307. def custom_model_credentials_validate(
  308. self, model_type: ModelType, model: str, credentials: dict
  309. ) -> tuple[ProviderModel | None, dict]:
  310. """
  311. Validate custom model credentials.
  312. :param model_type: model type
  313. :param model: model name
  314. :param credentials: model credentials
  315. :return:
  316. """
  317. # get provider model
  318. provider_model_record = self._get_custom_model_credentials(model_type, model)
  319. # Get provider credential secret variables
  320. provider_credential_secret_variables = self.extract_secret_variables(
  321. self.provider.model_credential_schema.credential_form_schemas
  322. if self.provider.model_credential_schema
  323. else []
  324. )
  325. if provider_model_record:
  326. try:
  327. original_credentials = (
  328. json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
  329. )
  330. except JSONDecodeError:
  331. original_credentials = {}
  332. # decrypt credentials
  333. for key, value in credentials.items():
  334. if key in provider_credential_secret_variables:
  335. # if send [__HIDDEN__] in secret input, it will be same as original value
  336. if value == HIDDEN_VALUE and key in original_credentials:
  337. credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
  338. model_provider_factory = ModelProviderFactory(self.tenant_id)
  339. credentials = model_provider_factory.model_credentials_validate(
  340. provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
  341. )
  342. for key, value in credentials.items():
  343. if key in provider_credential_secret_variables:
  344. credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  345. return provider_model_record, credentials
  346. def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
  347. """
  348. Add or update custom model credentials.
  349. :param model_type: model type
  350. :param model: model name
  351. :param credentials: model credentials
  352. :return:
  353. """
  354. # validate custom model config
  355. provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)
  356. # save provider model
  357. # Note: Do not switch the preferred provider, which allows users to use quotas first
  358. if provider_model_record:
  359. provider_model_record.encrypted_config = json.dumps(credentials)
  360. provider_model_record.is_valid = True
  361. provider_model_record.updated_at = naive_utc_now()
  362. db.session.commit()
  363. else:
  364. provider_model_record = ProviderModel()
  365. provider_model_record.tenant_id = self.tenant_id
  366. provider_model_record.provider_name = self.provider.provider
  367. provider_model_record.model_name = model
  368. provider_model_record.model_type = model_type.to_origin_model_type()
  369. provider_model_record.encrypted_config = json.dumps(credentials)
  370. provider_model_record.is_valid = True
  371. db.session.add(provider_model_record)
  372. db.session.commit()
  373. provider_model_credentials_cache = ProviderCredentialsCache(
  374. tenant_id=self.tenant_id,
  375. identity_id=provider_model_record.id,
  376. cache_type=ProviderCredentialsCacheType.MODEL,
  377. )
  378. provider_model_credentials_cache.delete()
  379. def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
  380. """
  381. Delete custom model credentials.
  382. :param model_type: model type
  383. :param model: model name
  384. :return:
  385. """
  386. # get provider model
  387. provider_model_record = self._get_custom_model_credentials(model_type, model)
  388. # delete provider model
  389. if provider_model_record:
  390. db.session.delete(provider_model_record)
  391. db.session.commit()
  392. provider_model_credentials_cache = ProviderCredentialsCache(
  393. tenant_id=self.tenant_id,
  394. identity_id=provider_model_record.id,
  395. cache_type=ProviderCredentialsCacheType.MODEL,
  396. )
  397. provider_model_credentials_cache.delete()
  398. def _get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None:
  399. """
  400. Get provider model setting.
  401. """
  402. model_provider_id = ModelProviderID(self.provider.provider)
  403. provider_names = [self.provider.provider]
  404. if model_provider_id.is_langgenius():
  405. provider_names.append(model_provider_id.provider_name)
  406. return (
  407. db.session.query(ProviderModelSetting)
  408. .where(
  409. ProviderModelSetting.tenant_id == self.tenant_id,
  410. ProviderModelSetting.provider_name.in_(provider_names),
  411. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  412. ProviderModelSetting.model_name == model,
  413. )
  414. .first()
  415. )
  416. def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  417. """
  418. Enable model.
  419. :param model_type: model type
  420. :param model: model name
  421. :return:
  422. """
  423. model_setting = self._get_provider_model_setting(model_type, model)
  424. if model_setting:
  425. model_setting.enabled = True
  426. model_setting.updated_at = naive_utc_now()
  427. db.session.commit()
  428. else:
  429. model_setting = ProviderModelSetting()
  430. model_setting.tenant_id = self.tenant_id
  431. model_setting.provider_name = self.provider.provider
  432. model_setting.model_type = model_type.to_origin_model_type()
  433. model_setting.model_name = model
  434. model_setting.enabled = True
  435. db.session.add(model_setting)
  436. db.session.commit()
  437. return model_setting
  438. def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  439. """
  440. Disable model.
  441. :param model_type: model type
  442. :param model: model name
  443. :return:
  444. """
  445. model_setting = self._get_provider_model_setting(model_type, model)
  446. if model_setting:
  447. model_setting.enabled = False
  448. model_setting.updated_at = naive_utc_now()
  449. db.session.commit()
  450. else:
  451. model_setting = ProviderModelSetting()
  452. model_setting.tenant_id = self.tenant_id
  453. model_setting.provider_name = self.provider.provider
  454. model_setting.model_type = model_type.to_origin_model_type()
  455. model_setting.model_name = model
  456. model_setting.enabled = False
  457. db.session.add(model_setting)
  458. db.session.commit()
  459. return model_setting
  460. def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
  461. """
  462. Get provider model setting.
  463. :param model_type: model type
  464. :param model: model name
  465. :return:
  466. """
  467. return self._get_provider_model_setting(model_type, model)
  468. def _get_load_balancing_config(self, model_type: ModelType, model: str) -> Optional[LoadBalancingModelConfig]:
  469. """
  470. Get load balancing config.
  471. """
  472. model_provider_id = ModelProviderID(self.provider.provider)
  473. provider_names = [self.provider.provider]
  474. if model_provider_id.is_langgenius():
  475. provider_names.append(model_provider_id.provider_name)
  476. return (
  477. db.session.query(LoadBalancingModelConfig)
  478. .where(
  479. LoadBalancingModelConfig.tenant_id == self.tenant_id,
  480. LoadBalancingModelConfig.provider_name.in_(provider_names),
  481. LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
  482. LoadBalancingModelConfig.model_name == model,
  483. )
  484. .first()
  485. )
  486. def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  487. """
  488. Enable model load balancing.
  489. :param model_type: model type
  490. :param model: model name
  491. :return:
  492. """
  493. model_provider_id = ModelProviderID(self.provider.provider)
  494. provider_names = [self.provider.provider]
  495. if model_provider_id.is_langgenius():
  496. provider_names.append(model_provider_id.provider_name)
  497. load_balancing_config_count = (
  498. db.session.query(LoadBalancingModelConfig)
  499. .where(
  500. LoadBalancingModelConfig.tenant_id == self.tenant_id,
  501. LoadBalancingModelConfig.provider_name.in_(provider_names),
  502. LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
  503. LoadBalancingModelConfig.model_name == model,
  504. )
  505. .count()
  506. )
  507. if load_balancing_config_count <= 1:
  508. raise ValueError("Model load balancing configuration must be more than 1.")
  509. model_setting = self._get_provider_model_setting(model_type, model)
  510. if model_setting:
  511. model_setting.load_balancing_enabled = True
  512. model_setting.updated_at = naive_utc_now()
  513. db.session.commit()
  514. else:
  515. model_setting = ProviderModelSetting()
  516. model_setting.tenant_id = self.tenant_id
  517. model_setting.provider_name = self.provider.provider
  518. model_setting.model_type = model_type.to_origin_model_type()
  519. model_setting.model_name = model
  520. model_setting.load_balancing_enabled = True
  521. db.session.add(model_setting)
  522. db.session.commit()
  523. return model_setting
  524. def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  525. """
  526. Disable model load balancing.
  527. :param model_type: model type
  528. :param model: model name
  529. :return:
  530. """
  531. model_provider_id = ModelProviderID(self.provider.provider)
  532. provider_names = [self.provider.provider]
  533. if model_provider_id.is_langgenius():
  534. provider_names.append(model_provider_id.provider_name)
  535. model_setting = (
  536. db.session.query(ProviderModelSetting)
  537. .where(
  538. ProviderModelSetting.tenant_id == self.tenant_id,
  539. ProviderModelSetting.provider_name.in_(provider_names),
  540. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  541. ProviderModelSetting.model_name == model,
  542. )
  543. .first()
  544. )
  545. if model_setting:
  546. model_setting.load_balancing_enabled = False
  547. model_setting.updated_at = naive_utc_now()
  548. db.session.commit()
  549. else:
  550. model_setting = ProviderModelSetting()
  551. model_setting.tenant_id = self.tenant_id
  552. model_setting.provider_name = self.provider.provider
  553. model_setting.model_type = model_type.to_origin_model_type()
  554. model_setting.model_name = model
  555. model_setting.load_balancing_enabled = False
  556. db.session.add(model_setting)
  557. db.session.commit()
  558. return model_setting
  559. def get_model_type_instance(self, model_type: ModelType) -> AIModel:
  560. """
  561. Get current model type instance.
  562. :param model_type: model type
  563. :return:
  564. """
  565. model_provider_factory = ModelProviderFactory(self.tenant_id)
  566. # Get model instance of LLM
  567. return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
  568. def get_model_schema(self, model_type: ModelType, model: str, credentials: dict) -> AIModelEntity | None:
  569. """
  570. Get model schema
  571. """
  572. model_provider_factory = ModelProviderFactory(self.tenant_id)
  573. return model_provider_factory.get_model_schema(
  574. provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
  575. )
  576. def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
  577. """
  578. Switch preferred provider type.
  579. :param provider_type:
  580. :return:
  581. """
  582. if provider_type == self.preferred_provider_type:
  583. return
  584. if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
  585. return
  586. # get preferred provider
  587. model_provider_id = ModelProviderID(self.provider.provider)
  588. provider_names = [self.provider.provider]
  589. if model_provider_id.is_langgenius():
  590. provider_names.append(model_provider_id.provider_name)
  591. preferred_model_provider = (
  592. db.session.query(TenantPreferredModelProvider)
  593. .where(
  594. TenantPreferredModelProvider.tenant_id == self.tenant_id,
  595. TenantPreferredModelProvider.provider_name.in_(provider_names),
  596. )
  597. .first()
  598. )
  599. if preferred_model_provider:
  600. preferred_model_provider.preferred_provider_type = provider_type.value
  601. else:
  602. preferred_model_provider = TenantPreferredModelProvider()
  603. preferred_model_provider.tenant_id = self.tenant_id
  604. preferred_model_provider.provider_name = self.provider.provider
  605. preferred_model_provider.preferred_provider_type = provider_type.value
  606. db.session.add(preferred_model_provider)
  607. db.session.commit()
  608. def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
  609. """
  610. Extract secret input form variables.
  611. :param credential_form_schemas:
  612. :return:
  613. """
  614. secret_input_form_variables = []
  615. for credential_form_schema in credential_form_schemas:
  616. if credential_form_schema.type == FormType.SECRET_INPUT:
  617. secret_input_form_variables.append(credential_form_schema.variable)
  618. return secret_input_form_variables
  619. def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
  620. """
  621. Obfuscated credentials.
  622. :param credentials: credentials
  623. :param credential_form_schemas: credential form schemas
  624. :return:
  625. """
  626. # Get provider credential secret variables
  627. credential_secret_variables = self.extract_secret_variables(credential_form_schemas)
  628. # Obfuscate provider credentials
  629. copy_credentials = credentials.copy()
  630. for key, value in copy_credentials.items():
  631. if key in credential_secret_variables:
  632. copy_credentials[key] = encrypter.obfuscated_token(value)
  633. return copy_credentials
  634. def get_provider_model(
  635. self, model_type: ModelType, model: str, only_active: bool = False
  636. ) -> Optional[ModelWithProviderEntity]:
  637. """
  638. Get provider model.
  639. :param model_type: model type
  640. :param model: model name
  641. :param only_active: return active model only
  642. :return:
  643. """
  644. provider_models = self.get_provider_models(model_type, only_active, model)
  645. for provider_model in provider_models:
  646. if provider_model.model == model:
  647. return provider_model
  648. return None
  649. def get_provider_models(
  650. self, model_type: Optional[ModelType] = None, only_active: bool = False, model: Optional[str] = None
  651. ) -> list[ModelWithProviderEntity]:
  652. """
  653. Get provider models.
  654. :param model_type: model type
  655. :param only_active: only active models
  656. :param model: model name
  657. :return:
  658. """
  659. model_provider_factory = ModelProviderFactory(self.tenant_id)
  660. provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
  661. model_types: list[ModelType] = []
  662. if model_type:
  663. model_types.append(model_type)
  664. else:
  665. model_types = list(provider_schema.supported_model_types)
  666. # Group model settings by model type and model
  667. model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
  668. for model_setting in self.model_settings:
  669. model_setting_map[model_setting.model_type][model_setting.model] = model_setting
  670. if self.using_provider_type == ProviderType.SYSTEM:
  671. provider_models = self._get_system_provider_models(
  672. model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
  673. )
  674. else:
  675. provider_models = self._get_custom_provider_models(
  676. model_types=model_types,
  677. provider_schema=provider_schema,
  678. model_setting_map=model_setting_map,
  679. model=model,
  680. )
  681. if only_active:
  682. provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
  683. # resort provider_models
  684. # Optimize sorting logic: first sort by provider.position order, then by model_type.value
  685. # Get the position list for model types (retrieve only once for better performance)
  686. model_type_positions = {}
  687. if hasattr(self.provider, "position") and self.provider.position:
  688. model_type_positions = self.provider.position
  689. def get_sort_key(model: ModelWithProviderEntity):
  690. # Get the position list for the current model type
  691. positions = model_type_positions.get(model.model_type.value, [])
  692. # If the model name is in the position list, use its index for sorting
  693. # Otherwise use a large value (list length) to place undefined models at the end
  694. position_index = positions.index(model.model) if model.model in positions else len(positions)
  695. # Return composite sort key: (model_type value, model position index)
  696. return (model.model_type.value, position_index)
  697. # Sort using the composite sort key
  698. return sorted(provider_models, key=get_sort_key)
  699. def _get_system_provider_models(
  700. self,
  701. model_types: Sequence[ModelType],
  702. provider_schema: ProviderEntity,
  703. model_setting_map: dict[ModelType, dict[str, ModelSettings]],
  704. ) -> list[ModelWithProviderEntity]:
  705. """
  706. Get system provider models.
  707. :param model_types: model types
  708. :param provider_schema: provider schema
  709. :param model_setting_map: model setting map
  710. :return:
  711. """
  712. provider_models = []
  713. for model_type in model_types:
  714. for m in provider_schema.models:
  715. if m.model_type != model_type:
  716. continue
  717. status = ModelStatus.ACTIVE
  718. if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
  719. model_setting = model_setting_map[m.model_type][m.model]
  720. if model_setting.enabled is False:
  721. status = ModelStatus.DISABLED
  722. provider_models.append(
  723. ModelWithProviderEntity(
  724. model=m.model,
  725. label=m.label,
  726. model_type=m.model_type,
  727. features=m.features,
  728. fetch_from=m.fetch_from,
  729. model_properties=m.model_properties,
  730. deprecated=m.deprecated,
  731. provider=SimpleModelProviderEntity(self.provider),
  732. status=status,
  733. )
  734. )
  735. if self.provider.provider not in original_provider_configurate_methods:
  736. original_provider_configurate_methods[self.provider.provider] = []
  737. for configurate_method in provider_schema.configurate_methods:
  738. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  739. should_use_custom_model = False
  740. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  741. should_use_custom_model = True
  742. for quota_configuration in self.system_configuration.quota_configurations:
  743. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  744. continue
  745. restrict_models = quota_configuration.restrict_models
  746. if len(restrict_models) == 0:
  747. break
  748. if should_use_custom_model:
  749. if original_provider_configurate_methods[self.provider.provider] == [
  750. ConfigurateMethod.CUSTOMIZABLE_MODEL
  751. ]:
  752. # only customizable model
  753. for restrict_model in restrict_models:
  754. copy_credentials = (
  755. self.system_configuration.credentials.copy()
  756. if self.system_configuration.credentials
  757. else {}
  758. )
  759. if restrict_model.base_model_name:
  760. copy_credentials["base_model_name"] = restrict_model.base_model_name
  761. try:
  762. custom_model_schema = self.get_model_schema(
  763. model_type=restrict_model.model_type,
  764. model=restrict_model.model,
  765. credentials=copy_credentials,
  766. )
  767. except Exception as ex:
  768. logger.warning("get custom model schema failed, %s", ex)
  769. continue
  770. if not custom_model_schema:
  771. continue
  772. if custom_model_schema.model_type not in model_types:
  773. continue
  774. status = ModelStatus.ACTIVE
  775. if (
  776. custom_model_schema.model_type in model_setting_map
  777. and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
  778. ):
  779. model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
  780. if model_setting.enabled is False:
  781. status = ModelStatus.DISABLED
  782. provider_models.append(
  783. ModelWithProviderEntity(
  784. model=custom_model_schema.model,
  785. label=custom_model_schema.label,
  786. model_type=custom_model_schema.model_type,
  787. features=custom_model_schema.features,
  788. fetch_from=FetchFrom.PREDEFINED_MODEL,
  789. model_properties=custom_model_schema.model_properties,
  790. deprecated=custom_model_schema.deprecated,
  791. provider=SimpleModelProviderEntity(self.provider),
  792. status=status,
  793. )
  794. )
  795. # if llm name not in restricted llm list, remove it
  796. restrict_model_names = [rm.model for rm in restrict_models]
  797. for model in provider_models:
  798. if model.model_type == ModelType.LLM and model.model not in restrict_model_names:
  799. model.status = ModelStatus.NO_PERMISSION
  800. elif not quota_configuration.is_valid:
  801. model.status = ModelStatus.QUOTA_EXCEEDED
  802. return provider_models
  803. def _get_custom_provider_models(
  804. self,
  805. model_types: Sequence[ModelType],
  806. provider_schema: ProviderEntity,
  807. model_setting_map: dict[ModelType, dict[str, ModelSettings]],
  808. model: Optional[str] = None,
  809. ) -> list[ModelWithProviderEntity]:
  810. """
  811. Get custom provider models.
  812. :param model_types: model types
  813. :param provider_schema: provider schema
  814. :param model_setting_map: model setting map
  815. :return:
  816. """
  817. provider_models = []
  818. credentials = None
  819. if self.custom_configuration.provider:
  820. credentials = self.custom_configuration.provider.credentials
  821. for model_type in model_types:
  822. if model_type not in self.provider.supported_model_types:
  823. continue
  824. for m in provider_schema.models:
  825. if m.model_type != model_type:
  826. continue
  827. status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
  828. load_balancing_enabled = False
  829. if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
  830. model_setting = model_setting_map[m.model_type][m.model]
  831. if model_setting.enabled is False:
  832. status = ModelStatus.DISABLED
  833. if len(model_setting.load_balancing_configs) > 1:
  834. load_balancing_enabled = True
  835. provider_models.append(
  836. ModelWithProviderEntity(
  837. model=m.model,
  838. label=m.label,
  839. model_type=m.model_type,
  840. features=m.features,
  841. fetch_from=m.fetch_from,
  842. model_properties=m.model_properties,
  843. deprecated=m.deprecated,
  844. provider=SimpleModelProviderEntity(self.provider),
  845. status=status,
  846. load_balancing_enabled=load_balancing_enabled,
  847. )
  848. )
  849. # custom models
  850. for model_configuration in self.custom_configuration.models:
  851. if model_configuration.model_type not in model_types:
  852. continue
  853. if model and model != model_configuration.model:
  854. continue
  855. try:
  856. custom_model_schema = self.get_model_schema(
  857. model_type=model_configuration.model_type,
  858. model=model_configuration.model,
  859. credentials=model_configuration.credentials,
  860. )
  861. except Exception as ex:
  862. logger.warning("get custom model schema failed, %s", ex)
  863. continue
  864. if not custom_model_schema:
  865. continue
  866. status = ModelStatus.ACTIVE
  867. load_balancing_enabled = False
  868. if (
  869. custom_model_schema.model_type in model_setting_map
  870. and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
  871. ):
  872. model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
  873. if model_setting.enabled is False:
  874. status = ModelStatus.DISABLED
  875. if len(model_setting.load_balancing_configs) > 1:
  876. load_balancing_enabled = True
  877. provider_models.append(
  878. ModelWithProviderEntity(
  879. model=custom_model_schema.model,
  880. label=custom_model_schema.label,
  881. model_type=custom_model_schema.model_type,
  882. features=custom_model_schema.features,
  883. fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
  884. model_properties=custom_model_schema.model_properties,
  885. deprecated=custom_model_schema.deprecated,
  886. provider=SimpleModelProviderEntity(self.provider),
  887. status=status,
  888. load_balancing_enabled=load_balancing_enabled,
  889. )
  890. )
  891. return provider_models
  892. class ProviderConfigurations(BaseModel):
  893. """
  894. Model class for provider configuration dict.
  895. """
  896. tenant_id: str
  897. configurations: dict[str, ProviderConfiguration] = Field(default_factory=dict)
  898. def __init__(self, tenant_id: str):
  899. super().__init__(tenant_id=tenant_id)
  900. def get_models(
  901. self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
  902. ) -> list[ModelWithProviderEntity]:
  903. """
  904. Get available models.
  905. If preferred provider type is `system`:
  906. Get the current **system mode** if provider supported,
  907. if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
  908. If there is no model configured in custom mode, it is treated as no_configure.
  909. system > custom > no_configure
  910. If preferred provider type is `custom`:
  911. If custom credentials are configured, it is treated as custom mode.
  912. Otherwise, get the current **system mode** if supported,
  913. If all system modes are not available (no quota), it is treated as no_configure.
  914. custom > system > no_configure
  915. If real mode is `system`, use system credentials to get models,
  916. paid quotas > provider free quotas > system free quotas
  917. include pre-defined models (exclude GPT-4, status marked as `no_permission`).
  918. If real mode is `custom`, use workspace custom credentials to get models,
  919. include pre-defined models, custom models(manual append).
  920. If real mode is `no_configure`, only return pre-defined models from `model runtime`.
  921. (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
  922. model status marked as `active` is available.
  923. :param provider: provider name
  924. :param model_type: model type
  925. :param only_active: only active models
  926. :return:
  927. """
  928. all_models = []
  929. for provider_configuration in self.values():
  930. if provider and provider_configuration.provider.provider != provider:
  931. continue
  932. all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
  933. return all_models
  934. def to_list(self) -> list[ProviderConfiguration]:
  935. """
  936. Convert to list.
  937. :return:
  938. """
  939. return list(self.values())
  940. def __getitem__(self, key):
  941. if "/" not in key:
  942. key = str(ModelProviderID(key))
  943. return self.configurations[key]
  944. def __setitem__(self, key, value):
  945. self.configurations[key] = value
  946. def __iter__(self):
  947. return iter(self.configurations)
  948. def values(self) -> Iterator[ProviderConfiguration]:
  949. return iter(self.configurations.values())
  950. def get(self, key, default=None) -> ProviderConfiguration | None:
  951. if "/" not in key:
  952. key = str(ModelProviderID(key))
  953. return self.configurations.get(key, default) # type: ignore
  954. class ProviderModelBundle(BaseModel):
  955. """
  956. Provider model bundle.
  957. """
  958. configuration: ProviderConfiguration
  959. model_type_instance: AIModel
  960. # pydantic configs
  961. model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())