### What problem does this PR solve?

#762

Adds LocalAI as a locally deployed model provider: a chat model, an embedding model, an image2text (vision) model, and a rerank stub, plus the matching factory entry and frontend option.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
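For context, LocalAI serves models behind an OpenAI-compatible REST API, which is what every class in this PR talks to. A minimal sketch of a direct call to that API; the address and model name are placeholders for whatever a local deployment serves:

```python
import requests

# Hypothetical LocalAI deployment; adjust host, port, and model name.
resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    headers={"Content-Type": "application/json"},
    json={
        "model": "llama-3-8b",
        "messages": [{"role": "user", "content": "Hello!"}],
        "temperature": 0.7,
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```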
```diff
@@ -20,7 +20,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
 from api.db import StatusEnum, LLMType
 from api.db.db_models import TenantLLM
 from api.utils.api_utils import get_json_result
-from rag.llm import EmbeddingModel, ChatModel, RerankModel
+from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel


 @manager.route('/factories', methods=['GET'])
```
```diff
@@ -126,6 +126,9 @@ def add_llm():
         api_key = '{' + f'"bedrock_ak": "{req.get("bedrock_ak", "")}", ' \
             f'"bedrock_sk": "{req.get("bedrock_sk", "")}", ' \
             f'"bedrock_region": "{req.get("bedrock_region", "")}", ' + '}'
+    elif factory == "LocalAI":
+        llm_name = req["llm_name"] + "___LocalAI"
+        api_key = "xxxxxxxxxxxxxxx"
     else:
         llm_name = req["llm_name"]
         api_key = "xxxxxxxxxxxxxxx"
```
```diff
@@ -176,6 +179,21 @@ def add_llm():
         except Exception as e:
             msg += f"\nFail to access model({llm['llm_name']})." + str(
                 e)
+    elif llm["model_type"] == LLMType.IMAGE2TEXT.value:
+        mdl = CvModel[factory](
+            key=None, model_name=llm["llm_name"], base_url=llm["api_base"]
+        )
+        try:
+            img_url = (
+                "https://upload.wikimedia.org/wikipedia/comm"
+                "ons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/256"
+                "0px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+            )
+            m, tc = mdl.describe(img_url)
+            if not tc:
+                raise Exception(m)
+        except Exception as e:
+            msg += f"\nFail to access model({llm['llm_name']})." + str(e)
     else:
         # TODO: check other type of models
         pass
```
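The image2text probe follows the same pattern as the existing embedding and chat checks: instantiate the model and run one inference against a known-good input. A standalone equivalent, assuming a reachable LocalAI deployment serving a vision model (model name and URL are placeholders):

```python
from rag.llm import CvModel

mdl = CvModel["LocalAI"](
    key=None,
    model_name="llava___LocalAI",         # hypothetical vision model
    base_url="http://localhost:8080/v1",  # hypothetical LocalAI address
)
answer, tokens = mdl.describe(
    "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/"
    "Gfp-wisconsin-madison-the-nature-boardwalk.jpg/"
    "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
)
print(answer, tokens)
```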
```diff
@@ -157,6 +157,13 @@
             "status": "1",
             "llm": []
         },
+        {
+            "name": "LocalAI",
+            "logo": "",
+            "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+            "status": "1",
+            "llm": []
+        },
         {
             "name": "Moonshot",
             "logo": "",
```
```diff
@@ -21,6 +21,7 @@ from .rerank_model import *
 EmbeddingModel = {
     "Ollama": OllamaEmbed,
+    "LocalAI": LocalAIEmbed,
     "OpenAI": OpenAIEmbed,
     "Azure-OpenAI": AzureEmbed,
     "Xinference": XinferenceEmbed,
```
```diff
@@ -46,7 +47,8 @@ CvModel = {
     "ZHIPU-AI": Zhipu4V,
     "Moonshot": LocalCV,
     'Gemini': GeminiCV,
-    'OpenRouter': OpenRouterCV
+    'OpenRouter': OpenRouterCV,
+    "LocalAI": LocalAICV
 }
```
```diff
@@ -56,6 +58,7 @@ ChatModel = {
     "ZHIPU-AI": ZhipuChat,
     "Tongyi-Qianwen": QWenChat,
     "Ollama": OllamaChat,
+    "LocalAI": LocalAIChat,
     "Xinference": XinferenceChat,
     "Moonshot": MoonshotChat,
     "DeepSeek": DeepSeekChat,
```
```diff
@@ -67,7 +70,7 @@ ChatModel = {
     'Gemini': GeminiChat,
     "Bedrock": BedrockChat,
     "Groq": GroqChat,
-    'OpenRouter': OpenRouterChat
+    'OpenRouter': OpenRouterChat,
 }
```
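These dictionaries are the factory registries: the supplier name stored for a tenant keys directly into the class to instantiate, which is how `add_llm()` above resolves `CvModel[factory]`. A usage sketch with placeholder connection details; note that `LocalAIChat` appends `/v1/chat/completions` itself, while the OpenAI-client-based classes expect the base URL to already end in `/v1`:

```python
from rag.llm import ChatModel, CvModel

# Hypothetical model names and addresses.
chat = ChatModel["LocalAI"](
    key=None, model_name="llama-3-8b___LocalAI", base_url="http://localhost:8080"
)
vision = CvModel["LocalAI"](
    key=None, model_name="llava___LocalAI", base_url="http://localhost:8080/v1"
)
```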
```diff
@@ -348,6 +348,82 @@ class OllamaChat(Base):
             yield 0


+class LocalAIChat(Base):
+    def __init__(self, key, model_name, base_url):
+        if base_url[-1] == "/":
+            base_url = base_url[:-1]
+        self.base_url = base_url + "/v1/chat/completions"
+        # Names are stored as "<model>___LocalAI"; send only the bare name.
+        self.model_name = model_name.split("___")[0]
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        # Keep only the generation options the endpoint understands.
+        for k in list(gen_conf.keys()):
+            if k not in ["temperature", "top_p", "max_tokens"]:
+                del gen_conf[k]
+        headers = {
+            "Content-Type": "application/json",
+        }
+        payload = json.dumps(
+            {"model": self.model_name, "messages": history, **gen_conf}
+        )
+        try:
+            response = requests.request(
+                "POST", url=self.base_url, headers=headers, data=payload
+            )
+            response = response.json()
+            ans = response["choices"][0]["message"]["content"].strip()
+            if response["choices"][0]["finish_reason"] == "length":
+                ans += (
+                    "...\nFor the content length reason, it stopped, continue?"
+                    if is_english([ans])
+                    else "······\n由于长度的原因,回答被截断了,要继续吗?"  # Chinese-locale truncation notice
+                )
+            return ans, response["usage"]["total_tokens"]
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        ans = ""
+        total_tokens = 0
+        try:
+            headers = {
+                "Content-Type": "application/json",
+            }
+            payload = json.dumps(
+                {
+                    "model": self.model_name,
+                    "messages": history,
+                    "stream": True,
+                    **gen_conf,
+                }
+            )
+            response = requests.request(
+                "POST",
+                url=self.base_url,
+                headers=headers,
+                data=payload,
+            )
+            # The body is a server-sent-event stream: "data: {...}" blocks
+            # separated by blank lines.
+            for resp in response.content.decode("utf-8").split("\n\n"):
+                if "choices" not in resp:
+                    continue
+                resp = json.loads(resp[6:])  # strip the leading "data: "
+                if "delta" in resp["choices"][0]:
+                    text = resp["choices"][0]["delta"]["content"]
+                else:
+                    continue
+                ans += text
+                total_tokens += 1  # chunk count; the stream carries no usage info
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
+
+
 class LocalLLM(Base):
     class RPCProxy:
         def __init__(self, host, port):
```
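`chat_streamly` yields the accumulated answer string after each chunk and finally yields an integer tally, so callers must check the type of each item. A consumption sketch, assuming the class above is in scope and using placeholder connection details:

```python
chat = LocalAIChat(
    key=None, model_name="llama-3-8b___LocalAI", base_url="http://localhost:8080"
)
for item in chat.chat_streamly(
    "You are a helpful assistant.",
    [{"role": "user", "content": "Say hi."}],
    {"temperature": 0.7},
):
    if isinstance(item, str):
        print(item)             # progressively longer answer text
    else:
        print("chunks:", item)  # final yield: the chunk tally
```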
```diff
@@ -189,6 +189,35 @@ class OllamaCV(Base):
             return "**ERROR**: " + str(e), 0


+class LocalAICV(Base):
+    def __init__(self, key, model_name, base_url, lang="Chinese"):
+        # LocalAI is OpenAI-compatible, so the stock client works with a dummy key.
+        self.client = OpenAI(api_key="empty", base_url=base_url)
+        self.model_name = model_name.split("___")[0]
+        self.lang = lang
+
+    def describe(self, image, max_tokens=300):
+        if not isinstance(image, bytes) and not isinstance(
+            image, BytesIO
+        ):  # if url string
+            prompt = self.prompt(image)
+            # "content" is a list of parts; point the image part at the URL.
+            for i in range(len(prompt)):
+                for c in prompt[i]["content"]:
+                    if "image_url" in c:
+                        c["image_url"]["url"] = image
+        else:
+            b64 = self.image2base64(image)
+            prompt = self.prompt(b64)
+            for i in range(len(prompt)):
+                for c in prompt[i]["content"]:
+                    if "text" in c:
+                        c["type"] = "text"
+
+        res = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=prompt,
+            max_tokens=max_tokens,
+        )
+        return res.choices[0].message.content.strip(), res.usage.total_tokens
+
+
 class XinferenceCV(Base):
     def __init__(self, key, model_name="", lang="Chinese", base_url=""):
         self.client = OpenAI(api_key="xxx", base_url=base_url)
```
```diff
@@ -111,6 +111,24 @@ class OpenAIEmbed(Base):
         return np.array(res.data[0].embedding), res.usage.total_tokens


+class LocalAIEmbed(Base):
+    def __init__(self, key, model_name, base_url):
+        self.base_url = base_url + "/embeddings"
+        self.headers = {
+            "Content-Type": "application/json",
+        }
+        self.model_name = model_name.split("___")[0]
+
+    def encode(self, texts: list, batch_size=None):
+        data = {"model": self.model_name, "input": texts, "encoding_type": "float"}
+        res = requests.post(self.base_url, headers=self.headers, json=data).json()
+        # The endpoint reports no usage, so a fixed placeholder count is returned.
+        return np.array([d["embedding"] for d in res["data"]]), 1024
+
+    def encode_queries(self, text):
+        embds, cnt = self.encode([text])
+        return np.array(embds[0]), cnt
+
+
 class AzureEmbed(OpenAIEmbed):
     def __init__(self, key, model_name, **kwargs):
         self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
```
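Usage sketch, assuming a LocalAI deployment serving an embedding model (name and address are placeholders); `base_url` must already end in `/v1` because the class only appends `/embeddings`:

```python
emb = LocalAIEmbed(
    key=None, model_name="bge-m3___LocalAI", base_url="http://localhost:8080/v1"
)
vectors, _ = emb.encode(["first passage", "second passage"])
print(vectors.shape)  # (2, embedding_dim)
query_vec, _ = emb.encode_queries("a search query")
```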
```diff
@@ -443,4 +461,4 @@ class GeminiEmbed(Base):
             task_type="retrieval_document",
             title="Embedding of single string")
         token_count = num_tokens_from_string(text)
-        return np.array(result['embedding']),token_count
\ No newline at end of file
+        return np.array(result['embedding']),token_count
```
```diff
@@ -135,7 +135,7 @@ class YoudaoRerank(DefaultRerank):
             if isinstance(scores, float): res.append(scores)
             else: res.extend(scores)
         return np.array(res), token_count


 class XInferenceRerank(Base):
     def __init__(self, key="xxxxxxx", model_name="", base_url=""):
```
```diff
@@ -156,3 +156,11 @@ class XInferenceRerank(Base):
         }
         res = requests.post(self.base_url, headers=self.headers, json=data).json()
         return np.array([d["relevance_score"] for d in res["results"]]), res["meta"]["tokens"]["input_tokens"] + res["meta"]["tokens"]["output_tokens"]
+
+
+class LocalAIRerank(Base):
+    def __init__(self, key, model_name, base_url):
+        pass
+
+    def similarity(self, query: str, texts: list):
+        raise NotImplementedError("The LocalAIRerank has not been implemented")
```
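The rerank class is deliberately a stub. If LocalAI exposes a rerank endpoint later, an implementation would likely mirror `XInferenceRerank` above; a hypothetical sketch (the endpoint path and response shape are assumptions, not a documented LocalAI API):

```python
import requests
import numpy as np

class HypotheticalLocalAIRerank:
    def __init__(self, key, model_name, base_url):
        self.base_url = base_url.rstrip("/") + "/rerank"  # assumed endpoint
        self.model_name = model_name.split("___")[0]

    def similarity(self, query: str, texts: list):
        data = {"model": self.model_name, "query": query, "documents": texts}
        res = requests.post(self.base_url, json=data).json()
        # Assumes an Xinference/Jina-style response carrying "results".
        return np.array([d["relevance_score"] for d in res["results"]]), 0
```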
```diff
@@ -17,4 +17,4 @@ export const UserSettingIconMap = {
 export * from '@/constants/setting';

-export const LocalLlmFactories = ['Ollama', 'Xinference'];
+export const LocalLlmFactories = ['Ollama', 'Xinference', 'LocalAI'];
```
```diff
@@ -75,6 +75,7 @@ const OllamaModal = ({
           <Option value="chat">chat</Option>
           <Option value="embedding">embedding</Option>
           <Option value="rerank">rerank</Option>
+          <Option value="image2text">image2text</Option>
         </Select>
       </Form.Item>
       <Form.Item<FieldType>
```