
feat: add LocalAI local embedding model support (#1021)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
tags/0.3.19
takatost 2 years ago
parent
commit 417c19577a
24 changed files with 1144 additions and 7 deletions
  1. api/core/model_providers/model_provider_factory.py (+3, -0)
  2. api/core/model_providers/models/embedding/localai_embedding.py (+29, -0)
  3. api/core/model_providers/models/llm/localai_model.py (+131, -0)
  4. api/core/model_providers/providers/localai_provider.py (+164, -0)
  5. api/core/model_providers/rules/_providers.json (+2, -1)
  6. api/core/model_providers/rules/localai.json (+7, -0)
  7. api/core/third_party/langchain/llms/chat_open_ai.py (+2, -1)
  8. api/core/third_party/langchain/llms/open_ai.py (+32, -3)
  9. api/tests/integration_tests/.env.example (+4, -1)
  10. api/tests/integration_tests/models/embedding/test_localai_embedding.py (+61, -0)
  11. api/tests/integration_tests/models/llm/test_localai_model.py (+68, -0)
  12. api/tests/unit_tests/model_providers/test_localai_provider.py (+116, -0)
  13. web/app/components/base/icons/assets/public/llm/localai-text.svg (+22, -0)
  14. web/app/components/base/icons/assets/public/llm/localai.svg (+15, -0)
  15. web/app/components/base/icons/src/public/llm/Localai.json (+107, -0)
  16. web/app/components/base/icons/src/public/llm/Localai.tsx (+14, -0)
  17. web/app/components/base/icons/src/public/llm/LocalaiText.json (+170, -0)
  18. web/app/components/base/icons/src/public/llm/LocalaiText.tsx (+14, -0)
  19. web/app/components/base/icons/src/public/llm/index.ts (+2, -0)
  20. web/app/components/header/account-setting/model-page/configs/index.ts (+2, -0)
  21. web/app/components/header/account-setting/model-page/configs/localai.tsx (+176, -0)
  22. web/app/components/header/account-setting/model-page/declarations.ts (+1, -0)
  23. web/app/components/header/account-setting/model-page/index.tsx (+1, -0)
  24. web/app/components/header/account-setting/model-page/utils.ts (+1, -1)

api/core/model_providers/model_provider_factory.py (+3, -0)

@@ -63,6 +63,9 @@ class ModelProviderFactory:
        elif provider_name == 'openllm':
            from core.model_providers.providers.openllm_provider import OpenLLMProvider
            return OpenLLMProvider
+       elif provider_name == 'localai':
+           from core.model_providers.providers.localai_provider import LocalAIProvider
+           return LocalAIProvider
        else:
            raise NotImplementedError


api/core/model_providers/models/embedding/localai_embedding.py (+29, -0)

@@ -0,0 +1,29 @@
from langchain.embeddings import LocalAIEmbeddings

from replicate.exceptions import ModelError, ReplicateError

from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding


class LocalAIEmbedding(BaseEmbedding):
    def __init__(self, model_provider: BaseModelProvider, name: str):
        credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        client = LocalAIEmbeddings(
            model=name,
            openai_api_key="1",
            openai_api_base=credentials['server_url'],
        )

        super().__init__(model_provider, client, name)

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, (ModelError, ReplicateError)):
            return LLMBadRequestError(f"LocalAI embedding: {str(ex)}")
        else:
            return ex
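
A minimal usage sketch (not part of this commit) of the LangChain client the wrapper above delegates to, assuming a LocalAI server at a placeholder URL and the embedding model name used in the integration test further below:

# Sketch only: values are placeholders; LocalAI ignores the API key, but the client requires one.
from langchain.embeddings import LocalAIEmbeddings

client = LocalAIEmbeddings(
    model='text-embedding-ada-002',            # placeholder model name
    openai_api_key='1',                        # dummy key, same convention as above
    openai_api_base='http://127.0.0.1:8080',   # placeholder LocalAI server URL
)

vector = client.embed_query('hello world')     # returns a list of floats
print(len(vector))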

api/core/model_providers/models/llm/localai_model.py (+131, -0)

@@ -0,0 +1,131 @@
import logging
from typing import List, Optional, Any

import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, get_buffer_string

from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
    LLMRateLimitError, LLMAuthorizationError
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs


class LocalAIModel(BaseLLM):
    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )

        if credentials['completion_type'] == 'chat_completion':
            self.model_mode = ModelMode.CHAT
        else:
            self.model_mode = ModelMode.COMPLETION

        super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        if self.model_mode == ModelMode.COMPLETION:
            client = EnhanceOpenAI(
                model_name=self.name,
                streaming=self.streaming,
                callbacks=self.callbacks,
                request_timeout=60,
                openai_api_key="1",
                openai_api_base=self.credentials['server_url'] + '/v1',
                **provider_model_kwargs
            )
        else:
            extra_model_kwargs = {
                'top_p': provider_model_kwargs.get('top_p')
            }

            client = EnhanceChatOpenAI(
                model_name=self.name,
                temperature=provider_model_kwargs.get('temperature'),
                max_tokens=provider_model_kwargs.get('max_tokens'),
                model_kwargs=extra_model_kwargs,
                streaming=self.streaming,
                callbacks=self.callbacks,
                request_timeout=60,
                openai_api_key="1",
                openai_api_base=self.credentials['server_url'] + '/v1'
            )

        return client

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        prompts = self._get_prompt_from_messages(messages)
        if isinstance(prompts, str):
            return self._client.get_num_tokens(prompts)
        else:
            return max(sum([self._client.get_num_tokens(get_buffer_string([m])) for m in prompts]) - len(prompts), 0)

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        if self.model_mode == ModelMode.COMPLETION:
            for k, v in provider_model_kwargs.items():
                if hasattr(self.client, k):
                    setattr(self.client, k, v)
        else:
            extra_model_kwargs = {
                'top_p': provider_model_kwargs.get('top_p')
            }

            self.client.temperature = provider_model_kwargs.get('temperature')
            self.client.max_tokens = provider_model_kwargs.get('max_tokens')
            self.client.model_kwargs = extra_model_kwargs

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, openai.error.InvalidRequestError):
            logging.warning("Invalid request to LocalAI API.")
            return LLMBadRequestError(str(ex))
        elif isinstance(ex, openai.error.APIConnectionError):
            logging.warning("Failed to connect to LocalAI API.")
            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
            logging.warning("LocalAI service unavailable.")
            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError(str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
            return LLMAuthorizationError(str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex

    @classmethod
    def support_streaming(cls):
        return True
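
The client construction above simply points the OpenAI-compatible SDK at `server_url + '/v1'` with a dummy API key, and `completion_type` decides between the completion and chat endpoints. A minimal sketch (not part of this commit) of the same two calls made directly with the `openai` package (pre-1.0, matching the `openai.error` usage above); the server URL and model name are placeholders:

# Sketch only, not part of this commit. Assumes openai<1.0 and a LocalAI server
# at the placeholder URL below serving the placeholder model 'ggml-gpt4all-j'.
import openai

openai.api_key = '1'                           # LocalAI does not validate the key
openai.api_base = 'http://127.0.0.1:8080/v1'   # same convention as server_url + '/v1' above

# completion_type == 'completion'
completion = openai.Completion.create(model='ggml-gpt4all-j', prompt='ping', max_tokens=10)
print(completion['choices'][0]['text'])

# completion_type == 'chat_completion'
chat = openai.ChatCompletion.create(
    model='ggml-gpt4all-j',
    messages=[{'role': 'user', 'content': 'ping'}],
    max_tokens=10,
)
print(chat['choices'][0]['message']['content'])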

api/core/model_providers/providers/localai_provider.py (+164, -0)

@@ -0,0 +1,164 @@
import json
from typing import Type

from langchain.embeddings import LocalAIEmbeddings
from langchain.schema import HumanMessage

from core.helper import encrypter
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule
from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError

from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
from models.provider import ProviderType


class LocalAIProvider(BaseModelProvider):
    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'localai'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        """
        if model_type == ModelType.TEXT_GENERATION:
            model_class = LocalAIModel
        elif model_type == ModelType.EMBEDDINGS:
            model_class = LocalAIEmbedding
        else:
            raise NotImplementedError

        return model_class

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """
        get model parameter rules.

        :param model_name:
        :param model_type:
        :return:
        """
        return ModelKwargsRules(
            temperature=KwargRule[float](min=0, max=2, default=0.7),
            top_p=KwargRule[float](min=0, max=1, default=1),
            max_tokens=KwargRule[int](min=10, max=4097, default=16),
        )

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        if 'server_url' not in credentials:
            raise CredentialsValidateFailedError('LocalAI Server URL must be provided.')

        try:
            if model_type == ModelType.EMBEDDINGS:
                model = LocalAIEmbeddings(
                    model=model_name,
                    openai_api_key='1',
                    openai_api_base=credentials['server_url']
                )

                model.embed_query("ping")
            else:
                if ('completion_type' not in credentials
                        or credentials['completion_type'] not in ['completion', 'chat_completion']):
                    raise CredentialsValidateFailedError('LocalAI Completion Type must be provided.')

                if credentials['completion_type'] == 'chat_completion':
                    model = EnhanceChatOpenAI(
                        model_name=model_name,
                        openai_api_key='1',
                        openai_api_base=credentials['server_url'] + '/v1',
                        max_tokens=10,
                        request_timeout=60,
                    )

                    model([HumanMessage(content='ping')])
                else:
                    model = EnhanceOpenAI(
                        model_name=model_name,
                        openai_api_key='1',
                        openai_api_base=credentials['server_url'] + '/v1',
                        max_tokens=10,
                        request_timeout=60,
                    )

                    model('ping')
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
        return credentials

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        if self.provider.provider_type != ProviderType.CUSTOM.value:
            raise NotImplementedError

        provider_model = self._get_provider_model(model_name, model_type)

        if not provider_model.encrypted_config:
            return {
                'server_url': None,
            }

        credentials = json.loads(provider_model.encrypted_config)
        if credentials['server_url']:
            credentials['server_url'] = encrypter.decrypt_token(
                self.provider.tenant_id,
                credentials['server_url']
            )

            if obfuscated:
                credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])

        return credentials

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        return

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        return {}

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        return {}
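
A short sketch (not part of this commit) of how the new validation entry point is exercised, following the signature above; the model name, server URL, and completion type are placeholder values:

from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.localai_provider import LocalAIProvider

try:
    # Sends a short "ping" request to the configured server, so it needs a reachable LocalAI instance.
    LocalAIProvider.is_model_credentials_valid_or_raise(
        model_name='ggml-gpt4all-j',
        model_type=ModelType.TEXT_GENERATION,
        credentials={
            'server_url': 'http://127.0.0.1:8080',
            'completion_type': 'completion',   # required for text generation models
        },
    )
except CredentialsValidateFailedError as e:
    print(f'LocalAI credentials rejected: {e}')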

api/core/model_providers/rules/_providers.json (+2, -1)

@@ -10,5 +10,6 @@
  "replicate",
  "huggingface_hub",
  "xinference",
- "openllm"
+ "openllm",
+ "localai"
]

api/core/model_providers/rules/localai.json (+7, -0)

@@ -0,0 +1,7 @@
{
  "support_provider_types": [
    "custom"
  ],
  "system_config": null,
  "model_flexibility": "configurable"
}

api/core/third_party/langchain/llms/chat_open_ai.py (+2, -1)

@@ -42,7 +42,8 @@ class EnhanceChatOpenAI(ChatOpenAI):
        return {
            **super()._default_params,
            "api_type": 'openai',
-           "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
+           "api_base": self.openai_api_base if self.openai_api_base
+           else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,

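With this change `EnhanceChatOpenAI` honours an explicitly passed `openai_api_base` instead of always falling back to the OPENAI_API_BASE environment variable, which is what lets the LocalAI provider point it at a self-hosted server. A minimal sketch (not part of this commit) with placeholder values, mirroring the constructor arguments used in localai_provider.py above:

from langchain.schema import HumanMessage

from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI

chat_model = EnhanceChatOpenAI(
    model_name='ggml-gpt4all-j',                 # placeholder model name
    openai_api_key='1',                          # dummy key, LocalAI does not validate it
    openai_api_base='http://127.0.0.1:8080/v1',  # placeholder LocalAI endpoint
    max_tokens=10,
    request_timeout=60,
)

print(chat_model([HumanMessage(content='ping')]).content)
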
api/core/third_party/langchain/llms/open_ai.py (+32, -3)

@@ -1,7 +1,10 @@
import os

-from typing import Dict, Any, Mapping, Optional, Union, Tuple
+from typing import Dict, Any, Mapping, Optional, Union, Tuple, List, Iterator
from langchain import OpenAI
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.openai import completion_with_retry, _stream_response_to_generation_chunk
+from langchain.schema.output import GenerationChunk
from pydantic import root_validator


@@ -33,7 +36,8 @@ class EnhanceOpenAI(OpenAI):
    def _invocation_params(self) -> Dict[str, Any]:
        return {**super()._invocation_params, **{
            "api_type": 'openai',
-           "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
+           "api_base": self.openai_api_base if self.openai_api_base
+           else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
@@ -43,8 +47,33 @@ class EnhanceOpenAI(OpenAI):
    def _identifying_params(self) -> Mapping[str, Any]:
        return {**super()._identifying_params, **{
            "api_type": 'openai',
-           "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
+           "api_base": self.openai_api_base if self.openai_api_base
+           else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }}
+
+   def _stream(
+       self,
+       prompt: str,
+       stop: Optional[List[str]] = None,
+       run_manager: Optional[CallbackManagerForLLMRun] = None,
+       **kwargs: Any,
+   ) -> Iterator[GenerationChunk]:
+       params = {**self._invocation_params, **kwargs, "stream": True}
+       self.get_sub_prompts(params, [prompt], stop)  # this mutates params
+       for stream_resp in completion_with_retry(
+           self, prompt=prompt, run_manager=run_manager, **params
+       ):
+           if 'text' in stream_resp["choices"][0]:
+               chunk = _stream_response_to_generation_chunk(stream_resp)
+               yield chunk
+               if run_manager:
+                   run_manager.on_llm_new_token(
+                       chunk.text,
+                       verbose=self.verbose,
+                       logprobs=chunk.generation_info["logprobs"]
+                       if chunk.generation_info
+                       else None,
+                   )
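
The `_stream` override mirrors LangChain's stock OpenAI streaming loop but skips chunks whose first choice carries no 'text' field, presumably to tolerate the chunk format LocalAI emits. A minimal sketch (not part of this commit) of driving it through the standard callback mechanism; the model name and server URL are placeholders, and with streaming=True the override above should be used under the hood:

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from core.third_party.langchain.llms.open_ai import EnhanceOpenAI

llm = EnhanceOpenAI(
    model_name='ggml-gpt4all-j',                 # placeholder model name
    openai_api_key='1',                          # dummy key
    openai_api_base='http://127.0.0.1:8080/v1',  # placeholder LocalAI endpoint
    max_tokens=32,
    streaming=True,
    callbacks=[StreamingStdOutCallbackHandler()],
)

# Each generated token is forwarded to the stdout callback as it arrives.
print(llm('Say hello in one short sentence.'))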

api/tests/integration_tests/.env.example (+4, -1)

@@ -39,4 +39,7 @@ XINFERENCE_SERVER_URL=
XINFERENCE_MODEL_UID=

# OpenLLM Credentials
-OPENLLM_SERVER_URL=
+OPENLLM_SERVER_URL=
+
+# LocalAI Credentials
+LOCALAI_SERVER_URL=

api/tests/integration_tests/models/embedding/test_localai_embedding.py (+61, -0)

@@ -0,0 +1,61 @@
import json
import os
from unittest.mock import patch, MagicMock

from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.localai_provider import LocalAIProvider
from models.provider import Provider, ProviderType, ProviderModel


def get_mock_provider():
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='localai',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config='',
        is_valid=True,
    )


def get_mock_embedding_model(mocker):
    model_name = 'text-embedding-ada-002'
    server_url = os.environ['LOCALAI_SERVER_URL']
    model_provider = LocalAIProvider(provider=get_mock_provider())

    mock_query = MagicMock()
    mock_query.filter.return_value.first.return_value = ProviderModel(
        provider_name='localai',
        model_name=model_name,
        model_type=ModelType.EMBEDDINGS.value,
        encrypted_config=json.dumps({
            'server_url': server_url,
        }),
        is_valid=True,
    )
    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)

    return LocalAIEmbedding(
        model_provider=model_provider,
        name=model_name
    )


def decrypt_side_effect(tenant_id, encrypted_api_key):
    return encrypted_api_key


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_documents(mock_decrypt, mocker):
    embedding_model = get_mock_embedding_model(mocker)
    rst = embedding_model.client.embed_documents(['test', 'test1'])
    assert isinstance(rst, list)
    assert len(rst) == 2


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_query(mock_decrypt, mocker):
    embedding_model = get_mock_embedding_model(mocker)
    rst = embedding_model.client.embed_query('test')
    assert isinstance(rst, list)

api/tests/integration_tests/models/llm/test_localai_model.py (+68, -0)

@@ -0,0 +1,68 @@
import json
import os
from unittest.mock import patch, MagicMock

from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.localai_provider import LocalAIProvider
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from models.provider import Provider, ProviderType, ProviderModel


def get_mock_provider(server_url):
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='localai',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=json.dumps({}),
        is_valid=True,
    )


def get_mock_model(model_name, mocker):
    model_kwargs = ModelKwargs(
        max_tokens=10,
        temperature=0
    )
    server_url = os.environ['LOCALAI_SERVER_URL']

    mock_query = MagicMock()
    mock_query.filter.return_value.first.return_value = ProviderModel(
        provider_name='localai',
        model_name=model_name,
        model_type=ModelType.TEXT_GENERATION.value,
        encrypted_config=json.dumps({'server_url': server_url, 'completion_type': 'completion'}),
        is_valid=True,
    )
    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)

    openai_provider = LocalAIProvider(provider=get_mock_provider(server_url))
    return LocalAIModel(
        model_provider=openai_provider,
        name=model_name,
        model_kwargs=model_kwargs
    )


def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
    return encrypted_openai_api_key


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt, mocker):
    openai_model = get_mock_model('ggml-gpt4all-j', mocker)
    rst = openai_model.get_num_tokens([PromptMessage(content='you are a kindness Assistant.')])
    assert rst > 0


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    openai_model = get_mock_model('ggml-gpt4all-j', mocker)
    rst = openai_model.run(
        [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],
        stop=['\nHuman:'],
    )
    assert len(rst.content) > 0

api/tests/unit_tests/model_providers/test_localai_provider.py (+116, -0)

@@ -0,0 +1,116 @@
import pytest
from unittest.mock import patch, MagicMock
import json

from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.localai_provider import LocalAIProvider
from models.provider import ProviderType, Provider, ProviderModel

PROVIDER_NAME = 'localai'
MODEL_PROVIDER_CLASS = LocalAIProvider
VALIDATE_CREDENTIAL = {
    'server_url': 'http://127.0.0.1:8080/'
}


def encrypt_side_effect(tenant_id, encrypt_key):
    return f'encrypted_{encrypt_key}'


def decrypt_side_effect(tenant_id, encrypted_key):
    return encrypted_key.replace('encrypted_', '')


def test_is_credentials_valid_or_raise_valid(mocker):
    mocker.patch('langchain.embeddings.localai.LocalAIEmbeddings.embed_query',
                 return_value="abc")

    MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
        model_name='username/test_model_name',
        model_type=ModelType.EMBEDDINGS,
        credentials=VALIDATE_CREDENTIAL.copy()
    )


def test_is_credentials_valid_or_raise_invalid():
    # raise CredentialsValidateFailedError if server_url is not in credentials
    with pytest.raises(CredentialsValidateFailedError):
        MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
            model_name='test_model_name',
            model_type=ModelType.EMBEDDINGS,
            credentials={}
        )


@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_model_credentials(mock_encrypt, mocker):
    server_url = 'http://127.0.0.1:8080/'

    result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
        tenant_id='tenant_id',
        model_name='test_model_name',
        model_type=ModelType.EMBEDDINGS,
        credentials=VALIDATE_CREDENTIAL.copy()
    )
    mock_encrypt.assert_called_with('tenant_id', server_url)
    assert result['server_url'] == f'encrypted_{server_url}'


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_custom(mock_decrypt, mocker):
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=None,
        is_valid=True,
    )

    encrypted_credential = VALIDATE_CREDENTIAL.copy()
    encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']

    mock_query = MagicMock()
    mock_query.filter.return_value.first.return_value = ProviderModel(
        encrypted_config=json.dumps(encrypted_credential)
    )
    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)

    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    result = model_provider.get_model_credentials(
        model_name='test_model_name',
        model_type=ModelType.EMBEDDINGS
    )
    assert result['server_url'] == 'http://127.0.0.1:8080/'


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
    provider = Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name=PROVIDER_NAME,
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config=None,
        is_valid=True,
    )

    encrypted_credential = VALIDATE_CREDENTIAL.copy()
    encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']

    mock_query = MagicMock()
    mock_query.filter.return_value.first.return_value = ProviderModel(
        encrypted_config=json.dumps(encrypted_credential)
    )
    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)

    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
    result = model_provider.get_model_credentials(
        model_name='test_model_name',
        model_type=ModelType.EMBEDDINGS,
        obfuscated=True
    )
    middle_token = result['server_url'][6:-2]
    assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['server_url']) - 8, 0)
    assert all(char == '*' for char in middle_token)

web/app/components/base/icons/assets/public/llm/localai-text.svg (+22, -0)
File diff suppressed because it is too large


web/app/components/base/icons/assets/public/llm/localai.svg (+15, -0)
File diff suppressed because it is too large


web/app/components/base/icons/src/public/llm/Localai.json (+107, -0)
File diff suppressed because it is too large


web/app/components/base/icons/src/public/llm/Localai.tsx (+14, -0)

@@ -0,0 +1,14 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY

import * as React from 'react'
import data from './Localai.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'

const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
  props,
  ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)

export default Icon

web/app/components/base/icons/src/public/llm/LocalaiText.json (+170, -0)
File diff suppressed because it is too large


web/app/components/base/icons/src/public/llm/LocalaiText.tsx (+14, -0)

@@ -0,0 +1,14 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY

import * as React from 'react'
import data from './LocalaiText.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'

const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
  props,
  ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)

export default Icon

web/app/components/base/icons/src/public/llm/index.ts (+2, -0)

@@ -14,6 +14,8 @@ export { default as Huggingface } from './Huggingface'
export { default as IflytekSparkTextCn } from './IflytekSparkTextCn'
export { default as IflytekSparkText } from './IflytekSparkText'
export { default as IflytekSpark } from './IflytekSpark'
export { default as LocalaiText } from './LocalaiText'
export { default as Localai } from './Localai'
export { default as Microsoft } from './Microsoft'
export { default as OpenaiBlack } from './OpenaiBlack'
export { default as OpenaiBlue } from './OpenaiBlue'

web/app/components/header/account-setting/model-page/configs/index.ts (+2, -0)

@@ -10,6 +10,7 @@ import minimax from './minimax'
import chatglm from './chatglm'
import xinference from './xinference'
import openllm from './openllm'
import localai from './localai'

export default {
  openai,
@@ -24,4 +25,5 @@ export default {
  chatglm,
  xinference,
  openllm,
  localai,
}

web/app/components/header/account-setting/model-page/configs/localai.tsx (+176, -0)

@@ -0,0 +1,176 @@
import { ProviderEnum } from '../declarations'
import type { FormValue, ProviderConfig } from '../declarations'
import { Localai, LocalaiText } from '@/app/components/base/icons/src/public/llm'

const config: ProviderConfig = {
  selector: {
    name: {
      'en': 'LocalAI',
      'zh-Hans': 'LocalAI',
    },
    icon: <Localai className='w-full h-full' />,
  },
  item: {
    key: ProviderEnum.localai,
    titleIcon: {
      'en': <LocalaiText className='h-6' />,
      'zh-Hans': <LocalaiText className='h-6' />,
    },
    disable: {
      tip: {
        'en': 'Only supports the ',
        'zh-Hans': '仅支持',
      },
      link: {
        href: {
          'en': 'https://docs.dify.ai/getting-started/install-self-hosted',
          'zh-Hans': 'https://docs.dify.ai/v/zh-hans/getting-started/install-self-hosted',
        },
        label: {
          'en': 'community open-source version',
          'zh-Hans': '社区开源版本',
        },
      },
    },
  },
  modal: {
    key: ProviderEnum.localai,
    title: {
      'en': 'LocalAI',
      'zh-Hans': 'LocalAI',
    },
    icon: <Localai className='h-6' />,
    link: {
      href: 'https://github.com/go-skynet/LocalAI',
      label: {
        'en': 'How to deploy LocalAI',
        'zh-Hans': '如何部署 LocalAI',
      },
    },
    defaultValue: {
      model_type: 'text-generation',
      completion_type: 'completion',
    },
    validateKeys: (v?: FormValue) => {
      if (v?.model_type === 'text-generation') {
        return [
          'model_type',
          'model_name',
          'server_url',
          'completion_type',
        ]
      }
      if (v?.model_type === 'embeddings') {
        return [
          'model_type',
          'model_name',
          'server_url',
        ]
      }
      return []
    },
    filterValue: (v?: FormValue) => {
      let filteredKeys: string[] = []
      if (v?.model_type === 'text-generation') {
        filteredKeys = [
          'model_type',
          'model_name',
          'server_url',
          'completion_type',
        ]
      }
      if (v?.model_type === 'embeddings') {
        filteredKeys = [
          'model_type',
          'model_name',
          'server_url',
        ]
      }
      return filteredKeys.reduce((prev: FormValue, next: string) => {
        prev[next] = v?.[next] || ''
        return prev
      }, {})
    },
    fields: [
      {
        type: 'radio',
        key: 'model_type',
        required: true,
        label: {
          'en': 'Model Type',
          'zh-Hans': '模型类型',
        },
        options: [
          {
            key: 'text-generation',
            label: {
              'en': 'Text Generation',
              'zh-Hans': '文本生成',
            },
          },
          {
            key: 'embeddings',
            label: {
              'en': 'Embeddings',
              'zh-Hans': 'Embeddings',
            },
          },
        ],
      },
      {
        type: 'text',
        key: 'model_name',
        required: true,
        label: {
          'en': 'Model Name',
          'zh-Hans': '模型名称',
        },
        placeholder: {
          'en': 'Enter your Model Name here',
          'zh-Hans': '在此输入您的模型名称',
        },
      },
      {
        hidden: (value?: FormValue) => value?.model_type === 'embeddings',
        type: 'radio',
        key: 'completion_type',
        required: true,
        label: {
          'en': 'Completion Type',
          'zh-Hans': 'Completion Type',
        },
        options: [
          {
            key: 'completion',
            label: {
              'en': 'Completion',
              'zh-Hans': 'Completion',
            },
          },
          {
            key: 'chat_completion',
            label: {
              'en': 'Chat Completion',
              'zh-Hans': 'Chat Completion',
            },
          },
        ],
      },
      {
        type: 'text',
        key: 'server_url',
        required: true,
        label: {
          'en': 'Server url',
          'zh-Hans': 'Server url',
        },
        placeholder: {
          'en': 'Enter your Server Url, eg: https://example.com/xxx',
          'zh-Hans': '在此输入您的 Server Url,如:https://example.com/xxx',
        },
      },
    ],
  },
}

export default config

web/app/components/header/account-setting/model-page/declarations.ts (+1, -0)

@@ -41,6 +41,7 @@ export enum ProviderEnum {
  'chatglm' = 'chatglm',
  'xinference' = 'xinference',
  'openllm' = 'openllm',
  'localai' = 'localai',
}

export type ProviderConfigItem = {

web/app/components/header/account-setting/model-page/index.tsx (+1, -0)

@@ -99,6 +99,7 @@ const ModelPage = () => {
config.chatglm,
config.xinference,
config.openllm,
config.localai,
]
}


web/app/components/header/account-setting/model-page/utils.ts (+1, -1)

@@ -2,7 +2,7 @@ import { ValidatedStatus } from '../key-validator/declarations'
import { ProviderEnum } from './declarations'
import { validateModelProvider } from '@/service/common'

-export const ConfigurableProviders = [ProviderEnum.azure_openai, ProviderEnum.replicate, ProviderEnum.huggingface_hub, ProviderEnum.xinference, ProviderEnum.openllm]
+export const ConfigurableProviders = [ProviderEnum.azure_openai, ProviderEnum.replicate, ProviderEnum.huggingface_hub, ProviderEnum.xinference, ProviderEnum.openllm, ProviderEnum.localai]

export const validateModelProviderFn = async (providerName: ProviderEnum, v: any) => {
  let body, url
