Преглед на файлове

fix(core): Fix incorrect type hints. (#5427)

tags/0.6.12
-LAN- преди 1 година
родител
ревизия
23fa3dedc4
No account linked to committer's email address

+ 4
- 2
api/core/extension/extensible.py Целия файл

import enum import enum
import importlib
import importlib.util
import json import json
import logging import logging
import os import os
# Dynamic loading {subdir_name}.py file and find the subclass of Extensible # Dynamic loading {subdir_name}.py file and find the subclass of Extensible
py_path = os.path.join(subdir_path, extension_name + '.py') py_path = os.path.join(subdir_path, extension_name + '.py')
spec = importlib.util.spec_from_file_location(extension_name, py_path) spec = importlib.util.spec_from_file_location(extension_name, py_path)
if not spec or not spec.loader:
raise Exception(f"Failed to load module {extension_name} from {py_path}")
mod = importlib.util.module_from_spec(spec) mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) spec.loader.exec_module(mod)


position=position position=position
)) ))


sorted_extensions = sort_to_dict_by_position_map(position_map, extensions, lambda x: x.name)
sorted_extensions = sort_to_dict_by_position_map(position_map=position_map, data=extensions, name_func=lambda x: x.name)


return sorted_extensions return sorted_extensions

+ 10
- 11
api/core/helper/module_import_helper.py Целия файл

from typing import AnyStr from typing import AnyStr




def import_module_from_source(
module_name: str,
py_file_path: AnyStr,
use_lazy_loader: bool = False
) -> ModuleType:
def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType:
""" """
Importing a module from the source file directly Importing a module from the source file directly
""" """
existed_spec = importlib.util.find_spec(module_name) existed_spec = importlib.util.find_spec(module_name)
if existed_spec: if existed_spec:
spec = existed_spec spec = existed_spec
if not spec.loader:
raise Exception(f"Failed to load module {module_name} from {py_file_path}")
else: else:
# Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
spec = importlib.util.spec_from_file_location(module_name, py_file_path) spec = importlib.util.spec_from_file_location(module_name, py_file_path)
if not spec or not spec.loader:
raise Exception(f"Failed to load module {module_name} from {py_file_path}")
if use_lazy_loader: if use_lazy_loader:
# Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports # Refer to: https://docs.python.org/3/library/importlib.html#implementing-lazy-imports
spec.loader = importlib.util.LazyLoader(spec.loader) spec.loader = importlib.util.LazyLoader(spec.loader)
spec.loader.exec_module(module) spec.loader.exec_module(module)
return module return module
except Exception as e: except Exception as e:
logging.exception(f'Failed to load module {module_name} from {py_file_path}: {str(e)}')
logging.exception(f"Failed to load module {module_name} from {py_file_path}: {str(e)}")
raise e raise e








def load_single_subclass_from_source( def load_single_subclass_from_source(
module_name: str,
script_path: AnyStr,
parent_type: type,
use_lazy_loader: bool = False,
*, module_name: str, script_path: AnyStr, parent_type: type, use_lazy_loader: bool = False
) -> type: ) -> type:
""" """
Load a single subclass from the source Load a single subclass from the source
""" """
module = import_module_from_source(module_name, script_path, use_lazy_loader)
module = import_module_from_source(
module_name=module_name, py_file_path=script_path, use_lazy_loader=use_lazy_loader
)
subclasses = get_subclasses_from_module(module, parent_type) subclasses = get_subclasses_from_module(module, parent_type)
match len(subclasses): match len(subclasses):
case 1: case 1:

+ 2
- 5
api/core/helper/position_helper.py Целия файл

import os import os
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Callable from collections.abc import Callable
from typing import Any, AnyStr
from typing import Any


from core.tools.utils.yaml_utils import load_yaml_file from core.tools.utils.yaml_utils import load_yaml_file




def get_position_map(
folder_path: AnyStr,
file_name: str = '_position.yaml',
) -> dict[str, int]:
def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") -> dict[str, int]:
""" """
Get the mapping from name to index from a YAML file Get the mapping from name to index from a YAML file
:param folder_path: :param folder_path:

