import enum
- import importlib
+ import importlib.util
import json
import logging
import os

# Dynamic loading {subdir_name}.py file and find the subclass of Extensible
py_path = os.path.join(subdir_path, extension_name + '.py')
spec = importlib.util.spec_from_file_location(extension_name, py_path)
+ if not spec or not spec.loader:
+     raise Exception(f"Failed to load module {extension_name} from {py_path}")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

position=position
))
- sorted_extensions = sort_to_dict_by_position_map(position_map, extensions, lambda x: x.name)
+ sorted_extensions = sort_to_dict_by_position_map(position_map=position_map, data=extensions, name_func=lambda x: x.name)
return sorted_extensions
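
The loading logic above is the standard importlib recipe for executing a single source file and scanning it for a plugin class. A minimal, self-contained sketch of the same pattern, with a placeholder ExtensibleBase standing in for the project's Extensible base class:

    import importlib.util
    import os


    class ExtensibleBase:
        """Placeholder for the project's Extensible base class."""


    def load_extension_class(subdir_path: str, extension_name: str) -> type:
        # Point a module spec at {extension_name}.py inside the extension directory.
        py_path = os.path.join(subdir_path, extension_name + ".py")
        spec = importlib.util.spec_from_file_location(extension_name, py_path)
        if not spec or not spec.loader:
            raise Exception(f"Failed to load module {extension_name} from {py_path}")
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        # Return the first class defined in the module that subclasses ExtensibleBase.
        for attr in vars(mod).values():
            if isinstance(attr, type) and issubclass(attr, ExtensibleBase) and attr is not ExtensibleBase:
                return attr
        raise Exception(f"No Extensible subclass found in {py_path}")

The added guard mirrors the new check above: spec_from_file_location returns None for paths it cannot handle, so failing early gives a clearer error than an AttributeError later.
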
from typing import AnyStr

- def import_module_from_source(
-     module_name: str,
-     py_file_path: AnyStr,
-     use_lazy_loader: bool = False
- ) -> ModuleType:
+ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType:
"""
Importing a module from the source file directly
"""
existed_spec = importlib.util.find_spec(module_name)
if existed_spec:
spec = existed_spec
+ if not spec.loader:
+     raise Exception(f"Failed to load module {module_name} from {py_file_path}")
else:
# Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
spec = importlib.util.spec_from_file_location(module_name, py_file_path)
+ if not spec or not spec.loader:
+     raise Exception(f"Failed to load module {module_name} from {py_file_path}")
if use_lazy_loader:
# Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports
spec.loader = importlib.util.LazyLoader(spec.loader)
spec.loader.exec_module(module)
return module
except Exception as e:
- logging.exception(f'Failed to load module {module_name} from {py_file_path}: {str(e)}')
+ logging.exception(f"Failed to load module {module_name} from {py_file_path}: {str(e)}")
raise e

def load_single_subclass_from_source(
-     module_name: str,
-     script_path: AnyStr,
-     parent_type: type,
-     use_lazy_loader: bool = False,
+     *, module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False
) -> type:
"""
Load a single subclass from the source
"""
- module = import_module_from_source(module_name, script_path, use_lazy_loader)
+ module = import_module_from_source(
+     module_name=module_name, py_file_path=script_path, use_lazy_loader=use_lazy_loader
+ )
subclasses = get_subclasses_from_module(module, parent_type)
match len(subclasses):
case 1:
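
When use_lazy_loader is set, the helper follows the lazy-import recipe from the Python documentation linked in the comment. A standalone sketch of that branch (the module name and file path are illustrative):

    import importlib.util
    import sys
    from types import ModuleType


    def lazy_import_from_path(module_name: str, py_file_path: str) -> ModuleType:
        # Build a spec for the source file, then wrap its loader so the module
        # body only executes when an attribute is first accessed.
        spec = importlib.util.spec_from_file_location(module_name, py_file_path)
        if not spec or not spec.loader:
            raise Exception(f"Failed to load module {module_name} from {py_file_path}")
        spec.loader = importlib.util.LazyLoader(spec.loader)
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
        return module

Registering the module in sys.modules before exec_module is part of the documented recipe; it keeps repeated imports of the same name from re-executing the file.
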
import os
from collections import OrderedDict
from collections.abc import Callable
- from typing import Any, AnyStr
+ from typing import Any
from core.tools.utils.yaml_utils import load_yaml_file

- def get_position_map(
-     folder_path: AnyStr,
-     file_name: str = '_position.yaml',
- ) -> dict[str, int]:
+ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> dict[str, int]:
"""
Get the mapping from name to index from a YAML file
:param folder_path:
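
The resulting position map is just a name-to-index dict read from _position.yaml. A small sketch of how such a map can drive ordering, using a generic helper rather than the project's sort_to_dict_by_position_map:

    from collections import OrderedDict
    from collections.abc import Callable
    from typing import Any


    def sort_by_position(position_map: dict[str, int], items: list[Any], name_func: Callable[[Any], str]) -> OrderedDict:
        # Names listed in the position map come first, in their declared order;
        # unlisted items keep their original relative order at the end (sorted is stable).
        ordered = sorted(items, key=lambda item: position_map.get(name_func(item), len(position_map)))
        return OrderedDict((name_func(item), item) for item in ordered)

For example, with position_map {"openai": 0, "anthropic": 1}, providers named "openai" and "anthropic" are placed ahead of every other entry.
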
import logging
import os
- from collections.abc import Generator
+ from collections.abc import Callable, Generator
from typing import IO, Optional, Union, cast
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle

def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
- stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
+ stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model

streaming=streaming
)

- def _round_robin_invoke(self, function: callable, *args, **kwargs):
+ def _round_robin_invoke(self, function: Callable, *args, **kwargs):
"""
Round-robin invoke
:param function: function to invoke

while True:
current_index = redis_client.incr(cache_key)
+ current_index = cast(int, current_index)
if current_index >= 10000000:
current_index = 1
redis_client.set(cache_key, current_index)

config.id
)
- return redis_client.exists(cooldown_cache_key)
+ res = redis_client.exists(cooldown_cache_key)
+ res = cast(bool, res)
+ return res

@classmethod
def get_config_in_cooldown_and_ttl(cls, tenant_id: str,

if ttl == -2:
return False, 0
+ ttl = cast(int, ttl)
return True, ttl
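
The round-robin path relies on a shared Redis counter: INCR hands out an ever-growing index, the code resets it once it reaches 10,000,000 so the value stays bounded, and that index selects one of the load-balancing configurations. A toy in-memory stand-in for the same idea (the Redis counter and config objects are replaced by plain Python here):

    from typing import Any, Optional


    class RoundRobinSelector:
        """In-memory sketch of the Redis-backed round-robin index."""

        def __init__(self, configs: list[Any]) -> None:
            self._configs = configs
            self._counter = 0  # the real code keeps this counter in Redis via INCR

        def next_config(self) -> Optional[Any]:
            if not self._configs:
                return None
            self._counter += 1
            # Mirror the wrap-around guard so the counter never grows unbounded.
            if self._counter >= 10_000_000:
                self._counter = 1
            return self._configs[self._counter % len(self._configs)]

The modulo mapping from counter to config is not visible in the hunk above, so treat it as an assumption about the surrounding code. The cast(int, ...) added in the diff only narrows the Redis return type for the type checker; it does not change behavior.
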
+ from collections.abc import Sequence
from enum import Enum
from typing import Optional
from pydantic import BaseModel, ConfigDict
from core.model_runtime.entities.common_entities import I18nObject
- from core.model_runtime.entities.model_entities import AIModelEntity, ModelType, ProviderModel
+ from core.model_runtime.entities.model_entities import ModelType, ProviderModel

class ConfigurateMethod(Enum):

label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
- supported_model_types: list[ModelType]
- models: list[AIModelEntity] = []
+ supported_model_types: Sequence[ModelType]
+ models: list[ProviderModel] = []

class ProviderHelpEntity(BaseModel):

icon_large: Optional[I18nObject] = None
background: Optional[str] = None
help: Optional[ProviderHelpEntity] = None
- supported_model_types: list[ModelType]
+ supported_model_types: Sequence[ModelType]
configurate_methods: list[ConfigurateMethod]
models: list[ProviderModel] = []
provider_credential_schema: Optional[ProviderCredentialSchema] = None
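
The change from list[ModelType] to Sequence[ModelType] is about typing rather than runtime behavior: list is invariant under static type checking, while Sequence is covariant and exposes no mutating methods, so callers can pass any concrete sequence without a checker complaint. A generic illustration (the Animal/Dog names are not from the project):

    from collections.abc import Sequence


    class Animal: ...
    class Dog(Animal): ...


    def count_list(animals: list[Animal]) -> int:
        return len(animals)


    def count_seq(animals: Sequence[Animal]) -> int:
        return len(animals)


    dogs: list[Dog] = [Dog()]
    count_seq(dogs)   # fine: Sequence[Animal] accepts a list[Dog]
    count_list(dogs)  # flagged by a type checker such as mypy: list is invariant

Both calls run at runtime; the difference only shows up under static analysis, which is presumably why fields that are never mutated switch to Sequence.
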
import decimal
import os
from abc import ABC, abstractmethod
+ from collections.abc import Mapping
from typing import Optional
from pydantic import ConfigDict

"""
Base class for all models.
"""
model_type: ModelType
- model_schemas: list[AIModelEntity] = None
+ model_schemas: Optional[list[AIModelEntity]] = None
started_at: float = 0

# pydantic configs
model_config = ConfigDict(protected_namespaces=())

@abstractmethod
- def validate_credentials(self, model: str, credentials: dict) -> None:
+ def validate_credentials(self, model: str, credentials: Mapping) -> None:
"""
Validate model credentials

# get price info from predefined model schema
price_config: Optional[PriceConfig] = None
- if model_schema:
-     price_config: PriceConfig = model_schema.pricing
+ if model_schema and model_schema.pricing:
+     price_config = model_schema.pricing

# get unit price
unit_price = None
if unit_price is None:
return PriceInfo(
-     unit_price=decimal.Decimal('0.0'),
-     unit=decimal.Decimal('0.0'),
-     total_amount=decimal.Decimal('0.0'),
+     unit_price=decimal.Decimal("0.0"),
+     unit=decimal.Decimal("0.0"),
+     total_amount=decimal.Decimal("0.0"),
currency="USD",
)

# calculate total amount
+ if not price_config:
+     raise ValueError(f"Price config not found for model {model}")
total_amount = tokens * unit_price * price_config.unit
total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

return model_schemas

- def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]:
+ def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> Optional[AIModelEntity]:
"""
Get model schema by model name and credentials

return None

- def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
+ def get_customizable_model_schema_from_credentials(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
"""
Get customizable model schema from credentials
:return: model schema
"""
return self._get_customizable_model_schema(model, credentials)

- def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
+ def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
"""
Get customizable model schema and fill in the template
"""
if not schema:
return None

# fill in the template
new_parameter_rules = []
for parameter_rule in schema.parameter_rules:

parameter_rule.help = I18nObject(
en_US=default_parameter_rule['help']['en_US'],
)
- if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']):
-     parameter_rule.help.en_US = default_parameter_rule['help']['en_US']
- if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']):
-     parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US'])
+ if (
+     parameter_rule.help
+     and not parameter_rule.help.en_US
+     and ("help" in default_parameter_rule and "en_US" in default_parameter_rule["help"])
+ ):
+     parameter_rule.help.en_US = default_parameter_rule["help"]["en_US"]
+ if (
+     parameter_rule.help
+     and not parameter_rule.help.zh_Hans
+     and ("help" in default_parameter_rule and "zh_Hans" in default_parameter_rule["help"])
+ ):
+     parameter_rule.help.zh_Hans = default_parameter_rule["help"].get(
+         "zh_Hans", default_parameter_rule["help"]["en_US"]
+     )
except ValueError:
pass

return schema

- def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
+ def get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
"""
Get customizable model schema

default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name)
if not default_parameter_rule:
-     raise Exception(f'Invalid model parameter rule name {name}')
+     raise Exception(f"Invalid model parameter rule name {name}")
return default_parameter_rule

:param text: plain text of prompt. You need to convert the original message to plain text
:return: number of tokens
"""
return GPT2Tokenizer.get_num_tokens(text)
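
The price computation keeps everything in decimal.Decimal to avoid binary floating-point rounding, then rounds half-up to seven decimal places. A worked example with made-up numbers (the unit and unit_price values are illustrative, not real pricing):

    import decimal

    tokens = 1234
    unit_price = decimal.Decimal("0.002")  # illustrative price per priced unit
    unit = decimal.Decimal("0.001")        # illustrative: price quoted per 1,000 tokens

    total_amount = tokens * unit_price * unit
    # Quantize to 7 decimal places with ROUND_HALF_UP, as in the base model
    total_amount = total_amount.quantize(decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP)
    print(total_amount)  # 0.0024680

The new `if not price_config` guard makes the failure explicit: without it, a model whose schema has no pricing section would crash on `price_config.unit` with an AttributeError on None instead of a readable ValueError.
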
import re
import time
from abc import abstractmethod
- from collections.abc import Generator
+ from collections.abc import Generator, Mapping
from typing import Optional, Union
from pydantic import ConfigDict
+ from typing_extensions import deprecated

def invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
- stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
+ stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model

user=user,
callbacks=callbacks
)
- else:
+ elif isinstance(result, LLMResult):
self._trigger_after_invoke_callbacks(
model=model,
result=result,

def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
- callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
+ callbacks: Optional[list[Callback]] = None) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper, ensure the response is a code block with output markdown quote

# override the system message
prompt_messages[0] = SystemPromptMessage(
content=block_prompts
- .replace("{{instructions}}", prompt_messages[0].content)
+ .replace("{{instructions}}", str(prompt_messages[0].content))
)
else:
# insert the system message

else:
yield piece
continue
- new_piece = ""
+ new_piece: str = ""
for char in piece:
+ char = str(char)
if state == "normal":
if char == "`":
state = "in_backticks"

if state == "done":
continue
- new_piece = ""
+ new_piece: str = ""
for char in piece:
if state == "search_start":
if char == "`":

# If backticks were counted but we're still collecting content, it was a false start
new_piece += "`" * backtick_count
backtick_count = 0
- new_piece += char
+ new_piece += str(char)
elif state == "done":
break

prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
- user: Optional[str] = None, callbacks: list[Callback] = None) -> Generator:
+ user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> Generator:
"""
Invoke result generator
:param result: result generator
:return: result generator
"""
+ callbacks = callbacks or []
prompt_message = AssistantPromptMessage(
content=""
)

def _llm_result_to_stream(self, result: LLMResult) -> Generator:
"""
Transform llm result to stream
:param result: llm result

return []

- def get_model_mode(self, model: str, credentials: Optional[dict] = None) -> LLMMode:
+ def get_model_mode(self, model: str, credentials: Optional[Mapping] = None) -> LLMMode:
"""
Get model mode

prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
- user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
+ user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
"""
Trigger before invoke callbacks

prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
- user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
+ user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
"""
Trigger new chunk callbacks

prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
- user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
+ user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
"""
Trigger after invoke callbacks

prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True,
- user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
+ user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
"""
Trigger invoke error callbacks
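
A recurring change in this file is `callbacks: list[Callback] = None` becoming `callbacks: Optional[list[Callback]] = None`, with `callbacks = callbacks or []` normalizing the value on entry. The annotation now matches the actual default, and None stays the default because a literal `= []` default would be one shared list reused across every call. A minimal sketch of the pattern (Callback is a stand-in class here):

    from typing import Optional


    class Callback:
        """Stand-in for the runtime's callback interface."""


    def invoke(prompt: str, callbacks: Optional[list[Callback]] = None) -> str:
        # Normalize once so the rest of the function can iterate without None checks.
        callbacks = callbacks or []
        for callback in callbacks:
            ...  # notify each callback; omitted in this sketch
        return prompt.upper()


    invoke("hello")                          # no callbacks supplied
    invoke("hello", callbacks=[Callback()])  # explicit list passed by the caller

The `elif isinstance(result, LLMResult)` change earlier in the file serves the type checker in a similar way: it narrows result before the after-invoke callbacks fire instead of assuming the else branch implies a non-streaming result.
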
import os
from abc import ABC, abstractmethod
+ from typing import Optional
from core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType

class ModelProvider(ABC):
- provider_schema: ProviderEntity = None
+ provider_schema: Optional[ProviderEntity] = None
model_instance_map: dict[str, AIModel] = {}

@abstractmethod
def get_provider_schema(self) -> ProviderEntity:
"""
Get provider schema
:return: provider schema
"""
if self.provider_schema:
return self.provider_schema

# get dirname of the current path
provider_name = self.__class__.__module__.split('.')[-1]

# get the path of the model_provider classes
base_path = os.path.abspath(__file__)
current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name)

# read provider schema from yaml file
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
yaml_data = load_yaml_file(yaml_path, ignore_error=True)

try:
# yaml_data to entity
provider_schema = ProviderEntity(**yaml_data)

# cache schema
self.provider_schema = provider_schema

return provider_schema

def models(self, model_type: ModelType) -> list[AIModelEntity]:
:return:
"""
# get dirname of the current path
- provider_name = self.__class__.__module__.split('.')[-1]
+ provider_name = self.__class__.__module__.split(".")[-1]

if f"{provider_name}.{model_type.value}" in self.model_instance_map:
return self.model_instance_map[f"{provider_name}.{model_type.value}"]

# Dynamic loading {model_type_name}.py file and find the subclass of AIModel
parent_module = '.'.join(self.__class__.__module__.split('.')[:-1])
mod = import_module_from_source(
-     f'{parent_module}.{model_type_name}.{model_type_name}', model_type_py_path)
- model_class = next(filter(lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
-                           get_subclasses_from_module(mod, AIModel)), None)
+     module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path
+ )
+ model_class = next(
+     filter(
+         lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
+         get_subclasses_from_module(mod, AIModel),
+     ),
+     None,
+ )
if not model_class:
-     raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}')
+     raise Exception(f"Missing AIModel Class for model type {model_type} in {model_type_py_path}")

model_instance_map = model_class()
self.model_instance_map[f"{provider_name}.{model_type.value}"] = model_instance_map
import logging
import os
+ from collections.abc import Sequence
from typing import Optional
from pydantic import BaseModel, ConfigDict

class ModelProviderExtension(BaseModel):
+ model_config = ConfigDict(arbitrary_types_allowed=True)
provider_instance: ModelProvider
name: str
position: Optional[int] = None
- model_config = ConfigDict(arbitrary_types_allowed=True)

class ModelProviderFactory:
- model_provider_extensions: dict[str, ModelProviderExtension] = None
+ model_provider_extensions: Optional[dict[str, ModelProviderExtension]] = None

def __init__(self) -> None:
# for cache in memory
self.get_providers()

- def get_providers(self) -> list[ProviderEntity]:
+ def get_providers(self) -> Sequence[ProviderEntity]:
"""
Get all providers
:return: list of providers

# traverse all model_provider_extensions
providers = []
- for name, model_provider_extension in model_provider_extensions.items():
+ for model_provider_extension in model_provider_extensions.values():
# get model_provider instance
model_provider_instance = model_provider_extension.provider_instance

# return providers
return providers

- def provider_credentials_validate(self, provider: str, credentials: dict) -> dict:
+ def provider_credentials_validate(self, *, provider: str, credentials: dict) -> dict:
"""
Validate provider credentials

# get provider_credential_schema and validate credentials according to the rules
provider_credential_schema = provider_schema.provider_credential_schema
+ if not provider_credential_schema:
+     raise ValueError(f"Provider {provider} does not have provider_credential_schema")

# validate provider credential schema
validator = ProviderCredentialSchemaValidator(provider_credential_schema)
filtered_credentials = validator.validate_and_filter(credentials)

return filtered_credentials

- def model_credentials_validate(self, provider: str, model_type: ModelType,
-                                model: str, credentials: dict) -> dict:
+ def model_credentials_validate(
+     self, *, provider: str, model_type: ModelType, model: str, credentials: dict
+ ) -> dict:
"""
Validate model credentials

# get model_credential_schema and validate credentials according to the rules
model_credential_schema = provider_schema.model_credential_schema
+ if not model_credential_schema:
+     raise ValueError(f"Provider {provider} does not have model_credential_schema")

# validate model credential schema
validator = ModelCredentialSchemaValidator(model_type, model_credential_schema)
filtered_credentials = validator.validate_and_filter(credentials)

return filtered_credentials

- def get_models(self,
-                provider: Optional[str] = None,
-                model_type: Optional[ModelType] = None,
-                provider_configs: Optional[list[ProviderConfig]] = None) \
-         -> list[SimpleProviderEntity]:
+ def get_models(
+     self,
+     *,
+     provider: Optional[str] = None,
+     model_type: Optional[ModelType] = None,
+     provider_configs: Optional[list[ProviderConfig]] = None,
+ ) -> list[SimpleProviderEntity]:
"""
Get all models for given model type
:param provider_configs: list of provider configs
:return: list of models
"""
+ provider_configs = provider_configs or []

# scan all providers
model_provider_extensions = self._get_model_provider_map()

# get the provider extension
model_provider_extension = model_provider_extensions.get(provider)
if not model_provider_extension:
-     raise Exception(f'Invalid provider: {provider}')
+     raise Exception(f"Invalid provider: {provider}")

# get the provider instance
model_provider_instance = model_provider_extension.provider_instance
return model_provider_instance

def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]:
+ """
+ Retrieves the model provider map.
+ This method retrieves the model provider map, which is a dictionary containing the model provider names as keys
+ and instances of `ModelProviderExtension` as values. The model provider map is used to store information about
+ available model providers.
+ Returns:
+     A dictionary containing the model provider map.
+ Raises:
+     None.
+ """
if self.model_provider_extensions:
return self.model_provider_extensions

# get the path of current classes
current_path = os.path.abspath(__file__)
model_providers_path = os.path.dirname(current_path)

model_provider_dir_paths = [
os.path.join(model_providers_path, model_provider_dir)
for model_provider_dir in os.listdir(model_providers_path)
- if not model_provider_dir.startswith('__')
-     and os.path.isdir(os.path.join(model_providers_path, model_provider_dir))
+ if not model_provider_dir.startswith("__")
+ and os.path.isdir(os.path.join(model_providers_path, model_provider_dir))
]

# get _position.yaml file path
file_names = os.listdir(model_provider_dir_path)

- if (model_provider_name + '.py') not in file_names:
+ if (model_provider_name + ".py") not in file_names:
logger.warning(f"Missing {model_provider_name}.py file in {model_provider_dir_path}, Skip.")
continue

# Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider
- py_path = os.path.join(model_provider_dir_path, model_provider_name + '.py')
+ py_path = os.path.join(model_provider_dir_path, model_provider_name + ".py")
model_provider_class = load_single_subclass_from_source(
-     module_name=f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}',
+     module_name=f"core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}",
script_path=py_path,
-     parent_type=ModelProvider)
+     parent_type=ModelProvider,
+ )

if not model_provider_class:
logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.")
continue

- if f'{model_provider_name}.yaml' not in file_names:
+ if f"{model_provider_name}.yaml" not in file_names:
logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.")
continue

- model_providers.append(ModelProviderExtension(
-     name=model_provider_name,
-     provider_instance=model_provider_class(),
-     position=position_map.get(model_provider_name)
- ))
+ model_providers.append(
+     ModelProviderExtension(
+         name=model_provider_name,
+         provider_instance=model_provider_class(),
+         position=position_map.get(model_provider_name),
+     )
+ )

sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name)
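
Provider discovery is a plain directory scan: skip dunder entries and non-directories, then require both {name}.py and {name}.yaml before the provider is registered. A condensed sketch of those checks (the path layout is the only assumption):

    import logging
    import os

    logger = logging.getLogger(__name__)


    def discover_provider_dirs(model_providers_path: str) -> list[str]:
        """Return provider directories that ship both {name}.py and {name}.yaml."""
        found = []
        for name in sorted(os.listdir(model_providers_path)):
            dir_path = os.path.join(model_providers_path, name)
            if name.startswith("__") or not os.path.isdir(dir_path):
                continue  # skip __pycache__ style entries and loose files
            file_names = os.listdir(dir_path)
            if f"{name}.py" not in file_names or f"{name}.yaml" not in file_names:
                logger.warning("Skipping %s: missing %s.py or %s.yaml", dir_path, name, name)
                continue
            found.append(dir_path)
        return found

The factory then instantiates each discovered provider class once and orders the extensions with the same position-map sort used for other extension types.
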
+ from collections.abc import Mapping

import openai
from httpx import Timeout

class _CommonOpenAI:
- def _to_credential_kwargs(self, credentials: dict) -> dict:
+ def _to_credential_kwargs(self, credentials: Mapping) -> dict:
"""
Transform credentials to kwargs for model instance

"max_retries": 1,
}

- if credentials.get('openai_api_base'):
-     credentials['openai_api_base'] = credentials['openai_api_base'].rstrip('/')
-     credentials_kwargs['base_url'] = credentials['openai_api_base'] + '/v1'
+ if credentials.get("openai_api_base"):
+     openai_api_base = credentials["openai_api_base"].rstrip("/")
+     credentials_kwargs["base_url"] = openai_api_base + "/v1"

if 'openai_organization' in credentials:
credentials_kwargs['organization'] = credentials['openai_organization']

:return: Invoke error mapping
"""
return {
-     InvokeConnectionError: [
-         openai.APIConnectionError,
-         openai.APITimeoutError
-     ],
-     InvokeServerUnavailableError: [
-         openai.InternalServerError
-     ],
-     InvokeRateLimitError: [
-         openai.RateLimitError
-     ],
-     InvokeAuthorizationError: [
-         openai.AuthenticationError,
-         openai.PermissionDeniedError
-     ],
+     InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
+     InvokeServerUnavailableError: [openai.InternalServerError],
+     InvokeRateLimitError: [openai.RateLimitError],
+     InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError],
InvokeBadRequestError: [
openai.BadRequestError,
openai.NotFoundError,
openai.UnprocessableEntityError,
-     openai.APIError
- ]
+     openai.APIError,
+ ],
}
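
The invoke error mapping translates the OpenAI SDK's exception types into the runtime's unified invoke errors, so callers only catch one family of exceptions regardless of provider. A sketch of how such a mapping is typically applied, with stand-in exception classes so it runs without the SDK or the runtime installed:

    class InvokeError(Exception):
        """Stand-in for the runtime's unified invoke error base class."""


    class InvokeConnectionError(InvokeError): ...
    class InvokeRateLimitError(InvokeError): ...


    class SDKConnectionError(Exception): ...
    class SDKTimeoutError(Exception): ...
    class SDKRateLimitError(Exception): ...


    # Same shape as the mapping above: unified error -> list of SDK exceptions.
    ERROR_MAPPING: dict[type, list[type]] = {
        InvokeConnectionError: [SDKConnectionError, SDKTimeoutError],
        InvokeRateLimitError: [SDKRateLimitError],
    }


    def translate_error(error: Exception) -> InvokeError:
        # Re-wrap the SDK exception in the first matching unified type.
        for unified_type, sdk_types in ERROR_MAPPING.items():
            if isinstance(error, tuple(sdk_types)):
                return unified_type(str(error))
        return InvokeError(str(error))


    try:
        raise SDKTimeoutError("request timed out")
    except Exception as e:
        print(type(translate_error(e)).__name__)  # InvokeConnectionError

How the mapping is consumed by the real runtime is not shown in this hunk, so treat translate_error as an assumption about the surrounding code rather than its actual implementation. The switch to a Mapping parameter also explains the new local variable: a Mapping does not support item assignment, so the rewritten branch no longer mutates the credentials it was given.
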
import logging
+ from collections.abc import Mapping

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError

class OpenAIProvider(ModelProvider):
- def validate_provider_credentials(self, credentials: dict) -> None:
+ def validate_provider_credentials(self, credentials: Mapping) -> None:
"""
Validate provider credentials
if validate failed, raise exception
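
validate_provider_credentials is expected to raise CredentialsValidateFailedError when the supplied credentials are rejected. A tiny usage sketch with a dummy provider, since constructing the real OpenAIProvider and its credential schema is outside this hunk:

    class CredentialsValidateFailedError(Exception):
        """Stand-in for core.model_runtime.errors.validate.CredentialsValidateFailedError."""


    class DummyProvider:
        """Stand-in provider following the same validate-or-raise contract."""

        def validate_provider_credentials(self, credentials: dict) -> None:
            if not credentials.get("openai_api_key"):
                raise CredentialsValidateFailedError("openai_api_key is required")


    provider = DummyProvider()
    try:
        provider.validate_provider_credentials({})
    except CredentialsValidateFailedError as e:
        print(f"credentials rejected: {e}")

With the Mapping annotation from the diff, the provider promises not to mutate the credentials it validates; the credential key name used here is an assumption for illustration.
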