### What problem does this PR solve?

Add support for Mistral models (chat and embedding). Resolves #433.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
@@ -157,6 +157,11 @@ factory_infos = [{
     "logo": "",
     "tags": "LLM,TEXT EMBEDDING",
     "status": "1",
+},{
+    "name": "Mistral",
+    "logo": "",
+    "tags": "LLM,TEXT EMBEDDING",
+    "status": "1",
 }
 # {
 #     "name": "文心一言",
@@ -584,6 +589,63 @@ def init_llm_factory():
             "max_tokens": 8192,
             "model_type": LLMType.CHAT.value
         },
+        # ------------------------ Mistral -----------------------
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "open-mixtral-8x22b",
+            "tags": "LLM,CHAT,64k",
+            "max_tokens": 64000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "open-mixtral-8x7b",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "open-mistral-7b",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "mistral-large-latest",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "mistral-small-latest",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "mistral-medium-latest",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "codestral-latest",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "mistral-embed",
+            "tags": "LLM,CHAT,8k",
+            "max_tokens": 8192,
+            "model_type": LLMType.EMBEDDING.value
+        },
     ]
     for info in factory_infos:
         try:
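Every new `llm_infos` entry above takes its `fid` from `factory_infos[14]["name"]`, so the change only holds together if the Mistral factory block appended earlier really lands at index 14. A tiny sanity check (hypothetical, not part of the PR) makes that coupling explicit:

```python
# Hypothetical check, not in the PR: the LLM rows above are keyed by
# factory_infos[14]["name"], so the newly appended factory entry must
# sit at index 14 of factory_infos.
assert factory_infos[14]["name"] == "Mistral"
```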
@@ -29,7 +29,8 @@ EmbeddingModel = {
     "Youdao": YoudaoEmbed,
     "BaiChuan": BaiChuanEmbed,
     "Jina": JinaEmbed,
-    "BAAI": DefaultEmbedding
+    "BAAI": DefaultEmbedding,
+    "Mistral": MistralEmbed
 }
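With the new `"Mistral": MistralEmbed` entry, the embedding class can be resolved purely by factory name. A minimal sketch, assuming `api_key` holds a valid Mistral API key (the variable is illustrative, not from the PR):

```python
# Illustrative lookup through the registry; api_key is a placeholder.
embd_mdl = EmbeddingModel["Mistral"](api_key, model_name="mistral-embed")
```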
@@ -52,7 +53,8 @@ ChatModel = {
     "Moonshot": MoonshotChat,
     "DeepSeek": DeepSeekChat,
     "BaiChuan": BaiChuanChat,
-    "MiniMax": MiniMaxChat
+    "MiniMax": MiniMaxChat,
+    "Mistral": MistralChat
 }
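The chat side works the same way: the factory name stored in the database selects the class at runtime. A minimal sketch, again with a placeholder `api_key`:

```python
# Illustrative lookup through the registry; api_key is a placeholder.
chat_mdl = ChatModel["Mistral"](api_key, "mistral-small-latest")
```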
@@ -472,3 +472,57 @@ class MiniMaxChat(Base):
         if not base_url:
             base_url="https://api.minimax.chat/v1/text/chatcompletion_v2"
         super().__init__(key, model_name, base_url)
+
+
+class MistralChat(Base):
+    def __init__(self, key, model_name, base_url=None):
+        from mistralai.client import MistralClient
+        self.client = MistralClient(api_key=key)
+        self.model_name = model_name
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        for k in list(gen_conf.keys()):
+            if k not in ["temperature", "top_p", "max_tokens"]:
+                del gen_conf[k]
+        try:
+            response = self.client.chat(
+                model=self.model_name,
+                messages=history,
+                **gen_conf)
+            ans = response.choices[0].message.content
+            if response.choices[0].finish_reason == "length":
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+            return ans, response.usage.total_tokens
+        except openai.APIError as e:
+            return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        for k in list(gen_conf.keys()):
+            if k not in ["temperature", "top_p", "max_tokens"]:
+                del gen_conf[k]
+        ans = ""
+        total_tokens = 0
+        try:
+            response = self.client.chat_stream(
+                model=self.model_name,
+                messages=history,
+                **gen_conf)
+            for resp in response:
+                if not resp.choices or not resp.choices[0].delta.content:
+                    continue
+                ans += resp.choices[0].delta.content
+                total_tokens += 1
+                if resp.choices[0].finish_reason == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                yield ans
+        except openai.APIError as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
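For reference, a hedged usage sketch of the new class (the API key, prompt, and generation settings are placeholders): `chat` returns the full answer together with the token usage reported by the API, while `chat_streamly` yields the growing answer string chunk by chunk and finally yields the chunk count it accumulated as a rough token total.

```python
# Illustrative only; api_key and the messages are placeholders.
mdl = MistralChat(api_key, "open-mistral-7b")
gen_conf = {"temperature": 0.3, "max_tokens": 256}

# Blocking call: returns (answer_text, total_tokens).
answer, tokens = mdl.chat(
    "You are a concise assistant.",
    [{"role": "user", "content": "Summarize RAG in one sentence."}],
    gen_conf,
)

# Streaming call: yields the growing answer string, then the chunk count last.
for chunk in mdl.chat_streamly(
        "You are a concise assistant.",
        [{"role": "user", "content": "Summarize RAG in one sentence."}],
        gen_conf):
    if isinstance(chunk, str):
        answer = chunk
    else:
        tokens = chunk
```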
@@ -343,4 +343,24 @@ class InfinityEmbed(Base):
     def encode_queries(self, text: str) -> tuple[np.ndarray, int]:
         # Using the internal tokenizer to encode the texts and get the total
         # number of tokens
-        return self.encode([text])
+        return self.encode([text])
+
+
+class MistralEmbed(Base):
+    def __init__(self, key, model_name="mistral-embed",
+                 base_url=None):
+        from mistralai.client import MistralClient
+        self.client = MistralClient(api_key=key)
+        self.model_name = model_name
+
+    def encode(self, texts: list, batch_size=32):
+        texts = [truncate(t, 8196) for t in texts]
+        res = self.client.embeddings(input=texts,
+                                     model=self.model_name)
+        return np.array([d.embedding for d in res.data]
+                        ), res.usage.total_tokens
+
+    def encode_queries(self, text):
+        res = self.client.embeddings(input=[truncate(text, 8196)],
+                                     model=self.model_name)
+        return np.array(res.data[0].embedding), res.usage.total_tokens
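And a matching sketch for the embedder (placeholder `api_key`): `encode` batches a list of texts, truncating each one before the request, and `encode_queries` embeds a single query string.

```python
# Illustrative only; api_key is a placeholder.
embd = MistralEmbed(api_key)  # model_name defaults to "mistral-embed"

# Batch encoding: (ndarray of shape [n_texts, dim], total_tokens).
vectors, tokens = embd.encode(["first chunk of text", "second chunk of text"])

# Query encoding: (1-D ndarray, total_tokens).
qvec, qtokens = embd.encode_queries("what does the document say about pricing?")
```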