+ 9
- 4
api/core/model_manager.py Целия файл

import logging import logging
import os import os
from collections.abc import Generator
from collections.abc import Callable, Generator
from typing import IO, Optional, Union, cast from typing import IO, Optional, Union, cast


from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle


def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
-> Union[LLMResult, Generator]: -> Union[LLMResult, Generator]:
""" """
Invoke large language model Invoke large language model
streaming=streaming streaming=streaming
) )


def _round_robin_invoke(self, function: callable, *args, **kwargs):
def _round_robin_invoke(self, function: Callable, *args, **kwargs):
""" """
Round-robin invoke Round-robin invoke
:param function: function to invoke :param function: function to invoke


while True: while True:
current_index = redis_client.incr(cache_key) current_index = redis_client.incr(cache_key)
current_index = cast(int, current_index)
if current_index >= 10000000: if current_index >= 10000000:
current_index = 1 current_index = 1
redis_client.set(cache_key, current_index) redis_client.set(cache_key, current_index)
config.id config.id
) )


return redis_client.exists(cooldown_cache_key)

res = redis_client.exists(cooldown_cache_key)
res = cast(bool, res)
return res


@classmethod @classmethod
def get_config_in_cooldown_and_ttl(cls, tenant_id: str, def get_config_in_cooldown_and_ttl(cls, tenant_id: str,
if ttl == -2: if ttl == -2:
return False, 0 return False, 0


ttl = cast(int, ttl)
return True, ttl return True, ttl

+ 5
- 4
api/core/model_runtime/entities/provider_entities.py Целия файл

from collections.abc import Sequence
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional


from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict


from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType, ProviderModel
from core.model_runtime.entities.model_entities import ModelType, ProviderModel




class ConfigurateMethod(Enum): class ConfigurateMethod(Enum):
label: I18nObject label: I18nObject
icon_small: Optional[I18nObject] = None icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None icon_large: Optional[I18nObject] = None
supported_model_types: list[ModelType]
models: list[AIModelEntity] = []
supported_model_types: Sequence[ModelType]
models: list[ProviderModel] = []




class ProviderHelpEntity(BaseModel): class ProviderHelpEntity(BaseModel):
icon_large: Optional[I18nObject] = None icon_large: Optional[I18nObject] = None
background: Optional[str] = None background: Optional[str] = None
help: Optional[ProviderHelpEntity] = None help: Optional[ProviderHelpEntity] = None
supported_model_types: list[ModelType]
supported_model_types: Sequence[ModelType]
configurate_methods: list[ConfigurateMethod] configurate_methods: list[ConfigurateMethod]
models: list[ProviderModel] = [] models: list[ProviderModel] = []
provider_credential_schema: Optional[ProviderCredentialSchema] = None provider_credential_schema: Optional[ProviderCredentialSchema] = None

+ 33
- 19
api/core/model_runtime/model_providers/__base/ai_model.py Целия файл

import decimal import decimal
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Optional from typing import Optional


from pydantic import ConfigDict from pydantic import ConfigDict
""" """
Base class for all models. Base class for all models.
""" """

model_type: ModelType model_type: ModelType
model_schemas: list[AIModelEntity] = None
model_schemas: Optional[list[AIModelEntity]] = None
started_at: float = 0 started_at: float = 0


# pydantic configs # pydantic configs
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())


@abstractmethod @abstractmethod
def validate_credentials(self, model: str, credentials: dict) -> None:
def validate_credentials(self, model: str, credentials: Mapping) -> None:
""" """
Validate model credentials Validate model credentials




# get price info from predefined model schema # get price info from predefined model schema
price_config: Optional[PriceConfig] = None price_config: Optional[PriceConfig] = None
if model_schema:
price_config: PriceConfig = model_schema.pricing
if model_schema and model_schema.pricing:
price_config = model_schema.pricing


# get unit price # get unit price
unit_price = None unit_price = None


if unit_price is None: if unit_price is None:
return PriceInfo( return PriceInfo(
unit_price=decimal.Decimal('0.0'),
unit=decimal.Decimal('0.0'),
total_amount=decimal.Decimal('0.0'),
unit_price=decimal.Decimal("0.0"),
unit=decimal.Decimal("0.0"),
total_amount=decimal.Decimal("0.0"),
currency="USD", currency="USD",
) )


# calculate total amount # calculate total amount
if not price_config:
raise ValueError(f"Price config not found for model {model}")
total_amount = tokens * unit_price * price_config.unit total_amount = tokens * unit_price * price_config.unit
total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) total_amount = total_amount.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)




