Przeglądaj źródła

Fix/localai (#2840)

tags/0.5.10
Yeuoly 1 rok temu
rodzic
commit
742be06ea9
No account linked to committer's email address

+ 10
- 4
api/core/model_runtime/model_providers/localai/llm/llm.py Wyświetl plik

from collections.abc import Generator from collections.abc import Generator
from typing import cast from typing import cast
from urllib.parse import urljoin


from httpx import Timeout from httpx import Timeout
from openai import ( from openai import (
from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message import FunctionCall from openai.types.chat.chat_completion_message import FunctionCall
from openai.types.completion import Completion from openai.types.completion import Completion
from yarl import URL


from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
UserPromptMessage(content='ping') UserPromptMessage(content='ping')
], model_parameters={ ], model_parameters={
'max_tokens': 10, 'max_tokens': 10,
}, stop=[])
}, stop=[], stream=False)
except Exception as ex: except Exception as ex:
raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}') raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}')


) )
] ]


model_properties = {
ModelPropertyKey.MODE: completion_model,
} if completion_model else {}

model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get('context_size', '2048'))

entity = AIModelEntity( entity = AIModelEntity(
model=model, model=model,
label=I18nObject( label=I18nObject(
), ),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM, model_type=ModelType.LLM,
model_properties={ ModelPropertyKey.MODE: completion_model } if completion_model else {},
model_properties=model_properties,
parameter_rules=rules parameter_rules=rules
) )


client_kwargs = { client_kwargs = {
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"api_key": "1", "api_key": "1",
"base_url": urljoin(credentials['server_url'], 'v1'),
"base_url": str(URL(credentials['server_url']) / 'v1'),
} }


return client_kwargs return client_kwargs

+ 9
- 0
api/core/model_runtime/model_providers/localai/localai.yaml Wyświetl plik

placeholder: placeholder:
zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080 zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080
en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080 en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080
- variable: context_size
label:
zh_Hans: 上下文大小
en_US: Context size
placeholder:
zh_Hans: 输入上下文大小
en_US: Enter context size
required: false
type: text-input

+ 25
- 3
api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py Wyświetl plik

import time import time
from json import JSONDecodeError, dumps from json import JSONDecodeError, dumps
from os.path import join
from typing import Optional from typing import Optional


from requests import post from requests import post
from yarl import URL


from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import ( from core.model_runtime.errors.invoke import (
InvokeAuthorizationError, InvokeAuthorizationError,
} }


try: try:
response = post(join(url, 'embeddings'), headers=headers, data=dumps(data), timeout=10)
response = post(str(URL(url) / 'embeddings'), headers=headers, data=dumps(data), timeout=10)
except Exception as e: except Exception as e:
raise InvokeConnectionError(str(e)) raise InvokeConnectionError(str(e))
# use GPT2Tokenizer to get num tokens # use GPT2Tokenizer to get num tokens
num_tokens += self._get_num_tokens_by_gpt2(text) num_tokens += self._get_num_tokens_by_gpt2(text)
return num_tokens return num_tokens
def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
    """
    Get customizable model schema

    Builds a text-embedding model entity labelled with the model name. The
    context size is taken from the ``context_size`` credential; a missing,
    ``None`` or blank value falls back to 512.

    :param model: model name
    :param credentials: model credentials
    :return: model schema
    :raises ValueError: if ``context_size`` is set to a non-integer value
    """
    # `dict.get`'s default only applies when the key is absent. An optional
    # field that is present but left empty (or None) would otherwise make
    # `int()` fail, so treat those cases like a missing key.
    raw_context_size = credentials.get('context_size')
    if raw_context_size is None or not str(raw_context_size).strip():
        raw_context_size = '512'

    return AIModelEntity(
        model=model,
        label=I18nObject(zh_Hans=model, en_US=model),
        model_type=ModelType.TEXT_EMBEDDING,
        features=[],
        fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
        model_properties={
            ModelPropertyKey.CONTEXT_SIZE: int(raw_context_size),
            ModelPropertyKey.MAX_CHUNKS: 1,
        },
        parameter_rules=[]
    )


def validate_credentials(self, model: str, credentials: dict) -> None: def validate_credentials(self, model: str, credentials: dict) -> None:
""" """

Ładowanie…
Anuluj
Zapisz