| @@ -1,6 +1,5 @@ | |||
| from collections.abc import Generator | |||
| from typing import cast | |||
| from urllib.parse import urljoin | |||
| from httpx import Timeout | |||
| from openai import ( | |||
| @@ -19,6 +18,7 @@ from openai import ( | |||
| from openai.types.chat import ChatCompletion, ChatCompletionChunk | |||
| from openai.types.chat.chat_completion_message import FunctionCall | |||
| from openai.types.completion import Completion | |||
| from yarl import URL | |||
| from core.model_runtime.entities.common_entities import I18nObject | |||
| from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta | |||
| @@ -181,7 +181,7 @@ class LocalAILarguageModel(LargeLanguageModel): | |||
| UserPromptMessage(content='ping') | |||
| ], model_parameters={ | |||
| 'max_tokens': 10, | |||
| }, stop=[]) | |||
| }, stop=[], stream=False) | |||
| except Exception as ex: | |||
| raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}') | |||
| @@ -227,6 +227,12 @@ class LocalAILarguageModel(LargeLanguageModel): | |||
| ) | |||
| ] | |||
| model_properties = { | |||
| ModelPropertyKey.MODE: completion_model, | |||
| } if completion_model else {} | |||
| model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get('context_size', '2048')) | |||
| entity = AIModelEntity( | |||
| model=model, | |||
| label=I18nObject( | |||
| @@ -234,7 +240,7 @@ class LocalAILarguageModel(LargeLanguageModel): | |||
| ), | |||
| fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | |||
| model_type=ModelType.LLM, | |||
| model_properties={ ModelPropertyKey.MODE: completion_model } if completion_model else {}, | |||
| model_properties=model_properties, | |||
| parameter_rules=rules | |||
| ) | |||
| @@ -319,7 +325,7 @@ class LocalAILarguageModel(LargeLanguageModel): | |||
| client_kwargs = { | |||
| "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), | |||
| "api_key": "1", | |||
| "base_url": urljoin(credentials['server_url'], 'v1'), | |||
| "base_url": str(URL(credentials['server_url']) / 'v1'), | |||
| } | |||
| return client_kwargs | |||
| @@ -56,3 +56,12 @@ model_credential_schema: | |||
| placeholder: | |||
| zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080 | |||
| en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080 | |||
| - variable: context_size | |||
| label: | |||
| zh_Hans: 上下文大小 | |||
| en_US: Context size | |||
| placeholder: | |||
| zh_Hans: 输入上下文大小 | |||
| en_US: Enter context size | |||
| required: false | |||
| type: text-input | |||
| @@ -1,11 +1,12 @@ | |||
| import time | |||
| from json import JSONDecodeError, dumps | |||
| from os.path import join | |||
| from typing import Optional | |||
| from requests import post | |||
| from yarl import URL | |||
| from core.model_runtime.entities.model_entities import PriceType | |||
| from core.model_runtime.entities.common_entities import I18nObject | |||
| from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType | |||
| from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult | |||
| from core.model_runtime.errors.invoke import ( | |||
| InvokeAuthorizationError, | |||
| @@ -57,7 +58,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): | |||
| } | |||
| try: | |||
| response = post(join(url, 'embeddings'), headers=headers, data=dumps(data), timeout=10) | |||
| response = post(str(URL(url) / 'embeddings'), headers=headers, data=dumps(data), timeout=10) | |||
| except Exception as e: | |||
| raise InvokeConnectionError(str(e)) | |||
| @@ -113,6 +114,27 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel): | |||
| # use GPT2Tokenizer to get num tokens | |||
| num_tokens += self._get_num_tokens_by_gpt2(text) | |||
| return num_tokens | |||
| def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: | |||
| """ | |||
| Get customizable model schema | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :return: model schema | |||
| """ | |||
| return AIModelEntity( | |||
| model=model, | |||
| label=I18nObject(zh_Hans=model, en_US=model), | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| features=[], | |||
| fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | |||
| model_properties={ | |||
| ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', '512')), | |||
| ModelPropertyKey.MAX_CHUNKS: 1, | |||
| }, | |||
| parameter_rules=[] | |||
| ) | |||
| def validate_credentials(self, model: str, credentials: dict) -> None: | |||
| """ | |||