return model_schemas return model_schemas


def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]:
def get_model_schema(self, model: str, credentials: Optional[Mapping] = None) -> Optional[AIModelEntity]:
""" """
Get model schema by model name and credentials Get model schema by model name and credentials




return None return None


def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
def get_customizable_model_schema_from_credentials(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
""" """
Get customizable model schema from credentials Get customizable model schema from credentials


:return: model schema :return: model schema
""" """
return self._get_customizable_model_schema(model, credentials) return self._get_customizable_model_schema(model, credentials)
def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
def _get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
""" """
Get customizable model schema and fill in the template Get customizable model schema and fill in the template
""" """


if not schema: if not schema:
return None return None
# fill in the template # fill in the template
new_parameter_rules = [] new_parameter_rules = []
for parameter_rule in schema.parameter_rules: for parameter_rule in schema.parameter_rules:
parameter_rule.help = I18nObject( parameter_rule.help = I18nObject(
en_US=default_parameter_rule['help']['en_US'], en_US=default_parameter_rule['help']['en_US'],
) )
if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']):
parameter_rule.help.en_US = default_parameter_rule['help']['en_US']
if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']):
parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US'])
if (
parameter_rule.help
and not parameter_rule.help.en_US
and ("help" in default_parameter_rule and "en_US" in default_parameter_rule["help"])
):
parameter_rule.help.en_US = default_parameter_rule["help"]["en_US"]
if (
parameter_rule.help
and not parameter_rule.help.zh_Hans
and ("help" in default_parameter_rule and "zh_Hans" in default_parameter_rule["help"])
):
parameter_rule.help.zh_Hans = default_parameter_rule["help"].get(
"zh_Hans", default_parameter_rule["help"]["en_US"]
)
except ValueError: except ValueError:
pass pass




return schema return schema


def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
def get_customizable_model_schema(self, model: str, credentials: Mapping) -> Optional[AIModelEntity]:
""" """
Get customizable model schema Get customizable model schema


default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name) default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name)


if not default_parameter_rule: if not default_parameter_rule:
raise Exception(f'Invalid model parameter rule name {name}')
raise Exception(f"Invalid model parameter rule name {name}")


return default_parameter_rule return default_parameter_rule


:param text: plain text of prompt. You need to convert the original message to plain text :param text: plain text of prompt. You need to convert the original message to plain text
:return: number of tokens :return: number of tokens
""" """
return GPT2Tokenizer.get_num_tokens(text)
return GPT2Tokenizer.get_num_tokens(text)

+ 17
- 14
api/core/model_runtime/model_providers/__base/large_language_model.py Целия файл

import re import re
import time import time
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Generator
from collections.abc import Generator, Mapping
from typing import Optional, Union from typing import Optional, Union


