@@ -1,6 +1,5 @@
+import codecs
 import json
-import logging
-import re
 from collections.abc import Generator
 from decimal import Decimal
 from typing import Optional, Union, cast
@@ -39,8 +38,6 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
 from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
 from core.model_runtime.utils import helper
 
-logger = logging.getLogger(__name__)
-
 
 class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
     """
@@ -100,7 +97,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         :param tools: tools for tool calling
         :return:
         """
-        return self._num_tokens_from_messages(model, prompt_messages, tools, credentials)
+        return self._num_tokens_from_messages(prompt_messages, tools, credentials)
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -399,6 +396,73 @@
 
         return self._handle_generate_response(model, credentials, response, prompt_messages)
 
+    def _create_final_llm_result_chunk(
+        self,
+        index: int,
+        message: AssistantPromptMessage,
+        finish_reason: str,
+        usage: dict,
+        model: str,
+        prompt_messages: list[PromptMessage],
+        credentials: dict,
+        full_content: str,
+    ) -> LLMResultChunk:
+        # calculate num tokens
+        prompt_tokens = usage and usage.get("prompt_tokens")
+        if prompt_tokens is None:
+            prompt_tokens = self._num_tokens_from_string(text=prompt_messages[0].content)
+        completion_tokens = usage and usage.get("completion_tokens")
+        if completion_tokens is None:
+            completion_tokens = self._num_tokens_from_string(text=full_content)
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+        return LLMResultChunk(
+            model=model,
+            prompt_messages=prompt_messages,
+            delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
+        )
+
+    def _get_tool_call(self, tool_call_id: str, tools_calls: list[AssistantPromptMessage.ToolCall]):
+        """
+        Get or create a tool call by ID
+
+        :param tool_call_id: tool call ID
+        :param tools_calls: list of existing tool calls
+        :return: existing or new tool call, updated tools_calls
+        """
+        if not tool_call_id:
+            return tools_calls[-1], tools_calls
+
+        tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
+        if tool_call is None:
+            tool_call = AssistantPromptMessage.ToolCall(
+                id=tool_call_id,
+                type="function",
+                function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
+            )
+            tools_calls.append(tool_call)
+
+        return tool_call, tools_calls
+
+    def _increase_tool_call(
+        self, new_tool_calls: list[AssistantPromptMessage.ToolCall], tools_calls: list[AssistantPromptMessage.ToolCall]
+    ) -> list[AssistantPromptMessage.ToolCall]:
+        for new_tool_call in new_tool_calls:
+            # get tool call
+            tool_call, tools_calls = self._get_tool_call(new_tool_call.function.name, tools_calls)
+            # update tool call
+            if new_tool_call.id:
+                tool_call.id = new_tool_call.id
+            if new_tool_call.type:
+                tool_call.type = new_tool_call.type
+            if new_tool_call.function.name:
+                tool_call.function.name = new_tool_call.function.name
+            if new_tool_call.function.arguments:
+                tool_call.function.arguments += new_tool_call.function.arguments
+        return tools_calls
+
     def _handle_generate_stream_response(
         self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
     ) -> Generator:
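The three methods in the hunk above are the closures that used to be nested inside `_handle_generate_stream_response` (the next hunk deletes the originals). The core idea in `_increase_tool_call` is that providers stream a tool call as fragments, so `function.arguments` has to be concatenated across chunks while scalar fields are overwritten. A minimal, self-contained sketch of that merge, with plain dataclasses standing in for `AssistantPromptMessage.ToolCall`; note the sketch keys the lookup on `id`, whereas the code above passes `new_tool_call.function.name` into `_get_tool_call`:

```python
from dataclasses import dataclass, field

@dataclass
class ToolCallFunction:  # stand-in for AssistantPromptMessage.ToolCall.ToolCallFunction
    name: str = ""
    arguments: str = ""

@dataclass
class ToolCall:  # stand-in for AssistantPromptMessage.ToolCall
    id: str = ""
    type: str = "function"
    function: ToolCallFunction = field(default_factory=ToolCallFunction)

def merge(new_calls: list[ToolCall], calls: list[ToolCall]) -> list[ToolCall]:
    # same shape as _increase_tool_call: find (or create) the accumulator,
    # overwrite scalar fields, append the argument fragment
    for delta in new_calls:
        target = next((c for c in calls if c.id == delta.id), None)
        if target is None:
            target = ToolCall(id=delta.id)
            calls.append(target)
        if delta.function.name:
            target.function.name = delta.function.name
        target.function.arguments += delta.function.arguments
    return calls

calls: list[ToolCall] = []
calls = merge([ToolCall(id="t1", function=ToolCallFunction("get_weather", '{"ci'))], calls)
calls = merge([ToolCall(id="t1", function=ToolCallFunction("", 'ty": "Paris"}'))], calls)
assert calls[0].function.arguments == '{"city": "Paris"}'
```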
@@ -411,71 +475,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: llm response chunk generator
         """
-        full_assistant_content = ""
-        chunk_index = 0
-
-        def create_final_llm_result_chunk(
-            id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict
-        ) -> LLMResultChunk:
-            # calculate num tokens
-            prompt_tokens = usage and usage.get("prompt_tokens")
-            if prompt_tokens is None:
-                prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
-            completion_tokens = usage and usage.get("completion_tokens")
-            if completion_tokens is None:
-                completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
-
-            # transform usage
-            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-            return LLMResultChunk(
-                id=id,
-                model=model,
-                prompt_messages=prompt_messages,
-                delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
-            )
-
+        full_assistant_content = ""
+        tools_calls: list[AssistantPromptMessage.ToolCall] = []
+        finish_reason = None
+        usage = None
+        is_reasoning_started = False
         # delimiter for stream response, need unicode_escape
-        import codecs
-
         delimiter = credentials.get("stream_mode_delimiter", "\n\n")
         delimiter = codecs.decode(delimiter, "unicode_escape")
 
-        tools_calls: list[AssistantPromptMessage.ToolCall] = []
-
-        def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
-            def get_tool_call(tool_call_id: str):
-                if not tool_call_id:
-                    return tools_calls[-1]
-
-                tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
-                if tool_call is None:
-                    tool_call = AssistantPromptMessage.ToolCall(
-                        id=tool_call_id,
-                        type="function",
-                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
-                    )
-                    tools_calls.append(tool_call)
-
-                return tool_call
-
-            for new_tool_call in new_tool_calls:
-                # get tool call
-                tool_call = get_tool_call(new_tool_call.function.name)
-                # update tool call
-                if new_tool_call.id:
-                    tool_call.id = new_tool_call.id
-                if new_tool_call.type:
-                    tool_call.type = new_tool_call.type
-                if new_tool_call.function.name:
-                    tool_call.function.name = new_tool_call.function.name
-                if new_tool_call.function.arguments:
-                    tool_call.function.arguments += new_tool_call.function.arguments
-
-        finish_reason = None  # The default value of finish_reason is None
-        message_id, usage = None, None
-        is_reasoning_started = False
-        is_reasoning_started_tag = False
         for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             chunk = chunk.strip()
             if chunk:
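`stream_mode_delimiter` is read from credentials as escaped text, so the hunk above decodes it with `unicode_escape` before splitting the SSE stream into events. A small illustration of why that decode step matters (values hypothetical):

```python
import codecs

# a credential stored via a form field arrives escaped: four characters, '\' 'n' '\' 'n'
raw = "\\n\\n"
delimiter = codecs.decode(raw, "unicode_escape")
assert delimiter == "\n\n"  # now a real blank line, the usual SSE event separator

stream = "data: {}\n\ndata: {}\n\ndata: [DONE]"
assert stream.split(delimiter) == ["data: {}", "data: {}", "data: [DONE]"]

# splitting on the undecoded value would not split at all
assert stream.split(raw) == [stream]
```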
@@ -490,12 +498,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                     chunk_json: dict = json.loads(decoded_chunk)
                 # stream ended
                 except json.JSONDecodeError as e:
-                    yield create_final_llm_result_chunk(
-                        id=message_id,
+                    yield self._create_final_llm_result_chunk(
                         index=chunk_index + 1,
                         message=AssistantPromptMessage(content=""),
                         finish_reason="Non-JSON encountered.",
                         usage=usage,
+                        model=model,
+                        credentials=credentials,
+                        prompt_messages=prompt_messages,
+                        full_content=full_assistant_content,
                     )
                     break
                 # handle the error here. for issue #11629
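When the stream ends on a non-JSON line, the hunk above emits a final chunk through `_create_final_llm_result_chunk`, which trusts provider-reported `usage` and only falls back to a local estimate when it is absent. A sketch of that `usage and usage.get(...)` guard; the function names here are hypothetical, not Dify's:

```python
def pick_prompt_tokens(usage, estimate):
    # `usage and usage.get(...)` short-circuits to None when the provider
    # sent no usage block, which triggers the local fallback below
    prompt_tokens = usage and usage.get("prompt_tokens")
    if prompt_tokens is None:
        prompt_tokens = estimate()
    return prompt_tokens

assert pick_prompt_tokens(None, lambda: 7) == 7                      # no usage block at all
assert pick_prompt_tokens({"prompt_tokens": 12}, lambda: 7) == 12    # provider count wins
assert pick_prompt_tokens({"completion_tokens": 3}, lambda: 7) == 7  # key missing, estimate
```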
@@ -510,42 +521,14 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
 
                 choice = chunk_json["choices"][0]
                 finish_reason = chunk_json["choices"][0].get("finish_reason")
-                message_id = chunk_json.get("id")
                 chunk_index += 1
 
                 if "delta" in choice:
                     delta = choice["delta"]
-                    delta_content = delta.get("content")
-                    if not delta_content:
-                        delta_content = ""
-
-                    if not is_reasoning_started_tag and "<think>" in delta_content:
-                        is_reasoning_started_tag = True
-                        delta_content = "> 💭 " + delta_content.replace("<think>", "")
-                    elif is_reasoning_started_tag and "</think>" in delta_content:
-                        delta_content = delta_content.replace("</think>", "") + "\n\n"
-                        is_reasoning_started_tag = False
-                    elif is_reasoning_started_tag:
-                        if "\n" in delta_content:
-                            delta_content = re.sub(r"\n(?!(>|\n))", "\n> ", delta_content)
-
-                    reasoning_content = delta.get("reasoning_content")
-                    if is_reasoning_started and not reasoning_content and not delta_content:
-                        delta_content = ""
-                    elif reasoning_content:
-                        if not is_reasoning_started:
-                            delta_content = "> 💭 " + reasoning_content
-                            is_reasoning_started = True
-                        else:
-                            delta_content = reasoning_content
-
-                        if "\n" in delta_content:
-                            delta_content = re.sub(r"\n(?!(>|\n))", "\n> ", delta_content)
-                    elif is_reasoning_started:
-                        # If we were in reasoning mode but now getting regular content,
-                        # add \n\n to close the reasoning block
-                        delta_content = "\n\n" + delta_content
-                        is_reasoning_started = False
+                    delta_content, is_reasoning_started = self._wrap_thinking_by_reasoning_content(
+                        delta, is_reasoning_started
+                    )
+                    delta_content = self._wrap_thinking_by_tag(delta_content)
 
                     assistant_message_tool_calls = None
 
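The inline logic deleted above rendered `reasoning_content` (and `<think>`-tagged content) as a `> 💭` Markdown blockquote; that responsibility now sits in the base-class helpers `_wrap_thinking_by_reasoning_content` and `_wrap_thinking_by_tag`. A standalone approximation of the blockquote wrapping, not the helpers' actual implementation:

```python
import re

def wrap_reasoning(delta: dict, started: bool) -> tuple[str, bool]:
    content = delta.get("content") or ""
    reasoning = delta.get("reasoning_content")
    if reasoning:
        content = reasoning if started else "> 💭 " + reasoning
        started = True
        # keep each new line of the reasoning inside the blockquote
        content = re.sub(r"\n(?!(>|\n))", "\n> ", content)
    elif started and content:
        content = "\n\n" + content  # a blank line closes the blockquote
        started = False
    return content, started

text, started = wrap_reasoning({"reasoning_content": "consider X\nthen Y"}, False)
assert text == "> 💭 consider X\n> then Y"
text, started = wrap_reasoning({"content": "final answer"}, started)
assert text == "\n\nfinal answer" and started is False
```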
@@ -559,12 +542,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                             {"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})}
                         ]
 
-                    # assistant_message_function_call = delta.delta.function_call
-
                     # extract tool calls from response
                     if assistant_message_tool_calls:
                         tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-                        increase_tool_call(tool_calls)
+                        tools_calls = self._increase_tool_call(tool_calls, tools_calls)
 
                     if delta_content is None or delta_content == "":
                         continue
@@ -589,7 +570,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                         continue
 
                     yield LLMResultChunk(
-                        id=message_id,
                         model=model,
                         prompt_messages=prompt_messages,
                         delta=LLMResultChunkDelta(
@@ -602,7 +582,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
 
         if tools_calls:
             yield LLMResultChunk(
-                id=message_id,
                 model=model,
                 prompt_messages=prompt_messages,
                 delta=LLMResultChunkDelta(
@@ -611,12 +590,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
                 ),
             )
 
-        yield create_final_llm_result_chunk(
-            id=message_id,
+        yield self._create_final_llm_result_chunk(
             index=chunk_index,
             message=AssistantPromptMessage(content=""),
             finish_reason=finish_reason,
             usage=usage,
+            model=model,
+            credentials=credentials,
+            prompt_messages=prompt_messages,
+            full_content=full_assistant_content,
         )
 
     def _handle_generate_response(
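After any buffered tool calls are flushed, the generator always terminates with an empty-content chunk from `_create_final_llm_result_chunk` that carries `finish_reason` and the computed usage. A hypothetical consumer showing the shape a caller can rely on, with `SimpleNamespace` standing in for `LLMResultChunk`:

```python
from types import SimpleNamespace as NS

def drain(chunks):
    # accumulate streamed text; the final chunk has empty content
    # but carries usage and finish_reason
    text, usage, finish = "", None, None
    for chunk in chunks:
        text += chunk.delta.message.content or ""
        usage = chunk.delta.usage or usage
        finish = chunk.delta.finish_reason or finish
    return text, usage, finish

stream = [
    NS(delta=NS(message=NS(content="Hel"), usage=None, finish_reason=None)),
    NS(delta=NS(message=NS(content="lo"), usage=None, finish_reason=None)),
    NS(delta=NS(message=NS(content=""), usage={"total_tokens": 5}, finish_reason="stop")),
]
assert drain(stream) == ("Hello", {"total_tokens": 5}, "stop")
```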
@@ -730,12 +712,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         return message_dict
 
     def _num_tokens_from_string(
-        self, model: str, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None
+        self, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None
     ) -> int:
         """
         Approximate num tokens for model with gpt2 tokenizer.
 
-        :param model: model name
         :param text: prompt text
         :param tools: tools for tool calling
         :return: number of tokens
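With `model` gone from the signature, the fallback estimate no longer varies by model; per the docstring it approximates with a GPT-2 tokenizer regardless of the target model. A rough standalone equivalent, assuming the Hugging Face `transformers` package is available (Dify's real counter is its own bundled GPT-2 tokenizer helper, not this call):

```python
from transformers import GPT2TokenizerFast  # assumption: transformers is installed

_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

def approx_num_tokens(text: str) -> int:
    # rough stand-in for the gpt2-based estimate; provider-reported
    # usage always takes precedence over this number
    return len(_tokenizer.encode(text))

print(approx_num_tokens("Hello, world!"))  # 4 tokens with the gpt2 vocabulary
```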
@@ -758,7 +739,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
 
     def _num_tokens_from_messages(
         self,
-        model: str,
         messages: list[PromptMessage],
         tools: Optional[list[PromptMessageTool]] = None,
         credentials: Optional[dict] = None,