您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

provider_configuration.py 78KB

(line-number gutter: lines 1–1863)
  1. import json
  2. import logging
  3. import re
  4. from collections import defaultdict
  5. from collections.abc import Iterator, Sequence
  6. from json import JSONDecodeError
  7. from pydantic import BaseModel, ConfigDict, Field
  8. from sqlalchemy import func, select
  9. from sqlalchemy.orm import Session
  10. from constants import HIDDEN_VALUE
  11. from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
  12. from core.entities.provider_entities import (
  13. CustomConfiguration,
  14. ModelSettings,
  15. SystemConfiguration,
  16. SystemConfigurationStatus,
  17. )
  18. from core.helper import encrypter
  19. from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
  20. from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
  21. from core.model_runtime.entities.provider_entities import (
  22. ConfigurateMethod,
  23. CredentialFormSchema,
  24. FormType,
  25. ProviderEntity,
  26. )
  27. from core.model_runtime.model_providers.__base.ai_model import AIModel
  28. from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
  29. from extensions.ext_database import db
  30. from libs.datetime_utils import naive_utc_now
  31. from models.provider import (
  32. LoadBalancingModelConfig,
  33. Provider,
  34. ProviderCredential,
  35. ProviderModel,
  36. ProviderModelCredential,
  37. ProviderModelSetting,
  38. ProviderType,
  39. TenantPreferredModelProvider,
  40. )
  41. from models.provider_ids import ModelProviderID
  42. from services.enterprise.plugin_manager_service import PluginCredentialType
  # Module-level logger (used, e.g., for warnings in _generate_next_api_key_name).
  43. logger = logging.getLogger(__name__)
  # Cache keyed by provider name: a snapshot of that provider's configurate_methods taken
  # the first time ProviderConfiguration.__init__ sees the provider. __init__ compares
  # against this snapshot because it may later append ConfigurateMethod.PREDEFINED_MODEL
  # to self.provider.configurate_methods.
  44. original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {}
  45. class ProviderConfiguration(BaseModel):
  46. """
  47. Provider configuration entity for managing model provider settings.
  48. This class handles:
  49. - Provider credentials CRUD and switch
  50. - Custom Model credentials CRUD and switch
  51. - System vs custom provider switching
  52. - Load balancing configurations
  53. - Model enablement/disablement
  54. TODO: lots of logic in a BaseModel entity should be separated, the exceptions should be classified
  55. """
  56. tenant_id: str
  57. provider: ProviderEntity
  58. preferred_provider_type: ProviderType
  59. using_provider_type: ProviderType
  60. system_configuration: SystemConfiguration
  61. custom_configuration: CustomConfiguration
  62. model_settings: list[ModelSettings]
  63. # pydantic configs
  64. model_config = ConfigDict(protected_namespaces=())
  def __init__(self, **data):
      """
      Build the provider configuration and, when needed, expose predefined models.

      The first time a provider is seen, its configurate methods are snapshotted
      into the module-level ``original_provider_configurate_methods``. If that
      snapshot is exactly ``[CUSTOMIZABLE_MODEL]`` and some system quota
      configuration restricts models, ``PREDEFINED_MODEL`` is appended to
      ``self.provider.configurate_methods`` (at most once).

      :param data: pydantic field values for this model
      """
      super().__init__(**data)
      provider_name = self.provider.provider

      if provider_name not in original_provider_configurate_methods:
          # Snapshot taken before any PREDEFINED_MODEL append below.
          original_provider_configurate_methods[provider_name] = list(self.provider.configurate_methods)

      originally_customizable_only = original_provider_configurate_methods[provider_name] == [
          ConfigurateMethod.CUSTOMIZABLE_MODEL
      ]
      if not originally_customizable_only:
          return

      quota_restricts_models = any(
          len(quota_configuration.restrict_models) > 0
          for quota_configuration in self.system_configuration.quota_configurations
      )
      if quota_restricts_models and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods:
          self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
  def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
      """
      Get the credentials currently in effect for a model.

      With the system provider, returns a copy of the system credentials. A
      ``base_model_name`` is added when a restricted model of the current quota
      type defines one.

      Otherwise returns the custom model's credentials, falling back to the
      custom provider's credentials. Before returning, it runs
      ``check_credential_policy_compliance`` on the active credential id or,
      when there is none, on every available provider credential.

      :param model_type: model type
      :param model: model name
      :return: the resolved credentials (may be None when nothing custom is configured)
      :raises ValueError: if a model setting disables this model
      """
      for model_setting in self.model_settings or []:
          is_target = model_setting.model_type == model_type and model_setting.model == model
          if is_target and not model_setting.enabled:
              raise ValueError(f"Model {model} is disabled.")

      if self.using_provider_type == ProviderType.SYSTEM:
          system_configuration = self.system_configuration
          current_quota_type = system_configuration.current_quota_type

          restrict_models = []
          for quota_configuration in system_configuration.quota_configurations:
              # No break: the last configuration matching the current quota type wins.
              if quota_configuration.quota_type == current_quota_type:
                  restrict_models = quota_configuration.restrict_models

          system_credentials = system_configuration.credentials
          resolved_credentials = system_credentials.copy() if system_credentials else {}
          for restrict_model in restrict_models or []:
              if (
                  restrict_model.model_type == model_type
                  and restrict_model.model == model
                  and restrict_model.base_model_name
              ):
                  resolved_credentials["base_model_name"] = restrict_model.base_model_name
          return resolved_credentials

      custom_configuration = self.custom_configuration
      matched_model = next(
          (
              model_configuration
              for model_configuration in custom_configuration.models or []
              if model_configuration.model_type == model_type and model_configuration.model == model
          ),
          None,
      )
      credentials = matched_model.credentials if matched_model is not None else None
      current_credential_id = matched_model.current_credential_id if matched_model is not None else None

      provider_configuration = custom_configuration.provider
      if not credentials and provider_configuration:
          # Fall back to the provider-level credentials.
          credentials = provider_configuration.credentials
          current_credential_id = provider_configuration.current_credential_id

      if current_credential_id:
          from core.helper.credential_utils import check_credential_policy_compliance

          check_credential_policy_compliance(
              credential_id=current_credential_id,
              provider=self.provider.provider,
              credential_type=PluginCredentialType.MODEL,
          )
      elif provider_configuration:
          # No active credential id: check every available provider credential instead.
          for credential_configuration in provider_configuration.available_credentials:
              from core.helper.credential_utils import check_credential_policy_compliance

              check_credential_policy_compliance(
                  credential_id=credential_configuration.credential_id,
                  provider=self.provider.provider,
                  credential_type=PluginCredentialType.MODEL,
              )
      return credentials
  def get_system_configuration_status(self) -> SystemConfigurationStatus | None:
      """
      Get the status of the system configuration.

      :return: UNSUPPORTED if the system configuration is disabled (or the matched
          quota configuration is falsy); None if no quota configuration matches the
          current quota type; otherwise ACTIVE when the matching quota configuration
          is valid and QUOTA_EXCEEDED when it is not
      """
      system_configuration = self.system_configuration
      if system_configuration.enabled is False:
          return SystemConfigurationStatus.UNSUPPORTED

      current_quota_type = system_configuration.current_quota_type
      current_quota_configuration = None
      for quota_configuration in system_configuration.quota_configurations:
          if quota_configuration.quota_type == current_quota_type:
              # First match wins.
              current_quota_configuration = quota_configuration
              break

      if current_quota_configuration is None:
          return None
      if not current_quota_configuration:
          return SystemConfigurationStatus.UNSUPPORTED
      if current_quota_configuration.is_valid:
          return SystemConfigurationStatus.ACTIVE
      return SystemConfigurationStatus.QUOTA_EXCEEDED
  def is_custom_configuration_available(self) -> bool:
      """
      Check whether any custom configuration exists.

      :return: True if the custom provider has at least one available credential,
          or at least one custom model configuration exists
      """
      custom_configuration = self.custom_configuration
      provider_configuration = custom_configuration.provider
      provider_credential_count = (
          0 if provider_configuration is None else len(provider_configuration.available_credentials)
      )
      custom_model_count = len(custom_configuration.models)
      return provider_credential_count > 0 or custom_model_count > 0
  def _get_provider_record(self, session: Session) -> Provider | None:
      """
      Get this tenant's custom Provider record, matching any stored alias of the provider name.

      :param session: database session to query with
      :return: the matching Provider, or None when there is none
          (``scalar_one_or_none`` raises if more than one row matches)
      """
      conditions = (
          Provider.tenant_id == self.tenant_id,
          Provider.provider_type == ProviderType.CUSTOM.value,
          Provider.provider_name.in_(self._get_provider_names()),
      )
      result = session.execute(select(Provider).where(*conditions))
      return result.scalar_one_or_none()
  def _get_specific_provider_credential(self, credential_id: str) -> dict | None:
      """
      Get a specific saved provider credential by ID.

      Secret variables are decrypted and the result is passed through
      ``obfuscated_credentials``.

      The credential is looked up under every stored alias of the provider name
      (``_get_provider_names``), the same way update/delete/switch locate it.
      Matching only the Provider record's name missed credentials that
      ``create_provider_credential`` stores under ``self.provider.provider``
      while a legacy Provider row uses the short name.

      :param credential_id: Credential ID
      :return: the obfuscated credentials
      :raises ValueError: if no credential with a non-empty config matches
      """
      credential_form_schemas = (
          self.provider.provider_credential_schema.credential_form_schemas
          if self.provider.provider_credential_schema
          else []
      )
      # Secret variable names from the provider credential schema (decrypted below).
      credential_secret_variables = self.extract_secret_variables(credential_form_schemas)

      with Session(db.engine) as session:
          stmt = select(ProviderCredential).where(
              ProviderCredential.id == credential_id,
              ProviderCredential.tenant_id == self.tenant_id,
              ProviderCredential.provider_name.in_(self._get_provider_names()),
          )
          credential = session.execute(stmt).scalar_one_or_none()
          if not credential or not credential.encrypted_config:
              raise ValueError(f"Credential with id {credential_id} not found.")
          encrypted_config = credential.encrypted_config

      try:
          credentials = json.loads(encrypted_config)
      except JSONDecodeError:
          credentials = {}

      # Decryption is best effort: a value that cannot be decrypted stays as stored,
      # but the failure is logged (key name only, never the value).
      for key in credential_secret_variables:
          if key in credentials and credentials[key] is not None:
              try:
                  credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=credentials[key])
              except Exception:
                  logger.warning("Failed to decrypt field %s of provider credential %s", key, credential_id)

      return self.obfuscated_credentials(
          credentials=credentials,
          credential_form_schemas=credential_form_schemas,
      )
  def _check_provider_credential_name_exists(
      self, credential_name: str, session: Session, exclude_id: str | None = None
  ) -> bool:
      """
      Check whether a credential name is already used by this provider.

      The same name is not allowed twice when creating or updating a credential.

      :param credential_name: name to look up
      :param session: database session
      :param exclude_id: credential id to ignore (e.g. the credential being updated)
      :return: True if at least one other credential of this provider has the name
      """
      stmt = select(ProviderCredential.id).where(
          ProviderCredential.tenant_id == self.tenant_id,
          ProviderCredential.provider_name.in_(self._get_provider_names()),
          ProviderCredential.credential_name == credential_name,
      )
      if exclude_id:
          stmt = stmt.where(ProviderCredential.id != exclude_id)
      # Existence check only: fetch at most one row. scalar_one_or_none() raised
      # MultipleResultsFound when duplicate names already existed (e.g. rows stored
      # under different provider-name aliases) instead of answering True.
      return session.execute(stmt.limit(1)).first() is not None
  237. def get_provider_credential(self, credential_id: str | None = None) -> dict | None:
  238. """
  239. Get provider credentials.
  240. :param credential_id: if provided, return the specified credential
  241. :return:
  242. """
  243. if credential_id:
  244. return self._get_specific_provider_credential(credential_id)
  245. # Default behavior: return current active provider credentials
  246. credentials = self.custom_configuration.provider.credentials if self.custom_configuration.provider else {}
  247. return self.obfuscated_credentials(
  248. credentials=credentials,
  249. credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
  250. if self.provider.provider_credential_schema
  251. else [],
  252. )
  def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None):
      """
      Validate custom provider credentials and return them with secret fields encrypted.

      :param credentials: provider credentials. The caller's dict is not modified:
          HIDDEN_VALUE substitution is applied to a copy, so decrypted secrets never
          leak back into the caller's object.
      :param credential_id: (Optional) if provided, secret fields sent as HIDDEN_VALUE are
          replaced with the stored (decrypted) value of that saved credential
      :param session: optional database session; a new one is opened when omitted
      :return: the validated credentials with secret variables encrypted
      """

      def _validate(s: Session):
          # Secret variable names from the provider credential schema.
          provider_credential_secret_variables = self.extract_secret_variables(
              self.provider.provider_credential_schema.credential_form_schemas
              if self.provider.provider_credential_schema
              else []
          )
          # Work on a copy so the caller's dict is left untouched.
          credentials_to_validate = dict(credentials)

          if credential_id:
              original_credentials: dict = {}
              try:
                  stmt = select(ProviderCredential).where(
                      ProviderCredential.tenant_id == self.tenant_id,
                      ProviderCredential.provider_name.in_(self._get_provider_names()),
                      ProviderCredential.id == credential_id,
                  )
                  credential_record = s.execute(stmt).scalar_one_or_none()
                  if credential_record and credential_record.encrypted_config:
                      if not credential_record.encrypted_config.startswith("{"):
                          # fix origin data: a non-JSON config is treated as a bare openai_api_key
                          original_credentials = {"openai_api_key": credential_record.encrypted_config}
                      else:
                          original_credentials = json.loads(credential_record.encrypted_config)
              except JSONDecodeError:
                  original_credentials = {}

              for key, value in credentials_to_validate.items():
                  # if send [__HIDDEN__] in secret input, it will be same as original value
                  if (
                      key in provider_credential_secret_variables
                      and value == HIDDEN_VALUE
                      and key in original_credentials
                  ):
                      credentials_to_validate[key] = encrypter.decrypt_token(
                          tenant_id=self.tenant_id, token=original_credentials[key]
                      )

          model_provider_factory = ModelProviderFactory(self.tenant_id)
          validated_credentials = model_provider_factory.provider_credentials_validate(
              provider=self.provider.provider, credentials=credentials_to_validate
          )

          # Encrypt secret variables before returning (values are replaced, keys unchanged).
          for key, value in validated_credentials.items():
              if key in provider_credential_secret_variables:
                  validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)

          return validated_credentials

      if session:
          return _validate(session)
      with Session(db.engine) as new_session:
          return _validate(new_session)
  307. def _generate_provider_credential_name(self, session) -> str:
  308. """
  309. Generate a unique credential name for provider.
  310. :return: credential name
  311. """
  312. return self._generate_next_api_key_name(
  313. session=session,
  314. query_factory=lambda: select(ProviderCredential).where(
  315. ProviderCredential.tenant_id == self.tenant_id,
  316. ProviderCredential.provider_name.in_(self._get_provider_names()),
  317. ),
  318. )
  319. def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str:
  320. """
  321. Generate a unique credential name for custom model.
  322. :return: credential name
  323. """
  324. return self._generate_next_api_key_name(
  325. session=session,
  326. query_factory=lambda: select(ProviderModelCredential).where(
  327. ProviderModelCredential.tenant_id == self.tenant_id,
  328. ProviderModelCredential.provider_name.in_(self._get_provider_names()),
  329. ProviderModelCredential.model_name == model,
  330. ProviderModelCredential.model_type == model_type.to_origin_model_type(),
  331. ),
  332. )
  333. def _generate_next_api_key_name(self, session, query_factory) -> str:
  334. """
  335. Generate next available API KEY name by finding the highest numbered suffix.
  336. :param session: database session
  337. :param query_factory: function that returns the SQLAlchemy query
  338. :return: next available API KEY name
  339. """
  340. try:
  341. stmt = query_factory()
  342. credential_records = session.execute(stmt).scalars().all()
  343. if not credential_records:
  344. return "API KEY 1"
  345. # Extract numbers from API KEY pattern using list comprehension
  346. pattern = re.compile(r"^API KEY\s+(\d+)$")
  347. numbers = [
  348. int(match.group(1))
  349. for cr in credential_records
  350. if cr.credential_name and (match := pattern.match(cr.credential_name.strip()))
  351. ]
  352. # Return next sequential number
  353. next_number = max(numbers, default=0) + 1
  354. return f"API KEY {next_number}"
  355. except Exception as e:
  356. logger.warning("Error generating next credential name: %s", str(e))
  357. return "API KEY 1"
  def _get_provider_names(self):
      """
      List every name under which this provider may be stored in the database.

      The provider name might be stored as either `openai` or `langgenius/openai/openai`.

      :return: ``[self.provider.provider]``, plus the short provider name when the
          provider id is a langgenius one
      """
      full_name = self.provider.provider
      provider_id = ModelProviderID(full_name)
      if not provider_id.is_langgenius():
          return [full_name]
      return [full_name, provider_id.provider_name]
  def create_provider_credential(self, credentials: dict, credential_name: str | None):
      """
      Add a new saved credential for this provider and make the custom provider preferred.

      If no custom Provider record exists yet, one is created pointing at the new credential.

      :param credentials: provider credentials; validated and secret-encrypted before storage
      :param credential_name: credential name; when empty, the next "API KEY N" name is generated
      :raises ValueError: if the name is already used by another credential of this provider
      :return:
      """
      with Session(db.engine) as session:
          if credential_name:
              if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
                  raise ValueError(f"Credential with name '{credential_name}' already exists.")
          else:
              credential_name = self._generate_provider_credential_name(session)

          credentials = self.validate_provider_credentials(credentials=credentials, session=session)
          provider_record = self._get_provider_record(session)
          try:
              new_record = ProviderCredential(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  encrypted_config=json.dumps(credentials),
                  credential_name=credential_name,
              )
              session.add(new_record)
              # Flush so new_record.id is populated before it is referenced below.
              session.flush()

              if not provider_record:
                  # If provider record does not exist, create it
                  provider_record = Provider(
                      tenant_id=self.tenant_id,
                      provider_name=self.provider.provider,
                      provider_type=ProviderType.CUSTOM.value,
                      is_valid=True,
                      credential_id=new_record.id,
                  )
                  session.add(provider_record)
                  # Flush as well, so provider_record.id is populated for the cache key below
                  # (the same reason new_record is flushed above).
                  session.flush()

              provider_model_credentials_cache = ProviderCredentialsCache(
                  tenant_id=self.tenant_id,
                  identity_id=provider_record.id,
                  cache_type=ProviderCredentialsCacheType.PROVIDER,
              )
              provider_model_credentials_cache.delete()

              self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
              session.commit()
          except Exception:
              session.rollback()
              raise
  def update_provider_credential(
      self,
      credentials: dict,
      credential_id: str,
      credential_name: str | None,
  ):
      """
      Update a saved provider credential (by credential_id).

      The new credentials are validated (secret fields sent as HIDDEN_VALUE reuse the
      stored values) and encrypted before saving. After the commit:
      - the provider credentials cache is cleared if this is the active credential;
      - load balancing configs that reference the credential are refreshed.

      :param credentials: provider credentials
      :param credential_id: credential id
      :param credential_name: new credential name; the name is left unchanged when empty
      :raises ValueError: if the name is taken or the credential record does not exist
      :return:
      """
      with Session(db.engine) as session:
          if credential_name:
              name_in_use = self._check_provider_credential_name_exists(
                  credential_name=credential_name, session=session, exclude_id=credential_id
              )
              if name_in_use:
                  raise ValueError(f"Credential with name '{credential_name}' already exists.")

          validated_credentials = self.validate_provider_credentials(
              credentials=credentials, credential_id=credential_id, session=session
          )
          provider_record = self._get_provider_record(session)

          credential_record = session.execute(
              select(ProviderCredential).where(
                  ProviderCredential.id == credential_id,
                  ProviderCredential.tenant_id == self.tenant_id,
                  ProviderCredential.provider_name.in_(self._get_provider_names()),
              )
          ).scalar_one_or_none()
          if not credential_record:
              raise ValueError("Credential record not found.")

          try:
              credential_record.encrypted_config = json.dumps(validated_credentials)
              credential_record.updated_at = naive_utc_now()
              if credential_name:
                  credential_record.credential_name = credential_name
              session.commit()

              # Only the active credential is cached at the provider level.
              if provider_record and provider_record.credential_id == credential_id:
                  ProviderCredentialsCache(
                      tenant_id=self.tenant_id,
                      identity_id=provider_record.id,
                      cache_type=ProviderCredentialsCacheType.PROVIDER,
                  ).delete()

              self._update_load_balancing_configs_with_credential(
                  credential_id=credential_id,
                  credential_record=credential_record,
                  credential_source="provider",
                  session=session,
              )
          except Exception:
              session.rollback()
              raise
  def _update_load_balancing_configs_with_credential(
      self,
      credential_id: str,
      credential_record: ProviderCredential | ProviderModelCredential,
      credential_source: str,
      session: Session,
  ):
      """
      Copy a credential's config and name into every load balancing config that references it.

      Each updated config has its load-balancing credentials cache cleared. The session
      is committed only when at least one config was updated.

      :param credential_id: credential id
      :param credential_record: source of the new encrypted_config and credential_name
      :param credential_source: the credential comes from the provider_credential(`provider`)
          or the provider_model_credential(`custom_model`)
      :param session: the database session
      :return:
      """
      referencing_configs = (
          session.execute(
              select(LoadBalancingModelConfig).where(
                  LoadBalancingModelConfig.tenant_id == self.tenant_id,
                  LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
                  LoadBalancingModelConfig.credential_id == credential_id,
                  LoadBalancingModelConfig.credential_source_type == credential_source,
              )
          )
          .scalars()
          .all()
      )
      if not referencing_configs:
          return

      for lb_config in referencing_configs:
          lb_config.encrypted_config = credential_record.encrypted_config
          lb_config.name = credential_record.credential_name
          lb_config.updated_at = naive_utc_now()
          ProviderCredentialsCache(
              tenant_id=self.tenant_id,
              identity_id=lb_config.id,
              cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
          ).delete()

      session.commit()
  def delete_provider_credential(self, credential_id: str):
      """
      Delete a saved provider credential (by credential_id).

      Load balancing configs of source type "provider" that reference the credential
      are deleted as well, and their caches cleared.

      If this was the provider's last credential, the custom Provider record is
      deleted. Otherwise, if it was the active credential, the Provider's
      credential_id is cleared. In either of those cases the provider cache is
      cleared and the preferred provider type switches to SYSTEM.

      :param credential_id: credential id
      :raises ValueError: if the credential record does not exist
      :return:
      """
      with Session(db.engine) as session:
          credential_record = session.execute(
              select(ProviderCredential).where(
                  ProviderCredential.id == credential_id,
                  ProviderCredential.tenant_id == self.tenant_id,
                  ProviderCredential.provider_name.in_(self._get_provider_names()),
              )
          ).scalar_one_or_none()
          if not credential_record:
              raise ValueError("Credential record not found.")

          lb_configs_using_credential = (
              session.execute(
                  select(LoadBalancingModelConfig).where(
                      LoadBalancingModelConfig.tenant_id == self.tenant_id,
                      LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
                      LoadBalancingModelConfig.credential_id == credential_id,
                      LoadBalancingModelConfig.credential_source_type == "provider",
                  )
              )
              .scalars()
              .all()
          )

          try:
              for lb_config in lb_configs_using_credential:
                  ProviderCredentialsCache(
                      tenant_id=self.tenant_id,
                      identity_id=lb_config.id,
                      cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
                  ).delete()
                  session.delete(lb_config)

              provider_record = self._get_provider_record(session)

              # Count BEFORE deleting: a count of 1 means this is the last credential.
              count_stmt = select(func.count(ProviderCredential.id)).where(
                  ProviderCredential.tenant_id == self.tenant_id,
                  ProviderCredential.provider_name.in_(self._get_provider_names()),
              )
              available_credentials_count = session.execute(count_stmt).scalar() or 0
              session.delete(credential_record)

              is_last_credential = available_credentials_count <= 1
              if provider_record and (is_last_credential or provider_record.credential_id == credential_id):
                  if is_last_credential:
                      session.delete(provider_record)
                  else:
                      # The active credential was removed; leave the provider without one.
                      provider_record.credential_id = None
                      provider_record.updated_at = naive_utc_now()
                  ProviderCredentialsCache(
                      tenant_id=self.tenant_id,
                      identity_id=provider_record.id,
                      cache_type=ProviderCredentialsCacheType.PROVIDER,
                  ).delete()
                  self.switch_preferred_provider_type(provider_type=ProviderType.SYSTEM, session=session)

              session.commit()
          except Exception:
              session.rollback()
              raise
  def switch_active_provider_credential(self, credential_id: str):
      """
      Make a saved credential the provider's active credential.

      After the commit, the provider credentials cache is cleared and the preferred
      provider type is switched to CUSTOM.

      :param credential_id: credential id
      :raises ValueError: if the credential record or the Provider record does not exist
      :return:
      """
      with Session(db.engine) as session:
          lookup = select(ProviderCredential).where(
              ProviderCredential.id == credential_id,
              ProviderCredential.tenant_id == self.tenant_id,
              ProviderCredential.provider_name.in_(self._get_provider_names()),
          )
          credential_record = session.execute(lookup).scalar_one_or_none()
          if not credential_record:
              raise ValueError("Credential record not found.")

          provider_record = self._get_provider_record(session)
          if not provider_record:
              raise ValueError("Provider record not found.")

          try:
              provider_record.credential_id = credential_record.id
              provider_record.updated_at = naive_utc_now()
              session.commit()

              ProviderCredentialsCache(
                  tenant_id=self.tenant_id,
                  identity_id=provider_record.id,
                  cache_type=ProviderCredentialsCacheType.PROVIDER,
              ).delete()
              self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
          except Exception:
              session.rollback()
              raise
  605. def _get_custom_model_record(
  606. self,
  607. model_type: ModelType,
  608. model: str,
  609. session: Session,
  610. ) -> ProviderModel | None:
  611. """
  612. Get custom model credentials.
  613. """
  614. # get provider model
  615. model_provider_id = ModelProviderID(self.provider.provider)
  616. provider_names = [self.provider.provider]
  617. if model_provider_id.is_langgenius():
  618. provider_names.append(model_provider_id.provider_name)
  619. stmt = select(ProviderModel).where(
  620. ProviderModel.tenant_id == self.tenant_id,
  621. ProviderModel.provider_name.in_(provider_names),
  622. ProviderModel.model_name == model,
  623. ProviderModel.model_type == model_type.to_origin_model_type(),
  624. )
  625. return session.execute(stmt).scalar_one_or_none()
  626. def _get_specific_custom_model_credential(
  627. self, model_type: ModelType, model: str, credential_id: str
  628. ) -> dict | None:
  629. """
  630. Get a specific provider credential by ID.
  631. :param credential_id: Credential ID
  632. :return:
  633. """
  634. model_credential_secret_variables = self.extract_secret_variables(
  635. self.provider.model_credential_schema.credential_form_schemas
  636. if self.provider.model_credential_schema
  637. else []
  638. )
  639. with Session(db.engine) as session:
  640. stmt = select(ProviderModelCredential).where(
  641. ProviderModelCredential.id == credential_id,
  642. ProviderModelCredential.tenant_id == self.tenant_id,
  643. ProviderModelCredential.provider_name.in_(self._get_provider_names()),
  644. ProviderModelCredential.model_name == model,
  645. ProviderModelCredential.model_type == model_type.to_origin_model_type(),
  646. )
  647. credential_record = session.execute(stmt).scalar_one_or_none()
  648. if not credential_record or not credential_record.encrypted_config:
  649. raise ValueError(f"Credential with id {credential_id} not found.")
  650. try:
  651. credentials = json.loads(credential_record.encrypted_config)
  652. except JSONDecodeError:
  653. credentials = {}
  654. # Decrypt secret variables
  655. for key in model_credential_secret_variables:
  656. if key in credentials and credentials[key] is not None:
  657. try:
  658. credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=credentials[key])
  659. except Exception:
  660. pass
  661. current_credential_id = credential_record.id
  662. current_credential_name = credential_record.credential_name
  663. credentials = self.obfuscated_credentials(
  664. credentials=credentials,
  665. credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
  666. if self.provider.model_credential_schema
  667. else [],
  668. )
  669. return {
  670. "current_credential_id": current_credential_id,
  671. "current_credential_name": current_credential_name,
  672. "credentials": credentials,
  673. }
  674. def _check_custom_model_credential_name_exists(
  675. self, model_type: ModelType, model: str, credential_name: str, session: Session, exclude_id: str | None = None
  676. ) -> bool:
  677. """
  678. not allowed same name when create or update a credential
  679. """
  680. stmt = select(ProviderModelCredential).where(
  681. ProviderModelCredential.tenant_id == self.tenant_id,
  682. ProviderModelCredential.provider_name.in_(self._get_provider_names()),
  683. ProviderModelCredential.model_name == model,
  684. ProviderModelCredential.model_type == model_type.to_origin_model_type(),
  685. ProviderModelCredential.credential_name == credential_name,
  686. )
  687. if exclude_id:
  688. stmt = stmt.where(ProviderModelCredential.id != exclude_id)
  689. return session.execute(stmt).scalar_one_or_none() is not None
  690. def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None:
  691. """
  692. Get custom model credentials.
  693. :param model_type: model type
  694. :param model: model name
  695. :return:
  696. """
  697. # If credential_id is provided, return the specific credential
  698. if credential_id:
  699. return self._get_specific_custom_model_credential(
  700. model_type=model_type, model=model, credential_id=credential_id
  701. )
  702. for model_configuration in self.custom_configuration.models:
  703. if (
  704. model_configuration.model_type == model_type
  705. and model_configuration.model == model
  706. and model_configuration.credentials
  707. ):
  708. current_credential_id = model_configuration.current_credential_id
  709. current_credential_name = model_configuration.current_credential_name
  710. credentials = self.obfuscated_credentials(
  711. credentials=model_configuration.credentials,
  712. credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
  713. if self.provider.model_credential_schema
  714. else [],
  715. )
  716. return {
  717. "current_credential_id": current_credential_id,
  718. "current_credential_name": current_credential_name,
  719. "credentials": credentials,
  720. }
  721. return None
  722. def validate_custom_model_credentials(
  723. self,
  724. model_type: ModelType,
  725. model: str,
  726. credentials: dict,
  727. credential_id: str = "",
  728. session: Session | None = None,
  729. ):
  730. """
  731. Validate custom model credentials.
  732. :param model_type: model type
  733. :param model: model name
  734. :param credentials: model credentials dict
  735. :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
  736. :return:
  737. """
  738. def _validate(s: Session):
  739. # Get provider credential secret variables
  740. provider_credential_secret_variables = self.extract_secret_variables(
  741. self.provider.model_credential_schema.credential_form_schemas
  742. if self.provider.model_credential_schema
  743. else []
  744. )
  745. if credential_id:
  746. try:
  747. stmt = select(ProviderModelCredential).where(
  748. ProviderModelCredential.id == credential_id,
  749. ProviderModelCredential.tenant_id == self.tenant_id,
  750. ProviderModelCredential.provider_name.in_(self._get_provider_names()),
  751. ProviderModelCredential.model_name == model,
  752. ProviderModelCredential.model_type == model_type.to_origin_model_type(),
  753. )
  754. credential_record = s.execute(stmt).scalar_one_or_none()
  755. original_credentials = (
  756. json.loads(credential_record.encrypted_config)
  757. if credential_record and credential_record.encrypted_config
  758. else {}
  759. )
  760. except JSONDecodeError:
  761. original_credentials = {}
  762. # decrypt credentials
  763. for key, value in credentials.items():
  764. if key in provider_credential_secret_variables:
  765. # if send [__HIDDEN__] in secret input, it will be same as original value
  766. if value == HIDDEN_VALUE and key in original_credentials:
  767. credentials[key] = encrypter.decrypt_token(
  768. tenant_id=self.tenant_id, token=original_credentials[key]
  769. )
  770. model_provider_factory = ModelProviderFactory(self.tenant_id)
  771. validated_credentials = model_provider_factory.model_credentials_validate(
  772. provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
  773. )
  774. for key, value in validated_credentials.items():
  775. if key in provider_credential_secret_variables:
  776. validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  777. return validated_credentials
  778. if session:
  779. return _validate(session)
  780. else:
  781. with Session(db.engine) as new_session:
  782. return _validate(new_session)
  783. def create_custom_model_credential(
  784. self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
  785. ) -> None:
  786. """
  787. Create a custom model credential.
  788. :param model_type: model type
  789. :param model: model name
  790. :param credentials: model credentials dict
  791. :return:
  792. """
  793. with Session(db.engine) as session:
  794. if credential_name:
  795. if self._check_custom_model_credential_name_exists(
  796. model=model, model_type=model_type, credential_name=credential_name, session=session
  797. ):
  798. raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
  799. else:
  800. credential_name = self._generate_custom_model_credential_name(
  801. model=model, model_type=model_type, session=session
  802. )
  803. # validate custom model config
  804. credentials = self.validate_custom_model_credentials(
  805. model_type=model_type, model=model, credentials=credentials, session=session
  806. )
  807. provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
  808. try:
  809. credential = ProviderModelCredential(
  810. tenant_id=self.tenant_id,
  811. provider_name=self.provider.provider,
  812. model_name=model,
  813. model_type=model_type.to_origin_model_type(),
  814. encrypted_config=json.dumps(credentials),
  815. credential_name=credential_name,
  816. )
  817. session.add(credential)
  818. session.flush()
  819. # save provider model
  820. if not provider_model_record:
  821. provider_model_record = ProviderModel(
  822. tenant_id=self.tenant_id,
  823. provider_name=self.provider.provider,
  824. model_name=model,
  825. model_type=model_type.to_origin_model_type(),
  826. credential_id=credential.id,
  827. is_valid=True,
  828. )
  829. session.add(provider_model_record)
  830. session.commit()
  831. provider_model_credentials_cache = ProviderCredentialsCache(
  832. tenant_id=self.tenant_id,
  833. identity_id=provider_model_record.id,
  834. cache_type=ProviderCredentialsCacheType.MODEL,
  835. )
  836. provider_model_credentials_cache.delete()
  837. except Exception:
  838. session.rollback()
  839. raise
  840. def update_custom_model_credential(
  841. self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
  842. ) -> None:
  843. """
  844. Update a custom model credential.
  845. :param model_type: model type
  846. :param model: model name
  847. :param credentials: model credentials dict
  848. :param credential_name: credential name
  849. :param credential_id: credential id
  850. :return:
  851. """
  852. with Session(db.engine) as session:
  853. if credential_name and self._check_custom_model_credential_name_exists(
  854. model=model,
  855. model_type=model_type,
  856. credential_name=credential_name,
  857. session=session,
  858. exclude_id=credential_id,
  859. ):
  860. raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
  861. # validate custom model config
  862. credentials = self.validate_custom_model_credentials(
  863. model_type=model_type,
  864. model=model,
  865. credentials=credentials,
  866. credential_id=credential_id,
  867. session=session,
  868. )
  869. provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
  870. stmt = select(ProviderModelCredential).where(
  871. ProviderModelCredential.id == credential_id,
  872. ProviderModelCredential.tenant_id == self.tenant_id,
  873. ProviderModelCredential.provider_name.in_(self._get_provider_names()),
  874. ProviderModelCredential.model_name == model,
  875. ProviderModelCredential.model_type == model_type.to_origin_model_type(),
  876. )
  877. credential_record = session.execute(stmt).scalar_one_or_none()
  878. if not credential_record:
  879. raise ValueError("Credential record not found.")
  880. try:
  881. # Update credential
  882. credential_record.encrypted_config = json.dumps(credentials)
  883. credential_record.updated_at = naive_utc_now()
  884. if credential_name:
  885. credential_record.credential_name = credential_name
  886. session.commit()
  887. if provider_model_record and provider_model_record.credential_id == credential_id:
  888. provider_model_credentials_cache = ProviderCredentialsCache(
  889. tenant_id=self.tenant_id,
  890. identity_id=provider_model_record.id,
  891. cache_type=ProviderCredentialsCacheType.MODEL,
  892. )
  893. provider_model_credentials_cache.delete()
  894. self._update_load_balancing_configs_with_credential(
  895. credential_id=credential_id,
  896. credential_record=credential_record,
  897. credential_source="custom_model",
  898. session=session,
  899. )
  900. except Exception:
  901. session.rollback()
  902. raise
  903. def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str):
  904. """
  905. Delete a saved provider credential (by credential_id).
  906. :param credential_id: credential id
  907. :return:
  908. """
  909. with Session(db.engine) as session:
  910. stmt = select(ProviderModelCredential).where(
  911. ProviderModelCredential.id == credential_id,
  912. ProviderModelCredential.tenant_id == self.tenant_id,
  913. ProviderModelCredential.provider_name.in_(self._get_provider_names()),
  914. ProviderModelCredential.model_name == model,
  915. ProviderModelCredential.model_type == model_type.to_origin_model_type(),
  916. )
  917. credential_record = session.execute(stmt).scalar_one_or_none()
  918. if not credential_record:
  919. raise ValueError("Credential record not found.")
  920. lb_stmt = select(LoadBalancingModelConfig).where(
  921. LoadBalancingModelConfig.tenant_id == self.tenant_id,
  922. LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
  923. LoadBalancingModelConfig.credential_id == credential_id,
  924. LoadBalancingModelConfig.credential_source_type == "custom_model",
  925. )
  926. lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
  927. try:
  928. for lb_config in lb_configs_using_credential:
  929. lb_credentials_cache = ProviderCredentialsCache(
  930. tenant_id=self.tenant_id,
  931. identity_id=lb_config.id,
  932. cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
  933. )
  934. lb_credentials_cache.delete()
  935. session.delete(lb_config)
  936. # Check if this is the currently active credential
  937. provider_model_record = self._get_custom_model_record(model_type, model, session=session)
  938. # Check available credentials count BEFORE deleting
  939. # if this is the last credential, we need to delete the custom model record
  940. count_stmt = select(func.count(ProviderModelCredential.id)).where(
  941. ProviderModelCredential.tenant_id == self.tenant_id,
  942. ProviderModelCredential.provider_name.in_(self._get_provider_names()),
  943. ProviderModelCredential.model_name == model,
  944. ProviderModelCredential.model_type == model_type.to_origin_model_type(),
  945. )
  946. available_credentials_count = session.execute(count_stmt).scalar() or 0
  947. session.delete(credential_record)
  948. if provider_model_record and available_credentials_count <= 1:
  949. # If all credentials are deleted, delete the custom model record
  950. session.delete(provider_model_record)
  951. elif provider_model_record and provider_model_record.credential_id == credential_id:
  952. provider_model_record.credential_id = None
  953. provider_model_record.updated_at = naive_utc_now()
  954. provider_model_credentials_cache = ProviderCredentialsCache(
  955. tenant_id=self.tenant_id,
  956. identity_id=provider_model_record.id,
  957. cache_type=ProviderCredentialsCacheType.PROVIDER,
  958. )
  959. provider_model_credentials_cache.delete()
  960. session.commit()
  961. except Exception:
  962. session.rollback()
  963. raise
  964. def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str):
  965. """
  966. if model list exist this custom model, switch the custom model credential.
  967. if model list not exist this custom model, use the credential to add a new custom model record.
  968. :param model_type: model type
  969. :param model: model name
  970. :param credential_id: credential id
  971. :return:
  972. """
  973. with Session(db.engine) as session:
  974. stmt = select(ProviderModelCredential).where(
  975. ProviderModelCredential.id == credential_id,
  976. ProviderModelCredential.tenant_id == self.tenant_id,
  977. ProviderModelCredential.provider_name.in_(self._get_provider_names()),
  978. ProviderModelCredential.model_name == model,
  979. ProviderModelCredential.model_type == model_type.to_origin_model_type(),
  980. )
  981. credential_record = session.execute(stmt).scalar_one_or_none()
  982. if not credential_record:
  983. raise ValueError("Credential record not found.")
  984. # validate custom model config
  985. provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
  986. if not provider_model_record:
  987. # create provider model record
  988. provider_model_record = ProviderModel(
  989. tenant_id=self.tenant_id,
  990. provider_name=self.provider.provider,
  991. model_name=model,
  992. model_type=model_type.to_origin_model_type(),
  993. is_valid=True,
  994. credential_id=credential_id,
  995. )
  996. else:
  997. if provider_model_record.credential_id == credential_record.id:
  998. raise ValueError("Can't add same credential")
  999. provider_model_record.credential_id = credential_record.id
  1000. provider_model_record.updated_at = naive_utc_now()
  1001. session.add(provider_model_record)
  1002. session.commit()
  1003. def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str):
  1004. """
  1005. switch the custom model credential.
  1006. :param model_type: model type
  1007. :param model: model name
  1008. :param credential_id: credential id
  1009. :return:
  1010. """
  1011. with Session(db.engine) as session:
  1012. stmt = select(ProviderModelCredential).where(
  1013. ProviderModelCredential.id == credential_id,
  1014. ProviderModelCredential.tenant_id == self.tenant_id,
  1015. ProviderModelCredential.provider_name.in_(self._get_provider_names()),
  1016. ProviderModelCredential.model_name == model,
  1017. ProviderModelCredential.model_type == model_type.to_origin_model_type(),
  1018. )
  1019. credential_record = session.execute(stmt).scalar_one_or_none()
  1020. if not credential_record:
  1021. raise ValueError("Credential record not found.")
  1022. provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
  1023. if not provider_model_record:
  1024. raise ValueError("The custom model record not found.")
  1025. provider_model_record.credential_id = credential_record.id
  1026. provider_model_record.updated_at = naive_utc_now()
  1027. session.add(provider_model_record)
  1028. session.commit()
  1029. def delete_custom_model(self, model_type: ModelType, model: str):
  1030. """
  1031. Delete custom model.
  1032. :param model_type: model type
  1033. :param model: model name
  1034. :return:
  1035. """
  1036. with Session(db.engine) as session:
  1037. # get provider model
  1038. provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
  1039. # delete provider model
  1040. if provider_model_record:
  1041. session.delete(provider_model_record)
  1042. session.commit()
  1043. provider_model_credentials_cache = ProviderCredentialsCache(
  1044. tenant_id=self.tenant_id,
  1045. identity_id=provider_model_record.id,
  1046. cache_type=ProviderCredentialsCacheType.MODEL,
  1047. )
  1048. provider_model_credentials_cache.delete()
  1049. def _get_provider_model_setting(
  1050. self, model_type: ModelType, model: str, session: Session
  1051. ) -> ProviderModelSetting | None:
  1052. """
  1053. Get provider model setting.
  1054. """
  1055. stmt = select(ProviderModelSetting).where(
  1056. ProviderModelSetting.tenant_id == self.tenant_id,
  1057. ProviderModelSetting.provider_name.in_(self._get_provider_names()),
  1058. ProviderModelSetting.model_type == model_type.to_origin_model_type(),
  1059. ProviderModelSetting.model_name == model,
  1060. )
  1061. return session.execute(stmt).scalars().first()
  1062. def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  1063. """
  1064. Enable model.
  1065. :param model_type: model type
  1066. :param model: model name
  1067. :return:
  1068. """
  1069. with Session(db.engine) as session:
  1070. model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
  1071. if model_setting:
  1072. model_setting.enabled = True
  1073. model_setting.updated_at = naive_utc_now()
  1074. else:
  1075. model_setting = ProviderModelSetting(
  1076. tenant_id=self.tenant_id,
  1077. provider_name=self.provider.provider,
  1078. model_type=model_type.to_origin_model_type(),
  1079. model_name=model,
  1080. enabled=True,
  1081. )
  1082. session.add(model_setting)
  1083. session.commit()
  1084. return model_setting
  1085. def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  1086. """
  1087. Disable model.
  1088. :param model_type: model type
  1089. :param model: model name
  1090. :return:
  1091. """
  1092. with Session(db.engine) as session:
  1093. model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
  1094. if model_setting:
  1095. model_setting.enabled = False
  1096. model_setting.updated_at = naive_utc_now()
  1097. else:
  1098. model_setting = ProviderModelSetting(
  1099. tenant_id=self.tenant_id,
  1100. provider_name=self.provider.provider,
  1101. model_type=model_type.to_origin_model_type(),
  1102. model_name=model,
  1103. enabled=False,
  1104. )
  1105. session.add(model_setting)
  1106. session.commit()
  1107. return model_setting
  1108. def get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None:
  1109. """
  1110. Get provider model setting.
  1111. :param model_type: model type
  1112. :param model: model name
  1113. :return:
  1114. """
  1115. with Session(db.engine) as session:
  1116. return self._get_provider_model_setting(model_type=model_type, model=model, session=session)
  1117. def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  1118. """
  1119. Enable model load balancing.
  1120. :param model_type: model type
  1121. :param model: model name
  1122. :return:
  1123. """
  1124. model_provider_id = ModelProviderID(self.provider.provider)
  1125. provider_names = [self.provider.provider]
  1126. if model_provider_id.is_langgenius():
  1127. provider_names.append(model_provider_id.provider_name)
  1128. with Session(db.engine) as session:
  1129. stmt = select(func.count(LoadBalancingModelConfig.id)).where(
  1130. LoadBalancingModelConfig.tenant_id == self.tenant_id,
  1131. LoadBalancingModelConfig.provider_name.in_(provider_names),
  1132. LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
  1133. LoadBalancingModelConfig.model_name == model,
  1134. )
  1135. load_balancing_config_count = session.execute(stmt).scalar() or 0
  1136. if load_balancing_config_count <= 1:
  1137. raise ValueError("Model load balancing configuration must be more than 1.")
  1138. model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
  1139. if model_setting:
  1140. model_setting.load_balancing_enabled = True
  1141. model_setting.updated_at = naive_utc_now()
  1142. else:
  1143. model_setting = ProviderModelSetting(
  1144. tenant_id=self.tenant_id,
  1145. provider_name=self.provider.provider,
  1146. model_type=model_type.to_origin_model_type(),
  1147. model_name=model,
  1148. load_balancing_enabled=True,
  1149. )
  1150. session.add(model_setting)
  1151. session.commit()
  1152. return model_setting
  1153. def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
  1154. """
  1155. Disable model load balancing.
  1156. :param model_type: model type
  1157. :param model: model name
  1158. :return:
  1159. """
  1160. with Session(db.engine) as session:
  1161. model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
  1162. if model_setting:
  1163. model_setting.load_balancing_enabled = False
  1164. model_setting.updated_at = naive_utc_now()
  1165. else:
  1166. model_setting = ProviderModelSetting(
  1167. tenant_id=self.tenant_id,
  1168. provider_name=self.provider.provider,
  1169. model_type=model_type.to_origin_model_type(),
  1170. model_name=model,
  1171. load_balancing_enabled=False,
  1172. )
  1173. session.add(model_setting)
  1174. session.commit()
  1175. return model_setting
  1176. def get_model_type_instance(self, model_type: ModelType) -> AIModel:
  1177. """
  1178. Get current model type instance.
  1179. :param model_type: model type
  1180. :return:
  1181. """
  1182. model_provider_factory = ModelProviderFactory(self.tenant_id)
  1183. # Get model instance of LLM
  1184. return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
  1185. def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None:
  1186. """
  1187. Get model schema
  1188. """
  1189. model_provider_factory = ModelProviderFactory(self.tenant_id)
  1190. return model_provider_factory.get_model_schema(
  1191. provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
  1192. )
  1193. def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None):
  1194. """
  1195. Switch preferred provider type.
  1196. :param provider_type:
  1197. :return:
  1198. """
  1199. if provider_type == self.preferred_provider_type:
  1200. return
  1201. if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
  1202. return
  1203. def _switch(s: Session):
  1204. stmt = select(TenantPreferredModelProvider).where(
  1205. TenantPreferredModelProvider.tenant_id == self.tenant_id,
  1206. TenantPreferredModelProvider.provider_name.in_(self._get_provider_names()),
  1207. )
  1208. preferred_model_provider = s.execute(stmt).scalars().first()
  1209. if preferred_model_provider:
  1210. preferred_model_provider.preferred_provider_type = provider_type.value
  1211. else:
  1212. preferred_model_provider = TenantPreferredModelProvider(
  1213. tenant_id=self.tenant_id,
  1214. provider_name=self.provider.provider,
  1215. preferred_provider_type=provider_type.value,
  1216. )
  1217. s.add(preferred_model_provider)
  1218. s.commit()
  1219. if session:
  1220. return _switch(session)
  1221. else:
  1222. with Session(db.engine) as session:
  1223. return _switch(session)
  1224. def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
  1225. """
  1226. Extract secret input form variables.
  1227. :param credential_form_schemas:
  1228. :return:
  1229. """
  1230. secret_input_form_variables = []
  1231. for credential_form_schema in credential_form_schemas:
  1232. if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
  1233. secret_input_form_variables.append(credential_form_schema.variable)
  1234. return secret_input_form_variables
  1235. def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]):
  1236. """
  1237. Obfuscated credentials.
  1238. :param credentials: credentials
  1239. :param credential_form_schemas: credential form schemas
  1240. :return:
  1241. """
  1242. # Get provider credential secret variables
  1243. credential_secret_variables = self.extract_secret_variables(credential_form_schemas)
  1244. # Obfuscate provider credentials
  1245. copy_credentials = credentials.copy()
  1246. for key, value in copy_credentials.items():
  1247. if key in credential_secret_variables:
  1248. copy_credentials[key] = encrypter.obfuscated_token(value)
  1249. return copy_credentials
  1250. def get_provider_model(
  1251. self, model_type: ModelType, model: str, only_active: bool = False
  1252. ) -> ModelWithProviderEntity | None:
  1253. """
  1254. Get provider model.
  1255. :param model_type: model type
  1256. :param model: model name
  1257. :param only_active: return active model only
  1258. :return:
  1259. """
  1260. provider_models = self.get_provider_models(model_type, only_active, model)
  1261. for provider_model in provider_models:
  1262. if provider_model.model == model:
  1263. return provider_model
  1264. return None
  1265. def get_provider_models(
  1266. self, model_type: ModelType | None = None, only_active: bool = False, model: str | None = None
  1267. ) -> list[ModelWithProviderEntity]:
  1268. """
  1269. Get provider models.
  1270. :param model_type: model type
  1271. :param only_active: only active models
  1272. :param model: model name
  1273. :return:
  1274. """
  1275. model_provider_factory = ModelProviderFactory(self.tenant_id)
  1276. provider_schema = model_provider_factory.get_provider_schema(self.provider.provider)
  1277. model_types: list[ModelType] = []
  1278. if model_type:
  1279. model_types.append(model_type)
  1280. else:
  1281. model_types = list(provider_schema.supported_model_types)
  1282. # Group model settings by model type and model
  1283. model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
  1284. for model_setting in self.model_settings:
  1285. model_setting_map[model_setting.model_type][model_setting.model] = model_setting
  1286. if self.using_provider_type == ProviderType.SYSTEM:
  1287. provider_models = self._get_system_provider_models(
  1288. model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
  1289. )
  1290. else:
  1291. provider_models = self._get_custom_provider_models(
  1292. model_types=model_types,
  1293. provider_schema=provider_schema,
  1294. model_setting_map=model_setting_map,
  1295. model=model,
  1296. )
  1297. if only_active:
  1298. provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
  1299. # resort provider_models
  1300. # Optimize sorting logic: first sort by provider.position order, then by model_type.value
  1301. # Get the position list for model types (retrieve only once for better performance)
  1302. model_type_positions = {}
  1303. if hasattr(self.provider, "position") and self.provider.position:
  1304. model_type_positions = self.provider.position
  1305. def get_sort_key(model: ModelWithProviderEntity):
  1306. # Get the position list for the current model type
  1307. positions = model_type_positions.get(model.model_type.value, [])
  1308. # If the model name is in the position list, use its index for sorting
  1309. # Otherwise use a large value (list length) to place undefined models at the end
  1310. position_index = positions.index(model.model) if model.model in positions else len(positions)
  1311. # Return composite sort key: (model_type value, model position index)
  1312. return (model.model_type.value, position_index)
  1313. # Sort using the composite sort key
  1314. return sorted(provider_models, key=get_sort_key)
  1315. def _get_system_provider_models(
  1316. self,
  1317. model_types: Sequence[ModelType],
  1318. provider_schema: ProviderEntity,
  1319. model_setting_map: dict[ModelType, dict[str, ModelSettings]],
  1320. ) -> list[ModelWithProviderEntity]:
  1321. """
  1322. Get system provider models.
  1323. :param model_types: model types
  1324. :param provider_schema: provider schema
  1325. :param model_setting_map: model setting map
  1326. :return:
  1327. """
  1328. provider_models = []
  1329. for model_type in model_types:
  1330. for m in provider_schema.models:
  1331. if m.model_type != model_type:
  1332. continue
  1333. status = ModelStatus.ACTIVE
  1334. if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
  1335. model_setting = model_setting_map[m.model_type][m.model]
  1336. if model_setting.enabled is False:
  1337. status = ModelStatus.DISABLED
  1338. provider_models.append(
  1339. ModelWithProviderEntity(
  1340. model=m.model,
  1341. label=m.label,
  1342. model_type=m.model_type,
  1343. features=m.features,
  1344. fetch_from=m.fetch_from,
  1345. model_properties=m.model_properties,
  1346. deprecated=m.deprecated,
  1347. provider=SimpleModelProviderEntity(self.provider),
  1348. status=status,
  1349. )
  1350. )
  1351. if self.provider.provider not in original_provider_configurate_methods:
  1352. original_provider_configurate_methods[self.provider.provider] = []
  1353. for configurate_method in provider_schema.configurate_methods:
  1354. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  1355. should_use_custom_model = False
  1356. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  1357. should_use_custom_model = True
  1358. for quota_configuration in self.system_configuration.quota_configurations:
  1359. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  1360. continue
  1361. restrict_models = quota_configuration.restrict_models
  1362. if len(restrict_models) == 0:
  1363. break
  1364. if should_use_custom_model:
  1365. if original_provider_configurate_methods[self.provider.provider] == [
  1366. ConfigurateMethod.CUSTOMIZABLE_MODEL
  1367. ]:
  1368. # only customizable model
  1369. for restrict_model in restrict_models:
  1370. copy_credentials = (
  1371. self.system_configuration.credentials.copy()
  1372. if self.system_configuration.credentials
  1373. else {}
  1374. )
  1375. if restrict_model.base_model_name:
  1376. copy_credentials["base_model_name"] = restrict_model.base_model_name
  1377. try:
  1378. custom_model_schema = self.get_model_schema(
  1379. model_type=restrict_model.model_type,
  1380. model=restrict_model.model,
  1381. credentials=copy_credentials,
  1382. )
  1383. except Exception as ex:
  1384. logger.warning("get custom model schema failed, %s", ex)
  1385. continue
  1386. if not custom_model_schema:
  1387. continue
  1388. if custom_model_schema.model_type not in model_types:
  1389. continue
  1390. status = ModelStatus.ACTIVE
  1391. if (
  1392. custom_model_schema.model_type in model_setting_map
  1393. and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
  1394. ):
  1395. model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
  1396. if model_setting.enabled is False:
  1397. status = ModelStatus.DISABLED
  1398. provider_models.append(
  1399. ModelWithProviderEntity(
  1400. model=custom_model_schema.model,
  1401. label=custom_model_schema.label,
  1402. model_type=custom_model_schema.model_type,
  1403. features=custom_model_schema.features,
  1404. fetch_from=FetchFrom.PREDEFINED_MODEL,
  1405. model_properties=custom_model_schema.model_properties,
  1406. deprecated=custom_model_schema.deprecated,
  1407. provider=SimpleModelProviderEntity(self.provider),
  1408. status=status,
  1409. )
  1410. )
  1411. # if llm name not in restricted llm list, remove it
  1412. restrict_model_names = [rm.model for rm in restrict_models]
  1413. for model in provider_models:
  1414. if model.model_type == ModelType.LLM and model.model not in restrict_model_names:
  1415. model.status = ModelStatus.NO_PERMISSION
  1416. elif not quota_configuration.is_valid:
  1417. model.status = ModelStatus.QUOTA_EXCEEDED
  1418. return provider_models
  1419. def _get_custom_provider_models(
  1420. self,
  1421. model_types: Sequence[ModelType],
  1422. provider_schema: ProviderEntity,
  1423. model_setting_map: dict[ModelType, dict[str, ModelSettings]],
  1424. model: str | None = None,
  1425. ) -> list[ModelWithProviderEntity]:
  1426. """
  1427. Get custom provider models.
  1428. :param model_types: model types
  1429. :param provider_schema: provider schema
  1430. :param model_setting_map: model setting map
  1431. :return:
  1432. """
  1433. provider_models = []
  1434. credentials = None
  1435. if self.custom_configuration.provider:
  1436. credentials = self.custom_configuration.provider.credentials
  1437. for model_type in model_types:
  1438. if model_type not in self.provider.supported_model_types:
  1439. continue
  1440. for m in provider_schema.models:
  1441. if m.model_type != model_type:
  1442. continue
  1443. status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
  1444. load_balancing_enabled = False
  1445. has_invalid_load_balancing_configs = False
  1446. if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
  1447. model_setting = model_setting_map[m.model_type][m.model]
  1448. if model_setting.enabled is False:
  1449. status = ModelStatus.DISABLED
  1450. provider_model_lb_configs = [
  1451. config
  1452. for config in model_setting.load_balancing_configs
  1453. if config.credential_source_type != "custom_model"
  1454. ]
  1455. load_balancing_enabled = model_setting.load_balancing_enabled
  1456. # when the user enable load_balancing but available configs are less than 2 display warning
  1457. has_invalid_load_balancing_configs = load_balancing_enabled and len(provider_model_lb_configs) < 2
  1458. provider_models.append(
  1459. ModelWithProviderEntity(
  1460. model=m.model,
  1461. label=m.label,
  1462. model_type=m.model_type,
  1463. features=m.features,
  1464. fetch_from=m.fetch_from,
  1465. model_properties=m.model_properties,
  1466. deprecated=m.deprecated,
  1467. provider=SimpleModelProviderEntity(self.provider),
  1468. status=status,
  1469. load_balancing_enabled=load_balancing_enabled,
  1470. has_invalid_load_balancing_configs=has_invalid_load_balancing_configs,
  1471. )
  1472. )
  1473. # custom models
  1474. for model_configuration in self.custom_configuration.models:
  1475. if model_configuration.model_type not in model_types:
  1476. continue
  1477. if model_configuration.unadded_to_model_list:
  1478. continue
  1479. if model and model != model_configuration.model:
  1480. continue
  1481. try:
  1482. custom_model_schema = self.get_model_schema(
  1483. model_type=model_configuration.model_type,
  1484. model=model_configuration.model,
  1485. credentials=model_configuration.credentials,
  1486. )
  1487. except Exception as ex:
  1488. logger.warning("get custom model schema failed, %s", ex)
  1489. continue
  1490. if not custom_model_schema:
  1491. continue
  1492. status = ModelStatus.ACTIVE
  1493. load_balancing_enabled = False
  1494. has_invalid_load_balancing_configs = False
  1495. if (
  1496. custom_model_schema.model_type in model_setting_map
  1497. and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
  1498. ):
  1499. model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
  1500. if model_setting.enabled is False:
  1501. status = ModelStatus.DISABLED
  1502. custom_model_lb_configs = [
  1503. config
  1504. for config in model_setting.load_balancing_configs
  1505. if config.credential_source_type != "provider"
  1506. ]
  1507. load_balancing_enabled = model_setting.load_balancing_enabled
  1508. # when the user enable load_balancing but available configs are less than 2 display warning
  1509. has_invalid_load_balancing_configs = load_balancing_enabled and len(custom_model_lb_configs) < 2
  1510. if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials:
  1511. status = ModelStatus.CREDENTIAL_REMOVED
  1512. provider_models.append(
  1513. ModelWithProviderEntity(
  1514. model=custom_model_schema.model,
  1515. label=custom_model_schema.label,
  1516. model_type=custom_model_schema.model_type,
  1517. features=custom_model_schema.features,
  1518. fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
  1519. model_properties=custom_model_schema.model_properties,
  1520. deprecated=custom_model_schema.deprecated,
  1521. provider=SimpleModelProviderEntity(self.provider),
  1522. status=status,
  1523. load_balancing_enabled=load_balancing_enabled,
  1524. has_invalid_load_balancing_configs=has_invalid_load_balancing_configs,
  1525. )
  1526. )
  1527. return provider_models
  1528. class ProviderConfigurations(BaseModel):
  1529. """
  1530. Model class for provider configuration dict.
  1531. """
  1532. tenant_id: str
  1533. configurations: dict[str, ProviderConfiguration] = Field(default_factory=dict)
  1534. def __init__(self, tenant_id: str):
  1535. super().__init__(tenant_id=tenant_id)
  1536. def get_models(
  1537. self, provider: str | None = None, model_type: ModelType | None = None, only_active: bool = False
  1538. ) -> list[ModelWithProviderEntity]:
  1539. """
  1540. Get available models.
  1541. If preferred provider type is `system`:
  1542. Get the current **system mode** if provider supported,
  1543. if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
  1544. If there is no model configured in custom mode, it is treated as no_configure.
  1545. system > custom > no_configure
  1546. If preferred provider type is `custom`:
  1547. If custom credentials are configured, it is treated as custom mode.
  1548. Otherwise, get the current **system mode** if supported,
  1549. If all system modes are not available (no quota), it is treated as no_configure.
  1550. custom > system > no_configure
  1551. If real mode is `system`, use system credentials to get models,
  1552. paid quotas > provider free quotas > system free quotas
  1553. include pre-defined models (exclude GPT-4, status marked as `no_permission`).
  1554. If real mode is `custom`, use workspace custom credentials to get models,
  1555. include pre-defined models, custom models(manual append).
  1556. If real mode is `no_configure`, only return pre-defined models from `model runtime`.
  1557. (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
  1558. model status marked as `active` is available.
  1559. :param provider: provider name
  1560. :param model_type: model type
  1561. :param only_active: only active models
  1562. :return:
  1563. """
  1564. all_models = []
  1565. for provider_configuration in self.values():
  1566. if provider and provider_configuration.provider.provider != provider:
  1567. continue
  1568. all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
  1569. return all_models
  1570. def to_list(self) -> list[ProviderConfiguration]:
  1571. """
  1572. Convert to list.
  1573. :return:
  1574. """
  1575. return list(self.values())
  1576. def __getitem__(self, key):
  1577. if "/" not in key:
  1578. key = str(ModelProviderID(key))
  1579. return self.configurations[key]
  1580. def __setitem__(self, key, value):
  1581. self.configurations[key] = value
  1582. def __contains__(self, key):
  1583. if "/" not in key:
  1584. key = str(ModelProviderID(key))
  1585. return key in self.configurations
  1586. def __iter__(self):
  1587. # Return an iterator of (key, value) tuples to match BaseModel's __iter__
  1588. yield from self.configurations.items()
  1589. def values(self) -> Iterator[ProviderConfiguration]:
  1590. return iter(self.configurations.values())
  1591. def get(self, key, default=None) -> ProviderConfiguration | None:
  1592. if "/" not in key:
  1593. key = str(ModelProviderID(key))
  1594. return self.configurations.get(key, default) # type: ignore
  1595. class ProviderModelBundle(BaseModel):
  1596. """
  1597. Provider model bundle.
  1598. """
  1599. configuration: ProviderConfiguration
  1600. model_type_instance: AIModel
  1601. # pydantic configs
  1602. model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())