|
|
|
@@ -4,12 +4,22 @@ from urllib.parse import urlparse |
|
|
|
|
|
|
|
import tiktoken |
|
|
|
|
|
|
|
from core.model_runtime.entities.llm_entities import LLMResult |
|
|
|
from core.model_runtime.entities.common_entities import I18nObject |
|
|
|
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult |
|
|
|
from core.model_runtime.entities.message_entities import ( |
|
|
|
PromptMessage, |
|
|
|
PromptMessageTool, |
|
|
|
SystemPromptMessage, |
|
|
|
) |
|
|
|
from core.model_runtime.entities.model_entities import ( |
|
|
|
AIModelEntity, |
|
|
|
FetchFrom, |
|
|
|
ModelFeature, |
|
|
|
ModelPropertyKey, |
|
|
|
ModelType, |
|
|
|
ParameterRule, |
|
|
|
ParameterType, |
|
|
|
) |
|
|
|
from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel |
|
|
|
|
|
|
|
|
|
|
|
@@ -125,3 +135,58 @@ class YiLargeLanguageModel(OpenAILargeLanguageModel): |
|
|
|
else: |
|
|
|
parsed_url = urlparse(credentials["endpoint_url"]) |
|
|
|
credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}" |
|
|
|
|
|
|
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: |
|
|
|
return AIModelEntity( |
|
|
|
model=model, |
|
|
|
label=I18nObject(en_US=model, zh_Hans=model), |
|
|
|
model_type=ModelType.LLM, |
|
|
|
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] |
|
|
|
if credentials.get("function_calling_type") == "tool_call" |
|
|
|
else [], |
|
|
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, |
|
|
|
model_properties={ |
|
|
|
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000)), |
|
|
|
ModelPropertyKey.MODE: LLMMode.CHAT.value, |
|
|
|
}, |
|
|
|
parameter_rules=[ |
|
|
|
ParameterRule( |
|
|
|
name="temperature", |
|
|
|
use_template="temperature", |
|
|
|
label=I18nObject(en_US="Temperature", zh_Hans="温度"), |
|
|
|
type=ParameterType.FLOAT, |
|
|
|
), |
|
|
|
ParameterRule( |
|
|
|
name="max_tokens", |
|
|
|
use_template="max_tokens", |
|
|
|
default=512, |
|
|
|
min=1, |
|
|
|
max=int(credentials.get("max_tokens", 8192)), |
|
|
|
label=I18nObject( |
|
|
|
en_US="Max Tokens", zh_Hans="指定生成结果长度的上限。如果生成结果截断,可以调大该参数" |
|
|
|
), |
|
|
|
type=ParameterType.INT, |
|
|
|
), |
|
|
|
ParameterRule( |
|
|
|
name="top_p", |
|
|
|
use_template="top_p", |
|
|
|
label=I18nObject( |
|
|
|
en_US="Top P", |
|
|
|
zh_Hans="控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。", |
|
|
|
), |
|
|
|
type=ParameterType.FLOAT, |
|
|
|
), |
|
|
|
ParameterRule( |
|
|
|
name="top_k", |
|
|
|
use_template="top_k", |
|
|
|
label=I18nObject(en_US="Top K", zh_Hans="取样数量"), |
|
|
|
type=ParameterType.FLOAT, |
|
|
|
), |
|
|
|
ParameterRule( |
|
|
|
name="frequency_penalty", |
|
|
|
use_template="frequency_penalty", |
|
|
|
label=I18nObject(en_US="Frequency Penalty", zh_Hans="重复惩罚"), |
|
|
|
type=ParameterType.FLOAT, |
|
|
|
), |
|
|
|
], |
|
|
|
) |