您最多可以选择25个主题。主题必须以字母或数字开头，可以包含连字符（-），并且长度不得超过35个字符。

openai_provider.py 1.4KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. import logging
  2. from typing import Optional, Union
  3. import openai
  4. from openai.error import AuthenticationError, OpenAIError
  5. from core.llm.moderation import Moderation
  6. from core.llm.provider.base import BaseProvider
  7. from core.llm.provider.errors import ValidateFailedError
  8. from models.provider import ProviderName
  9. class OpenAIProvider(BaseProvider):
  10. def get_models(self, model_id: Optional[str] = None) -> list[dict]:
  11. credentials = self.get_credentials(model_id)
  12. response = openai.Model.list(**credentials)
  13. return [{
  14. 'id': model['id'],
  15. 'name': model['id'],
  16. } for model in response['data']]
  17. def get_credentials(self, model_id: Optional[str] = None) -> dict:
  18. """
  19. Returns the credentials for the given tenant_id and provider_name.
  20. """
  21. return {
  22. 'openai_api_key': self.get_provider_api_key(model_id=model_id)
  23. }
  24. def get_provider_name(self):
  25. return ProviderName.OPENAI
  26. def config_validate(self, config: Union[dict | str]):
  27. """
  28. Validates the given config.
  29. """
  30. try:
  31. Moderation(self.get_provider_name().value, config).moderate('test')
  32. except (AuthenticationError, OpenAIError) as ex:
  33. raise ValidateFailedError(str(ex))
  34. except Exception as ex:
  35. logging.exception('OpenAI config validation failed')
  36. raise ex