Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

models.py 20KB


  1. import logging
  2. from flask_login import current_user
  3. from flask_restx import Resource, reqparse
  4. from werkzeug.exceptions import Forbidden
  5. from controllers.console import api
  6. from controllers.console.wraps import account_initialization_required, setup_required
  7. from core.model_runtime.entities.model_entities import ModelType
  8. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  9. from core.model_runtime.utils.encoders import jsonable_encoder
  10. from libs.helper import StrLen, uuid_value
  11. from libs.login import login_required
  12. from services.model_load_balancing_service import ModelLoadBalancingService
  13. from services.model_provider_service import ModelProviderService
  14. class DefaultModelApi(Resource):
  15. @setup_required
  16. @login_required
  17. @account_initialization_required
  18. def get(self):
  19. parser = reqparse.RequestParser()
  20. parser.add_argument(
  21. "model_type",
  22. type=str,
  23. required=True,
  24. nullable=False,
  25. choices=[mt.value for mt in ModelType],
  26. location="args",
  27. )
  28. args = parser.parse_args()
  29. tenant_id = current_user.current_tenant_id
  30. model_provider_service = ModelProviderService()
  31. default_model_entity = model_provider_service.get_default_model_of_model_type(
  32. tenant_id=tenant_id, model_type=args["model_type"]
  33. )
  34. return jsonable_encoder({"data": default_model_entity})
  35. @setup_required
  36. @login_required
  37. @account_initialization_required
  38. def post(self):
  39. if not current_user.is_admin_or_owner:
  40. raise Forbidden()
  41. parser = reqparse.RequestParser()
  42. parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json")
  43. args = parser.parse_args()
  44. tenant_id = current_user.current_tenant_id
  45. model_provider_service = ModelProviderService()
  46. model_settings = args["model_settings"]
  47. for model_setting in model_settings:
  48. if "model_type" not in model_setting or model_setting["model_type"] not in [mt.value for mt in ModelType]:
  49. raise ValueError("invalid model type")
  50. if "provider" not in model_setting:
  51. continue
  52. if "model" not in model_setting:
  53. raise ValueError("invalid model")
  54. try:
  55. model_provider_service.update_default_model_of_model_type(
  56. tenant_id=tenant_id,
  57. model_type=model_setting["model_type"],
  58. provider=model_setting["provider"],
  59. model=model_setting["model"],
  60. )
  61. except Exception as ex:
  62. logging.exception(
  63. "Failed to update default model, model type: %s, model: %s",
  64. model_setting["model_type"],
  65. model_setting.get("model"),
  66. )
  67. raise ex
  68. return {"result": "success"}
  69. class ModelProviderModelApi(Resource):
  70. @setup_required
  71. @login_required
  72. @account_initialization_required
  73. def get(self, provider):
  74. tenant_id = current_user.current_tenant_id
  75. model_provider_service = ModelProviderService()
  76. models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
  77. return jsonable_encoder({"data": models})
  78. @setup_required
  79. @login_required
  80. @account_initialization_required
  81. def post(self, provider: str):
  82. # To save the model's load balance configs
  83. if not current_user.is_admin_or_owner:
  84. raise Forbidden()
  85. tenant_id = current_user.current_tenant_id
  86. parser = reqparse.RequestParser()
  87. parser.add_argument("model", type=str, required=True, nullable=False, location="json")
  88. parser.add_argument(
  89. "model_type",
  90. type=str,
  91. required=True,
  92. nullable=False,
  93. choices=[mt.value for mt in ModelType],
  94. location="json",
  95. )
  96. parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
  97. parser.add_argument("config_from", type=str, required=False, nullable=True, location="json")
  98. parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
  99. args = parser.parse_args()
  100. if args.get("config_from", "") == "custom-model":
  101. if not args.get("credential_id"):
  102. raise ValueError("credential_id is required when configuring a custom-model")
  103. service = ModelProviderService()
  104. service.switch_active_custom_model_credential(
  105. tenant_id=current_user.current_tenant_id,
  106. provider=provider,
  107. model_type=args["model_type"],
  108. model=args["model"],
  109. credential_id=args["credential_id"],
  110. )
  111. model_load_balancing_service = ModelLoadBalancingService()
  112. if "load_balancing" in args and args["load_balancing"] and "configs" in args["load_balancing"]:
  113. # save load balancing configs
  114. model_load_balancing_service.update_load_balancing_configs(
  115. tenant_id=tenant_id,
  116. provider=provider,
  117. model=args["model"],
  118. model_type=args["model_type"],
  119. configs=args["load_balancing"]["configs"],
  120. config_from=args.get("config_from", ""),
  121. )
  122. if args.get("load_balancing", {}).get("enabled"):
  123. model_load_balancing_service.enable_model_load_balancing(
  124. tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
  125. )
  126. else:
  127. model_load_balancing_service.disable_model_load_balancing(
  128. tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
  129. )
  130. return {"result": "success"}, 200
  131. @setup_required
  132. @login_required
  133. @account_initialization_required
  134. def delete(self, provider: str):
  135. if not current_user.is_admin_or_owner:
  136. raise Forbidden()
  137. tenant_id = current_user.current_tenant_id
  138. parser = reqparse.RequestParser()
  139. parser.add_argument("model", type=str, required=True, nullable=False, location="json")
  140. parser.add_argument(
  141. "model_type",
  142. type=str,
  143. required=True,
  144. nullable=False,
  145. choices=[mt.value for mt in ModelType],
  146. location="json",
  147. )
  148. args = parser.parse_args()
  149. model_provider_service = ModelProviderService()
  150. model_provider_service.remove_model(
  151. tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
  152. )
  153. return {"result": "success"}, 204
  154. class ModelProviderModelCredentialApi(Resource):
  155. @setup_required
  156. @login_required
  157. @account_initialization_required
  158. def get(self, provider: str):
  159. tenant_id = current_user.current_tenant_id
  160. parser = reqparse.RequestParser()
  161. parser.add_argument("model", type=str, required=True, nullable=False, location="args")
  162. parser.add_argument(
  163. "model_type",
  164. type=str,
  165. required=True,
  166. nullable=False,
  167. choices=[mt.value for mt in ModelType],
  168. location="args",
  169. )
  170. parser.add_argument("config_from", type=str, required=False, nullable=True, location="args")
  171. parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
  172. args = parser.parse_args()
  173. model_provider_service = ModelProviderService()
  174. current_credential = model_provider_service.get_model_credential(
  175. tenant_id=tenant_id,
  176. provider=provider,
  177. model_type=args["model_type"],
  178. model=args["model"],
  179. credential_id=args.get("credential_id"),
  180. )
  181. model_load_balancing_service = ModelLoadBalancingService()
  182. is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
  183. tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
  184. )
  185. if args.get("config_from", "") == "predefined-model":
  186. available_credentials = model_provider_service.provider_manager.get_provider_available_credentials(
  187. tenant_id=tenant_id, provider_name=provider
  188. )
  189. else:
  190. model_type = ModelType.value_of(args["model_type"]).to_origin_model_type()
  191. available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
  192. tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args["model"]
  193. )
  194. return jsonable_encoder(
  195. {
  196. "credentials": current_credential.get("credentials") if current_credential else {},
  197. "current_credential_id": current_credential.get("current_credential_id")
  198. if current_credential
  199. else None,
  200. "current_credential_name": current_credential.get("current_credential_name")
  201. if current_credential
  202. else None,
  203. "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs},
  204. "available_credentials": available_credentials,
  205. }
  206. )
  207. @setup_required
  208. @login_required
  209. @account_initialization_required
  210. def post(self, provider: str):
  211. if not current_user.is_admin_or_owner:
  212. raise Forbidden()
  213. parser = reqparse.RequestParser()
  214. parser.add_argument("model", type=str, required=True, nullable=False, location="json")
  215. parser.add_argument(
  216. "model_type",
  217. type=str,
  218. required=True,
  219. nullable=False,
  220. choices=[mt.value for mt in ModelType],
  221. location="json",
  222. )
  223. parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
  224. parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
  225. args = parser.parse_args()
  226. tenant_id = current_user.current_tenant_id
  227. model_provider_service = ModelProviderService()
  228. try:
  229. model_provider_service.create_model_credential(
  230. tenant_id=tenant_id,
  231. provider=provider,
  232. model=args["model"],
  233. model_type=args["model_type"],
  234. credentials=args["credentials"],
  235. credential_name=args["name"],
  236. )
  237. except CredentialsValidateFailedError as ex:
  238. logging.exception(
  239. "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s",
  240. tenant_id,
  241. args.get("model"),
  242. args.get("model_type"),
  243. )
  244. raise ValueError(str(ex))
  245. return {"result": "success"}, 201
  246. @setup_required
  247. @login_required
  248. @account_initialization_required
  249. def put(self, provider: str):
  250. if not current_user.is_admin_or_owner:
  251. raise Forbidden()
  252. parser = reqparse.RequestParser()
  253. parser.add_argument("model", type=str, required=True, nullable=False, location="json")
  254. parser.add_argument(
  255. "model_type",
  256. type=str,
  257. required=True,
  258. nullable=False,
  259. choices=[mt.value for mt in ModelType],
  260. location="json",
  261. )
  262. parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
  263. parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
  264. parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
  265. args = parser.parse_args()
  266. model_provider_service = ModelProviderService()
  267. try:
  268. model_provider_service.update_model_credential(
  269. tenant_id=current_user.current_tenant_id,
  270. provider=provider,
  271. model_type=args["model_type"],
  272. model=args["model"],
  273. credentials=args["credentials"],
  274. credential_id=args["credential_id"],
  275. credential_name=args["name"],
  276. )
  277. except CredentialsValidateFailedError as ex:
  278. raise ValueError(str(ex))
  279. return {"result": "success"}
  280. @setup_required
  281. @login_required
  282. @account_initialization_required
  283. def delete(self, provider: str):
  284. if not current_user.is_admin_or_owner:
  285. raise Forbidden()
  286. parser = reqparse.RequestParser()
  287. parser.add_argument("model", type=str, required=True, nullable=False, location="json")
  288. parser.add_argument(
  289. "model_type",
  290. type=str,
  291. required=True,
  292. nullable=False,
  293. choices=[mt.value for mt in ModelType],
  294. location="json",
  295. )
  296. parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
  297. args = parser.parse_args()
  298. model_provider_service = ModelProviderService()
  299. model_provider_service.remove_model_credential(
  300. tenant_id=current_user.current_tenant_id,
  301. provider=provider,
  302. model_type=args["model_type"],
  303. model=args["model"],
  304. credential_id=args["credential_id"],
  305. )
  306. return {"result": "success"}, 204
  307. class ModelProviderModelCredentialSwitchApi(Resource):
  308. @setup_required
  309. @login_required
  310. @account_initialization_required
  311. def post(self, provider: str):
  312. if not current_user.is_admin_or_owner:
  313. raise Forbidden()
  314. parser = reqparse.RequestParser()
  315. parser.add_argument("model", type=str, required=True, nullable=False, location="json")
  316. parser.add_argument(
  317. "model_type",
  318. type=str,
  319. required=True,
  320. nullable=False,
  321. choices=[mt.value for mt in ModelType],
  322. location="json",
  323. )
  324. parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
  325. args = parser.parse_args()
  326. service = ModelProviderService()
  327. service.add_model_credential_to_model_list(
  328. tenant_id=current_user.current_tenant_id,
  329. provider=provider,
  330. model_type=args["model_type"],
  331. model=args["model"],
  332. credential_id=args["credential_id"],
  333. )
  334. return {"result": "success"}
  335. class ModelProviderModelEnableApi(Resource):
  336. @setup_required
  337. @login_required
  338. @account_initialization_required
  339. def patch(self, provider: str):
  340. tenant_id = current_user.current_tenant_id
  341. parser = reqparse.RequestParser()
  342. parser.add_argument("model", type=str, required=True, nullable=False, location="json")
  343. parser.add_argument(
  344. "model_type",
  345. type=str,
  346. required=True,
  347. nullable=False,
  348. choices=[mt.value for mt in ModelType],
  349. location="json",
  350. )
  351. args = parser.parse_args()
  352. model_provider_service = ModelProviderService()
  353. model_provider_service.enable_model(
  354. tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
  355. )
  356. return {"result": "success"}
  357. class ModelProviderModelDisableApi(Resource):
  358. @setup_required
  359. @login_required
  360. @account_initialization_required
  361. def patch(self, provider: str):
  362. tenant_id = current_user.current_tenant_id
  363. parser = reqparse.RequestParser()
  364. parser.add_argument("model", type=str, required=True, nullable=False, location="json")
  365. parser.add_argument(
  366. "model_type",
  367. type=str,
  368. required=True,
  369. nullable=False,
  370. choices=[mt.value for mt in ModelType],
  371. location="json",
  372. )
  373. args = parser.parse_args()
  374. model_provider_service = ModelProviderService()
  375. model_provider_service.disable_model(
  376. tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
  377. )
  378. return {"result": "success"}
  379. class ModelProviderModelValidateApi(Resource):
  380. @setup_required
  381. @login_required
  382. @account_initialization_required
  383. def post(self, provider: str):
  384. tenant_id = current_user.current_tenant_id
  385. parser = reqparse.RequestParser()
  386. parser.add_argument("model", type=str, required=True, nullable=False, location="json")
  387. parser.add_argument(
  388. "model_type",
  389. type=str,
  390. required=True,
  391. nullable=False,
  392. choices=[mt.value for mt in ModelType],
  393. location="json",
  394. )
  395. parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
  396. args = parser.parse_args()
  397. model_provider_service = ModelProviderService()
  398. result = True
  399. error = ""
  400. try:
  401. model_provider_service.validate_model_credentials(
  402. tenant_id=tenant_id,
  403. provider=provider,
  404. model=args["model"],
  405. model_type=args["model_type"],
  406. credentials=args["credentials"],
  407. )
  408. except CredentialsValidateFailedError as ex:
  409. result = False
  410. error = str(ex)
  411. response = {"result": "success" if result else "error"}
  412. if not result:
  413. response["error"] = error or ""
  414. return response
  415. class ModelProviderModelParameterRuleApi(Resource):
  416. @setup_required
  417. @login_required
  418. @account_initialization_required
  419. def get(self, provider: str):
  420. parser = reqparse.RequestParser()
  421. parser.add_argument("model", type=str, required=True, nullable=False, location="args")
  422. args = parser.parse_args()
  423. tenant_id = current_user.current_tenant_id
  424. model_provider_service = ModelProviderService()
  425. parameter_rules = model_provider_service.get_model_parameter_rules(
  426. tenant_id=tenant_id, provider=provider, model=args["model"]
  427. )
  428. return jsonable_encoder({"data": parameter_rules})
  429. class ModelProviderAvailableModelApi(Resource):
  430. @setup_required
  431. @login_required
  432. @account_initialization_required
  433. def get(self, model_type):
  434. tenant_id = current_user.current_tenant_id
  435. model_provider_service = ModelProviderService()
  436. models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
  437. return jsonable_encoder({"data": models})
  438. api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<path:provider>/models")
  439. api.add_resource(
  440. ModelProviderModelEnableApi,
  441. "/workspaces/current/model-providers/<path:provider>/models/enable",
  442. endpoint="model-provider-model-enable",
  443. )
  444. api.add_resource(
  445. ModelProviderModelDisableApi,
  446. "/workspaces/current/model-providers/<path:provider>/models/disable",
  447. endpoint="model-provider-model-disable",
  448. )
  449. api.add_resource(
  450. ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials"
  451. )
  452. api.add_resource(
  453. ModelProviderModelCredentialSwitchApi,
  454. "/workspaces/current/model-providers/<path:provider>/models/credentials/switch",
  455. )
  456. api.add_resource(
  457. ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate"
  458. )
  459. api.add_resource(
  460. ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<path:provider>/models/parameter-rules"
  461. )
  462. api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
  463. api.add_resource(DefaultModelApi, "/workspaces/current/default-model")