您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

__init__.py 2.1KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. # AFTER UPDATING THIS FILE, PLEASE ENSURE THAT docs/references/supported_models.mdx IS ALSO UPDATED for consistency!
  17. #
  18. import importlib
  19. import inspect
  20. ChatModel = globals().get("ChatModel", {})
  21. CvModel = globals().get("CvModel", {})
  22. EmbeddingModel = globals().get("EmbeddingModel", {})
  23. RerankModel = globals().get("RerankModel", {})
  24. Seq2txtModel = globals().get("Seq2txtModel", {})
  25. TTSModel = globals().get("TTSModel", {})
  26. MODULE_MAPPING = {
  27. "chat_model": ChatModel,
  28. "cv_model": CvModel,
  29. "embedding_model": EmbeddingModel,
  30. "rerank_model": RerankModel,
  31. "sequence2txt_model": Seq2txtModel,
  32. "tts_model": TTSModel,
  33. }
  34. package_name = __name__
  35. for module_name, mapping_dict in MODULE_MAPPING.items():
  36. full_module_name = f"{package_name}.{module_name}"
  37. module = importlib.import_module(full_module_name)
  38. base_class = None
  39. for name, obj in inspect.getmembers(module):
  40. if inspect.isclass(obj) and name == "Base":
  41. base_class = obj
  42. break
  43. if base_class is None:
  44. continue
  45. for _, obj in inspect.getmembers(module):
  46. if inspect.isclass(obj) and issubclass(obj, base_class) and obj is not base_class and hasattr(obj, "_FACTORY_NAME"):
  47. if isinstance(obj._FACTORY_NAME, list):
  48. for factory_name in obj._FACTORY_NAME:
  49. mapping_dict[factory_name] = obj
  50. else:
  51. mapping_dict[obj._FACTORY_NAME] = obj
  52. __all__ = [
  53. "ChatModel",
  54. "CvModel",
  55. "EmbeddingModel",
  56. "RerankModel",
  57. "Seq2txtModel",
  58. "TTSModel",
  59. ]