from pydantic import ConfigDict from pydantic import ConfigDict
def invoke(self, model: str, credentials: dict, def invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
-> Union[LLMResult, Generator]: -> Union[LLMResult, Generator]:
""" """
Invoke large language model Invoke large language model
user=user, user=user,
callbacks=callbacks callbacks=callbacks
) )
else:
elif isinstance(result, LLMResult):
self._trigger_after_invoke_callbacks( self._trigger_after_invoke_callbacks(
model=model, model=model,
result=result, result=result,
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
callbacks: Optional[list[Callback]] = None) -> Union[LLMResult, Generator]:
""" """
Code block mode wrapper, ensure the response is a code block with output markdown quote Code block mode wrapper, ensure the response is a code block with output markdown quote


# override the system message # override the system message
prompt_messages[0] = SystemPromptMessage( prompt_messages[0] = SystemPromptMessage(
content=block_prompts content=block_prompts
.replace("{{instructions}}", prompt_messages[0].content)
.replace("{{instructions}}", str(prompt_messages[0].content))
) )
else: else:
# insert the system message # insert the system message
else: else:
yield piece yield piece
continue continue
new_piece = ""
new_piece: str = ""
for char in piece: for char in piece:
char = str(char)
if state == "normal": if state == "normal":
if char == "`": if char == "`":
state = "in_backticks" state = "in_backticks"
if state == "done": if state == "done":
continue continue


new_piece = ""
new_piece: str = ""
for char in piece: for char in piece:
if state == "search_start": if state == "search_start":
if char == "`": if char == "`":
# If backticks were counted but we're still collecting content, it was a false start # If backticks were counted but we're still collecting content, it was a false start
new_piece += "`" * backtick_count new_piece += "`" * backtick_count
backtick_count = 0 backtick_count = 0
new_piece += char
new_piece += str(char)


elif state == "done": elif state == "done":
break break
prompt_messages: list[PromptMessage], model_parameters: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> Generator:
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> Generator:
""" """
Invoke result generator Invoke result generator


:param result: result generator :param result: result generator
:return: result generator :return: result generator
""" """
callbacks = callbacks or []
prompt_message = AssistantPromptMessage( prompt_message = AssistantPromptMessage(
content="" content=""
) )


def _llm_result_to_stream(self, result: LLMResult) -> Generator: def _llm_result_to_stream(self, result: LLMResult) -> Generator:
""" """
from typing_extensions import deprecated
Transform llm result to stream Transform llm result to stream


:param result: llm result :param result: llm result


return [] return []


def get_model_mode(self, model: str, credentials: Optional[dict] = None) -> LLMMode:
def get_model_mode(self, model: str, credentials: Optional[Mapping] = None) -> LLMMode:
""" """
Get model mode Get model mode


prompt_messages: list[PromptMessage], model_parameters: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
""" """
Trigger before invoke callbacks Trigger before invoke callbacks


prompt_messages: list[PromptMessage], model_parameters: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
""" """
Trigger new chunk callbacks Trigger new chunk callbacks


prompt_messages: list[PromptMessage], model_parameters: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
""" """
Trigger after invoke callbacks Trigger after invoke callbacks


prompt_messages: list[PromptMessage], model_parameters: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None, callbacks: list[Callback] = None) -> None:
user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) -> None:
""" """
Trigger invoke error callbacks Trigger invoke error callbacks



+ 18
- 11
api/core/model_runtime/model_providers/__base/model_provider.py Целия файл

import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional


from core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source from core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.model_entities import AIModelEntity, ModelType




class ModelProvider(ABC): class ModelProvider(ABC):
provider_schema: ProviderEntity = None
provider_schema: Optional[ProviderEntity] = None
model_instance_map: dict[str, AIModel] = {} model_instance_map: dict[str, AIModel] = {}


@abstractmethod @abstractmethod
def get_provider_schema(self) -> ProviderEntity: def get_provider_schema(self) -> ProviderEntity:
""" """
Get provider schema Get provider schema
:return: provider schema :return: provider schema
""" """
if self.provider_schema: if self.provider_schema:
return self.provider_schema return self.provider_schema
# get dirname of the current path # get dirname of the current path
provider_name = self.__class__.__module__.split('.')[-1] provider_name = self.__class__.__module__.split('.')[-1]


# get the path of the model_provider classes # get the path of the model_provider classes
base_path = os.path.abspath(__file__) base_path = os.path.abspath(__file__)
current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name) current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name)
# read provider schema from yaml file # read provider schema from yaml file
yaml_path = os.path.join(current_path, f'{provider_name}.yaml') yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
yaml_data = load_yaml_file(yaml_path, ignore_error=True) yaml_data = load_yaml_file(yaml_path, ignore_error=True)
try: try:
# yaml_data to entity # yaml_data to entity
provider_schema = ProviderEntity(**yaml_data) provider_schema = ProviderEntity(**yaml_data)


