| @@ -1,5 +1,5 @@ | |||
| import enum | |||
| import importlib | |||
| import importlib.util | |||
| import json | |||
| import logging | |||
| import os | |||
| @@ -74,6 +74,8 @@ class Extensible: | |||
| # Dynamic loading {subdir_name}.py file and find the subclass of Extensible | |||
| py_path = os.path.join(subdir_path, extension_name + '.py') | |||
| spec = importlib.util.spec_from_file_location(extension_name, py_path) | |||
| if not spec or not spec.loader: | |||
| raise Exception(f"Failed to load module {extension_name} from {py_path}") | |||
| mod = importlib.util.module_from_spec(spec) | |||
| spec.loader.exec_module(mod) | |||
| @@ -108,6 +110,6 @@ class Extensible: | |||
| position=position | |||
| )) | |||
| sorted_extensions = sort_to_dict_by_position_map(position_map, extensions, lambda x: x.name) | |||
| sorted_extensions = sort_to_dict_by_position_map(position_map=position_map, data=extensions, name_func=lambda x: x.name) | |||
| return sorted_extensions | |||
| @@ -5,11 +5,7 @@ from types import ModuleType | |||
| from typing import AnyStr | |||
| def import_module_from_source( | |||
| module_name: str, | |||
| py_file_path: AnyStr, | |||
| use_lazy_loader: bool = False | |||
| ) -> ModuleType: | |||
| def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType: | |||
| """ | |||
| Importing a module from the source file directly | |||
| """ | |||
| @@ -17,9 +13,13 @@ def import_module_from_source( | |||
| existed_spec = importlib.util.find_spec(module_name) | |||
| if existed_spec: | |||
| spec = existed_spec | |||
| if not spec.loader: | |||
| raise Exception(f"Failed to load module {module_name} from {py_file_path}") | |||
| else: | |||
| # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly | |||
| spec = importlib.util.spec_from_file_location(module_name, py_file_path) | |||
| if not spec or not spec.loader: | |||
| raise Exception(f"Failed to load module {module_name} from {py_file_path}") | |||
| if use_lazy_loader: | |||
| # Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports | |||
| spec.loader = importlib.util.LazyLoader(spec.loader) | |||
| @@ -29,7 +29,7 @@ def import_module_from_source( | |||
| spec.loader.exec_module(module) | |||
| return module | |||
| except Exception as e: | |||
| logging.exception(f'Failed to load module {module_name} from {py_file_path}: {str(e)}') | |||
| logging.exception(f"Failed to load module {module_name} from {py_file_path}: {str(e)}") | |||
| raise e | |||
| @@ -43,15 +43,14 @@ def get_subclasses_from_module(mod: ModuleType, parent_type: type) -> list[type] | |||
| def load_single_subclass_from_source( | |||
| module_name: str, | |||
| script_path: AnyStr, | |||
| parent_type: type, | |||
| use_lazy_loader: bool = False, | |||
| *, module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False | |||
| ) -> type: | |||
| """ | |||
| Load a single subclass from the source | |||
| """ | |||
| module = import_module_from_source(module_name, script_path, use_lazy_loader) | |||
| module = import_module_from_source( | |||
| module_name=module_name, py_file_path=script_path, use_lazy_loader=use_lazy_loader | |||
| ) | |||
| subclasses = get_subclasses_from_module(module, parent_type) | |||
| match len(subclasses): | |||
| case 1: | |||
| @@ -1,15 +1,12 @@ | |||
| import os | |||
| from collections import OrderedDict | |||
| from collections.abc import Callable | |||
| from typing import Any, AnyStr | |||
| from typing import Any | |||
| from core.tools.utils.yaml_utils import load_yaml_file | |||
| def get_position_map( | |||
| folder_path: AnyStr, | |||
| file_name: str = '_position.yaml', | |||
| ) -> dict[str, int]: | |||
| def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> dict[str, int]: | |||
| """ | |||
| Get the mapping from name to index from a YAML file | |||
| :param folder_path: | |||
| @@ -1,6 +1,6 @@ | |||
| import logging | |||
| import os | |||
| from collections.abc import Generator | |||
| from collections.abc import Callable, Generator | |||
| from typing import IO, Optional, Union, cast | |||
| from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle | |||
| @@ -102,7 +102,7 @@ class ModelInstance: | |||
| def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, | |||
| stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \ | |||
| stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \ | |||
| -> Union[LLMResult, Generator]: | |||
| """ | |||
| Invoke large language model | |||
| @@ -291,7 +291,7 @@ class ModelInstance: | |||
| streaming=streaming | |||
| ) | |||
| def _round_robin_invoke(self, function: callable, *args, **kwargs): | |||
| def _round_robin_invoke(self, function: Callable, *args, **kwargs): | |||
| """ | |||
| Round-robin invoke | |||
| :param function: function to invoke | |||
| @@ -437,6 +437,7 @@ class LBModelManager: | |||
| while True: | |||
| current_index = redis_client.incr(cache_key) | |||
| current_index = cast(int, current_index) | |||
| if current_index >= 10000000: | |||
| current_index = 1 | |||
| redis_client.set(cache_key, current_index) | |||
| @@ -499,7 +500,10 @@ class LBModelManager: | |||
| config.id | |||
| ) | |||
| return redis_client.exists(cooldown_cache_key) | |||
| res = redis_client.exists(cooldown_cache_key) | |||
| res = cast(bool, res) | |||
| return res | |||
| @classmethod | |||
| def get_config_in_cooldown_and_ttl(cls, tenant_id: str, | |||
| @@ -528,4 +532,5 @@ class LBModelManager: | |||
| if ttl == -2: | |||
| return False, 0 | |||
| ttl = cast(int, ttl) | |||
| return True, ttl | |||
| @@ -1,10 +1,11 @@ | |||
| from collections.abc import Sequence | |||
| from enum import Enum | |||
| from typing import Optional | |||
| from pydantic import BaseModel, ConfigDict | |||
| from core.model_runtime.entities.common_entities import I18nObject | |||
| from core.model_runtime.entities.model_entities import AIModelEntity, ModelType, ProviderModel | |||
| from core.model_runtime.entities.model_entities import ModelType, ProviderModel | |||
| class ConfigurateMethod(Enum): | |||
| @@ -93,8 +94,8 @@ class SimpleProviderEntity(BaseModel): | |||
| label: I18nObject | |||
| icon_small: Optional[I18nObject] = None | |||
| icon_large: Optional[I18nObject] = None | |||
| supported_model_types: list[ModelType] | |||
| models: list[AIModelEntity] = [] | |||
| supported_model_types: Sequence[ModelType] | |||
| models: list[ProviderModel] = [] | |||
| class ProviderHelpEntity(BaseModel): | |||
| @@ -116,7 +117,7 @@ class ProviderEntity(BaseModel): | |||
| icon_large: Optional[I18nObject] = None | |||
| background: Optional[str] = None | |||
| help: Optional[ProviderHelpEntity] = None | |||
| supported_model_types: list[ModelType] | |||
| supported_model_types: Sequence[ModelType] | |||
| configurate_methods: list[ConfigurateMethod] | |||
| models: list[ProviderModel] = [] | |||
| provider_credential_schema: Optional[ProviderCredentialSchema] = None | |||
| @@ -1,6 +1,7 @@ | |||
| import decimal | |||
| import os | |||
| from abc import ABC, abstractmethod | |||
| from collections.abc import Mapping | |||
| from typing import Optional | |||
| from pydantic import ConfigDict | |||
| @@ -26,15 +27,16 @@ class AIModel(ABC): | |||
| """ | |||
| Base class for all models. | |||
| """ | |||
| model_type: ModelType | |||
| model_schemas: list[AIModelEntity] = None | |||
| model_schemas: Optional[list[AIModelEntity]] = None | |||
| started_at: float = 0 | |||
| # pydantic configs | |||
| model_config = ConfigDict(protected_namespaces=()) | |||
| @abstractmethod | |||
| def validate_credentials(self, model: str, credentials: dict) -> None: | |||
| def validate_credentials(self, model: str, credentials: Mapping) -> None: | |||
| """ | |||
| Validate model credentials | |||
| @@ -90,8 +92,8 @@ class AIModel(ABC): | |||
| # get price info from predefined model schema | |||
| price_config: Optional[PriceConfig] = None | |||
| if model_schema: | |||
| price_config: PriceConfig = model_schema.pricing | |||
| if model_schema and model_schema.pricing: | |||
| price_config = model_schema.pricing | |||
| # get unit price | |||
| unit_price = None | |||
| @@ -103,13 +105,15 @@ class AIModel(ABC): | |||
| if unit_price is None: | |||
| return PriceInfo( | |||
| unit_price=decimal.Decimal('0.0'), | |||
| unit=decimal.Decimal('0.0'), | |||
| total_amount=decimal.Decimal('0.0'), | |||
| unit_price=decimal.Decimal("0.0"), | |||
| unit=decimal.Decimal("0.0"), | |||
| total_amount=decimal.Decimal("0.0"), | |||
| currency="USD", | |||
| ) | |||
| # calculate total amount | |||
| if not price_config: | |||
| raise ValueError(f"Price config not found for model {model}") | |||
| total_amount = tokens * unit_price * price_config.unit | |||
| total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) | |||
| @@ -209,7 +213,7 @@ class AIModel(ABC): | |||
| return model_schemas | |||
| def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]: | |||
| def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> Optional[AIModelEntity]: | |||
| """ | |||
| Get model schema by model name and credentials | |||
| @@ -231,7 +235,7 @@ class AIModel(ABC): | |||
| return None | |||
| def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]: | |||
| def get_customizable_model_schema_from_credentials(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]: | |||
| """ | |||
| Get customizable model schema from credentials | |||
| @@ -240,8 +244,8 @@ class AIModel(ABC): | |||
| :return: model schema | |||
| """ | |||
| return self._get_customizable_model_schema(model, credentials) | |||
| def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: | |||
| def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]: | |||
| """ | |||
| Get customizable model schema and fill in the template | |||
| """ | |||
| @@ -249,7 +253,7 @@ class AIModel(ABC): | |||
| if not schema: | |||
| return None | |||
| # fill in the template | |||
| new_parameter_rules = [] | |||
| for parameter_rule in schema.parameter_rules: | |||
| @@ -271,10 +275,20 @@ class AIModel(ABC): | |||
| parameter_rule.help = I18nObject( | |||
| en_US=default_parameter_rule['help']['en_US'], | |||
| ) | |||
| if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']): | |||
| parameter_rule.help.en_US = default_parameter_rule['help']['en_US'] | |||
| if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']): | |||
| parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US']) | |||
| if ( | |||
| parameter_rule.help | |||
| and not parameter_rule.help.en_US | |||
| and ("help" in default_parameter_rule and "en_US" in default_parameter_rule["help"]) | |||
| ): | |||
| parameter_rule.help.en_US = default_parameter_rule["help"]["en_US"] | |||
| if ( | |||
| parameter_rule.help | |||
| and not parameter_rule.help.zh_Hans | |||
| and ("help" in default_parameter_rule and "zh_Hans" in default_parameter_rule["help"]) | |||
| ): | |||
| parameter_rule.help.zh_Hans = default_parameter_rule["help"].get( | |||
| "zh_Hans", default_parameter_rule["help"]["en_US"] | |||
| ) | |||
| except ValueError: | |||
| pass | |||
| @@ -284,7 +298,7 @@ class AIModel(ABC): | |||
| return schema | |||
| def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: | |||
| def get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]: | |||
| """ | |||
| Get customizable model schema | |||
| @@ -304,7 +318,7 @@ class AIModel(ABC): | |||
| default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name) | |||
| if not default_parameter_rule: | |||
| raise Exception(f'Invalid model parameter rule name {name}') | |||
| raise Exception(f"Invalid model parameter rule name {name}") | |||
| return default_parameter_rule | |||
| @@ -318,4 +332,4 @@ class AIModel(ABC): | |||
| :param text: plain text of prompt. You need to convert the original message to plain text | |||
| :return: number of tokens | |||
| """ | |||
| return GPT2Tokenizer.get_num_tokens(text) | |||
| return GPT2Tokenizer.get_num_tokens(text) | |||
| @@ -3,7 +3,7 @@ import os | |||
| import re | |||
| import time | |||
| from abc import abstractmethod | |||
| from collections.abc import Generator | |||
| from collections.abc import Generator, Mapping | |||
| from typing import Optional, Union | |||
| from pydantic import ConfigDict | |||
| @@ -43,7 +43,7 @@ class LargeLanguageModel(AIModel): | |||
| def invoke(self, model: str, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, | |||
| stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \ | |||
| stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \ | |||
| -> Union[LLMResult, Generator]: | |||
| """ | |||
| Invoke large language model | |||
| @@ -129,7 +129,7 @@ class LargeLanguageModel(AIModel): | |||
| user=user, | |||
| callbacks=callbacks | |||
| ) | |||
| else: | |||
| elif isinstance(result, LLMResult): | |||
| self._trigger_after_invoke_callbacks( | |||
| model=model, | |||
| result=result, | |||
| @@ -148,7 +148,7 @@ class LargeLanguageModel(AIModel): | |||
| def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], | |||
| model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, | |||
| callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: | |||
| callbacks: Optional[list[Callback]] = None) -> Union[LLMResult, Generator]: | |||
| """ | |||
| Code block mode wrapper, ensure the response is a code block with output markdown quote | |||
| @@ -196,7 +196,7 @@ if you are not sure about the structure. | |||
| # override the system message | |||
| prompt_messages[0] = SystemPromptMessage( | |||
| content=block_prompts | |||
| .replace("{{instructions}}", prompt_messages[0].content) | |||
| .replace("{{instructions}}", str(prompt_messages[0].content)) | |||
| ) | |||
| else: | |||
| # insert the system message | |||
| @@ -274,8 +274,9 @@ if you are not sure about the structure. | |||
| else: | |||
| yield piece | |||
| continue | |||
| new_piece = "" | |||
| new_piece: str = "" | |||
| for char in piece: | |||
| char = str(char) | |||
| if state == "normal": | |||
| if char == "`": | |||
| state = "in_backticks" | |||
| @@ -340,7 +341,7 @@ if you are not sure about the structure. | |||
| if state == "done": | |||
| continue | |||
| new_piece = "" | |||
| new_piece: str = "" | |||
| for char in piece: | |||
| if state == "search_start": | |||
| if char == "`": | |||
| @@ -365,7 +366,7 @@ if you are not sure about the structure. | |||
| # If backticks were counted but we're still collecting content, it was a false start | |||
| new_piece += "`" * backtick_count | |||
| backtick_count = 0 | |||
| new_piece += char | |||
| new_piece += str(char) | |||
| elif state == "done": | |||
| break | |||
| @@ -388,13 +389,14 @@ if you are not sure about the structure. | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[list[str]] = None, stream: bool = True, | |||
| user: Optional[str] = None, callbacks: list[Callback] = None) -> Generator: | |||
| user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> Generator: | |||
| """ | |||
| Invoke result generator | |||
| :param result: result generator | |||
| :return: result generator | |||
| """ | |||
| callbacks = callbacks or [] | |||
| prompt_message = AssistantPromptMessage( | |||
| content="" | |||
| ) | |||
| @@ -489,6 +491,7 @@ if you are not sure about the structure. | |||
| def _llm_result_to_stream(self, result: LLMResult) -> Generator: | |||
| """ | |||
| from typing_extensions import deprecated | |||
| Transform llm result to stream | |||
| :param result: llm result | |||
| @@ -531,7 +534,7 @@ if you are not sure about the structure. | |||
| return [] | |||
| def get_model_mode(self, model: str, credentials: Optional[dict] = None) -> LLMMode: | |||
| def get_model_mode(self, model: str, credentials: Optional[Mapping] = None) -> LLMMode: | |||
| """ | |||
| Get model mode | |||
| @@ -595,7 +598,7 @@ if you are not sure about the structure. | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[list[str]] = None, stream: bool = True, | |||
| user: Optional[str] = None, callbacks: list[Callback] = None) -> None: | |||
| user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: | |||
| """ | |||
| Trigger before invoke callbacks | |||
| @@ -633,7 +636,7 @@ if you are not sure about the structure. | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[list[str]] = None, stream: bool = True, | |||
| user: Optional[str] = None, callbacks: list[Callback] = None) -> None: | |||
| user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: | |||
| """ | |||
| Trigger new chunk callbacks | |||
| @@ -672,7 +675,7 @@ if you are not sure about the structure. | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[list[str]] = None, stream: bool = True, | |||
| user: Optional[str] = None, callbacks: list[Callback] = None) -> None: | |||
| user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: | |||
| """ | |||
| Trigger after invoke callbacks | |||
| @@ -712,7 +715,7 @@ if you are not sure about the structure. | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[list[str]] = None, stream: bool = True, | |||
| user: Optional[str] = None, callbacks: list[Callback] = None) -> None: | |||
| user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None: | |||
| """ | |||
| Trigger invoke error callbacks | |||
| @@ -1,5 +1,6 @@ | |||
| import os | |||
| from abc import ABC, abstractmethod | |||
| from typing import Optional | |||
| from core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source | |||
| from core.model_runtime.entities.model_entities import AIModelEntity, ModelType | |||
| @@ -9,7 +10,7 @@ from core.tools.utils.yaml_utils import load_yaml_file | |||
| class ModelProvider(ABC): | |||
| provider_schema: ProviderEntity = None | |||
| provider_schema: Optional[ProviderEntity] = None | |||
| model_instance_map: dict[str, AIModel] = {} | |||
| @abstractmethod | |||
| @@ -28,23 +29,23 @@ class ModelProvider(ABC): | |||
| def get_provider_schema(self) -> ProviderEntity: | |||
| """ | |||
| Get provider schema | |||
| :return: provider schema | |||
| """ | |||
| if self.provider_schema: | |||
| return self.provider_schema | |||
| # get dirname of the current path | |||
| provider_name = self.__class__.__module__.split('.')[-1] | |||
| # get the path of the model_provider classes | |||
| base_path = os.path.abspath(__file__) | |||
| current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name) | |||
| # read provider schema from yaml file | |||
| yaml_path = os.path.join(current_path, f'{provider_name}.yaml') | |||
| yaml_data = load_yaml_file(yaml_path, ignore_error=True) | |||
| try: | |||
| # yaml_data to entity | |||
| provider_schema = ProviderEntity(**yaml_data) | |||
| @@ -53,7 +54,7 @@ class ModelProvider(ABC): | |||
| # cache schema | |||
| self.provider_schema = provider_schema | |||
| return provider_schema | |||
| def models(self, model_type: ModelType) -> list[AIModelEntity]: | |||
| @@ -84,7 +85,7 @@ class ModelProvider(ABC): | |||
| :return: | |||
| """ | |||
| # get dirname of the current path | |||
| provider_name = self.__class__.__module__.split('.')[-1] | |||
| provider_name = self.__class__.__module__.split(".")[-1] | |||
| if f"{provider_name}.{model_type.value}" in self.model_instance_map: | |||
| return self.model_instance_map[f"{provider_name}.{model_type.value}"] | |||
| @@ -101,11 +102,17 @@ class ModelProvider(ABC): | |||
| # Dynamic loading {model_type_name}.py file and find the subclass of AIModel | |||
| parent_module = '.'.join(self.__class__.__module__.split('.')[:-1]) | |||
| mod = import_module_from_source( | |||
| f'{parent_module}.{model_type_name}.{model_type_name}', model_type_py_path) | |||
| model_class = next(filter(lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__, | |||
| get_subclasses_from_module(mod, AIModel)), None) | |||
| module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path | |||
| ) | |||
| model_class = next( | |||
| filter( | |||
| lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__, | |||
| get_subclasses_from_module(mod, AIModel), | |||
| ), | |||
| None, | |||
| ) | |||
| if not model_class: | |||
| raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}') | |||
| raise Exception(f"Missing AIModel Class for model type {model_type} in {model_type_py_path}") | |||
| model_instance_map = model_class() | |||
| self.model_instance_map[f"{provider_name}.{model_type.value}"] = model_instance_map | |||
| @@ -1,5 +1,6 @@ | |||
| import logging | |||
| import os | |||
| from collections.abc import Sequence | |||
| from typing import Optional | |||
| from pydantic import BaseModel, ConfigDict | |||
| @@ -16,20 +17,21 @@ logger = logging.getLogger(__name__) | |||
| class ModelProviderExtension(BaseModel): | |||
| model_config = ConfigDict(arbitrary_types_allowed=True) | |||
| provider_instance: ModelProvider | |||
| name: str | |||
| position: Optional[int] = None | |||
| model_config = ConfigDict(arbitrary_types_allowed=True) | |||
| class ModelProviderFactory: | |||
| model_provider_extensions: dict[str, ModelProviderExtension] = None | |||
| model_provider_extensions: Optional[dict[str, ModelProviderExtension]] = None | |||
| def __init__(self) -> None: | |||
| # for cache in memory | |||
| self.get_providers() | |||
| def get_providers(self) -> list[ProviderEntity]: | |||
| def get_providers(self) -> Sequence[ProviderEntity]: | |||
| """ | |||
| Get all providers | |||
| :return: list of providers | |||
| @@ -39,7 +41,7 @@ class ModelProviderFactory: | |||
| # traverse all model_provider_extensions | |||
| providers = [] | |||
| for name, model_provider_extension in model_provider_extensions.items(): | |||
| for model_provider_extension in model_provider_extensions.values(): | |||
| # get model_provider instance | |||
| model_provider_instance = model_provider_extension.provider_instance | |||
| @@ -57,7 +59,7 @@ class ModelProviderFactory: | |||
| # return providers | |||
| return providers | |||
| def provider_credentials_validate(self, provider: str, credentials: dict) -> dict: | |||
| def provider_credentials_validate(self, *, provider: str, credentials: dict) -> dict: | |||
| """ | |||
| Validate provider credentials | |||
| @@ -74,6 +76,9 @@ class ModelProviderFactory: | |||
| # get provider_credential_schema and validate credentials according to the rules | |||
| provider_credential_schema = provider_schema.provider_credential_schema | |||
| if not provider_credential_schema: | |||
| raise ValueError(f"Provider {provider} does not have provider_credential_schema") | |||
| # validate provider credential schema | |||
| validator = ProviderCredentialSchemaValidator(provider_credential_schema) | |||
| filtered_credentials = validator.validate_and_filter(credentials) | |||
| @@ -83,8 +88,9 @@ class ModelProviderFactory: | |||
| return filtered_credentials | |||
| def model_credentials_validate(self, provider: str, model_type: ModelType, | |||
| model: str, credentials: dict) -> dict: | |||
| def model_credentials_validate( | |||
| self, *, provider: str, model_type: ModelType, model: str, credentials: dict | |||
| ) -> dict: | |||
| """ | |||
| Validate model credentials | |||
| @@ -103,6 +109,9 @@ class ModelProviderFactory: | |||
| # get model_credential_schema and validate credentials according to the rules | |||
| model_credential_schema = provider_schema.model_credential_schema | |||
| if not model_credential_schema: | |||
| raise ValueError(f"Provider {provider} does not have model_credential_schema") | |||
| # validate model credential schema | |||
| validator = ModelCredentialSchemaValidator(model_type, model_credential_schema) | |||
| filtered_credentials = validator.validate_and_filter(credentials) | |||
| @@ -115,11 +124,13 @@ class ModelProviderFactory: | |||
| return filtered_credentials | |||
| def get_models(self, | |||
| provider: Optional[str] = None, | |||
| model_type: Optional[ModelType] = None, | |||
| provider_configs: Optional[list[ProviderConfig]] = None) \ | |||
| -> list[SimpleProviderEntity]: | |||
| def get_models( | |||
| self, | |||
| *, | |||
| provider: Optional[str] = None, | |||
| model_type: Optional[ModelType] = None, | |||
| provider_configs: Optional[list[ProviderConfig]] = None, | |||
| ) -> list[SimpleProviderEntity]: | |||
| """ | |||
| Get all models for given model type | |||
| @@ -128,6 +139,8 @@ class ModelProviderFactory: | |||
| :param provider_configs: list of provider configs | |||
| :return: list of models | |||
| """ | |||
| provider_configs = provider_configs or [] | |||
| # scan all providers | |||
| model_provider_extensions = self._get_model_provider_map() | |||
| @@ -184,7 +197,7 @@ class ModelProviderFactory: | |||
| # get the provider extension | |||
| model_provider_extension = model_provider_extensions.get(provider) | |||
| if not model_provider_extension: | |||
| raise Exception(f'Invalid provider: {provider}') | |||
| raise Exception(f"Invalid provider: {provider}") | |||
| # get the provider instance | |||
| model_provider_instance = model_provider_extension.provider_instance | |||
| @@ -192,10 +205,22 @@ class ModelProviderFactory: | |||
| return model_provider_instance | |||
| def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]: | |||
| """ | |||
| Retrieves the model provider map. | |||
| This method retrieves the model provider map, which is a dictionary containing the model provider names as keys | |||
| and instances of `ModelProviderExtension` as values. The model provider map is used to store information about | |||
| available model providers. | |||
| Returns: | |||
| A dictionary containing the model provider map. | |||
| Raises: | |||
| None. | |||
| """ | |||
| if self.model_provider_extensions: | |||
| return self.model_provider_extensions | |||
| # get the path of current classes | |||
| current_path = os.path.abspath(__file__) | |||
| model_providers_path = os.path.dirname(current_path) | |||
| @@ -204,8 +229,8 @@ class ModelProviderFactory: | |||
| model_provider_dir_paths = [ | |||
| os.path.join(model_providers_path, model_provider_dir) | |||
| for model_provider_dir in os.listdir(model_providers_path) | |||
| if not model_provider_dir.startswith('__') | |||
| and os.path.isdir(os.path.join(model_providers_path, model_provider_dir)) | |||
| if not model_provider_dir.startswith("__") | |||
| and os.path.isdir(os.path.join(model_providers_path, model_provider_dir)) | |||
| ] | |||
| # get _position.yaml file path | |||
| @@ -219,30 +244,33 @@ class ModelProviderFactory: | |||
| file_names = os.listdir(model_provider_dir_path) | |||
| if (model_provider_name + '.py') not in file_names: | |||
| if (model_provider_name + ".py") not in file_names: | |||
| logger.warning(f"Missing {model_provider_name}.py file in {model_provider_dir_path}, Skip.") | |||
| continue | |||
| # Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider | |||
| py_path = os.path.join(model_provider_dir_path, model_provider_name + '.py') | |||
| py_path = os.path.join(model_provider_dir_path, model_provider_name + ".py") | |||
| model_provider_class = load_single_subclass_from_source( | |||
| module_name=f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}', | |||
| module_name=f"core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}", | |||
| script_path=py_path, | |||
| parent_type=ModelProvider) | |||
| parent_type=ModelProvider, | |||
| ) | |||
| if not model_provider_class: | |||
| logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.") | |||
| continue | |||
| if f'{model_provider_name}.yaml' not in file_names: | |||
| if f"{model_provider_name}.yaml" not in file_names: | |||
| logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.") | |||
| continue | |||
| model_providers.append(ModelProviderExtension( | |||
| name=model_provider_name, | |||
| provider_instance=model_provider_class(), | |||
| position=position_map.get(model_provider_name) | |||
| )) | |||
| model_providers.append( | |||
| ModelProviderExtension( | |||
| name=model_provider_name, | |||
| provider_instance=model_provider_class(), | |||
| position=position_map.get(model_provider_name), | |||
| ) | |||
| ) | |||
| sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name) | |||
| @@ -1,3 +1,5 @@ | |||
| from collections.abc import Mapping | |||
| import openai | |||
| from httpx import Timeout | |||
| @@ -12,7 +14,7 @@ from core.model_runtime.errors.invoke import ( | |||
| class _CommonOpenAI: | |||
| def _to_credential_kwargs(self, credentials: dict) -> dict: | |||
| def _to_credential_kwargs(self, credentials: Mapping) -> dict: | |||
| """ | |||
| Transform credentials to kwargs for model instance | |||
| @@ -25,9 +27,9 @@ class _CommonOpenAI: | |||
| "max_retries": 1, | |||
| } | |||
| if credentials.get('openai_api_base'): | |||
| credentials['openai_api_base'] = credentials['openai_api_base'].rstrip('/') | |||
| credentials_kwargs['base_url'] = credentials['openai_api_base'] + '/v1' | |||
| if credentials.get("openai_api_base"): | |||
| openai_api_base = credentials["openai_api_base"].rstrip("/") | |||
| credentials_kwargs["base_url"] = openai_api_base + "/v1" | |||
| if 'openai_organization' in credentials: | |||
| credentials_kwargs['organization'] = credentials['openai_organization'] | |||
| @@ -45,24 +47,14 @@ class _CommonOpenAI: | |||
| :return: Invoke error mapping | |||
| """ | |||
| return { | |||
| InvokeConnectionError: [ | |||
| openai.APIConnectionError, | |||
| openai.APITimeoutError | |||
| ], | |||
| InvokeServerUnavailableError: [ | |||
| openai.InternalServerError | |||
| ], | |||
| InvokeRateLimitError: [ | |||
| openai.RateLimitError | |||
| ], | |||
| InvokeAuthorizationError: [ | |||
| openai.AuthenticationError, | |||
| openai.PermissionDeniedError | |||
| ], | |||
| InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError], | |||
| InvokeServerUnavailableError: [openai.InternalServerError], | |||
| InvokeRateLimitError: [openai.RateLimitError], | |||
| InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError], | |||
| InvokeBadRequestError: [ | |||
| openai.BadRequestError, | |||
| openai.NotFoundError, | |||
| openai.UnprocessableEntityError, | |||
| openai.APIError | |||
| ] | |||
| openai.APIError, | |||
| ], | |||
| } | |||
| @@ -1,4 +1,5 @@ | |||
| import logging | |||
| from collections.abc import Mapping | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| @@ -9,7 +10,7 @@ logger = logging.getLogger(__name__) | |||
| class OpenAIProvider(ModelProvider): | |||
| def validate_provider_credentials(self, credentials: dict) -> None: | |||
| def validate_provider_credentials(self, credentials: Mapping) -> None: | |||
| """ | |||
| Validate provider credentials | |||
| if validate failed, raise exception | |||