Вы не можете выбрать более 25 тем Темы должны начинаться с буквы или цифры, могут содержать дефисы(-) и должны содержать не более 35 символов.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. # AFTER UPDATING THIS FILE, PLEASE ENSURE THAT docs/references/supported_models.mdx IS ALSO UPDATED for consistency!
  17. #
  18. import importlib
  19. import inspect
  20. from strenum import StrEnum
class SupportedLiteLLMProvider(StrEnum):
    """Provider (factory) names that are routed through the LiteLLM integration.

    Members are ``StrEnum`` values, so each one compares equal to its string
    value (e.g. ``SupportedLiteLLMProvider.Nvidia == "NVIDIA"``) and can be used
    wherever a plain factory-name string is expected.

    Every member currently has an entry in ``LITELLM_PROVIDER_PREFIX``; only
    some also have one in ``FACTORY_DEFAULT_BASE_URL``.  When adding a member,
    keep those tables (and docs/references/supported_models.mdx) in step.
    """

    Tongyi_Qianwen = "Tongyi-Qianwen"
    Dashscope = "Dashscope"
    Bedrock = "Bedrock"
    Moonshot = "Moonshot"
    xAI = "xAI"
    DeepInfra = "DeepInfra"
    Groq = "Groq"
    Cohere = "Cohere"
    Gemini = "Gemini"
    DeepSeek = "DeepSeek"
    Nvidia = "NVIDIA"  # note: the value is upper-case, unlike the member name
    TogetherAI = "TogetherAI"
    Anthropic = "Anthropic"
    Ollama = "Ollama"
# Default API base URL per LiteLLM provider.  Providers missing from this table
# have no default here (NOTE(review): presumably LiteLLM's own built-in endpoint
# is used for them — confirm at the call site).
FACTORY_DEFAULT_BASE_URL = {
    # Tongyi-Qianwen and Dashscope share DashScope's compatible-mode endpoint.
    SupportedLiteLLMProvider.Tongyi_Qianwen: "https://dashscope.aliyuncs.com/compatible-mode/v1",
    SupportedLiteLLMProvider.Dashscope: "https://dashscope.aliyuncs.com/compatible-mode/v1",
    SupportedLiteLLMProvider.Moonshot: "https://api.moonshot.cn/v1",
}
# LiteLLM provider route prefix for each supported provider (presumably prepended
# to the model name to build the LiteLLM model identifier — confirm at the call
# site).  An empty string means the provider needs no prefix.
LITELLM_PROVIDER_PREFIX = {
    # Tongyi-Qianwen uses the same "dashscope/" route as Dashscope.
    SupportedLiteLLMProvider.Tongyi_Qianwen: "dashscope/",
    SupportedLiteLLMProvider.Dashscope: "dashscope/",
    SupportedLiteLLMProvider.Bedrock: "bedrock/",
    SupportedLiteLLMProvider.Moonshot: "moonshot/",
    SupportedLiteLLMProvider.xAI: "xai/",
    SupportedLiteLLMProvider.DeepInfra: "deepinfra/",
    SupportedLiteLLMProvider.Groq: "groq/",
    SupportedLiteLLMProvider.Cohere: "",  # don't need a prefix
    SupportedLiteLLMProvider.Gemini: "gemini/",
    SupportedLiteLLMProvider.DeepSeek: "deepseek/",
    SupportedLiteLLMProvider.Nvidia: "nvidia_nim/",
    SupportedLiteLLMProvider.TogetherAI: "together_ai/",
    SupportedLiteLLMProvider.Anthropic: "",  # don't need a prefix
    SupportedLiteLLMProvider.Ollama: "ollama_chat/",
}
  57. ChatModel = globals().get("ChatModel", {})
  58. CvModel = globals().get("CvModel", {})
  59. EmbeddingModel = globals().get("EmbeddingModel", {})
  60. RerankModel = globals().get("RerankModel", {})
  61. Seq2txtModel = globals().get("Seq2txtModel", {})
  62. TTSModel = globals().get("TTSModel", {})
  63. MODULE_MAPPING = {
  64. "chat_model": ChatModel,
  65. "cv_model": CvModel,
  66. "embedding_model": EmbeddingModel,
  67. "rerank_model": RerankModel,
  68. "sequence2txt_model": Seq2txtModel,
  69. "tts_model": TTSModel,
  70. }
  71. package_name = __name__
  72. for module_name, mapping_dict in MODULE_MAPPING.items():
  73. full_module_name = f"{package_name}.{module_name}"
  74. module = importlib.import_module(full_module_name)
  75. base_class = None
  76. lite_llm_base_class = None
  77. for name, obj in inspect.getmembers(module):
  78. if inspect.isclass(obj):
  79. if name == "Base":
  80. base_class = obj
  81. elif name == "LiteLLMBase":
  82. lite_llm_base_class = obj
  83. assert hasattr(obj, "_FACTORY_NAME"), "LiteLLMbase should have _FACTORY_NAME field."
  84. if hasattr(obj, "_FACTORY_NAME"):
  85. if isinstance(obj._FACTORY_NAME, list):
  86. for factory_name in obj._FACTORY_NAME:
  87. mapping_dict[factory_name] = obj
  88. else:
  89. mapping_dict[obj._FACTORY_NAME] = obj
  90. if base_class is not None:
  91. for _, obj in inspect.getmembers(module):
  92. if inspect.isclass(obj) and issubclass(obj, base_class) and obj is not base_class and hasattr(obj, "_FACTORY_NAME"):
  93. if isinstance(obj._FACTORY_NAME, list):
  94. for factory_name in obj._FACTORY_NAME:
  95. mapping_dict[factory_name] = obj
  96. else:
  97. mapping_dict[obj._FACTORY_NAME] = obj
# Names exported by a star import of this package: the six factory-name -> class
# registries populated by the registration loop.  Other module-level names
# (SupportedLiteLLMProvider, FACTORY_DEFAULT_BASE_URL, LITELLM_PROVIDER_PREFIX,
# MODULE_MAPPING) are not listed and must be imported explicitly.
__all__ = [
    "ChatModel",
    "CvModel",
    "EmbeddingModel",
    "RerankModel",
    "Seq2txtModel",
    "TTSModel",
]