# cache schema # cache schema
self.provider_schema = provider_schema self.provider_schema = provider_schema
return provider_schema return provider_schema


def models(self, model_type: ModelType) -> list[AIModelEntity]: def models(self, model_type: ModelType) -> list[AIModelEntity]:
:return: :return:
""" """
# get dirname of the current path # get dirname of the current path
provider_name = self.__class__.__module__.split('.')[-1]
provider_name = self.__class__.__module__.split(".")[-1]


if f"{provider_name}.{model_type.value}" in self.model_instance_map: if f"{provider_name}.{model_type.value}" in self.model_instance_map:
return self.model_instance_map[f"{provider_name}.{model_type.value}"] return self.model_instance_map[f"{provider_name}.{model_type.value}"]
# Dynamic loading {model_type_name}.py file and find the subclass of AIModel # Dynamic loading {model_type_name}.py file and find the subclass of AIModel
parent_module = '.'.join(self.__class__.__module__.split('.')[:-1]) parent_module = '.'.join(self.__class__.__module__.split('.')[:-1])
mod = import_module_from_source( mod = import_module_from_source(
f'{parent_module}.{model_type_name}.{model_type_name}', model_type_py_path)
model_class = next(filter(lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
get_subclasses_from_module(mod, AIModel)), None)
module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path
)
model_class = next(
filter(
lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__,
get_subclasses_from_module(mod, AIModel),
),
None,
)
if not model_class: if not model_class:
raise Exception(f'Missing AIModel Class for model type {model_type} in {model_type_py_path}')
raise Exception(f"Missing AIModel Class for model type {model_type} in {model_type_py_path}")


model_instance_map = model_class() model_instance_map = model_class()
self.model_instance_map[f"{provider_name}.{model_type.value}"] = model_instance_map self.model_instance_map[f"{provider_name}.{model_type.value}"] = model_instance_map

+ 54
- 26
api/core/model_runtime/model_providers/model_provider_factory.py Целия файл

import logging import logging
import os import os
from collections.abc import Sequence
from typing import Optional from typing import Optional


from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict




