@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" fill="currentColor" viewBox="0 0 24 24" aria-hidden="true" class="" focusable="false" style="fill:currentColor;height:28px;width:28px"><path d="m3.005 8.858 8.783 12.544h3.904L6.908 8.858zM6.905 15.825 3 21.402h3.907l1.951-2.788zM16.585 2l-6.75 9.64 1.953 2.79L20.492 2zM17.292 7.965v13.437h3.2V3.395z"></path></svg>
@@ -0,0 +1,63 @@
model: grok-beta
label:
  en_US: Grok beta
model_type: llm
features:
  - multi-tool-call
model_properties:
  mode: chat
  context_size: 131072
parameter_rules:
  - name: temperature
    label:
      en_US: "Temperature"
      zh_Hans: "采样温度"
    type: float
    default: 0.7
    min: 0.0
    max: 2.0
    precision: 1
    required: true
    help:
      en_US: "Sampling temperature controls the randomness of the output, within the range [0.0, 2.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature according to your needs, avoiding adjusting both at the same time."
      zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 2.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
  - name: top_p
    label:
      en_US: "Top P"
      zh_Hans: "Top P"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The value range of this sampling method is [0.0, 1.0]. The model samples from the smallest set of candidate tokens whose cumulative probability exceeds top_p; when top_p is 0, this parameter has no effect. It is recommended to adjust either top_p or temperature according to your needs, avoiding adjusting both at the same time."
      zh_Hans: "采样方法的取值范围为 [0.0, 1.0]。模型从累积概率超过 top_p 的最小候选词集合中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
  - name: frequency_penalty
    use_template: frequency_penalty
    label:
      en_US: "Frequency Penalty"
      zh_Hans: "频率惩罚"
    type: float
    default: 0
    min: 0
    max: 2.0
    precision: 1
    required: false
    help:
      en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood of repeating the same line verbatim."
      zh_Hans: "介于 0 和 2.0 之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。"
  - name: user
    use_template: text
    label:
      en_US: "User"
      zh_Hans: "用户"
    type: string
    required: false
    help:
      en_US: "Used to track and differentiate conversation requests from different users."
      zh_Hans: "用于追踪和区分不同用户的对话请求。"
@@ -0,0 +1,37 @@
from collections.abc import Generator
from typing import Optional, Union

from yarl import URL

from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    PromptMessageTool,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel


class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        self._add_custom_parameters(credentials)
        # Forward `user` as well so per-user tracking reaches the upstream API.
        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        self._add_custom_parameters(credentials)
        super().validate_credentials(model, credentials)

    @staticmethod
    def _add_custom_parameters(credentials: dict) -> None:
        # Fall back to the default xAI endpoint when none is configured;
        # .get() avoids a KeyError for the optional endpoint_url field.
        credentials["endpoint_url"] = str(URL(credentials.get("endpoint_url") or "https://api.x.ai/v1"))
        credentials["mode"] = LLMMode.CHAT.value
        credentials["function_calling_type"] = "tool_call"
@@ -0,0 +1,25 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class XAIProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials.
        Raises an exception if validation fails.

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)
            # A lightweight call against grok-beta verifies that the key works.
            model_instance.validate_credentials(model="grok-beta", credentials=credentials)
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
            raise ex
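Callers are expected to treat any raised exception as a failed validation. A hedged usage sketch (the `api_key` field name matches the `provider_credential_schema` below):

```python
import os

provider = XAIProvider()
try:
    provider.validate_provider_credentials(credentials={"api_key": os.environ["XAI_API_KEY"]})
except CredentialsValidateFailedError:
    # Invalid or missing key: surface the error to the user instead of crashing.
    raise
```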
@@ -0,0 +1,38 @@
provider: x
label:
  en_US: xAI
description:
  en_US: xAI is a company working on building artificial intelligence to accelerate human scientific discovery. We are guided by our mission to advance our collective understanding of the universe.
icon_small:
  en_US: x-ai-logo.svg
icon_large:
  en_US: x-ai-logo.svg
help:
  title:
    en_US: Get your token from xAI
    zh_Hans: 从 xAI 获取 token
  url:
    en_US: https://x.ai/api
supported_model_types:
  - llm
configurate_methods:
  - predefined-model
provider_credential_schema:
  credential_form_schemas:
    - variable: api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
    - variable: endpoint_url
      label:
        en_US: API Base
      type: text-input
      required: false
      default: https://api.x.ai/v1
      placeholder:
        zh_Hans: 在此输入您的 API Base
        en_US: Enter your API Base
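The two form schemas above produce the credentials dict that `XAILargeLanguageModel` and the integration tests consume; a sketch of the expected shape (values illustrative):

```python
credentials = {
    "api_key": "xai-...",                   # secret-input, required
    "endpoint_url": "https://api.x.ai/v1",  # text-input, optional; defaulted when omitted
}
```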
@@ -95,3 +95,7 @@ GPUSTACK_API_KEY=

# Gitee AI Credentials
GITEE_AI_API_KEY=

# xAI Credentials
XAI_API_KEY=
XAI_API_BASE=
@@ -0,0 +1,204 @@
import os
from collections.abc import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.x.llm.llm import XAILargeLanguageModel

"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock


def test_predefined_models():
    model = XAILargeLanguageModel()
    model_schemas = model.predefined_models()

    assert len(model_schemas) >= 1
    assert isinstance(model_schemas[0], AIModelEntity)


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_validate_credentials_for_chat_model(setup_openai_mock):
    model = XAILargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        # model name is set to gpt-3.5-turbo because the mock only knows OpenAI models
        model.validate_credentials(
            model="gpt-3.5-turbo",
            credentials={"api_key": "invalid_key", "endpoint_url": os.environ.get("XAI_API_BASE"), "mode": "chat"},
        )

    model.validate_credentials(
        model="grok-beta",
        credentials={
            "api_key": os.environ.get("XAI_API_KEY"),
            "endpoint_url": os.environ.get("XAI_API_BASE"),
            "mode": "chat",
        },
    )


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_chat_model(setup_openai_mock):
    model = XAILargeLanguageModel()

    result = model.invoke(
        model="grok-beta",
        credentials={
            "api_key": os.environ.get("XAI_API_KEY"),
            "endpoint_url": os.environ.get("XAI_API_BASE"),
            "mode": "chat",
        },
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        model_parameters={
            "temperature": 0.0,
            "top_p": 1.0,
            "presence_penalty": 0.0,
            "frequency_penalty": 0.0,
            "max_tokens": 10,
        },
        stop=["How"],
        stream=False,
        user="foo",
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_chat_model_with_tools(setup_openai_mock):
    model = XAILargeLanguageModel()

    result = model.invoke(
        model="grok-beta",
        credentials={
            "api_key": os.environ.get("XAI_API_KEY"),
            "endpoint_url": os.environ.get("XAI_API_BASE"),
            "mode": "chat",
        },
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content="what's the weather today in London?",
            ),
        ],
        model_parameters={"temperature": 0.0, "max_tokens": 100},
        tools=[
            PromptMessageTool(
                name="get_weather",
                description="Determine weather in my location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["c", "f"]},
                    },
                    "required": ["location"],
                },
            ),
            PromptMessageTool(
                name="get_stock_price",
                description="Get the current stock price",
                parameters={
                    "type": "object",
                    "properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
                    "required": ["symbol"],
                },
            ),
        ],
        stream=False,
        user="foo",
    )

    assert isinstance(result, LLMResult)
    assert isinstance(result.message, AssistantPromptMessage)


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_stream_chat_model(setup_openai_mock):
    model = XAILargeLanguageModel()

    result = model.invoke(
        model="grok-beta",
        credentials={
            "api_key": os.environ.get("XAI_API_KEY"),
            "endpoint_url": os.environ.get("XAI_API_BASE"),
            "mode": "chat",
        },
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        model_parameters={"temperature": 0.0, "max_tokens": 100},
        stream=True,
        user="foo",
    )

    assert isinstance(result, Generator)

    for chunk in result:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
        else:
            # The final chunk carries the usage accounting.
            assert chunk.delta.usage is not None
            assert chunk.delta.usage.completion_tokens > 0


def test_get_num_tokens():
    model = XAILargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model="grok-beta",
        credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")},
        prompt_messages=[UserPromptMessage(content="Hello World!")],
    )

    assert num_tokens == 10

    num_tokens = model.get_num_tokens(
        model="grok-beta",
        credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")},
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        tools=[
            PromptMessageTool(
                name="get_weather",
                description="Determine weather in my location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["c", "f"]},
                    },
                    "required": ["location"],
                },
            ),
        ],
    )

    assert num_tokens == 77