|
|
|
@@ -170,13 +170,14 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): |
|
|
|
features = [] |
|
|
|
|
|
|
|
function_calling_type = credentials.get('function_calling_type', 'no_call') |
|
|
|
if function_calling_type == 'function_call': |
|
|
|
if function_calling_type in ['function_call']: |
|
|
|
features.append(ModelFeature.TOOL_CALL) |
|
|
|
endpoint_url = credentials["endpoint_url"] |
|
|
|
# if not endpoint_url.endswith('/'): |
|
|
|
# endpoint_url += '/' |
|
|
|
# if 'https://api.openai.com/v1/' == endpoint_url: |
|
|
|
# features.append(ModelFeature.STREAM_TOOL_CALL) |
|
|
|
elif function_calling_type in ['tool_call']: |
|
|
|
features.append(ModelFeature.MULTI_TOOL_CALL) |
|
|
|
|
|
|
|
stream_function_calling = credentials.get('stream_function_calling', 'supported') |
|
|
|
if stream_function_calling == 'supported': |
|
|
|
features.append(ModelFeature.STREAM_TOOL_CALL) |
|
|
|
|
|
|
|
vision_support = credentials.get('vision_support', 'not_support') |
|
|
|
if vision_support == 'support': |
|
|
|
@@ -386,29 +387,37 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): |
|
|
|
|
|
|
|
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]): |
|
|
|
def get_tool_call(tool_call_id: str): |
|
|
|
tool_call = next( |
|
|
|
(tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None |
|
|
|
) |
|
|
|
if not tool_call_id: |
|
|
|
return tools_calls[-1] |
|
|
|
|
|
|
|
tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None) |
|
|
|
if tool_call is None: |
|
|
|
tool_call = AssistantPromptMessage.ToolCall( |
|
|
|
id='', |
|
|
|
type='function', |
|
|
|
id=tool_call_id, |
|
|
|
type="function", |
|
|
|
function=AssistantPromptMessage.ToolCall.ToolCallFunction( |
|
|
|
name='', |
|
|
|
arguments='' |
|
|
|
name="", |
|
|
|
arguments="" |
|
|
|
) |
|
|
|
) |
|
|
|
tools_calls.append(tool_call) |
|
|
|
|
|
|
|
return tool_call |
|
|
|
|
|
|
|
for new_tool_call in new_tool_calls: |
|
|
|
# get tool call |
|
|
|
tool_call = get_tool_call(new_tool_call.id) |
|
|
|
tool_call = get_tool_call(new_tool_call.function.name) |
|
|
|
# update tool call |
|
|
|
tool_call.id = new_tool_call.id |
|
|
|
tool_call.type = new_tool_call.type |
|
|
|
tool_call.function.name = new_tool_call.function.name |
|
|
|
tool_call.function.arguments += new_tool_call.function.arguments |
|
|
|
if new_tool_call.id: |
|
|
|
tool_call.id = new_tool_call.id |
|
|
|
if new_tool_call.type: |
|
|
|
tool_call.type = new_tool_call.type |
|
|
|
if new_tool_call.function.name: |
|
|
|
tool_call.function.name = new_tool_call.function.name |
|
|
|
if new_tool_call.function.arguments: |
|
|
|
tool_call.function.arguments += new_tool_call.function.arguments |
|
|
|
|
|
|
|
finish_reason = 'Unknown' |
|
|
|
|
|
|
|
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter): |
|
|
|
if chunk: |
|
|
|
@@ -438,7 +447,17 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): |
|
|
|
delta = choice['delta'] |
|
|
|
delta_content = delta.get('content') |
|
|
|
|
|
|
|
assistant_message_tool_calls = delta.get('tool_calls', None) |
|
|
|
assistant_message_tool_calls = None |
|
|
|
|
|
|
|
if 'tool_calls' in delta and credentials.get('function_calling_type', 'no_call') == 'tool_call': |
|
|
|
assistant_message_tool_calls = delta.get('tool_calls', None) |
|
|
|
elif 'function_call' in delta and credentials.get('function_calling_type', 'no_call') == 'function_call': |
|
|
|
assistant_message_tool_calls = [{ |
|
|
|
'id': 'tool_call_id', |
|
|
|
'type': 'function', |
|
|
|
'function': delta.get('function_call', {}) |
|
|
|
}] |
|
|
|
|
|
|
|
# assistant_message_function_call = delta.delta.function_call |
|
|
|
|
|
|
|
# extract tool calls from response |
|
|
|
@@ -449,15 +468,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): |
|
|
|
if delta_content is None or delta_content == '': |
|
|
|
continue |
|
|
|
|
|
|
|
# function_call = self._extract_response_function_call(assistant_message_function_call) |
|
|
|
# tool_calls = [function_call] if function_call else [] |
|
|
|
|
|
|
|
# transform assistant message to prompt message |
|
|
|
assistant_prompt_message = AssistantPromptMessage( |
|
|
|
content=delta_content, |
|
|
|
tool_calls=tool_calls if assistant_message_tool_calls else [] |
|
|
|
) |
|
|
|
|
|
|
|
# reset tool calls |
|
|
|
tool_calls = [] |
|
|
|
full_assistant_content += delta_content |
|
|
|
elif 'text' in choice: |
|
|
|
choice_text = choice.get('text', '') |
|
|
|
@@ -470,37 +487,36 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): |
|
|
|
else: |
|
|
|
continue |
|
|
|
|
|
|
|
# check payload indicator for completion |
|
|
|
if finish_reason is not None: |
|
|
|
yield LLMResultChunk( |
|
|
|
model=model, |
|
|
|
prompt_messages=prompt_messages, |
|
|
|
delta=LLMResultChunkDelta( |
|
|
|
index=chunk_index, |
|
|
|
message=AssistantPromptMessage( |
|
|
|
tool_calls=tools_calls, |
|
|
|
), |
|
|
|
finish_reason=finish_reason |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
yield create_final_llm_result_chunk( |
|
|
|
yield LLMResultChunk( |
|
|
|
model=model, |
|
|
|
prompt_messages=prompt_messages, |
|
|
|
delta=LLMResultChunkDelta( |
|
|
|
index=chunk_index, |
|
|
|
message=assistant_prompt_message, |
|
|
|
finish_reason=finish_reason |
|
|
|
) |
|
|
|
else: |
|
|
|
yield LLMResultChunk( |
|
|
|
model=model, |
|
|
|
prompt_messages=prompt_messages, |
|
|
|
delta=LLMResultChunkDelta( |
|
|
|
index=chunk_index, |
|
|
|
message=assistant_prompt_message, |
|
|
|
) |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
chunk_index += 1 |
|
|
|
|
|
|
|
if tools_calls: |
|
|
|
yield LLMResultChunk( |
|
|
|
model=model, |
|
|
|
prompt_messages=prompt_messages, |
|
|
|
delta=LLMResultChunkDelta( |
|
|
|
index=chunk_index, |
|
|
|
message=AssistantPromptMessage( |
|
|
|
tool_calls=tools_calls, |
|
|
|
content="" |
|
|
|
), |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
yield create_final_llm_result_chunk( |
|
|
|
index=chunk_index, |
|
|
|
message=AssistantPromptMessage(content=""), |
|
|
|
finish_reason=finish_reason |
|
|
|
) |
|
|
|
|
|
|
|
def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response, |
|
|
|
prompt_messages: list[PromptMessage]) -> LLMResult: |
|
|
|
|
|
|
|
@@ -757,13 +773,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): |
|
|
|
if response_tool_calls: |
|
|
|
for response_tool_call in response_tool_calls: |
|
|
|
function = AssistantPromptMessage.ToolCall.ToolCallFunction( |
|
|
|
name=response_tool_call["function"]["name"], |
|
|
|
arguments=response_tool_call["function"]["arguments"] |
|
|
|
name=response_tool_call.get("function", {}).get("name", ""), |
|
|
|
arguments=response_tool_call.get("function", {}).get("arguments", "") |
|
|
|
) |
|
|
|
|
|
|
|
tool_call = AssistantPromptMessage.ToolCall( |
|
|
|
id=response_tool_call["id"], |
|
|
|
type=response_tool_call["type"], |
|
|
|
id=response_tool_call.get("id", ""), |
|
|
|
type=response_tool_call.get("type", ""), |
|
|
|
function=function |
|
|
|
) |
|
|
|
tool_calls.append(tool_call) |
|
|
|
@@ -781,12 +797,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): |
|
|
|
tool_call = None |
|
|
|
if response_function_call: |
|
|
|
function = AssistantPromptMessage.ToolCall.ToolCallFunction( |
|
|
|
name=response_function_call['name'], |
|
|
|
arguments=response_function_call['arguments'] |
|
|
|
name=response_function_call.get('name', ''), |
|
|
|
arguments=response_function_call.get('arguments', '') |
|
|
|
) |
|
|
|
|
|
|
|
tool_call = AssistantPromptMessage.ToolCall( |
|
|
|
id=response_function_call['name'], |
|
|
|
id=response_function_call.get('id', ''), |
|
|
|
type="function", |
|
|
|
function=function |
|
|
|
) |