class ModelProviderExtension(BaseModel): class ModelProviderExtension(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

provider_instance: ModelProvider provider_instance: ModelProvider
name: str name: str
position: Optional[int] = None position: Optional[int] = None
model_config = ConfigDict(arbitrary_types_allowed=True)




class ModelProviderFactory: class ModelProviderFactory:
model_provider_extensions: dict[str, ModelProviderExtension] = None
model_provider_extensions: Optional[dict[str, ModelProviderExtension]] = None


def __init__(self) -> None: def __init__(self) -> None:
# for cache in memory # for cache in memory
self.get_providers() self.get_providers()


def get_providers(self) -> list[ProviderEntity]:
def get_providers(self) -> Sequence[ProviderEntity]:
""" """
Get all providers Get all providers
:return: list of providers :return: list of providers


# traverse all model_provider_extensions # traverse all model_provider_extensions
providers = [] providers = []
for name, model_provider_extension in model_provider_extensions.items():
for model_provider_extension in model_provider_extensions.values():
# get model_provider instance # get model_provider instance
model_provider_instance = model_provider_extension.provider_instance model_provider_instance = model_provider_extension.provider_instance


# return providers # return providers
return providers return providers


def provider_credentials_validate(self, provider: str, credentials: dict) -> dict:
def provider_credentials_validate(self, *, provider: str, credentials: dict) -> dict:
""" """
Validate provider credentials Validate provider credentials


# get provider_credential_schema and validate credentials according to the rules # get provider_credential_schema and validate credentials according to the rules
provider_credential_schema = provider_schema.provider_credential_schema provider_credential_schema = provider_schema.provider_credential_schema


if not provider_credential_schema:
raise ValueError(f"Provider {provider} does not have provider_credential_schema")

# validate provider credential schema # validate provider credential schema
validator = ProviderCredentialSchemaValidator(provider_credential_schema) validator = ProviderCredentialSchemaValidator(provider_credential_schema)
filtered_credentials = validator.validate_and_filter(credentials) filtered_credentials = validator.validate_and_filter(credentials)


return filtered_credentials return filtered_credentials


def model_credentials_validate(self, provider: str, model_type: ModelType,
model: str, credentials: dict) -> dict:
def model_credentials_validate(
self, *, provider: str, model_type: ModelType, model: str, credentials: dict
) -> dict:
""" """
Validate model credentials Validate model credentials


# get model_credential_schema and validate credentials according to the rules # get model_credential_schema and validate credentials according to the rules
model_credential_schema = provider_schema.model_credential_schema model_credential_schema = provider_schema.model_credential_schema


if not model_credential_schema:
raise ValueError(f"Provider {provider} does not have model_credential_schema")

# validate model credential schema # validate model credential schema
validator = ModelCredentialSchemaValidator(model_type, model_credential_schema) validator = ModelCredentialSchemaValidator(model_type, model_credential_schema)
filtered_credentials = validator.validate_and_filter(credentials) filtered_credentials = validator.validate_and_filter(credentials)


return filtered_credentials return filtered_credentials


def get_models(self,
provider: Optional[str] = None,
model_type: Optional[ModelType] = None,
provider_configs: Optional[list[ProviderConfig]] = None) \
-> list[SimpleProviderEntity]:
def get_models(
self,
*,
provider: Optional[str] = None,
model_type: Optional[ModelType] = None,
provider_configs: Optional[list[ProviderConfig]] = None,
) -> list[SimpleProviderEntity]:
""" """
Get all models for given model type Get all models for given model type


:param provider_configs: list of provider configs :param provider_configs: list of provider configs
:return: list of models :return: list of models
""" """
provider_configs = provider_configs or []

# scan all providers # scan all providers
model_provider_extensions = self._get_model_provider_map() model_provider_extensions = self._get_model_provider_map()


# get the provider extension # get the provider extension
model_provider_extension = model_provider_extensions.get(provider) model_provider_extension = model_provider_extensions.get(provider)
if not model_provider_extension: if not model_provider_extension:
raise Exception(f'Invalid provider: {provider}')
raise Exception(f"Invalid provider: {provider}")


# get the provider instance # get the provider instance
model_provider_instance = model_provider_extension.provider_instance model_provider_instance = model_provider_extension.provider_instance
return model_provider_instance return model_provider_instance


def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]: def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]:
"""
Retrieves the model provider map.

This method retrieves the model provider map, which is a dictionary containing the model provider names as keys
and instances of `ModelProviderExtension` as values. The model provider map is used to store information about
available model providers.

Returns:
A dictionary containing the model provider map.

Raises:
None.
"""
if self.model_provider_extensions: if self.model_provider_extensions:
return self.model_provider_extensions return self.model_provider_extensions



# get the path of current classes # get the path of current classes
current_path = os.path.abspath(__file__) current_path = os.path.abspath(__file__)
model_providers_path = os.path.dirname(current_path) model_providers_path = os.path.dirname(current_path)
model_provider_dir_paths = [ model_provider_dir_paths = [
os.path.join(model_providers_path, model_provider_dir) os.path.join(model_providers_path, model_provider_dir)
for model_provider_dir in os.listdir(model_providers_path) for model_provider_dir in os.listdir(model_providers_path)
if not model_provider_dir.startswith('__')
and os.path.isdir(os.path.join(model_providers_path, model_provider_dir))
if not model_provider_dir.startswith("__")
and os.path.isdir(os.path.join(model_providers_path, model_provider_dir))
] ]


# get _position.yaml file path # get _position.yaml file path


file_names = os.listdir(model_provider_dir_path) file_names = os.listdir(model_provider_dir_path)


