Você não pode selecionar mais de 25 tópicos Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

provider_configuration.py 77KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839
  1. import json
  2. import logging
  3. import re
  4. from collections import defaultdict
  5. from collections.abc import Iterator, Sequence
  6. from json import JSONDecodeError
  7. from typing import Optional
  8. from pydantic import BaseModel, ConfigDict, Field
  9. from sqlalchemy import func, select
  10. from sqlalchemy.orm import Session
  11. from constants import HIDDEN_VALUE
  12. from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
  13. from core.entities.provider_entities import (
  14. CustomConfiguration,
  15. ModelSettings,
  16. SystemConfiguration,
  17. SystemConfigurationStatus,
  18. )
  19. from core.helper import encrypter
  20. from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
  21. from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
  22. from core.model_runtime.entities.provider_entities import (
  23. ConfigurateMethod,
  24. CredentialFormSchema,
  25. FormType,
  26. ProviderEntity,
  27. )
  28. from core.model_runtime.model_providers.__base.ai_model import AIModel
  29. from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
  30. from core.plugin.entities.plugin import ModelProviderID
  31. from extensions.ext_database import db
  32. from libs.datetime_utils import naive_utc_now
  33. from models.provider import (
  34. LoadBalancingModelConfig,
  35. Provider,
  36. ProviderCredential,
  37. ProviderModel,
  38. ProviderModelCredential,
  39. ProviderModelSetting,
  40. ProviderType,
  41. TenantPreferredModelProvider,
  42. )
  43. logger = logging.getLogger(__name__)
  44. original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {}
  45. class ProviderConfiguration(BaseModel):
  46. """
  47. Provider configuration entity for managing model provider settings.
  48. This class handles:
  49. - Provider credentials CRUD and switch
  50. - Custom Model credentials CRUD and switch
  51. - System vs custom provider switching
  52. - Load balancing configurations
  53. - Model enablement/disablement
  54. TODO: lots of logic in a BaseModel entity should be separated, the exceptions should be classified
  55. """
  56. tenant_id: str
  57. provider: ProviderEntity
  58. preferred_provider_type: ProviderType
  59. using_provider_type: ProviderType
  60. system_configuration: SystemConfiguration
  61. custom_configuration: CustomConfiguration
  62. model_settings: list[ModelSettings]
  63. # pydantic configs
  64. model_config = ConfigDict(protected_namespaces=())
  65. def __init__(self, **data):
  66. super().__init__(**data)
  67. if self.provider.provider not in original_provider_configurate_methods:
  68. original_provider_configurate_methods[self.provider.provider] = []
  69. for configurate_method in self.provider.configurate_methods:
  70. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  71. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  72. if (
  73. any(
  74. len(quota_configuration.restrict_models) > 0
  75. for quota_configuration in self.system_configuration.quota_configurations
  76. )
  77. and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
  78. ):
  79. self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
  80. def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
  81. """
  82. Get current credentials.
  83. :param model_type: model type
  84. :param model: model name
  85. :return:
  86. """
  87. if self.model_settings:
  88. # check if model is disabled by admin
  89. for model_setting in self.model_settings:
  90. if model_setting.model_type == model_type and model_setting.model == model:
  91. if not model_setting.enabled:
  92. raise ValueError(f"Model {model} is disabled.")
  93. if self.using_provider_type == ProviderType.SYSTEM:
  94. restrict_models = []
  95. for quota_configuration in self.system_configuration.quota_configurations:
  96. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  97. continue
  98. restrict_models = quota_configuration.restrict_models
  99. copy_credentials = (
  100. self.system_configuration.credentials.copy() if self.system_configuration.credentials else {}
  101. )
  102. if restrict_models:
  103. for restrict_model in restrict_models:
  104. if (
  105. restrict_model.model_type == model_type
  106. and restrict_model.model == model
  107. and restrict_model.base_model_name
  108. ):
  109. copy_credentials["base_model_name"] = restrict_model.base_model_name
  110. return copy_credentials
  111. else:
  112. credentials = None
  113. if self.custom_configuration.models:
  114. for model_configuration in self.custom_configuration.models:
  115. if model_configuration.model_type == model_type and model_configuration.model == model:
  116. credentials = model_configuration.credentials
  117. break
  118. if not credentials and self.custom_configuration.provider:
  119. credentials = self.custom_configuration.provider.credentials
  120. return credentials
  121. def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]:
  122. """
  123. Get system configuration status.
  124. :return:
  125. """
  126. if self.system_configuration.enabled is False:
  127. return SystemConfigurationStatus.UNSUPPORTED
  128. current_quota_type = self.system_configuration.current_quota_type
  129. current_quota_configuration = next(
  130. (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
  131. )
  132. if current_quota_configuration is None:
  133. return None
  134. if not current_quota_configuration:
  135. return SystemConfigurationStatus.UNSUPPORTED
  136. return (
  137. SystemConfigurationStatus.ACTIVE
  138. if current_quota_configuration.is_valid
  139. else SystemConfigurationStatus.QUOTA_EXCEEDED
  140. )
  141. def is_custom_configuration_available(self) -> bool:
  142. """
  143. Check custom configuration available.
  144. :return:
  145. """
  146. has_provider_credentials = (
  147. self.custom_configuration.provider is not None
  148. and len(self.custom_configuration.provider.available_credentials) > 0
  149. )
  150. has_model_configurations = len(self.custom_configuration.models) > 0
  151. return has_provider_credentials or has_model_configurations
  152. def _get_provider_record(self, session: Session) -> Provider | None:
  153. """
  154. Get custom provider record.
  155. """
  156. # get provider
  157. model_provider_id = ModelProviderID(self.provider.provider)
  158. provider_names = [self.provider.provider]
  159. if model_provider_id.is_langgenius():
  160. provider_names.append(model_provider_id.provider_name)
  161. stmt = select(Provider).where(
  162. Provider.tenant_id == self.tenant_id,
  163. Provider.provider_type == ProviderType.CUSTOM.value,
  164. Provider.provider_name.in_(provider_names),
  165. )
  166. return session.execute(stmt).scalar_one_or_none()
  167. def _get_specific_provider_credential(self, credential_id: str) -> dict | None:
  168. """
  169. Get a specific provider credential by ID.
  170. :param credential_id: Credential ID
  171. :return:
  172. """
  173. # Extract secret variables from provider credential schema
  174. credential_secret_variables = self.extract_secret_variables(
  175. self.provider.provider_credential_schema.credential_form_schemas
  176. if self.provider.provider_credential_schema
  177. else []
  178. )
  179. with Session(db.engine) as session:
  180. # Prefer the actual provider record name if exists (to handle aliased provider names)
  181. provider_record = self._get_provider_record(session)
  182. provider_name = provider_record.provider_name if provider_record else self.provider.provider
  183. stmt = select(ProviderCredential).where(
  184. ProviderCredential.id == credential_id,
  185. ProviderCredential.tenant_id == self.tenant_id,
  186. ProviderCredential.provider_name == provider_name,
  187. )
  188. credential = session.execute(stmt).scalar_one_or_none()
  189. if not credential or not credential.encrypted_config:
  190. raise ValueError(f"Credential with id {credential_id} not found.")
  191. try:
  192. credentials = json.loads(credential.encrypted_config)
  193. except JSONDecodeError:
  194. credentials = {}
  195. # Decrypt secret variables
  196. for key in credential_secret_variables:
  197. if key in credentials and credentials[key] is not None:
  198. try:
  199. credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=credentials[key])
  200. except Exception:
  201. pass
  202. return self.obfuscated_credentials(
  203. credentials=credentials,
  204. credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
  205. if self.provider.provider_credential_schema
  206. else [],
  207. )
  208. def _check_provider_credential_name_exists(
  209. self, credential_name: str, session: Session, exclude_id: str | None = None
  210. ) -> bool:
  211. """
  212. not allowed same name when create or update a credential
  213. """
  214. stmt = select(ProviderCredential.id).where(
  215. ProviderCredential.tenant_id == self.tenant_id,
  216. ProviderCredential.provider_name == self.provider.provider,
  217. ProviderCredential.credential_name == credential_name,
  218. )
  219. if exclude_id:
  220. stmt = stmt.where(ProviderCredential.id != exclude_id)
  221. return session.execute(stmt).scalar_one_or_none() is not None
  222. def get_provider_credential(self, credential_id: str | None = None) -> dict | None:
  223. """
  224. Get provider credentials.
  225. :param credential_id: if provided, return the specified credential
  226. :return:
  227. """
  228. if credential_id:
  229. return self._get_specific_provider_credential(credential_id)
  230. # Default behavior: return current active provider credentials
  231. credentials = self.custom_configuration.provider.credentials if self.custom_configuration.provider else {}
  232. return self.obfuscated_credentials(
  233. credentials=credentials,
  234. credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
  235. if self.provider.provider_credential_schema
  236. else [],
  237. )
  238. def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None):
  239. """
  240. Validate custom credentials.
  241. :param credentials: provider credentials
  242. :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
  243. :param session: optional database session
  244. :return:
  245. """
  246. def _validate(s: Session):
  247. # Get provider credential secret variables
  248. provider_credential_secret_variables = self.extract_secret_variables(
  249. self.provider.provider_credential_schema.credential_form_schemas
  250. if self.provider.provider_credential_schema
  251. else []
  252. )
  253. if credential_id:
  254. try:
  255. stmt = select(ProviderCredential).where(
  256. ProviderCredential.tenant_id == self.tenant_id,
  257. ProviderCredential.provider_name == self.provider.provider,
  258. ProviderCredential.id == credential_id,
  259. )
  260. credential_record = s.execute(stmt).scalar_one_or_none()
  261. # fix origin data
  262. if credential_record and credential_record.encrypted_config:
  263. if not credential_record.encrypted_config.startswith("{"):
  264. original_credentials = {"openai_api_key": credential_record.encrypted_config}
  265. else:
  266. original_credentials = json.loads(credential_record.encrypted_config)
  267. else:
  268. original_credentials = {}
  269. except JSONDecodeError:
  270. original_credentials = {}
  271. # encrypt credentials
  272. for key, value in credentials.items():
  273. if key in provider_credential_secret_variables:
  274. # if send [__HIDDEN__] in secret input, it will be same as original value
  275. if value == HIDDEN_VALUE and key in original_credentials:
  276. credentials[key] = encrypter.decrypt_token(
  277. tenant_id=self.tenant_id, token=original_credentials[key]
  278. )
  279. model_provider_factory = ModelProviderFactory(self.tenant_id)
  280. validated_credentials = model_provider_factory.provider_credentials_validate(
  281. provider=self.provider.provider, credentials=credentials
  282. )
  283. for key, value in validated_credentials.items():
  284. if key in provider_credential_secret_variables:
  285. validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  286. return validated_credentials
  287. if session:
  288. return _validate(session)
  289. else:
  290. with Session(db.engine) as new_session:
  291. return _validate(new_session)
  292. def _generate_provider_credential_name(self, session) -> str:
  293. """
  294. Generate a unique credential name for provider.
  295. :return: credential name
  296. """
  297. return self._generate_next_api_key_name(
  298. session=session,
  299. query_factory=lambda: select(ProviderCredential).where(
  300. ProviderCredential.tenant_id == self.tenant_id,
  301. ProviderCredential.provider_name == self.provider.provider,
  302. ),
  303. )
  304. def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str:
  305. """
  306. Generate a unique credential name for custom model.
  307. :return: credential name
  308. """
  309. return self._generate_next_api_key_name(
  310. session=session,
  311. query_factory=lambda: select(ProviderModelCredential).where(
  312. ProviderModelCredential.tenant_id == self.tenant_id,
  313. ProviderModelCredential.provider_name == self.provider.provider,
  314. ProviderModelCredential.model_name == model,
  315. ProviderModelCredential.model_type == model_type.to_origin_model_type(),
  316. ),
  317. )
  318. def _generate_next_api_key_name(self, session, query_factory) -> str:
  319. """
  320. Generate next available API KEY name by finding the highest numbered suffix.
  321. :param session: database session
  322. :param query_factory: function that returns the SQLAlchemy query
  323. :return: next available API KEY name
  324. """
  325. try:
  326. stmt = query_factory()
  327. credential_records = session.execute(stmt).scalars().all()
  328. if not credential_records:
  329. return "API KEY 1"
  330. # Extract numbers from API KEY pattern using list comprehension
  331. pattern = re.compile(r"^API KEY\s+(\d+)$")
  332. numbers = [
  333. int(match.group(1))
  334. for cr in credential_records
  335. if cr.credential_name and (match := pattern.match(cr.credential_name.strip()))
  336. ]
  337. # Return next sequential number
  338. next_number = max(numbers, default=0) + 1
  339. return f"API KEY {next_number}"
  340. except Exception as e:
  341. logger.warning("Error generating next credential name: %s", str(e))
  342. return "API KEY 1"
  343. def create_provider_credential(self, credentials: dict, credential_name: str | None):
  344. """
  345. Add custom provider credentials.
  346. :param credentials: provider credentials
  347. :param credential_name: credential name
  348. :return:
  349. """
  350. with Session(db.engine) as session:
  351. if credential_name:
  352. if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
  353. raise ValueError(f"Credential with name '{credential_name}' already exists.")
  354. else:
  355. credential_name = self._generate_provider_credential_name(session)
  356. credentials = self.validate_provider_credentials(credentials=credentials, session=session)
  357. provider_record = self._get_provider_record(session)
  358. try:
  359. new_record = ProviderCredential(
  360. tenant_id=self.tenant_id,
  361. provider_name=self.provider.provider,
  362. encrypted_config=json.dumps(credentials),
  363. credential_name=credential_name,
  364. )
  365. session.add(new_record)
  366. session.flush()
  367. if not provider_record:
  368. # If provider record does not exist, create it
  369. provider_record = Provider(
  370. tenant_id=self.tenant_id,
  371. provider_name=self.provider.provider,
  372. provider_type=ProviderType.CUSTOM.value,
  373. is_valid=True,
  374. credential_id=new_record.id,
  375. )
  376. session.add(provider_record)
  377. provider_model_credentials_cache = ProviderCredentialsCache(
  378. tenant_id=self.tenant_id,
  379. identity_id=provider_record.id,
  380. cache_type=ProviderCredentialsCacheType.PROVIDER,
  381. )
  382. provider_model_credentials_cache.delete()
  383. self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
  384. session.commit()
  385. except Exception:
  386. session.rollback()
  387. raise
  388. def update_provider_credential(
  389. self,
  390. credentials: dict,
  391. credential_id: str,
  392. credential_name: str | None,
  393. ):
  394. """
  395. update a saved provider credential (by credential_id).
  396. :param credentials: provider credentials
  397. :param credential_id: credential id
  398. :param credential_name: credential name
  399. :return:
  400. """
  401. with Session(db.engine) as session:
  402. if credential_name and self._check_provider_credential_name_exists(
  403. credential_name=credential_name, session=session, exclude_id=credential_id
  404. ):
  405. raise ValueError(f"Credential with name '{credential_name}' already exists.")
  406. credentials = self.validate_provider_credentials(
  407. credentials=credentials, credential_id=credential_id, session=session
  408. )
  409. provider_record = self._get_provider_record(session)
  410. stmt = select(ProviderCredential).where(
  411. ProviderCredential.id == credential_id,
  412. ProviderCredential.tenant_id == self.tenant_id,
  413. ProviderCredential.provider_name == self.provider.provider,
  414. )
  415. # Get the credential record to update
  416. credential_record = session.execute(stmt).scalar_one_or_none()
  417. if not credential_record:
  418. raise ValueError("Credential record not found.")
  419. try:
  420. # Update credential
  421. credential_record.encrypted_config = json.dumps(credentials)
  422. credential_record.updated_at = naive_utc_now()
  423. if credential_name:
  424. credential_record.credential_name = credential_name
  425. session.commit()
  426. if provider_record and provider_record.credential_id == credential_id:
  427. provider_model_credentials_cache = ProviderCredentialsCache(
  428. tenant_id=self.tenant_id,
  429. identity_id=provider_record.id,
  430. cache_type=ProviderCredentialsCacheType.PROVIDER,
  431. )
  432. provider_model_credentials_cache.delete()
  433. self._update_load_balancing_configs_with_credential(
  434. credential_id=credential_id,
  435. credential_record=credential_record,
  436. credential_source="provider",
  437. session=session,
  438. )
  439. except Exception:
  440. session.rollback()
  441. raise
  442. def _update_load_balancing_configs_with_credential(
  443. self,
  444. credential_id: str,
  445. credential_record: ProviderCredential | ProviderModelCredential,
  446. credential_source: str,
  447. session: Session,
  448. ):
  449. """
  450. Update load balancing configurations that reference the given credential_id.
  451. :param credential_id: credential id
  452. :param credential_record: the encrypted_config and credential_name
  453. :param credential_source: the credential comes from the provider_credential(`provider`)
  454. or the provider_model_credential(`custom_model`)
  455. :param session: the database session
  456. :return:
  457. """
  458. # Find all load balancing configs that use this credential_id
  459. stmt = select(LoadBalancingModelConfig).where(
  460. LoadBalancingModelConfig.tenant_id == self.tenant_id,
  461. LoadBalancingModelConfig.provider_name == self.provider.provider,
  462. LoadBalancingModelConfig.credential_id == credential_id,
  463. LoadBalancingModelConfig.credential_source_type == credential_source,
  464. )
  465. load_balancing_configs = session.execute(stmt).scalars().all()
  466. if not load_balancing_configs:
  467. return
  468. # Update each load balancing config with the new credentials
  469. for lb_config in load_balancing_configs:
  470. # Update the encrypted_config with the new credentials
  471. lb_config.encrypted_config = credential_record.encrypted_config
  472. lb_config.name = credential_record.credential_name
  473. lb_config.updated_at = naive_utc_now()
  474. # Clear cache for this load balancing config
  475. lb_credentials_cache = ProviderCredentialsCache(
  476. tenant_id=self.tenant_id,
  477. identity_id=lb_config.id,
  478. cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
  479. )
  480. lb_credentials_cache.delete()
  481. session.commit()
  482. def delete_provider_credential(self, credential_id: str):
  483. """
  484. Delete a saved provider credential (by credential_id).
  485. :param credential_id: credential id
  486. :return:
  487. """
  488. with Session(db.engine) as session:
  489. stmt = select(ProviderCredential).where(
  490. ProviderCredential.id == credential_id,
  491. ProviderCredential.tenant_id == self.tenant_id,
  492. ProviderCredential.provider_name == self.provider.provider,
  493. )
  494. # Get the credential record to update
  495. credential_record = session.execute(stmt).scalar_one_or_none()
  496. if not credential_record:
  497. raise ValueError("Credential record not found.")
  498. # Check if this credential is used in load balancing configs
  499. lb_stmt = select(LoadBalancingModelConfig).where(
  500. LoadBalancingModelConfig.tenant_id == self.tenant_id,
  501. LoadBalancingModelConfig.provider_name == self.provider.provider,
  502. LoadBalancingModelConfig.credential_id == credential_id,
  503. LoadBalancingModelConfig.credential_source_type == "provider",
  504. )
  505. lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
  506. try:
  507. for lb_config in lb_configs_using_credential:
  508. lb_credentials_cache = ProviderCredentialsCache(
  509. tenant_id=self.tenant_id,
  510. identity_id=lb_config.id,
  511. cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
  512. )
  513. lb_credentials_cache.delete()
  514. session.delete(lb_config)
  515. # Check if this is the currently active credential
  516. provider_record = self._get_provider_record(session)
  517. # Check available credentials count BEFORE deleting
  518. # if this is the last credential, we need to delete the provider record
  519. count_stmt = select(func.count(ProviderCredential.id)).where(
  520. ProviderCredential.tenant_id == self.tenant_id,
  521. ProviderCredential.provider_name == self.provider.provider,
  522. )
  523. available_credentials_count = session.execute(count_stmt).scalar() or 0
  524. session.delete(credential_record)
  525. if provider_record and available_credentials_count <= 1:
  526. # If all credentials are deleted, delete the provider record
  527. session.delete(provider_record)
  528. provider_model_credentials_cache = ProviderCredentialsCache(
  529. tenant_id=self.tenant_id,
  530. identity_id=provider_record.id,
  531. cache_type=ProviderCredentialsCacheType.PROVIDER,
  532. )
  533. provider_model_credentials_cache.delete()
  534. self.switch_preferred_provider_type(provider_type=ProviderType.SYSTEM, session=session)
  535. elif provider_record and provider_record.credential_id == credential_id:
  536. provider_record.credential_id = None
  537. provider_record.updated_at = naive_utc_now()
  538. provider_model_credentials_cache = ProviderCredentialsCache(
  539. tenant_id=self.tenant_id,
  540. identity_id=provider_record.id,
  541. cache_type=ProviderCredentialsCacheType.PROVIDER,
  542. )
  543. provider_model_credentials_cache.delete()
  544. self.switch_preferred_provider_type(provider_type=ProviderType.SYSTEM, session=session)
  545. session.commit()
  546. except Exception:
  547. session.rollback()
  548. raise
  def switch_active_provider_credential(self, credential_id: str):
      """
      Make one of the saved provider credentials the active one.

      Re-points the provider record at the chosen credential, evicts the cached
      provider credentials and switches the preferred provider type to CUSTOM.

      :param credential_id: id of the ProviderCredential to activate
      :return: None
      :raises ValueError: if the credential or the provider record cannot be found
      """
      with Session(db.engine) as session:
          lookup = select(ProviderCredential).where(
              ProviderCredential.id == credential_id,
              ProviderCredential.tenant_id == self.tenant_id,
              ProviderCredential.provider_name == self.provider.provider,
          )
          chosen = session.execute(lookup).scalar_one_or_none()
          if not chosen:
              raise ValueError("Credential record not found.")

          provider_row = self._get_provider_record(session)
          if not provider_row:
              raise ValueError("Provider record not found.")

          try:
              provider_row.credential_id = chosen.id
              provider_row.updated_at = naive_utc_now()
              session.commit()

              # The provider-level cache entry is keyed by the provider record id.
              ProviderCredentialsCache(
                  tenant_id=self.tenant_id,
                  identity_id=provider_row.id,
                  cache_type=ProviderCredentialsCacheType.PROVIDER,
              ).delete()

              self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
          except Exception:
              session.rollback()
              raise
  def _get_custom_model_record(
      self,
      model_type: ModelType,
      model: str,
      session: Session,
  ) -> ProviderModel | None:
      """
      Look up this tenant's custom ProviderModel row for ``model`` / ``model_type``.

      For langgenius providers, rows stored under ``ModelProviderID.provider_name``
      are matched in addition to the full provider id.

      :return: the matching ProviderModel, or None when there is none
      """
      provider_id = ModelProviderID(self.provider.provider)
      extra_names = [provider_id.provider_name] if provider_id.is_langgenius() else []
      accepted_names = [self.provider.provider, *extra_names]

      query = select(ProviderModel).where(
          ProviderModel.tenant_id == self.tenant_id,
          ProviderModel.provider_name.in_(accepted_names),
          ProviderModel.model_name == model,
          ProviderModel.model_type == model_type.to_origin_model_type(),
      )
      return session.execute(query).scalar_one_or_none()
  def _get_specific_custom_model_credential(
      self, model_type: ModelType, model: str, credential_id: str
  ) -> dict | None:
      """
      Load one saved custom-model credential by id, with its secret fields obfuscated.

      :param model_type: model type the credential belongs to
      :param model: model name the credential belongs to
      :param credential_id: ProviderModelCredential id
      :return: dict with ``current_credential_id``, ``current_credential_name`` and
          ``credentials`` (secret fields obfuscated)
      :raises ValueError: if the credential does not exist for this model or has no stored config
      """
      schema = self.provider.model_credential_schema
      form_schemas = schema.credential_form_schemas if schema else []
      secret_keys = self.extract_secret_variables(form_schemas)

      with Session(db.engine) as session:
          query = select(ProviderModelCredential).where(
              ProviderModelCredential.id == credential_id,
              ProviderModelCredential.tenant_id == self.tenant_id,
              ProviderModelCredential.provider_name == self.provider.provider,
              ProviderModelCredential.model_name == model,
              ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          )
          record = session.execute(query).scalar_one_or_none()

      if not record or not record.encrypted_config:
          raise ValueError(f"Credential with id {credential_id} not found.")

      try:
          stored = json.loads(record.encrypted_config)
      except JSONDecodeError:
          stored = {}

      # Decrypt secret fields; a value that fails to decrypt is left as stored.
      for secret_key in secret_keys:
          if secret_key not in stored or stored[secret_key] is None:
              continue
          try:
              stored[secret_key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=stored[secret_key])
          except Exception:
              pass

      return {
          "current_credential_id": record.id,
          "current_credential_name": record.credential_name,
          "credentials": self.obfuscated_credentials(
              credentials=stored,
              credential_form_schemas=form_schemas,
          ),
      }
  def _check_custom_model_credential_name_exists(
      self, model_type: ModelType, model: str, credential_name: str, session: Session, exclude_id: str | None = None
  ) -> bool:
      """
      Tell whether a credential of this custom model already uses ``credential_name``.

      Used to forbid duplicate names when creating or updating a credential.

      :param model_type: model type
      :param model: model name
      :param credential_name: name to check
      :param session: session to run the query on
      :param exclude_id: credential id to ignore (the one being updated), if any
      :return: True if another credential with that name exists
      """
      stmt = select(ProviderModelCredential.id).where(
          ProviderModelCredential.tenant_id == self.tenant_id,
          ProviderModelCredential.provider_name == self.provider.provider,
          ProviderModelCredential.model_name == model,
          ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          ProviderModelCredential.credential_name == credential_name,
      )
      if exclude_id:
          stmt = stmt.where(ProviderModelCredential.id != exclude_id)
      # Only existence matters. LIMIT 1 + first() never raises, whereas
      # scalar_one_or_none() raises MultipleResultsFound once duplicates already exist.
      return session.execute(stmt.limit(1)).first() is not None
  def get_custom_model_credential(
      self, model_type: ModelType, model: str, credential_id: str | None
  ) -> Optional[dict]:
      """
      Return a custom model's credential with its secret fields obfuscated.

      With ``credential_id`` the given saved credential is loaded. Otherwise the first
      entry of ``self.custom_configuration.models`` matching ``model_type``/``model``
      that has credentials is used.

      :param model_type: model type
      :param model: model name
      :param credential_id: optional id of a specific saved credential
      :return: dict with ``current_credential_id``, ``current_credential_name`` and
          ``credentials``, or None when no matching configured model has credentials
      """
      if credential_id:
          return self._get_specific_custom_model_credential(
              model_type=model_type, model=model, credential_id=credential_id
          )

      found = next(
          (
              entry
              for entry in self.custom_configuration.models
              if entry.model_type == model_type and entry.model == model and entry.credentials
          ),
          None,
      )
      if found is None:
          return None

      schema = self.provider.model_credential_schema
      return {
          "current_credential_id": found.current_credential_id,
          "current_credential_name": found.current_credential_name,
          "credentials": self.obfuscated_credentials(
              credentials=found.credentials,
              credential_form_schemas=schema.credential_form_schemas if schema else [],
          ),
      }
  def validate_custom_model_credentials(
      self,
      model_type: ModelType,
      model: str,
      credentials: dict,
      credential_id: str = "",
      session: Session | None = None,
  ):
      """
      Validate custom model credentials with the model provider and encrypt their secrets.

      When ``credential_id`` is given, secret fields submitted as ``HIDDEN_VALUE`` are
      replaced with the decrypted value stored on that credential before validation.
      The caller's ``credentials`` dict is not modified.

      :param model_type: model type
      :param model: model name
      :param credentials: submitted model credentials
      :param credential_id: (optional) existing credential whose stored secrets back any
          ``HIDDEN_VALUE`` placeholders
      :param session: (optional) session to reuse; a short-lived one is opened otherwise
      :return: the validated credentials with secret fields encrypted
      """

      def _validate(s: Session):
          secret_variables = set(
              self.extract_secret_variables(
                  self.provider.model_credential_schema.credential_form_schemas
                  if self.provider.model_credential_schema
                  else []
              )
          )
          # Work on a copy so decrypted secrets are never written back into the caller's dict.
          working_credentials = dict(credentials)

          if credential_id:
              stmt = select(ProviderModelCredential).where(
                  ProviderModelCredential.id == credential_id,
                  ProviderModelCredential.tenant_id == self.tenant_id,
                  ProviderModelCredential.provider_name == self.provider.provider,
                  ProviderModelCredential.model_name == model,
                  ProviderModelCredential.model_type == model_type.to_origin_model_type(),
              )
              credential_record = s.execute(stmt).scalar_one_or_none()

              original_credentials = {}
              if credential_record and credential_record.encrypted_config:
                  try:
                      original_credentials = json.loads(credential_record.encrypted_config)
                  except JSONDecodeError:
                      original_credentials = {}

              for key, value in working_credentials.items():
                  # A HIDDEN_VALUE secret means "keep the stored value".
                  if key in secret_variables and value == HIDDEN_VALUE and key in original_credentials:
                      working_credentials[key] = encrypter.decrypt_token(
                          tenant_id=self.tenant_id, token=original_credentials[key]
                      )

          model_provider_factory = ModelProviderFactory(self.tenant_id)
          validated_credentials = model_provider_factory.model_credentials_validate(
              provider=self.provider.provider, model_type=model_type, model=model, credentials=working_credentials
          )

          for key, value in validated_credentials.items():
              if key in secret_variables:
                  validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)

          return validated_credentials

      if session:
          return _validate(session)
      with Session(db.engine) as new_session:
          return _validate(new_session)
  def create_custom_model_credential(
      self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
  ) -> None:
      """
      Save a new credential for a custom model.

      The name must be unique for the model; when omitted, one is generated. The
      credentials are validated and their secrets encrypted before being stored. If the
      model has no custom model record yet, one is created pointing at the new credential.

      :param model_type: model type
      :param model: model name
      :param credentials: model credentials dict
      :param credential_name: desired credential name; None/empty means auto-generate
      :raises ValueError: if ``credential_name`` is already used for this model
      """
      with Session(db.engine) as session:
          if not credential_name:
              credential_name = self._generate_custom_model_credential_name(
                  model=model, model_type=model_type, session=session
              )
          elif self._check_custom_model_credential_name_exists(
              model=model, model_type=model_type, credential_name=credential_name, session=session
          ):
              raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")

          encrypted_credentials = self.validate_custom_model_credentials(
              model_type=model_type, model=model, credentials=credentials, session=session
          )
          model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)

          try:
              new_credential = ProviderModelCredential(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_name=model,
                  model_type=model_type.to_origin_model_type(),
                  encrypted_config=json.dumps(encrypted_credentials),
                  credential_name=credential_name,
              )
              session.add(new_credential)
              # Flush so new_credential.id can be referenced by the model record below.
              session.flush()

              if not model_record:
                  model_record = ProviderModel(
                      tenant_id=self.tenant_id,
                      provider_name=self.provider.provider,
                      model_name=model,
                      model_type=model_type.to_origin_model_type(),
                      credential_id=new_credential.id,
                      is_valid=True,
                  )
                  session.add(model_record)

              session.commit()

              ProviderCredentialsCache(
                  tenant_id=self.tenant_id,
                  identity_id=model_record.id,
                  cache_type=ProviderCredentialsCacheType.MODEL,
              ).delete()
          except Exception:
              session.rollback()
              raise
  def update_custom_model_credential(
      self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
  ) -> None:
      """
      Overwrite a saved custom model credential.

      The new credentials are validated (reusing stored secrets for HIDDEN_VALUE fields)
      and encrypted. The name is changed only when ``credential_name`` is given. The
      model-level cache is evicted when this credential is the model's active one, and
      load-balancing configs referencing the credential are updated.

      :param model_type: model type
      :param model: model name
      :param credentials: model credentials dict
      :param credential_name: new credential name, or None/empty to keep the current one
      :param credential_id: credential id
      :raises ValueError: if the name is taken by another credential, or the credential does not exist
      """
      with Session(db.engine) as session:
          if credential_name:
              if self._check_custom_model_credential_name_exists(
                  model=model,
                  model_type=model_type,
                  credential_name=credential_name,
                  session=session,
                  exclude_id=credential_id,
              ):
                  raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")

          encrypted_credentials = self.validate_custom_model_credentials(
              model_type=model_type,
              model=model,
              credentials=credentials,
              credential_id=credential_id,
              session=session,
          )
          model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)

          lookup = select(ProviderModelCredential).where(
              ProviderModelCredential.id == credential_id,
              ProviderModelCredential.tenant_id == self.tenant_id,
              ProviderModelCredential.provider_name == self.provider.provider,
              ProviderModelCredential.model_name == model,
              ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          )
          target = session.execute(lookup).scalar_one_or_none()
          if not target:
              raise ValueError("Credential record not found.")

          try:
              target.encrypted_config = json.dumps(encrypted_credentials)
              target.updated_at = naive_utc_now()
              if credential_name:
                  target.credential_name = credential_name
              session.commit()

              # Evict the model-level cache only when the edited credential is the active one.
              if model_record and model_record.credential_id == credential_id:
                  ProviderCredentialsCache(
                      tenant_id=self.tenant_id,
                      identity_id=model_record.id,
                      cache_type=ProviderCredentialsCacheType.MODEL,
                  ).delete()

              self._update_load_balancing_configs_with_credential(
                  credential_id=credential_id,
                  credential_record=target,
                  credential_source="custom_model",
                  session=session,
              )
          except Exception:
              session.rollback()
              raise
  def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str):
      """
      Delete a saved custom model credential (by credential_id).

      Also deletes the load-balancing configs ("custom_model" source) that reference the
      credential. If it is the model's last credential, the custom model record is deleted
      as well. Otherwise, if it was the model's active credential, the model record is
      detached from it. Affected credential caches are evicted after the commit.

      :param model_type: model type
      :param model: model name
      :param credential_id: credential id
      :raises ValueError: if the credential does not exist for this tenant/provider/model
      """
      with Session(db.engine) as session:
          stmt = select(ProviderModelCredential).where(
              ProviderModelCredential.id == credential_id,
              ProviderModelCredential.tenant_id == self.tenant_id,
              ProviderModelCredential.provider_name == self.provider.provider,
              ProviderModelCredential.model_name == model,
              ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          )
          credential_record = session.execute(stmt).scalar_one_or_none()
          if not credential_record:
              raise ValueError("Credential record not found.")

          lb_stmt = select(LoadBalancingModelConfig).where(
              LoadBalancingModelConfig.tenant_id == self.tenant_id,
              LoadBalancingModelConfig.provider_name == self.provider.provider,
              LoadBalancingModelConfig.credential_id == credential_id,
              LoadBalancingModelConfig.credential_source_type == "custom_model",
          )
          lb_configs_using_credential = session.execute(lb_stmt).scalars().all()

          # (identity_id, cache_type) pairs to evict once the transaction is committed.
          # Evicting before the commit would let a concurrent reader re-cache the old rows.
          stale_caches = []
          try:
              for lb_config in lb_configs_using_credential:
                  stale_caches.append((lb_config.id, ProviderCredentialsCacheType.LOAD_BALANCING_MODEL))
                  session.delete(lb_config)

              provider_model_record = self._get_custom_model_record(model_type, model, session=session)

              # Count BEFORE deleting: if this is the last credential, the custom model record goes too.
              count_stmt = select(func.count(ProviderModelCredential.id)).where(
                  ProviderModelCredential.tenant_id == self.tenant_id,
                  ProviderModelCredential.provider_name == self.provider.provider,
                  ProviderModelCredential.model_name == model,
                  ProviderModelCredential.model_type == model_type.to_origin_model_type(),
              )
              available_credentials_count = session.execute(count_stmt).scalar() or 0
              session.delete(credential_record)

              if provider_model_record and available_credentials_count <= 1:
                  stale_caches.append((provider_model_record.id, ProviderCredentialsCacheType.MODEL))
                  session.delete(provider_model_record)
              elif provider_model_record and provider_model_record.credential_id == credential_id:
                  provider_model_record.credential_id = None
                  provider_model_record.updated_at = naive_utc_now()
                  # Custom model records are cached under MODEL (as in create/update), not PROVIDER.
                  stale_caches.append((provider_model_record.id, ProviderCredentialsCacheType.MODEL))

              session.commit()
          except Exception:
              session.rollback()
              raise

          for identity_id, cache_type in stale_caches:
              ProviderCredentialsCache(
                  tenant_id=self.tenant_id,
                  identity_id=identity_id,
                  cache_type=cache_type,
              ).delete()
  def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str):
      """
      Attach a saved credential to a custom model.

      If the custom model record exists, its active credential is switched to
      ``credential_id``. Otherwise a new custom model record using it is created.
      The model-level credentials cache is evicted after the commit.

      :param model_type: model type
      :param model: model name
      :param credential_id: credential id
      :raises ValueError: if the credential does not exist, or is already the model's active credential
      """
      with Session(db.engine) as session:
          stmt = select(ProviderModelCredential).where(
              ProviderModelCredential.id == credential_id,
              ProviderModelCredential.tenant_id == self.tenant_id,
              ProviderModelCredential.provider_name == self.provider.provider,
              ProviderModelCredential.model_name == model,
              ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          )
          credential_record = session.execute(stmt).scalar_one_or_none()
          if not credential_record:
              raise ValueError("Credential record not found.")

          provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
          if not provider_model_record:
              # No custom model yet: create one that uses this credential.
              provider_model_record = ProviderModel(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_name=model,
                  model_type=model_type.to_origin_model_type(),
                  is_valid=True,
                  credential_id=credential_id,
              )
          else:
              if provider_model_record.credential_id == credential_record.id:
                  raise ValueError("Can't add same credential")
              provider_model_record.credential_id = credential_record.id
              provider_model_record.updated_at = naive_utc_now()
          session.add(provider_model_record)
          session.commit()

          # The model's active credential changed. Evict the MODEL cache keyed by the
          # model record id, which is also evicted when the active credential is edited.
          ProviderCredentialsCache(
              tenant_id=self.tenant_id,
              identity_id=provider_model_record.id,
              cache_type=ProviderCredentialsCacheType.MODEL,
          ).delete()
  def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str):
      """
      Make a saved credential the active credential of an existing custom model.

      The model-level credentials cache is evicted after the commit.

      :param model_type: model type
      :param model: model name
      :param credential_id: credential id
      :raises ValueError: if the credential or the custom model record does not exist
      """
      with Session(db.engine) as session:
          stmt = select(ProviderModelCredential).where(
              ProviderModelCredential.id == credential_id,
              ProviderModelCredential.tenant_id == self.tenant_id,
              ProviderModelCredential.provider_name == self.provider.provider,
              ProviderModelCredential.model_name == model,
              ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          )
          credential_record = session.execute(stmt).scalar_one_or_none()
          if not credential_record:
              raise ValueError("Credential record not found.")

          provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
          if not provider_model_record:
              raise ValueError("The custom model record not found.")

          provider_model_record.credential_id = credential_record.id
          provider_model_record.updated_at = naive_utc_now()
          session.add(provider_model_record)
          session.commit()

          # The active credential changed. Evict the MODEL cache keyed by the model
          # record id so the previously active credential's cached values are dropped.
          ProviderCredentialsCache(
              tenant_id=self.tenant_id,
              identity_id=provider_model_record.id,
              cache_type=ProviderCredentialsCacheType.MODEL,
          ).delete()
  def delete_custom_model(self, model_type: ModelType, model: str):
      """
      Delete the custom model record for ``model`` / ``model_type`` and evict its cache.

      Does nothing when no such record exists.

      :param model_type: model type
      :param model: model name
      :return: None
      """
      with Session(db.engine) as session:
          model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
          if not model_record:
              return

          session.delete(model_record)
          session.commit()

          ProviderCredentialsCache(
              tenant_id=self.tenant_id,
              identity_id=model_record.id,
              cache_type=ProviderCredentialsCacheType.MODEL,
          ).delete()
  def _get_provider_model_setting(
      self, model_type: ModelType, model: str, session: Session
  ) -> ProviderModelSetting | None:
      """
      Fetch this tenant's ProviderModelSetting row for ``model`` / ``model_type``.

      For langgenius providers, rows stored under ``ModelProviderID.provider_name``
      are matched too. Returns the first matching row, or None.
      """
      provider_id = ModelProviderID(self.provider.provider)
      extra_names = [provider_id.provider_name] if provider_id.is_langgenius() else []

      query = select(ProviderModelSetting).where(
          ProviderModelSetting.tenant_id == self.tenant_id,
          ProviderModelSetting.provider_name.in_([self.provider.provider, *extra_names]),
          ProviderModelSetting.model_type == model_type.to_origin_model_type(),
          ProviderModelSetting.model_name == model,
      )
      return session.execute(query).scalars().first()
  def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Enable a model for this tenant, creating its ProviderModelSetting row if needed.

      :param model_type: model type
      :param model: model name
      :return: the saved ProviderModelSetting, with attributes loaded so it stays readable
          after the session is closed
      """
      with Session(db.engine) as session:
          model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
          if model_setting:
              model_setting.enabled = True
              model_setting.updated_at = naive_utc_now()
          else:
              model_setting = ProviderModelSetting(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_type=model_type.to_origin_model_type(),
                  model_name=model,
                  enabled=True,
              )
              session.add(model_setting)
          session.commit()
          # commit() expires the instance (expire_on_commit defaults to True), and leaving the
          # ``with`` block detaches it. Reload now so callers can read the returned object
          # without a DetachedInstanceError.
          session.refresh(model_setting)
      return model_setting
  def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Disable a model for this tenant, creating its ProviderModelSetting row if needed.

      :param model_type: model type
      :param model: model name
      :return: the saved ProviderModelSetting, with attributes loaded so it stays readable
          after the session is closed
      """
      with Session(db.engine) as session:
          model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
          if model_setting:
              model_setting.enabled = False
              model_setting.updated_at = naive_utc_now()
          else:
              model_setting = ProviderModelSetting(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_type=model_type.to_origin_model_type(),
                  model_name=model,
                  enabled=False,
              )
              session.add(model_setting)
          session.commit()
          # commit() expires the instance (expire_on_commit defaults to True), and leaving the
          # ``with`` block detaches it. Reload now so callers can read the returned object
          # without a DetachedInstanceError.
          session.refresh(model_setting)
      return model_setting
  def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
      """
      Return this tenant's ProviderModelSetting for ``model`` / ``model_type``, or None.

      The lookup runs in a short-lived session that is closed before returning.

      :param model_type: model type
      :param model: model name
      """
      with Session(db.engine) as session:
          found = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
      return found
  def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Enable load balancing for a model, creating its ProviderModelSetting row if needed.

      :param model_type: model type
      :param model: model name
      :return: the saved ProviderModelSetting, with attributes loaded so it stays readable
          after the session is closed
      :raises ValueError: if the model has at most one load-balancing config
      """
      model_provider_id = ModelProviderID(self.provider.provider)
      provider_names = [self.provider.provider]
      if model_provider_id.is_langgenius():
          provider_names.append(model_provider_id.provider_name)

      with Session(db.engine) as session:
          stmt = select(func.count(LoadBalancingModelConfig.id)).where(
              LoadBalancingModelConfig.tenant_id == self.tenant_id,
              LoadBalancingModelConfig.provider_name.in_(provider_names),
              LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
              LoadBalancingModelConfig.model_name == model,
          )
          load_balancing_config_count = session.execute(stmt).scalar() or 0
          if load_balancing_config_count <= 1:
              raise ValueError("Model load balancing configuration must be more than 1.")

          model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
          if model_setting:
              model_setting.load_balancing_enabled = True
              model_setting.updated_at = naive_utc_now()
          else:
              model_setting = ProviderModelSetting(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_type=model_type.to_origin_model_type(),
                  model_name=model,
                  load_balancing_enabled=True,
              )
              session.add(model_setting)
          session.commit()
          # commit() expires the instance (expire_on_commit defaults to True), and leaving the
          # ``with`` block detaches it. Reload now so the returned object stays readable.
          session.refresh(model_setting)
      return model_setting
  def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Disable load balancing for a model, creating its ProviderModelSetting row if needed.

      :param model_type: model type
      :param model: model name
      :return: the saved ProviderModelSetting, with attributes loaded so it stays readable
          after the session is closed
      """
      with Session(db.engine) as session:
          model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
          if model_setting:
              model_setting.load_balancing_enabled = False
              model_setting.updated_at = naive_utc_now()
          else:
              model_setting = ProviderModelSetting(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_type=model_type.to_origin_model_type(),
                  model_name=model,
                  load_balancing_enabled=False,
              )
              session.add(model_setting)
          session.commit()
          # commit() expires the instance (expire_on_commit defaults to True), and leaving the
          # ``with`` block detaches it. Reload now so the returned object stays readable.
          session.refresh(model_setting)
      return model_setting
  def get_model_type_instance(self, model_type: ModelType) -> AIModel:
      """
      Return this provider's AIModel implementation for ``model_type``.

      :param model_type: model type
      :return: the model-type instance resolved by ModelProviderFactory
      """
      return ModelProviderFactory(self.tenant_id).get_model_type_instance(
          provider=self.provider.provider, model_type=model_type
      )
  def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None:
      """
      Resolve the schema of ``model`` for this provider via ModelProviderFactory.

      :param model_type: model type
      :param model: model name
      :param credentials: credentials passed through to the factory, if any
      :return: the model entity, or None if the factory returns none
      """
      factory = ModelProviderFactory(self.tenant_id)
      schema = factory.get_model_schema(
          provider=self.provider.provider,
          model_type=model_type,
          model=model,
          credentials=credentials,
      )
      return schema
  def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None):
      """
      Persist ``provider_type`` as the tenant's preferred provider type for this provider.

      Does nothing when it already is the preferred type, or when SYSTEM is requested
      while the system configuration is disabled. The change is committed on ``session``
      when one is given, otherwise on a short-lived session.

      :param provider_type: provider type to prefer
      :param session: optional session to use (it will be committed)
      :return: None
      """
      if provider_type == self.preferred_provider_type or (
          provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled
      ):
          return

      def _persist(db_session: Session):
          provider_id = ModelProviderID(self.provider.provider)
          extra_names = [provider_id.provider_name] if provider_id.is_langgenius() else []
          lookup = select(TenantPreferredModelProvider).where(
              TenantPreferredModelProvider.tenant_id == self.tenant_id,
              TenantPreferredModelProvider.provider_name.in_([self.provider.provider, *extra_names]),
          )
          preferred = db_session.execute(lookup).scalars().first()

          if preferred:
              preferred.preferred_provider_type = provider_type.value
          else:
              db_session.add(
                  TenantPreferredModelProvider(
                      tenant_id=self.tenant_id,
                      provider_name=self.provider.provider,
                      preferred_provider_type=provider_type.value,
                  )
              )
          db_session.commit()

      if session:
          return _persist(session)
      with Session(db.engine) as owned_session:
          return _persist(owned_session)
  def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
      """
      Collect the variable names of all SECRET_INPUT fields, in schema order.

      :param credential_form_schemas: credential form schemas to scan
      :return: list of secret variable names
      """
      return [
          form_schema.variable
          for form_schema in credential_form_schemas
          if form_schema.type == FormType.SECRET_INPUT
      ]
  def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]):
      """
      Return a copy of ``credentials`` with every secret field obfuscated.

      Secret fields are the SECRET_INPUT variables of ``credential_form_schemas``.
      Other values are copied unchanged, and the input dict is not modified.

      :param credentials: credentials to mask
      :param credential_form_schemas: credential form schemas declaring the secret fields
      :return: the masked copy
      """
      secret_keys = set(self.extract_secret_variables(credential_form_schemas))
      masked = credentials.copy()
      for key, value in credentials.items():
          if key in secret_keys:
              masked[key] = encrypter.obfuscated_token(value)
      return masked
  def get_provider_model(
      self, model_type: ModelType, model: str, only_active: bool = False
  ) -> Optional[ModelWithProviderEntity]:
      """
      Find a single model of this provider by name.

      :param model_type: model type
      :param model: model name
      :param only_active: only consider models whose status is ACTIVE
      :return: the first model named ``model``, or None
      """
      candidates = self.get_provider_models(model_type, only_active, model)
      return next((candidate for candidate in candidates if candidate.model == model), None)
  def get_provider_models(
      self, model_type: Optional[ModelType] = None, only_active: bool = False, model: Optional[str] = None
  ) -> list[ModelWithProviderEntity]:
      """
      List this provider's models, sorted for display.

      Models come from the system configuration when the provider is used in SYSTEM mode,
      otherwise from the custom configuration. The result is ordered by model type value,
      then by the model's index in ``self.provider.position`` for that type. Models not
      listed there come last.

      :param model_type: restrict to this model type; all supported types when omitted
      :param only_active: keep only models whose status is ACTIVE
      :param model: model name passed through to the custom-model lookup
      :return: sorted list of ModelWithProviderEntity
      """
      provider_schema = ModelProviderFactory(self.tenant_id).get_provider_schema(self.provider.provider)
      model_types: list[ModelType] = [model_type] if model_type else list(provider_schema.supported_model_types)

      # model type -> model name -> ModelSettings
      model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
      for setting in self.model_settings:
          model_setting_map[setting.model_type][setting.model] = setting

      if self.using_provider_type == ProviderType.SYSTEM:
          provider_models = self._get_system_provider_models(
              model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
          )
      else:
          provider_models = self._get_custom_provider_models(
              model_types=model_types,
              provider_schema=provider_schema,
              model_setting_map=model_setting_map,
              model=model,
          )

      if only_active:
          provider_models = [entry for entry in provider_models if entry.status == ModelStatus.ACTIVE]

      # Per-type ordering hints; the getattr default covers providers without a ``position`` attribute.
      position_by_type = getattr(self.provider, "position", None) or {}

      def _sort_key(entry: ModelWithProviderEntity):
          ordered_names = position_by_type.get(entry.model_type.value, [])
          rank = ordered_names.index(entry.model) if entry.model in ordered_names else len(ordered_names)
          return (entry.model_type.value, rank)

      return sorted(provider_models, key=_sort_key)
  1302. def _get_system_provider_models(
  1303. self,
  1304. model_types: Sequence[ModelType],
  1305. provider_schema: ProviderEntity,
  1306. model_setting_map: dict[ModelType, dict[str, ModelSettings]],
  1307. ) -> list[ModelWithProviderEntity]:
  1308. """
  1309. Get system provider models.
  1310. :param model_types: model types
  1311. :param provider_schema: provider schema
  1312. :param model_setting_map: model setting map
  1313. :return:
  1314. """
  1315. provider_models = []
  1316. for model_type in model_types:
  1317. for m in provider_schema.models:
  1318. if m.model_type != model_type:
  1319. continue
  1320. status = ModelStatus.ACTIVE
  1321. if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
  1322. model_setting = model_setting_map[m.model_type][m.model]
  1323. if model_setting.enabled is False:
  1324. status = ModelStatus.DISABLED
  1325. provider_models.append(
  1326. ModelWithProviderEntity(
  1327. model=m.model,
  1328. label=m.label,
  1329. model_type=m.model_type,
  1330. features=m.features,
  1331. fetch_from=m.fetch_from,
  1332. model_properties=m.model_properties,
  1333. deprecated=m.deprecated,
  1334. provider=SimpleModelProviderEntity(self.provider),
  1335. status=status,
  1336. )
  1337. )
  1338. if self.provider.provider not in original_provider_configurate_methods:
  1339. original_provider_configurate_methods[self.provider.provider] = []
  1340. for configurate_method in provider_schema.configurate_methods:
  1341. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  1342. should_use_custom_model = False
  1343. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  1344. should_use_custom_model = True
  1345. for quota_configuration in self.system_configuration.quota_configurations:
  1346. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  1347. continue
  1348. restrict_models = quota_configuration.restrict_models
  1349. if len(restrict_models) == 0:
  1350. break
  1351. if should_use_custom_model:
  1352. if original_provider_configurate_methods[self.provider.provider] == [
  1353. ConfigurateMethod.CUSTOMIZABLE_MODEL
  1354. ]:
  1355. # only customizable model
  1356. for restrict_model in restrict_models:
  1357. copy_credentials = (
  1358. self.system_configuration.credentials.copy()
  1359. if self.system_configuration.credentials
  1360. else {}
  1361. )
  1362. if restrict_model.base_model_name:
  1363. copy_credentials["base_model_name"] = restrict_model.base_model_name
  1364. try:
  1365. custom_model_schema = self.get_model_schema(
  1366. model_type=restrict_model.model_type,
  1367. model=restrict_model.model,
  1368. credentials=copy_credentials,
  1369. )
  1370. except Exception as ex:
  1371. logger.warning("get custom model schema failed, %s", ex)
  1372. continue
  1373. if not custom_model_schema:
  1374. continue
  1375. if custom_model_schema.model_type not in model_types:
  1376. continue
  1377. status = ModelStatus.ACTIVE
  1378. if (
  1379. custom_model_schema.model_type in model_setting_map
  1380. and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
  1381. ):
  1382. model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
  1383. if model_setting.enabled is False:
  1384. status = ModelStatus.DISABLED
  1385. provider_models.append(
  1386. ModelWithProviderEntity(
  1387. model=custom_model_schema.model,
  1388. label=custom_model_schema.label,
  1389. model_type=custom_model_schema.model_type,
  1390. features=custom_model_schema.features,
  1391. fetch_from=FetchFrom.PREDEFINED_MODEL,
  1392. model_properties=custom_model_schema.model_properties,
  1393. deprecated=custom_model_schema.deprecated,
  1394. provider=SimpleModelProviderEntity(self.provider),
  1395. status=status,
  1396. )
  1397. )
  1398. # if llm name not in restricted llm list, remove it
  1399. restrict_model_names = [rm.model for rm in restrict_models]
  1400. for model in provider_models:
  1401. if model.model_type == ModelType.LLM and model.model not in restrict_model_names:
  1402. model.status = ModelStatus.NO_PERMISSION
  1403. elif not quota_configuration.is_valid:
  1404. model.status = ModelStatus.QUOTA_EXCEEDED
  1405. return provider_models
  1406. def _get_custom_provider_models(
  1407. self,
  1408. model_types: Sequence[ModelType],
  1409. provider_schema: ProviderEntity,
  1410. model_setting_map: dict[ModelType, dict[str, ModelSettings]],
  1411. model: Optional[str] = None,
  1412. ) -> list[ModelWithProviderEntity]:
  1413. """
  1414. Get custom provider models.
  1415. :param model_types: model types
  1416. :param provider_schema: provider schema
  1417. :param model_setting_map: model setting map
  1418. :return:
  1419. """
  1420. provider_models = []
  1421. credentials = None
  1422. if self.custom_configuration.provider:
  1423. credentials = self.custom_configuration.provider.credentials
  1424. for model_type in model_types:
  1425. if model_type not in self.provider.supported_model_types:
  1426. continue
  1427. for m in provider_schema.models:
  1428. if m.model_type != model_type:
  1429. continue
  1430. status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
  1431. load_balancing_enabled = False
  1432. has_invalid_load_balancing_configs = False
  1433. if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
  1434. model_setting = model_setting_map[m.model_type][m.model]
  1435. if model_setting.enabled is False:
  1436. status = ModelStatus.DISABLED
  1437. provider_model_lb_configs = [
  1438. config
  1439. for config in model_setting.load_balancing_configs
  1440. if config.credential_source_type != "custom_model"
  1441. ]
  1442. load_balancing_enabled = model_setting.load_balancing_enabled
  1443. # when the user enable load_balancing but available configs are less than 2 display warning
  1444. has_invalid_load_balancing_configs = load_balancing_enabled and len(provider_model_lb_configs) < 2
  1445. provider_models.append(
  1446. ModelWithProviderEntity(
  1447. model=m.model,
  1448. label=m.label,
  1449. model_type=m.model_type,
  1450. features=m.features,
  1451. fetch_from=m.fetch_from,
  1452. model_properties=m.model_properties,
  1453. deprecated=m.deprecated,
  1454. provider=SimpleModelProviderEntity(self.provider),
  1455. status=status,
  1456. load_balancing_enabled=load_balancing_enabled,
  1457. has_invalid_load_balancing_configs=has_invalid_load_balancing_configs,
  1458. )
  1459. )
  1460. # custom models
  1461. for model_configuration in self.custom_configuration.models:
  1462. if model_configuration.model_type not in model_types:
  1463. continue
  1464. if model_configuration.unadded_to_model_list:
  1465. continue
  1466. if model and model != model_configuration.model:
  1467. continue
  1468. try:
  1469. custom_model_schema = self.get_model_schema(
  1470. model_type=model_configuration.model_type,
  1471. model=model_configuration.model,
  1472. credentials=model_configuration.credentials,
  1473. )
  1474. except Exception as ex:
  1475. logger.warning("get custom model schema failed, %s", ex)
  1476. continue
  1477. if not custom_model_schema:
  1478. continue
  1479. status = ModelStatus.ACTIVE
  1480. load_balancing_enabled = False
  1481. has_invalid_load_balancing_configs = False
  1482. if (
  1483. custom_model_schema.model_type in model_setting_map
  1484. and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
  1485. ):
  1486. model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
  1487. if model_setting.enabled is False:
  1488. status = ModelStatus.DISABLED
  1489. custom_model_lb_configs = [
  1490. config
  1491. for config in model_setting.load_balancing_configs
  1492. if config.credential_source_type != "provider"
  1493. ]
  1494. load_balancing_enabled = model_setting.load_balancing_enabled
  1495. # when the user enable load_balancing but available configs are less than 2 display warning
  1496. has_invalid_load_balancing_configs = load_balancing_enabled and len(custom_model_lb_configs) < 2
  1497. if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials:
  1498. status = ModelStatus.CREDENTIAL_REMOVED
  1499. provider_models.append(
  1500. ModelWithProviderEntity(
  1501. model=custom_model_schema.model,
  1502. label=custom_model_schema.label,
  1503. model_type=custom_model_schema.model_type,
  1504. features=custom_model_schema.features,
  1505. fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
  1506. model_properties=custom_model_schema.model_properties,
  1507. deprecated=custom_model_schema.deprecated,
  1508. provider=SimpleModelProviderEntity(self.provider),
  1509. status=status,
  1510. load_balancing_enabled=load_balancing_enabled,
  1511. has_invalid_load_balancing_configs=has_invalid_load_balancing_configs,
  1512. )
  1513. )
  1514. return provider_models
  1515. class ProviderConfigurations(BaseModel):
  1516. """
  1517. Model class for provider configuration dict.
  1518. """
  1519. tenant_id: str
  1520. configurations: dict[str, ProviderConfiguration] = Field(default_factory=dict)
  1521. def __init__(self, tenant_id: str):
  1522. super().__init__(tenant_id=tenant_id)
  1523. def get_models(
  1524. self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
  1525. ) -> list[ModelWithProviderEntity]:
  1526. """
  1527. Get available models.
  1528. If preferred provider type is `system`:
  1529. Get the current **system mode** if provider supported,
  1530. if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
  1531. If there is no model configured in custom mode, it is treated as no_configure.
  1532. system > custom > no_configure
  1533. If preferred provider type is `custom`:
  1534. If custom credentials are configured, it is treated as custom mode.
  1535. Otherwise, get the current **system mode** if supported,
  1536. If all system modes are not available (no quota), it is treated as no_configure.
  1537. custom > system > no_configure
  1538. If real mode is `system`, use system credentials to get models,
  1539. paid quotas > provider free quotas > system free quotas
  1540. include pre-defined models (exclude GPT-4, status marked as `no_permission`).
  1541. If real mode is `custom`, use workspace custom credentials to get models,
  1542. include pre-defined models, custom models(manual append).
  1543. If real mode is `no_configure`, only return pre-defined models from `model runtime`.
  1544. (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
  1545. model status marked as `active` is available.
  1546. :param provider: provider name
  1547. :param model_type: model type
  1548. :param only_active: only active models
  1549. :return:
  1550. """
  1551. all_models = []
  1552. for provider_configuration in self.values():
  1553. if provider and provider_configuration.provider.provider != provider:
  1554. continue
  1555. all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
  1556. return all_models
  1557. def to_list(self) -> list[ProviderConfiguration]:
  1558. """
  1559. Convert to list.
  1560. :return:
  1561. """
  1562. return list(self.values())
  1563. def __getitem__(self, key):
  1564. if "/" not in key:
  1565. key = str(ModelProviderID(key))
  1566. return self.configurations[key]
  1567. def __setitem__(self, key, value):
  1568. self.configurations[key] = value
  1569. def __iter__(self):
  1570. return iter(self.configurations)
  1571. def values(self) -> Iterator[ProviderConfiguration]:
  1572. return iter(self.configurations.values())
  1573. def get(self, key, default=None) -> ProviderConfiguration | None:
  1574. if "/" not in key:
  1575. key = str(ModelProviderID(key))
  1576. return self.configurations.get(key, default) # type: ignore
  1577. class ProviderModelBundle(BaseModel):
  1578. """
  1579. Provider model bundle.
  1580. """
  1581. configuration: ProviderConfiguration
  1582. model_type_instance: AIModel
  1583. # pydantic configs
  1584. model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())