|
|
|
@@ -63,6 +63,9 @@ from core.model_runtime.model_providers.xinference.xinference_helper import ( |
|
|
|
) |
|
|
|
from core.model_runtime.utils import helper |
|
|
|
|
|
|
|
DEFAULT_MAX_RETRIES = 3 |
|
|
|
DEFAULT_INVOKE_TIMEOUT = 60 |
|
|
|
|
|
|
|
|
|
|
|
class XinferenceAILargeLanguageModel(LargeLanguageModel): |
|
|
|
def _invoke( |
|
|
|
@@ -315,7 +318,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): |
|
|
|
message_dict = {"role": "system", "content": message.content} |
|
|
|
elif isinstance(message, ToolPromptMessage): |
|
|
|
message = cast(ToolPromptMessage, message) |
|
|
|
message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content} |
|
|
|
message_dict = { |
|
|
|
"tool_call_id": message.tool_call_id, |
|
|
|
"role": "tool", |
|
|
|
"content": message.content, |
|
|
|
"name": message.name, |
|
|
|
} |
|
|
|
else: |
|
|
|
raise ValueError(f"Unknown message type {type(message)}") |
|
|
|
|
|
|
|
@@ -466,8 +474,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): |
|
|
|
client = OpenAI( |
|
|
|
base_url=f'{credentials["server_url"]}/v1', |
|
|
|
api_key=api_key, |
|
|
|
max_retries=3, |
|
|
|
timeout=60, |
|
|
|
max_retries=int(credentials.get("max_retries") or DEFAULT_MAX_RETRIES), |
|
|
|
timeout=int(credentials.get("invoke_timeout") or DEFAULT_INVOKE_TIMEOUT), |
|
|
|
) |
|
|
|
|
|
|
|
xinference_client = Client( |