Co-authored-by: moon <moon@vessl.ai>tags/0.11.0
| @@ -0,0 +1,3 @@ | |||
| <svg width="1200" height="925" viewBox="0 0 1200 925" fill="none" xmlns="http://www.w3.org/2000/svg"> | |||
| <path d="M780.152 250.999L907.882 462.174C907.882 462.174 880.925 510.854 867.43 535.21C834.845 594.039 764.171 612.49 710.442 508.333L420.376 0H0L459.926 803.307C552.303 964.663 787.366 964.663 879.743 803.307C989.874 610.952 1089.87 441.97 1200 249.646L1052.28 0H639.519L780.152 250.999Z" fill="#3366FF"/> | |||
| </svg> | |||
| @@ -0,0 +1,83 @@ | |||
| from decimal import Decimal | |||
| from core.model_runtime.entities.common_entities import I18nObject | |||
| from core.model_runtime.entities.llm_entities import LLMMode | |||
| from core.model_runtime.entities.model_entities import ( | |||
| AIModelEntity, | |||
| DefaultParameterName, | |||
| FetchFrom, | |||
| ModelPropertyKey, | |||
| ModelType, | |||
| ParameterRule, | |||
| ParameterType, | |||
| PriceConfig, | |||
| ) | |||
| from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel | |||
class VesslAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
    """VESSL AI large language model, layered on the OpenAI-API-compatible model class."""

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        """Build the schema entity for a user-configured VESSL AI model.

        Parameter defaults and pricing are read from ``credentials`` with
        fallbacks: temperature 0.7, top_p 1, top_k 50, a max-tokens ceiling of
        4096 (key ``max_tokens_to_sample``), zero pricing and ``USD`` currency.

        :param model: model name, used as both the entity's model id and label
        :param credentials: model credentials; ``mode`` must be ``"chat"`` or
            ``"completion"``
        :return: the assembled ``AIModelEntity``
        :raises ValueError: if ``mode`` is missing or not a supported value
        """
        # Resolve the mode up front so that an unsupported value fails with a
        # clear ValueError naming the offending value, rather than a KeyError
        # from looking up a credential key that is never provided.
        mode = credentials.get("mode")
        if mode == "chat":
            llm_mode = LLMMode.CHAT.value
        elif mode == "completion":
            llm_mode = LLMMode.COMPLETION.value
        else:
            raise ValueError(f"Unknown completion type {mode}")

        return AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            model_type=ModelType.LLM,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            features=[],
            model_properties={
                ModelPropertyKey.MODE: llm_mode,
            },
            parameter_rules=[
                ParameterRule(
                    name=DefaultParameterName.TEMPERATURE.value,
                    label=I18nObject(en_US="Temperature"),
                    type=ParameterType.FLOAT,
                    default=float(credentials.get("temperature", 0.7)),
                    min=0,
                    max=2,
                    precision=2,
                ),
                ParameterRule(
                    name=DefaultParameterName.TOP_P.value,
                    label=I18nObject(en_US="Top P"),
                    type=ParameterType.FLOAT,
                    default=float(credentials.get("top_p", 1)),
                    min=0,
                    max=1,
                    precision=2,
                ),
                ParameterRule(
                    name=DefaultParameterName.TOP_K.value,
                    label=I18nObject(en_US="Top K"),
                    type=ParameterType.INT,
                    default=int(credentials.get("top_k", 50)),
                    min=-2147483647,
                    max=2147483647,
                    precision=0,
                ),
                ParameterRule(
                    name=DefaultParameterName.MAX_TOKENS.value,
                    label=I18nObject(en_US="Max Tokens"),
                    type=ParameterType.INT,
                    default=512,
                    min=1,
                    # The upper bound is configurable per model via credentials.
                    max=int(credentials.get("max_tokens_to_sample", 4096)),
                ),
            ],
            pricing=PriceConfig(
                input=Decimal(credentials.get("input_price", 0)),
                output=Decimal(credentials.get("output_price", 0)),
                unit=Decimal(credentials.get("unit", 0)),
                currency=credentials.get("currency", "USD"),
            ),
        )
| @@ -0,0 +1,10 @@ | |||
| import logging | |||
| from core.model_runtime.model_providers.__base.model_provider import ModelProvider | |||
| logger = logging.getLogger(__name__) | |||
class VesslAIProvider(ModelProvider):
    """Model provider entry for VESSL AI."""

    def validate_provider_credentials(self, credentials: dict) -> None:
        """Accept the given provider credentials without inspecting them.

        This implementation performs no checks and never raises.

        :param credentials: provider credentials (unused)
        """
        return None
