|
|
|
@@ -25,6 +25,7 @@ from core.model_runtime.entities.model_entities import ( |
|
|
|
AIModelEntity, |
|
|
|
DefaultParameterName, |
|
|
|
FetchFrom, |
|
|
|
ModelFeature, |
|
|
|
ModelPropertyKey, |
|
|
|
ModelType, |
|
|
|
ParameterRule, |
|
|
|
@@ -166,11 +167,23 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): |
|
|
|
""" |
|
|
|
generate custom model entities from credentials |
|
|
|
""" |
|
|
|
support_function_call = False |
|
|
|
features = [] |
|
|
|
function_calling_type = credentials.get('function_calling_type', 'no_call') |
|
|
|
if function_calling_type == 'function_call': |
|
|
|
features = [ModelFeature.TOOL_CALL] |
|
|
|
support_function_call = True |
|
|
|
endpoint_url = credentials["endpoint_url"] |
|
|
|
# if not endpoint_url.endswith('/'): |
|
|
|
# endpoint_url += '/' |
|
|
|
# if 'https://api.openai.com/v1/' == endpoint_url: |
|
|
|
# features = [ModelFeature.STREAM_TOOL_CALL] |
|
|
|
entity = AIModelEntity( |
|
|
|
model=model, |
|
|
|
label=I18nObject(en_US=model), |
|
|
|
model_type=ModelType.LLM, |
|
|
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, |
|
|
|
features=features if support_function_call else [], |
|
|
|
model_properties={ |
|
|
|
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")), |
|
|
|
ModelPropertyKey.MODE: credentials.get('mode'), |
|
|
|
@@ -194,14 +207,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): |
|
|
|
max=1, |
|
|
|
precision=2 |
|
|
|
), |
|
|
|
ParameterRule( |
|
|
|
name="top_k", |
|
|
|
label=I18nObject(en_US="Top K"), |
|
|
|
type=ParameterType.INT, |
|
|
|
default=int(credentials.get('top_k', 1)), |
|
|
|
min=1, |
|
|
|
max=100 |
|
|
|
), |
|
|
|
ParameterRule( |
|
|
|
name=DefaultParameterName.FREQUENCY_PENALTY.value, |
|
|
|
label=I18nObject(en_US="Frequency Penalty"), |
|
|
|
@@ -232,7 +237,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): |
|
|
|
output=Decimal(credentials.get('output_price', 0)), |
|
|
|
unit=Decimal(credentials.get('unit', 0)), |
|
|
|
currency=credentials.get('currency', "USD") |
|
|
|
) |
|
|
|
), |
|
|
|
) |
|
|
|
|
|
|
|
if credentials['mode'] == 'chat': |
|
|
|
@@ -292,14 +297,22 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): |
|
|
|
raise ValueError("Unsupported completion type for model configuration.") |
|
|
|
|
|
|
|
# annotate tools with names, descriptions, etc. |
|
|
|
function_calling_type = credentials.get('function_calling_type', 'no_call') |
|
|
|
formatted_tools = [] |
|
|
|
if tools: |
|
|
|
data["tool_choice"] = "auto" |
|
|
|
if function_calling_type == 'function_call': |
|
|
|
data['functions'] = [{ |
|
|
|
"name": tool.name, |
|
|
|
"description": tool.description, |
|
|
|
"parameters": tool.parameters |
|
|
|
} for tool in tools] |
|
|
|
elif function_calling_type == 'tool_call': |
|
|
|
data["tool_choice"] = "auto" |
|
|
|
|
|
|
|
for tool in tools: |
|
|
|
formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool))) |
|
|
|
for tool in tools: |
|
|
|
formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool))) |
|
|
|
|
|
|
|
data["tools"] = formatted_tools |
|
|
|
data["tools"] = formatted_tools |
|
|
|
|
|
|
|
if stop: |
|
|
|
data["stop"] = stop |
|
|
|
@@ -367,9 +380,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): |
|
|
|
|
|
|
|
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter): |
|
|
|
if chunk: |
|
|
|
#ignore sse comments |
|
|
|
# ignore sse comments |
|
|
|
if chunk.startswith(':'): |
|
|
|
continue |
|
|
|
continue |
|
|
|
decoded_chunk = chunk.strip().lstrip('data: ').lstrip() |
|
|
|
chunk_json = None |
|
|
|
try: |
|
|
|
@@ -452,10 +465,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): |
|
|
|
|
|
|
|
response_content = '' |
|
|
|
tool_calls = None |
|
|
|
|
|
|
|
function_calling_type = credentials.get('function_calling_type', 'no_call') |
|
|
|
if completion_type is LLMMode.CHAT: |
|
|
|
response_content = output.get('message', {})['content'] |
|
|
|
tool_calls = output.get('message', {}).get('tool_calls') |
|
|
|
if function_calling_type == 'tool_call': |
|
|
|
tool_calls = output.get('message', {}).get('tool_calls') |
|
|
|
elif function_calling_type == 'function_call': |
|
|
|
tool_calls = output.get('message', {}).get('function_call') |
|
|
|
|
|
|
|
elif completion_type is LLMMode.COMPLETION: |
|
|
|
response_content = output['text'] |
|
|
|
@@ -463,7 +479,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): |
|
|
|
assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[]) |
|
|
|
|
|
|
|
if tool_calls: |
|
|
|
assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls) |
|
|
|
if function_calling_type == 'tool_call': |
|
|
|
assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls) |
|
|
|
elif function_calling_type == 'function_call': |
|
|
|
assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)] |
|
|
|
|
|
|
|
usage = response_json.get("usage") |
|
|
|
if usage: |
|
|
|
@@ -522,33 +541,34 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): |
|
|
|
message = cast(AssistantPromptMessage, message) |
|
|
|
message_dict = {"role": "assistant", "content": message.content} |
|
|
|
if message.tool_calls: |
|
|
|
message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call |
|
|
|
in |
|
|
|
message.tool_calls] |
|
|
|
# function_call = message.tool_calls[0] |
|
|
|
# message_dict["function_call"] = { |
|
|
|
# "name": function_call.function.name, |
|
|
|
# "arguments": function_call.function.arguments, |
|
|
|
# } |
|
|
|
# message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call |
|
|
|
# in |
|
|
|
# message.tool_calls] |
|
|
|
|
|
|
|
function_call = message.tool_calls[0] |
|
|
|
message_dict["function_call"] = { |
|
|
|
"name": function_call.function.name, |
|
|
|
"arguments": function_call.function.arguments, |
|
|
|
} |
|
|
|
elif isinstance(message, SystemPromptMessage): |
|
|
|
message = cast(SystemPromptMessage, message) |
|
|
|
message_dict = {"role": "system", "content": message.content} |
|
|
|
elif isinstance(message, ToolPromptMessage): |
|
|
|
message = cast(ToolPromptMessage, message) |
|
|
|
message_dict = { |
|
|
|
"role": "tool", |
|
|
|
"content": message.content, |
|
|
|
"tool_call_id": message.tool_call_id |
|
|
|
} |
|
|
|
# message_dict = { |
|
|
|
# "role": "function", |
|
|
|
# "role": "tool", |
|
|
|
# "content": message.content, |
|
|
|
# "name": message.tool_call_id |
|
|
|
# "tool_call_id": message.tool_call_id |
|
|
|
# } |
|
|
|
message_dict = { |
|
|
|
"role": "function", |
|
|
|
"content": message.content, |
|
|
|
"name": message.tool_call_id |
|
|
|
} |
|
|
|
else: |
|
|
|
raise ValueError(f"Got unknown type {message}") |
|
|
|
|
|
|
|
if message.name is not None: |
|
|
|
if message.name: |
|
|
|
message_dict["name"] = message.name |
|
|
|
|
|
|
|
return message_dict |
|
|
|
@@ -693,3 +713,26 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): |
|
|
|
tool_calls.append(tool_call) |
|
|
|
|
|
|
|
return tool_calls |
|
|
|
|
|
|
|
def _extract_response_function_call(self, response_function_call) -> AssistantPromptMessage.ToolCall:
    """
    Convert a single ``function_call`` payload from a response into a tool call.

    The payload carries no call id of its own here, so the function name
    is reused as the tool call id.

    :param response_function_call: mapping with ``name`` and ``arguments`` keys
    :return: the tool call, or None when the payload is empty/falsy
    """
    # Guard clause: an absent or empty payload yields no tool call.
    if not response_function_call:
        return None

    called_name = response_function_call['name']
    return AssistantPromptMessage.ToolCall(
        id=called_name,
        type="function",
        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
            name=called_name,
            arguments=response_function_call['arguments'],
        ),
    )