You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

__init__.py 4.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. # AFTER UPDATING THIS FILE, PLEASE ENSURE THAT docs/references/supported_models.mdx IS ALSO UPDATED for consistency!
  17. #
  18. import importlib
  19. import inspect
  20. from strenum import StrEnum
  21. class SupportedLiteLLMProvider(StrEnum):
  22. Tongyi_Qianwen = "Tongyi-Qianwen"
  23. Dashscope = "Dashscope"
  24. Bedrock = "Bedrock"
  25. Moonshot = "Moonshot"
  26. xAI = "xAI"
  27. DeepInfra = "DeepInfra"
  28. Groq = "Groq"
  29. Cohere = "Cohere"
  30. Gemini = "Gemini"
  31. DeepSeek = "DeepSeek"
  32. Nvidia = "NVIDIA"
  33. TogetherAI = "TogetherAI"
  34. Anthropic = "Anthropic"
  35. FACTORY_DEFAULT_BASE_URL = {
  36. SupportedLiteLLMProvider.Tongyi_Qianwen: "https://dashscope.aliyuncs.com/compatible-mode/v1",
  37. SupportedLiteLLMProvider.Dashscope: "https://dashscope.aliyuncs.com/compatible-mode/v1",
  38. SupportedLiteLLMProvider.Moonshot: "https://api.moonshot.cn/v1",
  39. }
  40. LITELLM_PROVIDER_PREFIX = {
  41. SupportedLiteLLMProvider.Tongyi_Qianwen: "dashscope/",
  42. SupportedLiteLLMProvider.Dashscope: "dashscope/",
  43. SupportedLiteLLMProvider.Bedrock: "bedrock/",
  44. SupportedLiteLLMProvider.Moonshot: "moonshot/",
  45. SupportedLiteLLMProvider.xAI: "xai/",
  46. SupportedLiteLLMProvider.DeepInfra: "deepinfra/",
  47. SupportedLiteLLMProvider.Groq: "groq/",
  48. SupportedLiteLLMProvider.Cohere: "", # don't need a prefix
  49. SupportedLiteLLMProvider.Gemini: "gemini/",
  50. SupportedLiteLLMProvider.DeepSeek: "deepseek/",
  51. SupportedLiteLLMProvider.Nvidia: "nvidia_nim/",
  52. SupportedLiteLLMProvider.TogetherAI: "together_ai/",
  53. SupportedLiteLLMProvider.Anthropic: "", # don't need a prefix
  54. }
  55. ChatModel = globals().get("ChatModel", {})
  56. CvModel = globals().get("CvModel", {})
  57. EmbeddingModel = globals().get("EmbeddingModel", {})
  58. RerankModel = globals().get("RerankModel", {})
  59. Seq2txtModel = globals().get("Seq2txtModel", {})
  60. TTSModel = globals().get("TTSModel", {})
  61. MODULE_MAPPING = {
  62. "chat_model": ChatModel,
  63. "cv_model": CvModel,
  64. "embedding_model": EmbeddingModel,
  65. "rerank_model": RerankModel,
  66. "sequence2txt_model": Seq2txtModel,
  67. "tts_model": TTSModel,
  68. }
  69. package_name = __name__
  70. for module_name, mapping_dict in MODULE_MAPPING.items():
  71. full_module_name = f"{package_name}.{module_name}"
  72. module = importlib.import_module(full_module_name)
  73. base_class = None
  74. lite_llm_base_class = None
  75. for name, obj in inspect.getmembers(module):
  76. if inspect.isclass(obj):
  77. if name == "Base":
  78. base_class = obj
  79. elif name == "LiteLLMBase":
  80. lite_llm_base_class = obj
  81. assert hasattr(obj, "_FACTORY_NAME"), "LiteLLMbase should have _FACTORY_NAME field."
  82. if hasattr(obj, "_FACTORY_NAME"):
  83. if isinstance(obj._FACTORY_NAME, list):
  84. for factory_name in obj._FACTORY_NAME:
  85. mapping_dict[factory_name] = obj
  86. else:
  87. mapping_dict[obj._FACTORY_NAME] = obj
  88. if base_class is not None:
  89. for _, obj in inspect.getmembers(module):
  90. if inspect.isclass(obj) and issubclass(obj, base_class) and obj is not base_class and hasattr(obj, "_FACTORY_NAME"):
  91. if isinstance(obj._FACTORY_NAME, list):
  92. for factory_name in obj._FACTORY_NAME:
  93. mapping_dict[factory_name] = obj
  94. else:
  95. mapping_dict[obj._FACTORY_NAME] = obj
  96. __all__ = [
  97. "ChatModel",
  98. "CvModel",
  99. "EmbeddingModel",
  100. "RerankModel",
  101. "Seq2txtModel",
  102. "TTSModel",
  103. ]