| @@ -0,0 +1,56 @@ | |||
| provider: vessl_ai | |||
| label: | |||
| en_US: vessl_ai | |||
| icon_small: | |||
| en_US: icon_s_en.svg | |||
| icon_large: | |||
| en_US: icon_l_en.png | |||
| background: "#F1EFED" | |||
| help: | |||
| title: | |||
| en_US: How to deploy VESSL AI LLM Model Endpoint | |||
| url: | |||
| en_US: https://docs.vessl.ai/guides/get-started/llama3-deployment | |||
| supported_model_types: | |||
| - llm | |||
| configurate_methods: | |||
| - customizable-model | |||
| model_credential_schema: | |||
| model: | |||
| label: | |||
| en_US: Model Name | |||
| placeholder: | |||
| en_US: Enter your model name | |||
| credential_form_schemas: | |||
| - variable: endpoint_url | |||
| label: | |||
| en_US: Endpoint URL | |||
| type: text-input | |||
| required: true | |||
| placeholder: | |||
| en_US: Enter the URL of your endpoint | |||
| - variable: api_key | |||
| required: true | |||
| label: | |||
| en_US: API Key | |||
| type: secret-input | |||
| placeholder: | |||
| en_US: Enter your VESSL AI secret key | |||
| - variable: mode | |||
| show_on: | |||
| - variable: __model_type | |||
| value: llm | |||
| label: | |||
| en_US: Completion mode | |||
| type: select | |||
| required: false | |||
| default: chat | |||
| placeholder: | |||
| en_US: Select completion mode | |||
| options: | |||
| - value: completion | |||
| label: | |||
| en_US: Completion | |||
| - value: chat | |||
| label: | |||
| en_US: Chat | |||
| @@ -84,5 +84,10 @@ VOLC_EMBEDDING_ENDPOINT_ID= | |||
| # 360 AI Credentials | |||
| ZHINAO_API_KEY= | |||
| # VESSL AI Credentials | |||
| VESSL_AI_MODEL_NAME= | |||
| VESSL_AI_API_KEY= | |||
| VESSL_AI_ENDPOINT_URL= | |||
| # Gitee AI Credentials | |||
| GITEE_AI_API_KEY= | |||
| GITEE_AI_API_KEY= | |||
| @@ -0,0 +1,131 @@ | |||
| import os | |||
| from collections.abc import Generator | |||
| import pytest | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta | |||
| from core.model_runtime.entities.message_entities import ( | |||
| AssistantPromptMessage, | |||
| SystemPromptMessage, | |||
| UserPromptMessage, | |||
| ) | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.vessl_ai.llm.llm import VesslAILargeLanguageModel | |||
def test_validate_credentials():
    """Invalid API key or endpoint URL must be rejected; the configured ones must pass.

    The model name, API key and endpoint URL are read from the
    ``VESSL_AI_*`` environment variables.
    """
    llm = VesslAILargeLanguageModel()
    model_name = os.environ.get("VESSL_AI_MODEL_NAME")
    configured_key = os.environ.get("VESSL_AI_API_KEY")
    configured_url = os.environ.get("VESSL_AI_ENDPOINT_URL")

    # Each of these credential sets is expected to fail validation.
    rejected_credentials = (
        {
            "api_key": "invalid_key",
            "endpoint_url": configured_url,
            "mode": "chat",
        },
        {
            "api_key": configured_key,
            "endpoint_url": "http://invalid_url",
            "mode": "chat",
        },
    )
    for bad_credentials in rejected_credentials:
        with pytest.raises(CredentialsValidateFailedError):
            llm.validate_credentials(model=model_name, credentials=bad_credentials)

    # The configured credentials must validate without raising.
    llm.validate_credentials(
        model=model_name,
        credentials={
            "api_key": configured_key,
            "endpoint_url": configured_url,
            "mode": "chat",
        },
    )
def test_invoke_model():
    """A non-streaming chat invocation returns an ``LLMResult`` with non-empty content."""
    llm = VesslAILargeLanguageModel()

    chat_credentials = {
        "api_key": os.environ.get("VESSL_AI_API_KEY"),
        "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"),
        "mode": "chat",
    }
    conversation = [
        SystemPromptMessage(
            content="You are a helpful AI assistant.",
        ),
        UserPromptMessage(content="Who are you?"),
    ]
    sampling = {
        "temperature": 1.0,
        "top_k": 2,
        "top_p": 0.5,
    }

    result = llm.invoke(
        model=os.environ.get("VESSL_AI_MODEL_NAME"),
        credentials=chat_credentials,
        prompt_messages=conversation,
        model_parameters=sampling,
        stop=["How"],
        stream=False,
        user="abc-123",
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
def test_invoke_stream_model():
    """A streaming chat invocation yields well-typed ``LLMResultChunk`` items."""
    llm = VesslAILargeLanguageModel()

    chat_credentials = {
        "api_key": os.environ.get("VESSL_AI_API_KEY"),
        "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"),
        "mode": "chat",
    }
    conversation = [
        SystemPromptMessage(
            content="You are a helpful AI assistant.",
        ),
        UserPromptMessage(content="Who are you?"),
    ]
    sampling = {
        "temperature": 1.0,
        "top_k": 2,
        "top_p": 0.5,
    }

    stream = llm.invoke(
        model=os.environ.get("VESSL_AI_MODEL_NAME"),
        credentials=chat_credentials,
        prompt_messages=conversation,
        model_parameters=sampling,
        stop=["How"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(stream, Generator)

    # Every streamed piece must be a chunk carrying an assistant-message delta.
    for piece in stream:
        assert isinstance(piece, LLMResultChunk)
        assert isinstance(piece.delta, LLMResultChunkDelta)
        assert isinstance(piece.delta.message, AssistantPromptMessage)
def test_get_num_tokens():
    """Token counting for a fixed two-message prompt returns the integer 21."""
    llm = VesslAILargeLanguageModel()

    endpoint_credentials = {
        "api_key": os.environ.get("VESSL_AI_API_KEY"),
        "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"),
    }
    conversation = [
        SystemPromptMessage(
            content="You are a helpful AI assistant.",
        ),
        UserPromptMessage(content="Hello World!"),
    ]

    token_count = llm.get_num_tokens(
        model=os.environ.get("VESSL_AI_MODEL_NAME"),
        credentials=endpoint_credentials,
        prompt_messages=conversation,
    )

    assert isinstance(token_count, int)
    assert token_count == 21