You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
  1. import logging
  2. from flask_login import current_user
  3. from flask_restx import Resource, reqparse
  4. from werkzeug.exceptions import Forbidden
  5. from controllers.console import api
  6. from controllers.console.wraps import account_initialization_required, setup_required
  7. from core.model_runtime.entities.model_entities import ModelType
  8. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  9. from core.model_runtime.utils.encoders import jsonable_encoder
  10. from libs.helper import StrLen, uuid_value
  11. from libs.login import login_required
  12. from services.model_load_balancing_service import ModelLoadBalancingService
  13. from services.model_provider_service import ModelProviderService
  # Module-level logger for the console model endpoints defined in this file.
  logger = logging.getLogger(__name__)


  class DefaultModelApi(Resource):
      """Read and update the workspace's default model for each model type."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self):
          """Return the default model configured for ``model_type``.

          Query args:
              model_type: one of the ``ModelType`` values (required).

          Returns:
              ``{"data": <default model entity>}`` passed through ``jsonable_encoder``.
          """
          parser = reqparse.RequestParser()
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="args",
          )
          args = parser.parse_args()
          tenant_id = current_user.current_tenant_id
          model_provider_service = ModelProviderService()
          default_model_entity = model_provider_service.get_default_model_of_model_type(
              tenant_id=tenant_id, model_type=args["model_type"]
          )
          return jsonable_encoder({"data": default_model_entity})

      @setup_required
      @login_required
      @account_initialization_required
      def post(self):
          """Update the default model for each entry of ``model_settings``.

          JSON body:
              model_settings: list of objects carrying ``model_type`` (required, a
                  ``ModelType`` value), ``provider`` and ``model``. Entries without
                  ``provider`` are skipped.

          Every entry is validated before any update is applied, so a malformed
          entry later in the list no longer leaves earlier defaults already changed.

          Raises:
              Forbidden: the current user is not a workspace admin or owner.
              ValueError: an entry is not an object, has a missing/unknown
                  ``model_type``, or has ``provider`` without ``model``.
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()
          parser = reqparse.RequestParser()
          parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json")
          args = parser.parse_args()
          tenant_id = current_user.current_tenant_id
          model_provider_service = ModelProviderService()
          for model_setting in self._settings_to_apply(args["model_settings"]):
              try:
                  model_provider_service.update_default_model_of_model_type(
                      tenant_id=tenant_id,
                      model_type=model_setting["model_type"],
                      provider=model_setting["provider"],
                      model=model_setting["model"],
                  )
              except Exception:
                  logger.exception(
                      "Failed to update default model, model type: %s, model: %s",
                      model_setting["model_type"],
                      model_setting.get("model"),
                  )
                  # Bare `raise` re-raises the active exception unchanged.
                  raise
          return {"result": "success"}

      @staticmethod
      def _settings_to_apply(model_settings):
          """Validate every ``model_settings`` entry; return the entries that carry a ``provider``.

          Raises:
              ValueError: "invalid model type" or "invalid model", as ``post`` reports them.
          """
          # Built once; a tuple (not a set) so unhashable JSON values are compared
          # rather than raising TypeError on hashing.
          valid_model_types = tuple(mt.value for mt in ModelType)
          to_apply = []
          for model_setting in model_settings:
              # Reject non-object entries up front: a string entry such as "model_type"
              # would pass the `in` test as a substring match and then fail on indexing.
              if not isinstance(model_setting, dict) or model_setting.get("model_type") not in valid_model_types:
                  raise ValueError("invalid model type")
              if "provider" not in model_setting:
                  continue
              if "model" not in model_setting:
                  raise ValueError("invalid model")
              to_apply.append(model_setting)
          return to_apply
  class ModelProviderModelApi(Resource):
      """List, configure and remove the models of one provider in the current workspace."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider):
          """Return ``{"data": <models>}`` for ``provider`` in the current workspace."""
          tenant_id = current_user.current_tenant_id
          model_provider_service = ModelProviderService()
          models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
          return jsonable_encoder({"data": models})

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider: str):
          """Save a model's active custom credential and load-balancing configuration.

          JSON body:
              model, model_type: identify the model (required).
              config_from: when ``"custom-model"``, ``credential_id`` is made the
                  model's active custom credential.
              credential_id: UUID of that credential.
              load_balancing: optional object. If it has ``configs`` they are saved;
                  a truthy ``enabled`` enables load balancing, anything else
                  (including omitting the object) disables it.

          Raises:
              Forbidden: the current user is not a workspace admin or owner.
              ValueError: ``config_from`` is ``"custom-model"`` without ``credential_id``.
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()
          tenant_id = current_user.current_tenant_id
          parser = reqparse.RequestParser()
          parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="json",
          )
          parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
          parser.add_argument("config_from", type=str, required=False, nullable=True, location="json")
          parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
          args = parser.parse_args()

          # reqparse stores None (not a missing key) for omitted optional arguments,
          # so `args.get(key, default)` never yields the default. Normalize here:
          # previously an omitted/null `load_balancing` crashed on `None.get("enabled")`.
          config_from = args.get("config_from") or ""
          load_balancing = args.get("load_balancing") or {}

          if config_from == "custom-model":
              if not args.get("credential_id"):
                  raise ValueError("credential_id is required when configuring a custom-model")
              service = ModelProviderService()
              service.switch_active_custom_model_credential(
                  tenant_id=tenant_id,
                  provider=provider,
                  model_type=args["model_type"],
                  model=args["model"],
                  credential_id=args["credential_id"],
              )

          model_load_balancing_service = ModelLoadBalancingService()
          if "configs" in load_balancing:
              # save load balancing configs
              model_load_balancing_service.update_load_balancing_configs(
                  tenant_id=tenant_id,
                  provider=provider,
                  model=args["model"],
                  model_type=args["model_type"],
                  configs=load_balancing["configs"],
                  config_from=config_from,
              )

          if load_balancing.get("enabled"):
              model_load_balancing_service.enable_model_load_balancing(
                  tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
              )
          else:
              model_load_balancing_service.disable_model_load_balancing(
                  tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
              )
          return {"result": "success"}, 200

      @setup_required
      @login_required
      @account_initialization_required
      def delete(self, provider: str):
          """Remove ``model``/``model_type`` (JSON body) from ``provider``.

          Raises:
              Forbidden: the current user is not a workspace admin or owner.
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()
          tenant_id = current_user.current_tenant_id
          parser = reqparse.RequestParser()
          parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="json",
          )
          args = parser.parse_args()
          model_provider_service = ModelProviderService()
          model_provider_service.remove_model(
              tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
          )
          return {"result": "success"}, 204
  class ModelProviderModelCredentialApi(Resource):
      """Read, create, update and delete the credentials of a provider's model."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider: str):
          """Return a model's current credential, load-balancing state and available credentials.

          Query args:
              model, model_type: identify the model (required).
              config_from: ``"predefined-model"`` lists the provider's available
                  credentials; any other value lists the model's available credentials.
              credential_id: optional UUID passed through to ``get_model_credential``.
          """
          tenant_id = current_user.current_tenant_id
          parser = reqparse.RequestParser()
          parser.add_argument("model", type=str, required=True, nullable=False, location="args")
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="args",
          )
          parser.add_argument("config_from", type=str, required=False, nullable=True, location="args")
          parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
          args = parser.parse_args()
          model_provider_service = ModelProviderService()
          current_credential = model_provider_service.get_model_credential(
              tenant_id=tenant_id,
              provider=provider,
              model_type=args["model_type"],
              model=args["model"],
              credential_id=args.get("credential_id"),
          )
          model_load_balancing_service = ModelLoadBalancingService()
          is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
              tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
          )
          if args.get("config_from", "") == "predefined-model":
              available_credentials = model_provider_service.provider_manager.get_provider_available_credentials(
                  tenant_id=tenant_id, provider_name=provider
              )
          else:
              model_type = ModelType.value_of(args["model_type"]).to_origin_model_type()
              available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
                  tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args["model"]
              )

          # With no current credential, report empty credentials and no id/name.
          if current_credential:
              credentials = current_credential.get("credentials")
              current_credential_id = current_credential.get("current_credential_id")
              current_credential_name = current_credential.get("current_credential_name")
          else:
              credentials, current_credential_id, current_credential_name = {}, None, None

          return jsonable_encoder(
              {
                  "credentials": credentials,
                  "current_credential_id": current_credential_id,
                  "current_credential_name": current_credential_name,
                  "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs},
                  "available_credentials": available_credentials,
              }
          )

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider: str):
          """Create a named credential for a model.

          Raises:
              Forbidden: the current user is not a workspace admin or owner.
              ValueError: credential validation failed (original error chained as cause).
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()
          parser = reqparse.RequestParser()
          parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="json",
          )
          parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
          parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
          args = parser.parse_args()
          tenant_id = current_user.current_tenant_id
          model_provider_service = ModelProviderService()
          try:
              model_provider_service.create_model_credential(
                  tenant_id=tenant_id,
                  provider=provider,
                  model=args["model"],
                  model_type=args["model_type"],
                  credentials=args["credentials"],
                  credential_name=args["name"],
              )
          except CredentialsValidateFailedError as ex:
              logger.exception(
                  "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s",
                  tenant_id,
                  args.get("model"),
                  args.get("model_type"),
              )
              raise ValueError(str(ex)) from ex
          return {"result": "success"}, 201

      @setup_required
      @login_required
      @account_initialization_required
      def put(self, provider: str):
          """Update the credentials and name of an existing model credential.

          Raises:
              Forbidden: the current user is not a workspace admin or owner.
              ValueError: credential validation failed (original error chained as cause).
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()
          parser = reqparse.RequestParser()
          parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="json",
          )
          parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
          parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
          parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
          args = parser.parse_args()
          tenant_id = current_user.current_tenant_id
          model_provider_service = ModelProviderService()
          try:
              model_provider_service.update_model_credential(
                  tenant_id=tenant_id,
                  provider=provider,
                  model_type=args["model_type"],
                  model=args["model"],
                  credentials=args["credentials"],
                  credential_id=args["credential_id"],
                  credential_name=args["name"],
              )
          except CredentialsValidateFailedError as ex:
              # Mirror `post` so failed updates are as diagnosable as failed creates.
              logger.exception(
                  "Failed to update model credentials, tenant_id: %s, model: %s, model_type: %s",
                  tenant_id,
                  args.get("model"),
                  args.get("model_type"),
              )
              raise ValueError(str(ex)) from ex
          return {"result": "success"}

      @setup_required
      @login_required
      @account_initialization_required
      def delete(self, provider: str):
          """Delete the model credential identified by ``credential_id``.

          Raises:
              Forbidden: the current user is not a workspace admin or owner.
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()
          parser = reqparse.RequestParser()
          parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="json",
          )
          parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
          args = parser.parse_args()
          model_provider_service = ModelProviderService()
          model_provider_service.remove_model_credential(
              tenant_id=current_user.current_tenant_id,
              provider=provider,
              model_type=args["model_type"],
              model=args["model"],
              credential_id=args["credential_id"],
          )
          return {"result": "success"}, 204
  class ModelProviderModelCredentialSwitchApi(Resource):
      """Attach an existing credential to a model via ``add_model_credential_to_model_list``."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider: str):
          """Add the credential ``credential_id`` to ``model``/``model_type`` of ``provider``.

          Raises:
              Forbidden: the current user is not a workspace admin or owner.
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()
          parser = reqparse.RequestParser()
          parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="json",
          )
          # Validate as a UUID like every other credential endpoint, so malformed
          # ids are rejected by the parser instead of reaching the service layer.
          parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
          args = parser.parse_args()
          service = ModelProviderService()
          service.add_model_credential_to_model_list(
              tenant_id=current_user.current_tenant_id,
              provider=provider,
              model_type=args["model_type"],
              model=args["model"],
              credential_id=args["credential_id"],
          )
          return {"result": "success"}
  class ModelProviderModelEnableApi(Resource):
      """Enable a provider's model in the current workspace."""

      @setup_required
      @login_required
      @account_initialization_required
      def patch(self, provider: str):
          """Enable ``model``/``model_type`` (JSON body) of ``provider``.

          Raises:
              Forbidden: the current user is not a workspace admin or owner.
          """
          # Enabling a model changes workspace configuration; require the same role
          # as every other mutating endpoint in this module.
          if not current_user.is_admin_or_owner:
              raise Forbidden()
          tenant_id = current_user.current_tenant_id
          parser = reqparse.RequestParser()
          parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="json",
          )
          args = parser.parse_args()
          model_provider_service = ModelProviderService()
          model_provider_service.enable_model(
              tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
          )
          return {"result": "success"}
  class ModelProviderModelDisableApi(Resource):
      """Disable a provider's model in the current workspace."""

      @setup_required
      @login_required
      @account_initialization_required
      def patch(self, provider: str):
          """Disable ``model``/``model_type`` (JSON body) of ``provider``.

          Raises:
              Forbidden: the current user is not a workspace admin or owner.
          """
          # Disabling a model changes workspace configuration; require the same role
          # as every other mutating endpoint in this module.
          if not current_user.is_admin_or_owner:
              raise Forbidden()
          tenant_id = current_user.current_tenant_id
          parser = reqparse.RequestParser()
          parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="json",
          )
          args = parser.parse_args()
          model_provider_service = ModelProviderService()
          model_provider_service.disable_model(
              tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
          )
          return {"result": "success"}
  class ModelProviderModelValidateApi(Resource):
      """Run ``validate_model_credentials`` on submitted credentials and report the outcome."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider: str):
          """Validate ``credentials`` for ``model``/``model_type`` of ``provider``.

          Returns:
              ``{"result": "success"}`` when validation passes, otherwise
              ``{"result": "error", "error": <validation message>}``.
          """
          tenant_id = current_user.current_tenant_id
          request_parser = reqparse.RequestParser()
          request_parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          request_parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="json",
          )
          request_parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
          parsed = request_parser.parse_args()
          service = ModelProviderService()
          try:
              service.validate_model_credentials(
                  tenant_id=tenant_id,
                  provider=provider,
                  model=parsed["model"],
                  model_type=parsed["model_type"],
                  credentials=parsed["credentials"],
              )
          except CredentialsValidateFailedError as validation_error:
              # Validation failures are reported in the body, not raised.
              return {"result": "error", "error": str(validation_error)}
          return {"result": "success"}
  class ModelProviderModelParameterRuleApi(Resource):
      """Expose the parameter rules of a provider's model."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider: str):
          """Return ``{"data": <parameter rules>}`` for the ``model`` query argument."""
          request_parser = reqparse.RequestParser()
          request_parser.add_argument("model", type=str, required=True, nullable=False, location="args")
          model_name = request_parser.parse_args()["model"]
          tenant_id = current_user.current_tenant_id
          rules = ModelProviderService().get_model_parameter_rules(
              tenant_id=tenant_id, provider=provider, model=model_name
          )
          return jsonable_encoder({"data": rules})
  class ModelProviderAvailableModelApi(Resource):
      """List the current workspace's models of a given model type."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, model_type):
          """Return ``{"data": <models>}`` for the ``model_type`` URL segment."""
          workspace_id = current_user.current_tenant_id
          service = ModelProviderService()
          return jsonable_encoder(
              {"data": service.get_models_by_model_type(tenant_id=workspace_id, model_type=model_type)}
          )
  def _register_model_routes():
      """Attach this module's resources to the console ``api``.

      Routes are registered in the listed order. ``endpoint`` is passed to
      ``add_resource`` only for entries that name one.
      """
      routes = (
          (ModelProviderModelApi, "/workspaces/current/model-providers/<path:provider>/models", None),
          (
              ModelProviderModelEnableApi,
              "/workspaces/current/model-providers/<path:provider>/models/enable",
              "model-provider-model-enable",
          ),
          (
              ModelProviderModelDisableApi,
              "/workspaces/current/model-providers/<path:provider>/models/disable",
              "model-provider-model-disable",
          ),
          (
              ModelProviderModelCredentialApi,
              "/workspaces/current/model-providers/<path:provider>/models/credentials",
              None,
          ),
          (
              ModelProviderModelCredentialSwitchApi,
              "/workspaces/current/model-providers/<path:provider>/models/credentials/switch",
              None,
          ),
          (
              ModelProviderModelValidateApi,
              "/workspaces/current/model-providers/<path:provider>/models/credentials/validate",
              None,
          ),
          (
              ModelProviderModelParameterRuleApi,
              "/workspaces/current/model-providers/<path:provider>/models/parameter-rules",
              None,
          ),
          (ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>", None),
          (DefaultModelApi, "/workspaces/current/default-model", None),
      )
      for resource, url, endpoint in routes:
          if endpoint is None:
              api.add_resource(resource, url)
          else:
              api.add_resource(resource, url, endpoint=endpoint)


  _register_model_routes()