if (model_provider_name + '.py') not in file_names:
if (model_provider_name + ".py") not in file_names:
logger.warning(f"Missing {model_provider_name}.py file in {model_provider_dir_path}, Skip.") logger.warning(f"Missing {model_provider_name}.py file in {model_provider_dir_path}, Skip.")
continue continue


# Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider # Dynamic loading {model_provider_name}.py file and find the subclass of ModelProvider
py_path = os.path.join(model_provider_dir_path, model_provider_name + '.py')
py_path = os.path.join(model_provider_dir_path, model_provider_name + ".py")
model_provider_class = load_single_subclass_from_source( model_provider_class = load_single_subclass_from_source(
module_name=f'core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}',
module_name=f"core.model_runtime.model_providers.{model_provider_name}.{model_provider_name}",
script_path=py_path, script_path=py_path,
parent_type=ModelProvider)
parent_type=ModelProvider,
)


if not model_provider_class: if not model_provider_class:
logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.") logger.warning(f"Missing Model Provider Class that extends ModelProvider in {py_path}, Skip.")
continue continue


if f'{model_provider_name}.yaml' not in file_names:
if f"{model_provider_name}.yaml" not in file_names:
logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.") logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.")
continue continue


model_providers.append(ModelProviderExtension(
name=model_provider_name,
provider_instance=model_provider_class(),
position=position_map.get(model_provider_name)
))
model_providers.append(
ModelProviderExtension(
name=model_provider_name,
provider_instance=model_provider_class(),
position=position_map.get(model_provider_name),
)
)


sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name) sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name)



+ 12
- 20
api/core/model_runtime/model_providers/openai/_common.py Целия файл

from collections.abc import Mapping

import openai import openai
from httpx import Timeout from httpx import Timeout






class _CommonOpenAI: class _CommonOpenAI:
def _to_credential_kwargs(self, credentials: dict) -> dict:
def _to_credential_kwargs(self, credentials: Mapping) -> dict:
""" """
Transform credentials to kwargs for model instance Transform credentials to kwargs for model instance


"max_retries": 1, "max_retries": 1,
} }


if credentials.get('openai_api_base'):
credentials['openai_api_base'] = credentials['openai_api_base'].rstrip('/')
credentials_kwargs['base_url'] = credentials['openai_api_base'] + '/v1'
if credentials.get("openai_api_base"):
openai_api_base = credentials["openai_api_base"].rstrip("/")
credentials_kwargs["base_url"] = openai_api_base + "/v1"


if 'openai_organization' in credentials: if 'openai_organization' in credentials:
credentials_kwargs['organization'] = credentials['openai_organization'] credentials_kwargs['organization'] = credentials['openai_organization']
:return: Invoke error mapping :return: Invoke error mapping
""" """
return { return {
InvokeConnectionError: [
openai.APIConnectionError,
openai.APITimeoutError
],
InvokeServerUnavailableError: [
openai.InternalServerError
],
InvokeRateLimitError: [
openai.RateLimitError
],
InvokeAuthorizationError: [
openai.AuthenticationError,
openai.PermissionDeniedError
],
InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
InvokeServerUnavailableError: [openai.InternalServerError],
InvokeRateLimitError: [openai.RateLimitError],
InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError],
InvokeBadRequestError: [ InvokeBadRequestError: [
openai.BadRequestError, openai.BadRequestError,
openai.NotFoundError, openai.NotFoundError,
openai.UnprocessableEntityError, openai.UnprocessableEntityError,
openai.APIError
]
openai.APIError,
],
} }

+ 2
- 1
api/core/model_runtime/model_providers/openai/openai.py Целия файл

import logging import logging
from collections.abc import Mapping


from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError


class OpenAIProvider(ModelProvider): class OpenAIProvider(ModelProvider):


def validate_provider_credentials(self, credentials: dict) -> None:
def validate_provider_credentials(self, credentials: Mapping) -> None:
""" """
Validate provider credentials Validate provider credentials
if validate failed, raise exception if validate failed, raise exception

Loading…
Отказ
Запис