You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

model_provider_service.py 22KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554
  1. import logging
  2. from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity
  3. from core.model_runtime.entities.model_entities import ModelType, ParameterRule
  4. from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
  5. from core.provider_manager import ProviderManager
  6. from models.provider import ProviderType
  7. from services.entities.model_provider_entities import (
  8. CustomConfigurationResponse,
  9. CustomConfigurationStatus,
  10. DefaultModelResponse,
  11. ModelWithProviderEntityResponse,
  12. ProviderResponse,
  13. ProviderWithModelsResponse,
  14. SimpleProviderEntityResponse,
  15. SystemConfigurationResponse,
  16. )
  17. from services.errors.app_model_config import ProviderNotFoundError
  # Module-scoped logger, named after this module.
  logger: logging.Logger = logging.getLogger(name=__name__)
  class ModelProviderService:
      """
      Model Provider Service.

      Service layer over ``ProviderManager``. Public methods look up the workspace's
      provider configurations and delegate credential, model and default-model
      operations to them. String ``model_type`` arguments are converted with
      ``ModelType.value_of`` and ``preferred_provider_type`` with
      ``ProviderType.value_of`` before delegation.
      """

      def __init__(self):
          self.provider_manager = ProviderManager()

      def _get_provider_configuration(self, tenant_id: str, provider: str):
          """
          Get provider configuration or raise exception if not found.

          Args:
              tenant_id: Workspace identifier
              provider: Provider name

          Returns:
              Provider configuration instance

          Raises:
              ProviderNotFoundError: If provider doesn't exist
          """
          # Get all provider configurations of the current workspace
          provider_configurations = self.provider_manager.get_configurations(tenant_id)
          provider_configuration = provider_configurations.get(provider)
          if not provider_configuration:
              raise ProviderNotFoundError(f"Provider {provider} does not exist.")
          return provider_configuration

      def get_provider_list(self, tenant_id: str, model_type: str | None = None) -> list[ProviderResponse]:
          """
          Get the provider list of a workspace.

          :param tenant_id: workspace id
          :param model_type: optional model type; when given, providers whose
              ``supported_model_types`` do not contain it are skipped
          :return: one ProviderResponse per (matching) provider configuration
          """
          # Convert the filter once rather than once per provider. This also rejects
          # an invalid model_type consistently, even when the workspace has no providers.
          model_type_entity = ModelType.value_of(model_type) if model_type else None

          # Get all provider configurations of the current workspace
          provider_configurations = self.provider_manager.get_configurations(tenant_id)

          provider_responses = []
          for provider_configuration in provider_configurations.values():
              provider_entity = provider_configuration.provider
              if model_type_entity is not None and model_type_entity not in provider_entity.supported_model_types:
                  continue

              custom_configuration = provider_configuration.custom_configuration
              provider_config = custom_configuration.provider
              model_config = custom_configuration.models
              can_added_models = custom_configuration.can_added_models

              provider_response = ProviderResponse(
                  tenant_id=tenant_id,
                  provider=provider_entity.provider,
                  label=provider_entity.label,
                  description=provider_entity.description,
                  icon_small=provider_entity.icon_small,
                  icon_large=provider_entity.icon_large,
                  background=provider_entity.background,
                  help=provider_entity.help,
                  supported_model_types=provider_entity.supported_model_types,
                  configurate_methods=provider_entity.configurate_methods,
                  provider_credential_schema=provider_entity.provider_credential_schema,
                  model_credential_schema=provider_entity.model_credential_schema,
                  preferred_provider_type=provider_configuration.preferred_provider_type,
                  custom_configuration=CustomConfigurationResponse(
                      status=CustomConfigurationStatus.ACTIVE
                      if provider_configuration.is_custom_configuration_available()
                      else CustomConfigurationStatus.NO_CONFIGURE,
                      # getattr defaults tolerate a provider_config lacking these attributes.
                      current_credential_id=getattr(provider_config, "current_credential_id", None),
                      current_credential_name=getattr(provider_config, "current_credential_name", None),
                      available_credentials=getattr(provider_config, "available_credentials", []),
                      custom_models=model_config,
                      can_added_models=can_added_models,
                  ),
                  system_configuration=SystemConfigurationResponse(
                      enabled=provider_configuration.system_configuration.enabled,
                      current_quota_type=provider_configuration.system_configuration.current_quota_type,
                      quota_configurations=provider_configuration.system_configuration.quota_configurations,
                  ),
              )
              provider_responses.append(provider_response)

          return provider_responses

      def get_models_by_provider(self, tenant_id: str, provider: str) -> list[ModelWithProviderEntityResponse]:
          """
          Get the models of a single provider.

          Used by the model provider page; only a single provider can be queried.
          An unknown provider is not rejected here; the result is whatever
          ``get_models(provider=...)`` yields for it.

          :param tenant_id: workspace id
          :param provider: provider name
          :return: list of ModelWithProviderEntityResponse
          """
          # Get all provider configurations of the current workspace
          provider_configurations = self.provider_manager.get_configurations(tenant_id)

          # Get provider available models
          return [
              ModelWithProviderEntityResponse(tenant_id=tenant_id, model=model)
              for model in provider_configurations.get_models(provider=provider)
          ]

      def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None:
          """
          Get provider credentials.

          :param tenant_id: workspace id
          :param provider: provider name
          :param credential_id: credential id; if not provided, the currently used credentials are returned
          :return: credentials dict, or None
          :raises ProviderNotFoundError: if the provider does not exist
          """
          provider_configuration = self._get_provider_configuration(tenant_id, provider)
          return provider_configuration.get_provider_credential(credential_id=credential_id)  # type: ignore

      def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict):
          """
          Validate provider credentials before saving.

          :param tenant_id: workspace id
          :param provider: provider name
          :param credentials: provider credentials dict
          :raises ProviderNotFoundError: if the provider does not exist
          """
          provider_configuration = self._get_provider_configuration(tenant_id, provider)
          provider_configuration.validate_provider_credentials(credentials)

      def create_provider_credential(
          self, tenant_id: str, provider: str, credentials: dict, credential_name: str | None
      ) -> None:
          """
          Create and save new provider credentials.

          :param tenant_id: workspace id
          :param provider: provider name
          :param credentials: provider credentials dict
          :param credential_name: credential name
          :raises ProviderNotFoundError: if the provider does not exist
          """
          provider_configuration = self._get_provider_configuration(tenant_id, provider)
          provider_configuration.create_provider_credential(credentials, credential_name)

      def update_provider_credential(
          self,
          tenant_id: str,
          provider: str,
          credentials: dict,
          credential_id: str,
          credential_name: str | None,
      ) -> None:
          """
          Update a saved provider credential (by credential_id).

          :param tenant_id: workspace id
          :param provider: provider name
          :param credentials: provider credentials dict
          :param credential_id: credential id
          :param credential_name: credential name
          :raises ProviderNotFoundError: if the provider does not exist
          """
          provider_configuration = self._get_provider_configuration(tenant_id, provider)
          provider_configuration.update_provider_credential(
              credential_id=credential_id,
              credentials=credentials,
              credential_name=credential_name,
          )

      def remove_provider_credential(self, tenant_id: str, provider: str, credential_id: str):
          """
          Remove a saved provider credential (by credential_id).

          :param tenant_id: workspace id
          :param provider: provider name
          :param credential_id: credential id
          :raises ProviderNotFoundError: if the provider does not exist
          """
          provider_configuration = self._get_provider_configuration(tenant_id, provider)
          provider_configuration.delete_provider_credential(credential_id=credential_id)

      def switch_active_provider_credential(self, tenant_id: str, provider: str, credential_id: str):
          """
          Switch the active provider credential to the given saved credential.

          :param tenant_id: workspace id
          :param provider: provider name
          :param credential_id: credential id
          :raises ProviderNotFoundError: if the provider does not exist
          """
          provider_configuration = self._get_provider_configuration(tenant_id, provider)
          provider_configuration.switch_active_provider_credential(credential_id=credential_id)

      def get_model_credential(
          self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None
      ) -> dict | None:
          """
          Retrieve model-specific credentials.

          :param tenant_id: workspace id
          :param provider: provider name
          :param model_type: model type (converted with ModelType.value_of)
          :param model: model name
          :param credential_id: optional credential id; the current one is used if not provided
          :return: credentials dict, or None
          :raises ProviderNotFoundError: if the provider does not exist
          """
          provider_configuration = self._get_provider_configuration(tenant_id, provider)
          return provider_configuration.get_custom_model_credential(  # type: ignore
              model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
          )

      def validate_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict):
          """
          Validate model credentials.

          :param tenant_id: workspace id
          :param provider: provider name
          :param model_type: model type (converted with ModelType.value_of)
          :param model: model name
          :param credentials: model credentials dict
          :raises ProviderNotFoundError: if the provider does not exist
          """
          provider_configuration = self._get_provider_configuration(tenant_id, provider)
          provider_configuration.validate_custom_model_credentials(
              model_type=ModelType.value_of(model_type), model=model, credentials=credentials
          )

      def create_model_credential(
          self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict, credential_name: str | None
      ) -> None:
          """
          Create and save model credentials.

          :param tenant_id: workspace id
          :param provider: provider name
          :param model_type: model type (converted with ModelType.value_of)
          :param model: model name
          :param credentials: model credentials dict
          :param credential_name: credential name
          :raises ProviderNotFoundError: if the provider does not exist
          """
          provider_configuration = self._get_provider_configuration(tenant_id, provider)
          provider_configuration.create_custom_model_credential(
              model_type=ModelType.value_of(model_type),
              model=model,
              credentials=credentials,
              credential_name=credential_name,
          )

      def update_model_credential(
          self,
          tenant_id: str,
          provider: str,
          model_type: str,
          model: str,
          credentials: dict,
          credential_id: str,
          credential_name: str | None,
      ) -> None:
          """
          Update model credentials.

          :param tenant_id: workspace id
          :param provider: provider name
          :param model_type: model type (converted with ModelType.value_of)
          :param model: model name
          :param credentials: model credentials dict
          :param credential_id: credential id
          :param credential_name: credential name
          :raises ProviderNotFoundError: if the provider does not exist
          """
          provider_configuration = self._get_provider_configuration(tenant_id, provider)
          provider_configuration.update_custom_model_credential(
              model_type=ModelType.value_of(model_type),
              model=model,
              credentials=credentials,
              credential_id=credential_id,
              credential_name=credential_name,
          )

      def remove_model_credential(self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str):
          """
          Remove a saved model credential.

          :param tenant_id: workspace id
          :param provider: provider name
          :param model_type: model type (converted with ModelType.value_of)
          :param model: model name
          :param credential_id: credential id
          :raises ProviderNotFoundError: if the provider does not exist
          """
          provider_configuration = self._get_provider_configuration(tenant_id, provider)
          provider_configuration.delete_custom_model_credential(
              model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
          )

      def switch_active_custom_model_credential(
          self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
      ):
          """
          Switch the active credential of a custom model.

          :param tenant_id: workspace id
          :param provider: provider name
          :param model_type: model type (converted with ModelType.value_of)
          :param model: model name
          :param credential_id: credential id
          :raises ProviderNotFoundError: if the provider does not exist
          """
          provider_configuration = self._get_provider_configuration(tenant_id, provider)
          provider_configuration.switch_custom_model_credential(
              model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
          )

      def add_model_credential_to_model_list(
          self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
      ):
          """
          Add an existing model credential to a model (model list).

          :param tenant_id: workspace id
          :param provider: provider name
          :param model_type: model type (converted with ModelType.value_of)
          :param model: model name
          :param credential_id: credential id
          :raises ProviderNotFoundError: if the provider does not exist
          """
          provider_configuration = self._get_provider_configuration(tenant_id, provider)
          provider_configuration.add_model_credential_to_model(
              model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
          )

      def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str):
          """
          Remove a custom model (delegates to ``delete_custom_model``).

          :param tenant_id: workspace id
          :param provider: provider name
          :param model_type: model type (converted with ModelType.value_of)
          :param model: model name
          :raises ProviderNotFoundError: if the provider does not exist
          """
          provider_configuration = self._get_provider_configuration(tenant_id, provider)
          provider_configuration.delete_custom_model(model_type=ModelType.value_of(model_type), model=model)

      def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]:
          """
          Get models of a model type, grouped by provider.

          Deprecated models are dropped; providers left with no models are omitted.

          :param tenant_id: workspace id
          :param model_type: model type (converted with ModelType.value_of)
          :return: one ProviderWithModelsResponse per provider, in order of first appearance
          """
          # Get all provider configurations of the current workspace
          provider_configurations = self.provider_manager.get_configurations(tenant_id)

          # Get provider available models
          models = provider_configurations.get_models(model_type=ModelType.value_of(model_type), only_active=True)

          # Group non-deprecated models by provider name. The bucket is created before
          # the deprecation check so provider order follows first appearance in `models`.
          provider_models: dict[str, list[ModelWithProviderEntity]] = {}
          for model in models:
              bucket = provider_models.setdefault(model.provider.provider, [])
              if not model.deprecated:
                  bucket.append(model)

          # convert to ProviderWithModelsResponse list
          providers_with_models: list[ProviderWithModelsResponse] = []
          for provider, grouped_models in provider_models.items():
              # Providers whose models were all deprecated have an empty bucket.
              if not grouped_models:
                  continue

              # Provider metadata is taken from the first model of the group.
              first_model = grouped_models[0]
              providers_with_models.append(
                  ProviderWithModelsResponse(
                      tenant_id=tenant_id,
                      provider=provider,
                      label=first_model.provider.label,
                      icon_small=first_model.provider.icon_small,
                      icon_large=first_model.provider.icon_large,
                      status=CustomConfigurationStatus.ACTIVE,
                      models=[
                          ProviderModelWithStatusEntity(
                              model=grouped_model.model,
                              label=grouped_model.label,
                              model_type=grouped_model.model_type,
                              features=grouped_model.features,
                              fetch_from=grouped_model.fetch_from,
                              model_properties=grouped_model.model_properties,
                              status=grouped_model.status,
                              load_balancing_enabled=grouped_model.load_balancing_enabled,
                          )
                          for grouped_model in grouped_models
                      ],
                  )
              )

          return providers_with_models

      def get_model_parameter_rules(self, tenant_id: str, provider: str, model: str) -> list[ParameterRule]:
          """
          Get model parameter rules. Only supports LLM.

          :param tenant_id: workspace id
          :param provider: provider name
          :param model: model name
          :return: the schema's parameter rules; [] when there are no current
              credentials or no model schema
          :raises ProviderNotFoundError: if the provider does not exist
          """
          provider_configuration = self._get_provider_configuration(tenant_id, provider)

          # fetch credentials
          credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model)
          if not credentials:
              return []

          model_schema = provider_configuration.get_model_schema(
              model_type=ModelType.LLM, model=model, credentials=credentials
          )
          return model_schema.parameter_rules if model_schema else []

      def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> DefaultModelResponse | None:
          """
          Get the default model of a model type.

          Best-effort: any exception raised while fetching or converting the default
          model is logged at debug level (with traceback) and None is returned.
          An invalid ``model_type`` is not caught, since it is converted before the try.

          :param tenant_id: workspace id
          :param model_type: model type (converted with ModelType.value_of)
          :return: DefaultModelResponse, or None if there is none or lookup failed
          """
          model_type_enum = ModelType.value_of(model_type)

          try:
              result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum)
              return (
                  DefaultModelResponse(
                      model=result.model,
                      model_type=result.model_type,
                      provider=SimpleProviderEntityResponse(
                          tenant_id=tenant_id,
                          provider=result.provider.provider,
                          label=result.provider.label,
                          icon_small=result.provider.icon_small,
                          icon_large=result.provider.icon_large,
                          supported_model_types=result.provider.supported_model_types,
                      ),
                  )
                  if result
                  else None
              )
          except Exception as e:
              # Keep the traceback so failures are diagnosable rather than silently lost.
              logger.debug("get_default_model_of_model_type error: %s", e, exc_info=True)
              return None

      def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str):
          """
          Update the default model of a model type.

          :param tenant_id: workspace id
          :param model_type: model type (converted with ModelType.value_of)
          :param provider: provider name
          :param model: model name
          """
          model_type_enum = ModelType.value_of(model_type)
          self.provider_manager.update_default_model_record(
              tenant_id=tenant_id, model_type=model_type_enum, provider=provider, model=model
          )

      def get_model_provider_icon(
          self, tenant_id: str, provider: str, icon_type: str, lang: str
      ) -> tuple[bytes | None, str | None]:
          """
          Get a model provider icon.

          :param tenant_id: workspace id
          :param provider: provider name
          :param icon_type: icon type (icon_small or icon_large)
          :param lang: language (zh_Hans or en_US)
          :return: (icon bytes, mime type) as returned by ModelProviderFactory.get_provider_icon
          """
          model_provider_factory = ModelProviderFactory(tenant_id)
          byte_data, mime_type = model_provider_factory.get_provider_icon(provider, icon_type, lang)
          return byte_data, mime_type

      def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str):
          """
          Switch the preferred provider type of a provider.

          :param tenant_id: workspace id
          :param provider: provider name
          :param preferred_provider_type: preferred provider type (converted with ProviderType.value_of)
          :raises ProviderNotFoundError: if the provider does not exist
          """
          provider_configuration = self._get_provider_configuration(tenant_id, provider)

          # Convert preferred_provider_type to ProviderType
          preferred_provider_type_enum = ProviderType.value_of(preferred_provider_type)

          # Switch preferred provider type
          provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum)

      def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str):
          """
          Enable a model.

          :param tenant_id: workspace id
          :param provider: provider name
          :param model: model name
          :param model_type: model type (converted with ModelType.value_of)
          :raises ProviderNotFoundError: if the provider does not exist
          """
          provider_configuration = self._get_provider_configuration(tenant_id, provider)
          provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type))

      def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str):
          """
          Disable a model.

          :param tenant_id: workspace id
          :param provider: provider name
          :param model: model name
          :param model_type: model type (converted with ModelType.value_of)
          :raises ProviderNotFoundError: if the provider does not exist
          """
          provider_configuration = self._get_provider_configuration(tenant_id, provider)
          provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type))