|
|
|
@@ -1,5 +1,6 @@
 import logging
 from decimal import Decimal
+from urllib.parse import urljoin

 import requests
 import json
@@ -9,9 +10,12 @@ from typing import Optional, Generator, Union, List, cast
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.utils import helper

-from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, AssistantPromptMessage, PromptMessageContent, \
-    PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, ToolPromptMessage
-from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, DefaultParameterName, \
-    ParameterType, FetchFrom, AIModelEntity
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, \
+    AssistantPromptMessage, PromptMessageContent, \
+    PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, \
+    ToolPromptMessage
+from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, \
+    DefaultParameterName, \
+    ParameterType, FetchFrom, AIModelEntity
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.errors.invoke import InvokeError
@@ -70,7 +74,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         :return:
         """
         return self._num_tokens_from_messages(model, prompt_messages, tools)

     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
-        Validate model credentials
+        Validate model credentials using requests to ensure compatibility with all providers following OpenAI's API standard.
@@ -89,6 +93,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 headers["Authorization"] = f"Bearer {api_key}"

             endpoint_url = credentials['endpoint_url']
+            if not endpoint_url.endswith('/'):
+                endpoint_url += '/'
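+            # urljoin() drops the last path segment of a base URL that lacks a
+            # trailing slash, so normalize the base before the paths are joined below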

             # prepare the payload for a simple ping to the model
             data = {
@@ -105,11 +111,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                         "content": "ping"
                     },
                 ]
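+                # route the ping to the endpoint that matches the configured mode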
+                endpoint_url = urljoin(endpoint_url, 'chat/completions')
             elif completion_type is LLMMode.COMPLETION:
                 data['prompt'] = 'ping'
+                endpoint_url = urljoin(endpoint_url, 'completions')
             else:
                 raise ValueError("Unsupported completion type for model configuration.")

             # send a post request to validate the credentials
             response = requests.post(
                 endpoint_url,
@@ -119,8 +127,24 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             )

             if response.status_code != 200:
-                raise CredentialsValidateFailedError(f'Credentials validation failed with status code {response.status_code}: {response.text}')
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed with status code {response.status_code}')
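+
+            # a 200 response can still carry a non-JSON body, so parse it before inspecting the payload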
+            try:
+                json_result = response.json()
+            except json.JSONDecodeError as e:
+                raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error')
+
+            if (completion_type is LLMMode.CHAT
+                    and ('object' not in json_result or json_result['object'] != 'chat.completion')):
+                raise CredentialsValidateFailedError(
+                    'Credentials validation failed: invalid response object, must be \'chat.completion\'')
+            elif (completion_type is LLMMode.COMPLETION
+                    and ('object' not in json_result or json_result['object'] != 'text_completion')):
+                raise CredentialsValidateFailedError(
+                    'Credentials validation failed: invalid response object, must be \'text_completion\'')
         except CredentialsValidateFailedError:
             raise
         except Exception as ex:
             raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
@@ -134,8 +158,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             model_type=ModelType.LLM,
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_properties={
-                ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size'),
-                ModelPropertyKey.MODE: 'chat'
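+                # context_size may arrive as a string from the credentials form, hence the explicit int() cast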
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
+                ModelPropertyKey.MODE: credentials.get('mode'),
             },
             parameter_rules=[
                 ParameterRule(
@@ -197,11 +221,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):

         return entity

-    # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
-    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
-                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, stream: bool = True, \
-                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
+    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
+                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                  stream: bool = True,
+                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke llm completion model
@@ -223,7 +247,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             headers["Authorization"] = f"Bearer {api_key}"

         endpoint_url = credentials["endpoint_url"]
+        if not endpoint_url.endswith('/'):
+            endpoint_url += '/'

         data = {
             "model": model,
             "stream": stream,
@@ -233,8 +259,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         completion_type = LLMMode.value_of(credentials['mode'])

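+        # chat mode posts an OpenAI-style messages array; completion mode sends a single prompt string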
         if completion_type is LLMMode.CHAT:
+            endpoint_url = urljoin(endpoint_url, 'chat/completions')
             data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
         elif completion_type == LLMMode.COMPLETION:
+            endpoint_url = urljoin(endpoint_url, 'completions')
             data['prompt'] = prompt_messages[0].content
         else:
             raise ValueError("Unsupported completion type for model configuration.")
@@ -245,8 +273,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             data["tool_choice"] = "auto"

             for tool in tools:
-                formatted_tools.append( helper.dump_model(PromptMessageFunction(function=tool)))
+                formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))

             data["tools"] = formatted_tools

         if stop:
@@ -254,7 +282,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):

         if user:
             data["user"] = user

         response = requests.post(
             endpoint_url,
             headers=headers,
@@ -275,8 +303,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):

         return self._handle_generate_response(model, credentials, response, prompt_messages)

     def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
                                          prompt_messages: list[PromptMessage]) -> Generator:
         """
         Handle llm stream response
@@ -313,51 +341,64 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             if chunk:
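+                # strip the SSE "data: " framing so the payload parses as plain JSON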
                 decoded_chunk = chunk.decode('utf-8').strip().lstrip('data: ').lstrip()

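+                # chunk_json stays None when decoding fails; the guard below then skips the chunk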
+                chunk_json = None
                 try:
                     chunk_json = json.loads(decoded_chunk)
                 # stream ended
                 except json.JSONDecodeError as e:
                     yield create_final_llm_result_chunk(
                         index=chunk_index + 1,
                         message=AssistantPromptMessage(content=""),
                         finish_reason="Non-JSON encountered."
                     )

-                if len(chunk_json['choices']) == 0:
+                if not chunk_json or len(chunk_json['choices']) == 0:
                     continue

-                delta = chunk_json['choices'][0]['delta']
-                chunk_index = chunk_json['choices'][0]['index']
+                choice = chunk_json['choices'][0]
+                chunk_index = choice['index'] if 'index' in choice else chunk_index

-                if delta.get('finish_reason') is None and (delta.get('content') is None or delta.get('content') == ''):
-                    continue
-
-                assistant_message_tool_calls = delta.get('tool_calls', None)
-                # assistant_message_function_call = delta.delta.function_call
-
-                # extract tool calls from response
-                if assistant_message_tool_calls:
-                    tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-                    # function_call = self._extract_response_function_call(assistant_message_function_call)
-                    # tool_calls = [function_call] if function_call else []
-
-                # transform assistant message to prompt message
-                assistant_prompt_message = AssistantPromptMessage(
-                    content=delta.get('content', ''),
-                    tool_calls=tool_calls if assistant_message_tool_calls else []
-                )
-
-                full_assistant_content += delta.get('content', '')
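+                # chat-style chunks carry a 'delta' object; completion-style chunks carry raw 'text'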
+                if 'delta' in choice:
+                    delta = choice['delta']
+                    if delta.get('content') is None or delta.get('content') == '':
+                        continue
+
+                    assistant_message_tool_calls = delta.get('tool_calls', None)
+                    # assistant_message_function_call = delta.delta.function_call
+
+                    # extract tool calls from response
+                    if assistant_message_tool_calls:
+                        tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
+                        # function_call = self._extract_response_function_call(assistant_message_function_call)
+                        # tool_calls = [function_call] if function_call else []
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(
+                        content=delta.get('content', ''),
+                        tool_calls=tool_calls if assistant_message_tool_calls else []
+                    )
+
+                    full_assistant_content += delta.get('content', '')
+                elif 'text' in choice:
+                    if choice.get('text') is None or choice.get('text') == '':
+                        continue
+
+                    # transform assistant message to prompt message
+                    assistant_prompt_message = AssistantPromptMessage(
+                        content=choice.get('text', '')
+                    )
+
+                    full_assistant_content += choice.get('text', '')
+                else:
+                    continue

                 # check payload indicator for completion
                 if chunk_json['choices'][0].get('finish_reason') is not None:
                     yield create_final_llm_result_chunk(
                         index=chunk_index,
                         message=assistant_prompt_message,
                         finish_reason=chunk_json['choices'][0]['finish_reason']
                     )
                 else:
                     yield LLMResultChunk(
                         model=model,
@@ -373,10 +414,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
+            chunk_index += 1
+
         yield create_final_llm_result_chunk(
             index=chunk_index,
             message=AssistantPromptMessage(content=""),
             finish_reason="End of stream."
         )

     def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response,
                                   prompt_messages: list[PromptMessage]) -> LLMResult:

         response_json = response.json()

         completion_type = LLMMode.value_of(credentials['mode'])
@@ -455,7 +498,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             message = cast(AssistantPromptMessage, message)
             message_dict = {"role": "assistant", "content": message.content}
             if message.tool_calls:
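+                # serialize each tool call through PromptMessageFunction into a plain dict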
-                message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call in
-                                              message.tool_calls]
+                message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
+                                              in message.tool_calls]
                 # function_call = message.tool_calls[0]
                 # message_dict["function_call"] = {
@@ -484,7 +528,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             message_dict["name"] = message.name

         return message_dict

     def _num_tokens_from_string(self, model: str, text: str,
                                 tools: Optional[list[PromptMessageTool]] = None) -> int:
         """
@@ -507,10 +551,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
         """
         Approximate num tokens with GPT2 tokenizer.
         """

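+        # per OpenAI's token accounting, each message adds a fixed 3-token wrapper,
+        # and each explicit name field costs one extra token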
         tokens_per_message = 3
         tokens_per_name = 1

         num_tokens = 0
         messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
         for message in messages_dict:
@@ -599,7 +643,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     num_tokens += self._get_num_tokens_by_gpt2(required_field)

         return num_tokens

     def _extract_response_tool_calls(self,
                                      response_tool_calls: list[dict]) \
             -> list[AssistantPromptMessage.ToolCall]: