Browse Source

Feat/blocking function call (#2247)

tags/0.5.3
Yeuoly 1 year ago
parent
commit
6d5b386394
33 changed files with 430 additions and 95 deletions
 1. +11 -3    api/core/app_runner/assistant_app_runner.py
 2. +14 -1    api/core/features/assistant_base_runner.py
 3. +4 -3     api/core/features/assistant_cot_runner.py
 4. +105 -23  api/core/features/assistant_fc_runner.py
 5. +1 -0     api/core/model_runtime/entities/model_entities.py
 6. +5 -0     api/core/model_runtime/model_providers/azure_openai/_constant.py
 7. +24 -4    api/core/model_runtime/model_providers/azure_openai/llm/llm.py
 8. +5 -1     api/core/model_runtime/model_providers/chatglm/llm/llm.py
 9. +2 -0     api/core/model_runtime/model_providers/minimax/llm/abab5.5-chat.yaml
10. +2 -0     api/core/model_runtime/model_providers/minimax/llm/abab6-chat.yaml
11. +1 -2     api/core/model_runtime/model_providers/minimax/llm/chat_completion.py
12. +37 -9    api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py
13. +42 -1    api/core/model_runtime/model_providers/minimax/llm/llm.py
14. +10 -0    api/core/model_runtime/model_providers/minimax/llm/types.py
15. +1 -0     api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml
16. +1 -0     api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml
17. +1 -0     api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml
18. +1 -0     api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml
19. +1 -0     api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml
20. +1 -0     api/core/model_runtime/model_providers/openai/llm/gpt-4-0125-preview.yaml
21. +1 -0     api/core/model_runtime/model_providers/openai/llm/gpt-4-1106-preview.yaml
22. +1 -0     api/core/model_runtime/model_providers/openai/llm/gpt-4-32k.yaml
23. +1 -0     api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-preview.yaml
24. +1 -0     api/core/model_runtime/model_providers/openai/llm/gpt-4.yaml
25. +1 -1     api/core/model_runtime/model_providers/openai/llm/llm.py
26. +27 -4    api/core/model_runtime/model_providers/xinference/llm/llm.py
27. +19 -4    api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py
28. +19 -6    api/core/model_runtime/model_providers/xinference/xinference_helper.py
29. +4 -0     api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml
30. +4 -0     api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml
31. +21 -0    api/core/model_runtime/model_providers/zhipuai/llm/llm.py
32. +1 -1     api/requirements.txt
33. +61 -32   api/tests/integration_tests/model_runtime/__mock/xinference.py

+11 -3  api/core/app_runner/assistant_app_runner.py

  from core.memory.token_buffer_memory import TokenBufferMemory
  from core.model_manager import ModelInstance
  from core.model_runtime.entities.llm_entities import LLMUsage
+ from core.model_runtime.entities.model_entities import ModelFeature
  from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  from core.moderation.base import ModerationException
  from core.tools.entities.tool_entities import ToolRuntimeVariablePool

              memory=memory,
          )

+         # change function call strategy based on LLM model
+         llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
+         model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
+
+         if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features):
+             agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING

          # start agent runner
          if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
              assistant_cot_runner = AssistantCotApplicationRunner(
                  prompt_messages=prompt_message,
                  variables_pool=tool_variables,
                  db_variables=tool_conversation_variables,
+                 model_instance=model_instance
              )
              invoke_result = assistant_cot_runner.run(
-                 model_instance=model_instance,
                  conversation=conversation,
                  message=message,
                  query=query,
                  memory=memory,

                  prompt_messages=prompt_message,
                  variables_pool=tool_variables,
-                 db_variables=tool_conversation_variables
+                 db_variables=tool_conversation_variables,
+                 model_instance=model_instance
              )
              invoke_result = assistant_fc_runner.run(
-                 model_instance=model_instance,
                  conversation=conversation,
                  message=message,
                  query=query,
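
To make the strategy switch above concrete, here is a minimal standalone sketch (the enums are simplified stand-ins, not the Dify entities): any advertised native tool-calling feature is enough to move the agent off ReAct-style chain-of-thought.

from enum import Enum

class ModelFeature(Enum):
    TOOL_CALL = 'tool-call'
    MULTI_TOOL_CALL = 'multi-tool-call'
    AGENT_THOUGHT = 'agent-thought'

class Strategy(Enum):
    CHAIN_OF_THOUGHT = 'chain-of-thought'
    FUNCTION_CALLING = 'function-calling'

def pick_strategy(features: list) -> Strategy:
    # set intersection mirrors the check in the diff: one hit is enough
    if {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}.intersection(features):
        return Strategy.FUNCTION_CALLING
    return Strategy.CHAIN_OF_THOUGHT

print(pick_strategy([ModelFeature.AGENT_THOUGHT]))                          # CHAIN_OF_THOUGHT
print(pick_strategy([ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL]))  # FUNCTION_CALLING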

+14 -1  api/core/features/assistant_base_runner.py

  import logging
  import json

- from typing import Optional, List, Tuple, Union
+ from typing import Optional, List, Tuple, Union, cast
  from datetime import datetime
  from mimetypes import guess_extension

      AgentEntity, AppOrchestrationConfigEntity, ApplicationGenerateEntity, InvokeFrom
  from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
  from core.model_runtime.entities.llm_entities import LLMUsage
+ from core.model_runtime.entities.model_entities import ModelFeature
  from core.model_runtime.utils.encoders import jsonable_encoder
+ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+ from core.model_manager import ModelInstance
  from core.file.message_file_parser import FileTransferMethod

  logger = logging.getLogger(__name__)

               prompt_messages: Optional[List[PromptMessage]] = None,
               variables_pool: Optional[ToolRuntimeVariablePool] = None,
               db_variables: Optional[ToolConversationVariables] = None,
+              model_instance: ModelInstance = None
               ) -> None:
          """
          Agent runner

          self.history_prompt_messages = prompt_messages
          self.variables_pool = variables_pool
          self.db_variables_pool = db_variables
+         self.model_instance = model_instance

          # init callback
          self.agent_callback = DifyAgentCallbackHandler()

              MessageAgentThought.message_id == self.message.id,
          ).count()

+         # check if model supports stream tool call
+         llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
+         model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
+         if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
+             self.stream_tool_call = True
+         else:
+             self.stream_tool_call = False

      def _repacket_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity:
          """
          Repacket app orchestration config
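
The capability probe added here is defensive on two levels; a self-contained sketch (ModelSchema below is a simplified stand-in for the runtime entity) shows why both guards matter:

from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelSchema:
    features: Optional[list] = None   # providers may omit the features list entirely

def supports_stream_tool_call(schema: Optional[ModelSchema]) -> bool:
    # `schema and ... (schema.features or [])` survives both a missing schema
    # and a schema whose features attribute is None
    return bool(schema and 'stream-tool-call' in (schema.features or []))

assert supports_stream_tool_call(ModelSchema(features=['stream-tool-call']))
assert not supports_stream_tool_call(ModelSchema())
assert not supports_stream_tool_call(None)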

+4 -3  api/core/features/assistant_cot_runner.py

  from models.model import Conversation, Message

  class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
-     def run(self, model_instance: ModelInstance,
-             conversation: Conversation,
+     def run(self, conversation: Conversation,
              message: Message,
              query: str,
      ) -> Union[Generator, LLMResult]:

          llm_usage.prompt_price += usage.prompt_price
          llm_usage.completion_price += usage.completion_price

+     model_instance = self.model_instance

      while function_call_state and iteration_step <= max_iteration_steps:
          # continue to run until there is not any tool call
          function_call_state = False

          # remove Action: xxx from agent thought
          agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)

-         if action_name and action_input:
+         if action_name and action_input is not None:
              return AgentScratchpadUnit(
                  agent_response=content,
                  thought=agent_thought,
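
The `is not None` change is a correctness fix, not a style one: a tool can legitimately be invoked with falsy input such as `{}` or `''`, which the old truthiness test silently discarded. A short illustration:

def should_emit_action(action_name, action_input):
    # old check `action_name and action_input` returned False for `{}` or ''
    return bool(action_name) and action_input is not None

print(should_emit_action('current_time', {}))    # True (was False before the fix)
print(should_emit_action('current_time', None))  # False (no input parsed yet)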

+105 -23  api/core/features/assistant_fc_runner.py

  from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage,\
      SystemPromptMessage, AssistantPromptMessage, ToolPromptMessage, PromptMessageTool
- from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage
+ from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage, LLMResultChunkDelta
  from core.model_manager import ModelInstance
  from core.application_queue_manager import PublishFrom

  logger = logging.getLogger(__name__)

  class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
-     def run(self, model_instance: ModelInstance,
-             conversation: Conversation,
+     def run(self, conversation: Conversation,
              message: Message,
              query: str,
      ) -> Generator[LLMResultChunk, None, None]:

          llm_usage.prompt_price += usage.prompt_price
          llm_usage.completion_price += usage.completion_price

+     model_instance = self.model_instance

      while function_call_state and iteration_step <= max_iteration_steps:
          function_call_state = False

          # recale llm max tokens
          self.recale_llm_max_tokens(self.model_config, prompt_messages)

          # invoke model
-         chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
+         chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
              prompt_messages=prompt_messages,
              model_parameters=app_orchestration_config.model_config.parameters,
              tools=prompt_messages_tools,
              stop=app_orchestration_config.model_config.stop,
-             stream=True,
+             stream=self.stream_tool_call,
              user=self.user_id,
              callbacks=[],
          )

          current_llm_usage = None

-         for chunk in chunks:
+         if self.stream_tool_call:
+             for chunk in chunks:
+                 # check if there is any tool call
+                 if self.check_tool_calls(chunk):
+                     function_call_state = True
+                     tool_calls.extend(self.extract_tool_calls(chunk))
+                     tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
+                     try:
+                         tool_call_inputs = json.dumps({
+                             tool_call[1]: tool_call[2] for tool_call in tool_calls
+                         }, ensure_ascii=False)
+                     except json.JSONDecodeError as e:
+                         # ensure ascii to avoid encoding error
+                         tool_call_inputs = json.dumps({
+                             tool_call[1]: tool_call[2] for tool_call in tool_calls
+                         })
+
+                 if chunk.delta.message and chunk.delta.message.content:
+                     if isinstance(chunk.delta.message.content, list):
+                         for content in chunk.delta.message.content:
+                             response += content.data
+                     else:
+                         response += chunk.delta.message.content
+
+                 if chunk.delta.usage:
+                     increase_usage(llm_usage, chunk.delta.usage)
+                     current_llm_usage = chunk.delta.usage
+
+                 yield chunk
+         else:
+             result: LLMResult = chunks
              # check if there is any tool call
-             if self.check_tool_calls(chunk):
+             if self.check_blocking_tool_calls(result):
                  function_call_state = True
-                 tool_calls.extend(self.extract_tool_calls(chunk))
+                 tool_calls.extend(self.extract_blocking_tool_calls(result))
                  tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
                  try:
                      tool_call_inputs = json.dumps({
                          tool_call[1]: tool_call[2] for tool_call in tool_calls
                      })

-             if chunk.delta.message and chunk.delta.message.content:
-                 if isinstance(chunk.delta.message.content, list):
-                     for content in chunk.delta.message.content:
+             if result.usage:
+                 increase_usage(llm_usage, result.usage)
+                 current_llm_usage = result.usage
+
+             if result.message and result.message.content:
+                 if isinstance(result.message.content, list):
+                     for content in result.message.content:
                          response += content.data
                  else:
-                     response += chunk.delta.message.content
-
-             if chunk.delta.usage:
-                 increase_usage(llm_usage, chunk.delta.usage)
-                 current_llm_usage = chunk.delta.usage
+                     response += result.message.content

+             if not result.message.content:
+                 result.message.content = ''
+
+             yield LLMResultChunk(
+                 model=model_instance.model,
+                 prompt_messages=result.prompt_messages,
+                 system_fingerprint=result.system_fingerprint,
+                 delta=LLMResultChunkDelta(
+                     index=0,
+                     message=result.message,
+                     usage=result.usage,
+                 )
+             )

-             yield chunk
+         if tool_calls:
+             prompt_messages.append(AssistantPromptMessage(
+                 content='',
+                 name='',
+                 tool_calls=[AssistantPromptMessage.ToolCall(
+                     id=tool_call[0],
+                     type='function',
+                     function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                         name=tool_call[1],
+                         arguments=json.dumps(tool_call[2], ensure_ascii=False)
+                     )
+                 ) for tool_call in tool_calls]
+             ))

          # save thought
          self.save_agent_thought(

          final_answer += response + '\n'

-         # update prompt messages
-         if response.strip():
-             prompt_messages.append(AssistantPromptMessage(
-                 content=response,
-             ))

          # call tools
          tool_responses = []
          for tool_call_id, tool_call_name, tool_call_args in tool_calls:

          )
          self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

+         # update prompt messages
+         if response.strip():
+             prompt_messages.append(AssistantPromptMessage(
+                 content=response,
+             ))

          # update prompt tool
          for prompt_tool in prompt_messages_tools:
              self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

          if llm_result_chunk.delta.message.tool_calls:
              return True
          return False

+     def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
+         """
+         Check if there is any blocking tool call in llm result
+         """
+         if llm_result.message.tool_calls:
+             return True
+         return False

      def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
          """

              ))

          return tool_calls

+     def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
+         """
+         Extract blocking tool calls from llm result
+
+         Returns:
+             List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
+         """
+         tool_calls = []
+         for prompt_message in llm_result.message.tool_calls:
+             tool_calls.append((
+                 prompt_message.id,
+                 prompt_message.function.name,
+                 json.loads(prompt_message.function.arguments),
+             ))
+
+         return tool_calls

      def organize_prompt_messages(self, prompt_template: str,
                                   query: str = None,
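
The shape of the change above: `invoke_llm` now returns either a chunk generator or a whole `LLMResult`, and the runner branches on `stream_tool_call`, re-emitting the blocking result as one synthetic chunk so downstream consumers keep a single interface. A reduced sketch of that dispatch, with stand-in types rather than the Dify entities:

from dataclasses import dataclass
from typing import Generator, Union

@dataclass
class Result:                 # stand-in for LLMResult
    content: str

def invoke_llm(stream: bool) -> Union[Generator[str, None, None], Result]:
    # stand-in model: streams yield fragments, blocking calls return one result
    if stream:
        return (part for part in ('Hello', ' ', 'world'))
    return Result(content='Hello world')

def run(stream_tool_call: bool) -> str:
    response = ''
    out = invoke_llm(stream=stream_tool_call)
    if stream_tool_call:
        for chunk in out:        # incremental path: accumulate chunk by chunk
            response += chunk
    else:
        result: Result = out     # blocking path: consume the single result
        response += result.content
    return response

assert run(True) == run(False) == 'Hello world'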

+1 -0  api/core/model_runtime/entities/model_entities.py

      MULTI_TOOL_CALL = "multi-tool-call"
      AGENT_THOUGHT = "agent-thought"
      VISION = "vision"
+     STREAM_TOOL_CALL = "stream-tool-call"


  class DefaultParameterName(Enum):

+5 -0  api/core/model_runtime/model_providers/azure_openai/_constant.py

          features=[
              ModelFeature.AGENT_THOUGHT,
              ModelFeature.MULTI_TOOL_CALL,
+             ModelFeature.STREAM_TOOL_CALL,
          ],
          fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
          model_properties={

          features=[
              ModelFeature.AGENT_THOUGHT,
              ModelFeature.MULTI_TOOL_CALL,
+             ModelFeature.STREAM_TOOL_CALL,
          ],
          fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
          model_properties={

          features=[
              ModelFeature.AGENT_THOUGHT,
              ModelFeature.MULTI_TOOL_CALL,
+             ModelFeature.STREAM_TOOL_CALL,
          ],
          fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
          model_properties={

          features=[
              ModelFeature.AGENT_THOUGHT,
              ModelFeature.MULTI_TOOL_CALL,
+             ModelFeature.STREAM_TOOL_CALL,
          ],
          fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
          model_properties={

          features=[
              ModelFeature.AGENT_THOUGHT,
              ModelFeature.MULTI_TOOL_CALL,
+             ModelFeature.STREAM_TOOL_CALL,
          ],
          fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
          model_properties={

+24 -4  api/core/model_runtime/model_providers/azure_openai/llm/llm.py

                 tools: Optional[list[PromptMessageTool]] = None) -> Generator:
          index = 0
          full_assistant_content = ''
+         delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
          real_model = model
          system_fingerprint = None
          completion = ''

              delta = chunk.choices[0]

-             if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
+             if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == '') and \
+                 delta.delta.function_call is None:
                  continue

              # assistant_message_tool_calls = delta.delta.tool_calls
              assistant_message_function_call = delta.delta.function_call

+             # extract tool calls from response
+             if delta_assistant_message_function_call_storage is not None:
+                 # handle process of stream function call
+                 if assistant_message_function_call:
+                     # message has not ended ever
+                     delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments
+                     continue
+                 else:
+                     # message has ended
+                     assistant_message_function_call = delta_assistant_message_function_call_storage
+                     delta_assistant_message_function_call_storage = None
+             else:
+                 if assistant_message_function_call:
+                     # start of stream function call
+                     delta_assistant_message_function_call_storage = assistant_message_function_call
+                     if delta_assistant_message_function_call_storage.arguments is None:
+                         delta_assistant_message_function_call_storage.arguments = ''
+                     continue
+
              # extract tool calls from response
              # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
              function_call = self._extract_response_function_call(assistant_message_function_call)

          else:
              raise ValueError(f"Got unknown type {message}")

-         if message.name is not None:
+         if message.name:
              message_dict["name"] = message.name

          return message_dict

          num_tokens = 0
          for tool in tools:
              num_tokens += len(encoding.encode('type'))
-             num_tokens += len(encoding.encode(tool.get("type")))
              num_tokens += len(encoding.encode('function'))

              # calculate num tokens for function object
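
The buffering logic added above handles the way OpenAI-style streams split a function call across deltas: the name arrives first, the arguments trickle in as string fragments, and the call is complete once a delta carries no function_call. A reduced, runnable model of that state machine (the types are stand-ins, not the OpenAI SDK classes):

from dataclasses import dataclass
from typing import Optional

@dataclass
class FunctionCallDelta:
    name: Optional[str] = None
    arguments: Optional[str] = None

def assemble(deltas: list) -> Optional[FunctionCallDelta]:
    storage = None
    for delta in deltas:
        if storage is not None:
            if delta is not None:
                storage.arguments += delta.arguments or ''  # still streaming: append
                continue
            return storage                                  # stream ended: flush buffer
        if delta is not None:
            storage = delta                                 # first fragment: start buffering
            if storage.arguments is None:
                storage.arguments = ''
    return storage

call = assemble([FunctionCallDelta(name='get_weather'),
                 FunctionCallDelta(arguments='{"city": '),
                 FunctionCallDelta(arguments='"Paris"}'),
                 None])
assert call.name == 'get_weather'
assert call.arguments == '{"city": "Paris"}'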

+5 -1  api/core/model_runtime/model_providers/chatglm/llm/llm.py

  from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
  from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageFunction,
-                                                           PromptMessageTool, SystemPromptMessage, UserPromptMessage)
+                                                           PromptMessageTool, SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
  from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                                InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
  from core.model_runtime.errors.validate import CredentialsValidateFailedError

          elif isinstance(message, SystemPromptMessage):
              message = cast(SystemPromptMessage, message)
              message_dict = {"role": "system", "content": message.content}
+         elif isinstance(message, ToolPromptMessage):
+             # check if last message is user message
+             message = cast(ToolPromptMessage, message)
+             message_dict = {"role": "function", "content": message.content}
          else:
              raise ValueError(f"Unknown message type {type(message)}")

+2 -0  api/core/model_runtime/model_providers/minimax/llm/abab5.5-chat.yaml

  model_type: llm
  features:
    - agent-thought
+   - tool-call
+   - stream-tool-call
  model_properties:
    mode: chat
    context_size: 16384

+2 -0  api/core/model_runtime/model_providers/minimax/llm/abab6-chat.yaml

  model_type: llm
  features:
    - agent-thought
+   - tool-call
+   - stream-tool-call
  model_properties:
    mode: chat
    context_size: 32768

+1 -2  api/core/model_runtime/model_providers/minimax/llm/chat_completion.py

      """
      def generate(self, model: str, api_key: str, group_id: str,
                   prompt_messages: List[MinimaxMessage], model_parameters: dict,
-                  tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \
+                  tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \
          -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
          """
          generate chat completion

              continue

          for choice in choices:
-             print(choice)
              message = choice['delta']
              yield MinimaxMessage(
                  content=message,

+37 -9  api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py

      """
      def generate(self, model: str, api_key: str, group_id: str,
                   prompt_messages: List[MinimaxMessage], model_parameters: dict,
-                  tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \
+                  tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \
          -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
          """
          generate chat completion

          **extra_kwargs
      }

+     if tools:
+         body['functions'] = tools
+         body['function_call'] = { 'type': 'auto' }
+
      try:
          response = post(
              url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))

      """
      handle stream chat generate response
      """
+     function_call_storage = None
      for line in response.iter_lines():
          if not line:
              continue

              msg = data['base_resp']['status_msg']
              self._handle_error(code, msg)

-         if data['reply']:
+         if data['reply'] or 'usage' in data and data['usage']:
              total_tokens = data['usage']['total_tokens']
              message = MinimaxMessage(
                  role=MinimaxMessage.Role.ASSISTANT.value,

                  'total_tokens': total_tokens
              }
              message.stop_reason = data['choices'][0]['finish_reason']

+             if function_call_storage:
+                 function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
+                 function_call_message.function_call = function_call_storage
+                 yield function_call_message
+
              yield message
              return

              continue

          for choice in choices:
-             message = choice['messages'][0]['text']
-             if not message:
-                 continue
+             message = choice['messages'][0]
+
+             if 'function_call' in message:
+                 if not function_call_storage:
+                     function_call_storage = message['function_call']
+                     if 'arguments' not in function_call_storage or not function_call_storage['arguments']:
+                         function_call_storage['arguments'] = ''
+                         continue
+                 else:
+                     function_call_storage['arguments'] += message['function_call']['arguments']
+                     continue
+             else:
+                 if function_call_storage:
+                     message['function_call'] = function_call_storage
+                     function_call_storage = None
-             yield MinimaxMessage(
-                 content=message,
-                 role=MinimaxMessage.Role.ASSISTANT.value
-             )
+             minimax_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
+
+             if 'function_call' in message:
+                 minimax_message.function_call = message['function_call']
+
+             if 'text' in message:
+                 minimax_message.content = message['text']
+
+             yield minimax_message
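
For reference, a sketch of the kind of request body this builds when tools are present. The functions / function_call wiring matches the diff above; the other fields and the tool definition itself are illustrative, not taken from the commit:

from json import dumps

tools = [{'name': 'get_weather',
          'description': 'Query current weather',
          'parameters': {'type': 'object', 'properties': {'city': {'type': 'string'}}}}]

body = {
    'model': 'abab5.5-chat',
    'messages': [{'sender_type': 'USER', 'sender_name': '我', 'text': 'Weather in Paris?'}],
}
if tools:
    body['functions'] = tools                 # same wiring as in generate() above
    body['function_call'] = {'type': 'auto'}

print(dumps(body, ensure_ascii=False, indent=2))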

+42 -1  api/core/model_runtime/model_providers/minimax/llm/llm.py

  from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
  from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
-                                                           SystemPromptMessage, UserPromptMessage)
+                                                           SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
  from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                                InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
  from core.model_runtime.errors.validate import CredentialsValidateFailedError

      """
      client: MinimaxChatCompletionPro = self.model_apis[model]()

+     if tools:
+         tools = [{
+             "name": tool.name,
+             "description": tool.description,
+             "parameters": tool.parameters
+         } for tool in tools]
+
      response = client.generate(
          model=model,
          api_key=credentials['minimax_api_key'],

      elif isinstance(prompt_message, UserPromptMessage):
          return MinimaxMessage(role=MinimaxMessage.Role.USER.value, content=prompt_message.content)
      elif isinstance(prompt_message, AssistantPromptMessage):
+         if prompt_message.tool_calls:
+             message = MinimaxMessage(
+                 role=MinimaxMessage.Role.ASSISTANT.value,
+                 content=''
+             )
+             message.function_call={
+                 'name': prompt_message.tool_calls[0].function.name,
+                 'arguments': prompt_message.tool_calls[0].function.arguments
+             }
+             return message
          return MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content)
+     elif isinstance(prompt_message, ToolPromptMessage):
+         return MinimaxMessage(role=MinimaxMessage.Role.FUNCTION.value, content=prompt_message.content)
      else:
          raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported')

                  finish_reason=message.stop_reason if message.stop_reason else None,
              ),
          )
+     elif message.function_call:
+         if 'name' not in message.function_call or 'arguments' not in message.function_call:
+             continue
+
+         yield LLMResultChunk(
+             model=model,
+             prompt_messages=prompt_messages,
+             delta=LLMResultChunkDelta(
+                 index=0,
+                 message=AssistantPromptMessage(
+                     content='',
+                     tool_calls=[AssistantPromptMessage.ToolCall(
+                         id='',
+                         type='function',
+                         function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                             name=message.function_call['name'],
+                             arguments=message.function_call['arguments']
+                         )
+                     )]
+                 ),
+             ),
+         )
      else:
          yield LLMResultChunk(
              model=model,

+10 -0  api/core/model_runtime/model_providers/minimax/llm/types.py

      USER = 'USER'
      ASSISTANT = 'BOT'
      SYSTEM = 'SYSTEM'
+     FUNCTION = 'FUNCTION'

      role: str = Role.USER.value
      content: str
      usage: Dict[str, int] = None
      stop_reason: str = ''
+     function_call: Dict[str, Any] = None

      def to_dict(self) -> Dict[str, Any]:
+         if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value:
+             return {
+                 'sender_type': 'BOT',
+                 'sender_name': '专家',
+                 'text': '',
+                 'function_call': self.function_call
+             }
          return {
              'sender_type': self.role,
              'sender_name': '我' if self.role == 'USER' else '专家',

+1 -0  api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml

  features:
    - multi-tool-call
    - agent-thought
+   - stream-tool-call
  model_properties:
    mode: chat
    context_size: 4096

+1 -0  api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml

  features:
    - multi-tool-call
    - agent-thought
+   - stream-tool-call
  model_properties:
    mode: chat
    context_size: 16385

+1 -0  api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml

  features:
    - multi-tool-call
    - agent-thought
+   - stream-tool-call
  model_properties:
    mode: chat
    context_size: 16385

+1 -0  api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml

  features:
    - multi-tool-call
    - agent-thought
+   - stream-tool-call
  model_properties:
    mode: chat
    context_size: 16385

+1 -0  api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml

  features:
    - multi-tool-call
    - agent-thought
+   - stream-tool-call
  model_properties:
    mode: chat
    context_size: 4096

+1 -0  api/core/model_runtime/model_providers/openai/llm/gpt-4-0125-preview.yaml

  features:
    - multi-tool-call
    - agent-thought
+   - stream-tool-call
  model_properties:
    mode: chat
    context_size: 128000

+1 -0  api/core/model_runtime/model_providers/openai/llm/gpt-4-1106-preview.yaml

  features:
    - multi-tool-call
    - agent-thought
+   - stream-tool-call
  model_properties:
    mode: chat
    context_size: 128000

+1 -0  api/core/model_runtime/model_providers/openai/llm/gpt-4-32k.yaml

  features:
    - multi-tool-call
    - agent-thought
+   - stream-tool-call
  model_properties:
    mode: chat
    context_size: 32768

+1 -0  api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-preview.yaml

  features:
    - multi-tool-call
    - agent-thought
+   - stream-tool-call
  model_properties:
    mode: chat
    context_size: 128000

+1 -0  api/core/model_runtime/model_providers/openai/llm/gpt-4.yaml

  features:
    - multi-tool-call
    - agent-thought
+   - stream-tool-call
  model_properties:
    mode: chat
    context_size: 8192

+1 -1  api/core/model_runtime/model_providers/openai/llm/llm.py

          else:
              raise ValueError(f"Got unknown type {message}")

-         if message.name is not None:
+         if message.name:
              message_dict["name"] = message.name

          return message_dict

+27 -4  api/core/model_runtime/model_providers/xinference/llm/llm.py

  from core.model_runtime.entities.common_entities import I18nObject
  from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
  from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
-                                                           SystemPromptMessage, UserPromptMessage)
+                                                           SystemPromptMessage, UserPromptMessage, ToolPromptMessage)
  from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType,
-                                                         ParameterRule, ParameterType)
+                                                         ParameterRule, ParameterType, ModelFeature)
  from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                                InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
  from core.model_runtime.errors.validate import CredentialsValidateFailedError
  from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
- from core.model_runtime.model_providers.xinference.llm.xinference_helper import (XinferenceHelper,
+ from core.model_runtime.model_providers.xinference.xinference_helper import (XinferenceHelper,
                                                                                XinferenceModelExtraParameter)
  from core.model_runtime.utils import helper
  from openai import (APIConnectionError, APITimeoutError, AuthenticationError, ConflictError, InternalServerError,

          see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke`
          """
+         if 'temperature' in model_parameters:
+             if model_parameters['temperature'] < 0.01:
+                 model_parameters['temperature'] = 0.01
+             elif model_parameters['temperature'] > 1.0:
+                 model_parameters['temperature'] = 0.99
+
          return self._generate(
              model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters,
              tools=tools, stop=stop, stream=stream, user=user,

                  credentials['completion_type'] = 'completion'
              else:
                  raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported')

+             if extra_param.support_function_call:
+                 credentials['support_function_call'] = True

          except RuntimeError as e:
              raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}')

          elif isinstance(message, SystemPromptMessage):
              message = cast(SystemPromptMessage, message)
              message_dict = {"role": "system", "content": message.content}
+         elif isinstance(message, ToolPromptMessage):
+             message = cast(ToolPromptMessage, message)
+             message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content}
          else:
              raise ValueError(f"Unknown message type {type(message)}")

                  label=I18nObject(
                      zh_Hans='温度',
                      en_US='Temperature'
-                 )
+                 ),
              ),
              ParameterRule(
                  name='top_p',

              completion_type = LLMMode.COMPLETION.value
          else:
              raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported')

+         support_function_call = credentials.get('support_function_call', False)

          entity = AIModelEntity(
              model=model,

              ),
              fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
              model_type=ModelType.LLM,
+             features=[
+                 ModelFeature.TOOL_CALL
+             ] if support_function_call else [],
              model_properties={
                  ModelPropertyKey.MODE: completion_type,
              },

          extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter`
          """
+         if 'server_url' not in credentials:
+             raise CredentialsValidateFailedError('server_url is required in credentials')
+
+         if credentials['server_url'].endswith('/'):
+             credentials['server_url'] = credentials['server_url'][:-1]
+
          client = OpenAI(
              base_url=f'{credentials["server_url"]}/v1',
              api_key='abc',
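
The temperature clamp added to `_invoke` is small enough to restate as a standalone helper; presumably it exists because the backing inference server rejects values at or outside the (0, 1) bounds:

def clamp_temperature(params: dict) -> dict:
    # pull out-of-range values just inside the open interval instead of failing
    if 'temperature' in params:
        if params['temperature'] < 0.01:
            params['temperature'] = 0.01
        elif params['temperature'] > 1.0:
            params['temperature'] = 0.99
    return params

assert clamp_temperature({'temperature': 0.0})['temperature'] == 0.01
assert clamp_temperature({'temperature': 2.0})['temperature'] == 0.99
assert clamp_temperature({'temperature': 0.7})['temperature'] == 0.7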

+19 -4  api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py

  from typing import Optional

  from core.model_runtime.entities.common_entities import I18nObject
- from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
+ from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType, ModelPropertyKey
  from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
  from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                                InvokeError, InvokeRateLimitError, InvokeServerUnavailableError)
  from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
  from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle, RESTfulModelHandle

+ from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper

  class XinferenceTextEmbeddingModel(TextEmbeddingModel):
      """

          """
          server_url = credentials['server_url']
          model_uid = credentials['model_uid']

+         if server_url.endswith('/'):
+             server_url = server_url[:-1]
+
          client = Client(base_url=server_url)
          try:

          :return:
          """
          try:
+             server_url = credentials['server_url']
+             model_uid = credentials['model_uid']
+             extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
+
+             if extra_args.max_tokens:
+                 credentials['max_tokens'] = extra_args.max_tokens
+
              self._invoke(model=model, credentials=credentials, texts=['ping'])
-         except InvokeAuthorizationError:
+         except (InvokeAuthorizationError, RuntimeError):
              raise CredentialsValidateFailedError('Invalid api key')

      @property

          """
          used to define customizable model schema
          """
          entity = AIModelEntity(
              model=model,
              label=I18nObject(

              ),
              fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
              model_type=ModelType.TEXT_EMBEDDING,
-             model_properties={},
+             model_properties={
+                 ModelPropertyKey.MAX_CHUNKS: 1,
+                 ModelPropertyKey.CONTEXT_SIZE: 'max_tokens' in credentials and credentials['max_tokens'] or 512,
+             },
              parameter_rules=[]
          )
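
The `'max_tokens' in credentials and credentials['max_tokens'] or 512` expression for `CONTEXT_SIZE` is the pre-ternary and/or idiom, with one quirk worth knowing: a falsy stored value also falls through to the default.

def context_size(credentials: dict) -> int:
    # falls back to 512 when the key is absent -- and also when the value is falsy
    return 'max_tokens' in credentials and credentials['max_tokens'] or 512

assert context_size({}) == 512
assert context_size({'max_tokens': 2048}) == 2048
assert context_size({'max_tokens': 0}) == 512   # falsy value hits the fallback too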



api/core/model_runtime/model_providers/xinference/llm/xinference_helper.py → api/core/model_runtime/model_providers/xinference/xinference_helper.py

  from threading import Lock
  from time import time
  from typing import List
+ from os import path

  from requests import get
  from requests.adapters import HTTPAdapter

      model_format: str
      model_handle_type: str
      model_ability: List[str]
+     max_tokens: int = 512
+     support_function_call: bool = False

-     def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str]) -> None:
+     def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str],
+                  support_function_call: bool, max_tokens: int) -> None:
          self.model_format = model_format
          self.model_handle_type = model_handle_type
          self.model_ability = model_ability
+         self.support_function_call = support_function_call
+         self.max_tokens = max_tokens

  cache = {}
  cache_lock = Lock()

      get xinference model extra parameter like model_format and model_handle_type
      """

-     url = f'{server_url}/v1/models/{model_uid}'
+     url = path.join(server_url, 'v1/models', model_uid)

      # this methid is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
      session = Session()

      response_json = response.json()

-     model_format = response_json['model_format']
-     model_ability = response_json['model_ability']
+     model_format = response_json.get('model_format', 'ggmlv3')
+     model_ability = response_json.get('model_ability', [])

-     if model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
+     if response_json.get('model_type') == 'embedding':
+         model_handle_type = 'embedding'
+     elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']:
          model_handle_type = 'chatglm'
      elif 'generate' in model_ability:
          model_handle_type = 'generate'

      else:
          raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported')

+     support_function_call = 'tools' in model_ability
+     max_tokens = response_json.get('max_tokens', 512)

      return XinferenceModelExtraParameter(
          model_format=model_format,
          model_handle_type=model_handle_type,
-         model_ability=model_ability
+         model_ability=model_ability,
+         support_function_call=support_function_call,
+         max_tokens=max_tokens
      )
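
One caveat about the switch from an f-string to `path.join` for the URL: `os.path.join` uses the host OS separator and treats absolute components specially, so it builds a correct URL only on POSIX and only when no component starts with `/`:

from os import path

print(path.join('http://127.0.0.1:9997', 'v1/models', 'uid'))
# POSIX: http://127.0.0.1:9997/v1/models/uid (Windows would insert backslashes)

print(path.join('http://127.0.0.1:9997', '/v1/models', 'uid'))
# /v1/models/uid -- an absolute component resets the join and drops the host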

+4 -0  api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml

  label:
    en_US: glm-3-turbo
  model_type: llm
+ features:
+   - multi-tool-call
+   - agent-thought
+   - stream-tool-call
  model_properties:
    mode: chat
  parameter_rules:

+4 -0  api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml

  label:
    en_US: glm-4
  model_type: llm
+ features:
+   - multi-tool-call
+   - agent-thought
+   - stream-tool-call
  model_properties:
    mode: chat
  parameter_rules:

+21 -0  api/core/model_runtime/model_providers/zhipuai/llm/llm.py

                  'content': prompt_message.content,
                  'tool_call_id': prompt_message.tool_call_id
              })
+         elif isinstance(prompt_message, AssistantPromptMessage):
+             if prompt_message.tool_calls:
+                 params['messages'].append({
+                     'role': 'assistant',
+                     'content': prompt_message.content,
+                     'tool_calls': [
+                         {
+                             'id': tool_call.id,
+                             'type': tool_call.type,
+                             'function': {
+                                 'name': tool_call.function.name,
+                                 'arguments': tool_call.function.arguments
+                             }
+                         } for tool_call in prompt_message.tool_calls
+                     ]
+                 })
+             else:
+                 params['messages'].append({
+                     'role': 'assistant',
+                     'content': prompt_message.content
+                 })
          else:
              params['messages'].append({
                  'role': prompt_message.role.value,

+1 -1  api/requirements.txt

  huggingface_hub~=0.16.4
  transformers~=4.31.0
  pandas==1.5.3
- xinference-client~=0.6.4
+ xinference-client~=0.8.1
  safetensors==0.3.2
  zhipuai==1.0.7
  werkzeug~=3.0.1

+61 -32  api/tests/integration_tests/model_runtime/__mock/xinference.py

          raise RuntimeError('404 Not Found')

      if 'generate' == model_uid:
-         return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url)
+         return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
      if 'chat' == model_uid:
-         return RESTfulChatModelHandle(model_uid, base_url=self.base_url)
+         return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
      if 'embedding' == model_uid:
-         return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url)
+         return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
      if 'rerank' == model_uid:
-         return RESTfulRerankModelHandle(model_uid, base_url=self.base_url)
+         return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
      raise RuntimeError('404 Not Found')

  def get(self: Session, url: str, **kwargs):
-     if '/v1/models/' in url:
-         response = Response()
+     response = Response()
+     if 'v1/models/' in url:
          # get model uid
          model_uid = url.split('/')[-1]
          if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
              model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
              response.status_code = 404
-             raise ConnectionError('404 Not Found')
+             return response

          # check if url is valid
          if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
              response.status_code = 404
-             raise ConnectionError('404 Not Found')
+             return response

+         if model_uid in ['generate', 'chat']:
+             response.status_code = 200
+             response._content = b'''{
+                 "model_type": "LLM",
+                 "address": "127.0.0.1:43877",
+                 "accelerators": [
+                     "0",
+                     "1"
+                 ],
+                 "model_name": "chatglm3-6b",
+                 "model_lang": [
+                     "en"
+                 ],
+                 "model_ability": [
+                     "generate",
+                     "chat"
+                 ],
+                 "model_description": "latest chatglm3",
+                 "model_format": "pytorch",
+                 "model_size_in_billions": 7,
+                 "quantization": "none",
+                 "model_hub": "huggingface",
+                 "revision": null,
+                 "context_length": 2048,
+                 "replica": 1
+             }'''
+             return response
+         elif model_uid == 'embedding':
+             response.status_code = 200
+             response._content = b'''{
+                 "model_type": "embedding",
+                 "address": "127.0.0.1:43877",
+                 "accelerators": [
+                     "0",
+                     "1"
+                 ],
+                 "model_name": "bge",
+                 "model_lang": [
+                     "en"
+                 ],
+                 "revision": null,
+                 "max_tokens": 512
+             }'''
+             return response
+     elif 'v1/cluster/auth' in url:
          response.status_code = 200
          response._content = b'''{
-             "model_type": "LLM",
-             "address": "127.0.0.1:43877",
-             "accelerators": [
-                 "0",
-                 "1"
-             ],
-             "model_name": "chatglm3-6b",
-             "model_lang": [
-                 "en"
-             ],
-             "model_ability": [
-                 "generate",
-                 "chat"
-             ],
-             "model_description": "latest chatglm3",
-             "model_format": "pytorch",
-             "model_size_in_billions": 7,
-             "quantization": "none",
-             "model_hub": "huggingface",
-             "revision": null,
-             "context_length": 2048,
-             "replica": 1
+             "auth": true
          }'''
          return response

+ def _check_cluster_authenticated(self):
+     self._cluster_authed = True

  def rerank(self: RESTfulRerankModelHandle, documents: List[str], query: str, top_n: int) -> dict:
      # check if self._model_uid is a valid uuid
      if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \

  def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
      if MOCK:
          monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model)
+         monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated)
          monkeypatch.setattr(Session, 'get', MockXinferenceClass.get)
          monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding)
          monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank)
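
The mock's mechanism is pytest's MonkeyPatch swapping methods on the real client classes; a minimal self-contained version of the same pattern (with a toy class instead of xinference's handles):

import pytest

class Client:
    def get_model(self, uid: str) -> str:
        raise ConnectionError('real call needs a live server')

def fake_get_model(self: Client, uid: str) -> str:
    return f'mock-model:{uid}'   # deterministic stand-in for tests

def test_get_model(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setattr(Client, 'get_model', fake_get_model)
    assert Client().get_model('chat') == 'mock-model:chat'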
