### What problem does this PR solve?

#433

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
| "logo": "", | "logo": "", | ||||
| "tags": "LLM,TEXT EMBEDDING", | "tags": "LLM,TEXT EMBEDDING", | ||||
| "status": "1", | "status": "1", | ||||
| },{ | |||||
| "name": "Mistral", | |||||
| "logo": "", | |||||
| "tags": "LLM,TEXT EMBEDDING", | |||||
| "status": "1", | |||||
| } | } | ||||
| # { | # { | ||||
| # "name": "文心一言", | # "name": "文心一言", | ||||
| "max_tokens": 8192, | "max_tokens": 8192, | ||||
| "model_type": LLMType.CHAT.value | "model_type": LLMType.CHAT.value | ||||
| }, | }, | ||||
| # ------------------------ Mistral ----------------------- | |||||
| { | |||||
| "fid": factory_infos[14]["name"], | |||||
| "llm_name": "open-mixtral-8x22b", | |||||
| "tags": "LLM,CHAT,64k", | |||||
| "max_tokens": 64000, | |||||
| "model_type": LLMType.CHAT.value | |||||
| }, | |||||
| { | |||||
| "fid": factory_infos[14]["name"], | |||||
| "llm_name": "open-mixtral-8x7b", | |||||
| "tags": "LLM,CHAT,32k", | |||||
| "max_tokens": 32000, | |||||
| "model_type": LLMType.CHAT.value | |||||
| }, | |||||
| { | |||||
| "fid": factory_infos[14]["name"], | |||||
| "llm_name": "open-mistral-7b", | |||||
| "tags": "LLM,CHAT,32k", | |||||
| "max_tokens": 32000, | |||||
| "model_type": LLMType.CHAT.value | |||||
| }, | |||||
| { | |||||
| "fid": factory_infos[14]["name"], | |||||
| "llm_name": "mistral-large-latest", | |||||
| "tags": "LLM,CHAT,32k", | |||||
| "max_tokens": 32000, | |||||
| "model_type": LLMType.CHAT.value | |||||
| }, | |||||
| { | |||||
| "fid": factory_infos[14]["name"], | |||||
| "llm_name": "mistral-small-latest", | |||||
| "tags": "LLM,CHAT,32k", | |||||
| "max_tokens": 32000, | |||||
| "model_type": LLMType.CHAT.value | |||||
| }, | |||||
| { | |||||
| "fid": factory_infos[14]["name"], | |||||
| "llm_name": "mistral-medium-latest", | |||||
| "tags": "LLM,CHAT,32k", | |||||
| "max_tokens": 32000, | |||||
| "model_type": LLMType.CHAT.value | |||||
| }, | |||||
| { | |||||
| "fid": factory_infos[14]["name"], | |||||
| "llm_name": "codestral-latest", | |||||
| "tags": "LLM,CHAT,32k", | |||||
| "max_tokens": 32000, | |||||
| "model_type": LLMType.CHAT.value | |||||
| }, | |||||
| { | |||||
| "fid": factory_infos[14]["name"], | |||||
| "llm_name": "mistral-embed", | |||||
| "tags": "LLM,CHAT,8k", | |||||
| "max_tokens": 8192, | |||||
| "model_type": LLMType.EMBEDDING | |||||
| }, | |||||
| ] | ] | ||||
| for info in factory_infos: | for info in factory_infos: | ||||
| try: | try: |
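Note that the new rows reference their factory purely by position (`factory_infos[14]`). A minimal sketch of that coupling is below; the placeholder factory names are hypothetical and only illustrate why the index must be kept in sync if another factory is ever inserted above Mistral.

```python
# Minimal sketch of the positional coupling; factory names other than
# "Mistral" are hypothetical placeholders, not real entries.
factory_infos = [{"name": f"Factory{i}"} for i in range(14)] + [{"name": "Mistral"}]

llm_infos = [{
    "fid": factory_infos[14]["name"],  # resolves to "Mistral" only while it sits at index 14
    "llm_name": "open-mistral-7b",
    "tags": "LLM,CHAT,32k",
    "max_tokens": 32000,
}]

assert llm_infos[0]["fid"] == "Mistral"
```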
| "Youdao": YoudaoEmbed, | "Youdao": YoudaoEmbed, | ||||
| "BaiChuan": BaiChuanEmbed, | "BaiChuan": BaiChuanEmbed, | ||||
| "Jina": JinaEmbed, | "Jina": JinaEmbed, | ||||
| "BAAI": DefaultEmbedding | |||||
| "BAAI": DefaultEmbedding, | |||||
| "Mistral": MistralEmbed | |||||
| } | } | ||||
| "Moonshot": MoonshotChat, | "Moonshot": MoonshotChat, | ||||
| "DeepSeek": DeepSeekChat, | "DeepSeek": DeepSeekChat, | ||||
| "BaiChuan": BaiChuanChat, | "BaiChuan": BaiChuanChat, | ||||
| "MiniMax": MiniMaxChat | |||||
| "MiniMax": MiniMaxChat, | |||||
| "Mistral": MistralChat | |||||
| } | } | ||||
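Assuming these dicts are the `EmbeddingModel` and `ChatModel` registries exposed by the `rag.llm` package (the surrounding entries suggest so, but the names are not visible in this hunk), a provider is then resolved by factory name. A rough usage sketch with a placeholder key:

```python
# Rough sketch, not part of the diff: the API key is a placeholder and the
# registry names / import path are assumed from the surrounding file.
from rag.llm import ChatModel, EmbeddingModel

chat_mdl = ChatModel["Mistral"](key="YOUR_MISTRAL_API_KEY",
                                model_name="mistral-small-latest")
embd_mdl = EmbeddingModel["Mistral"](key="YOUR_MISTRAL_API_KEY",
                                     model_name="mistral-embed")

answer, chat_tokens = chat_mdl.chat(
    None, [{"role": "user", "content": "ping"}], {"temperature": 0.2})
vectors, embed_tokens = embd_mdl.encode(["ping"])
```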
The chat wrapper follows the same pattern as the other providers, using the `mistralai` client:

```diff
         if not base_url:
             base_url="https://api.minimax.chat/v1/text/chatcompletion_v2"
         super().__init__(key, model_name, base_url)
+
+
+class MistralChat(Base):
+    def __init__(self, key, model_name, base_url=None):
+        from mistralai.client import MistralClient
+        self.client = MistralClient(api_key=key)
+        self.model_name = model_name
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        for k in list(gen_conf.keys()):
+            if k not in ["temperature", "top_p", "max_tokens"]:
+                del gen_conf[k]
+        try:
+            response = self.client.chat(
+                model=self.model_name,
+                messages=history,
+                **gen_conf)
+            ans = response.choices[0].message.content
+            if response.choices[0].finish_reason == "length":
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+            return ans, response.usage.total_tokens
+        except openai.APIError as e:
+            return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        for k in list(gen_conf.keys()):
+            if k not in ["temperature", "top_p", "max_tokens"]:
+                del gen_conf[k]
+        ans = ""
+        total_tokens = 0
+        try:
+            response = self.client.chat_stream(
+                model=self.model_name,
+                messages=history,
+                **gen_conf)
+            for resp in response:
+                if not resp.choices or not resp.choices[0].delta.content:
+                    continue
+                ans += resp.choices[0].delta.content
+                total_tokens += 1
+                if resp.choices[0].finish_reason == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                yield ans
+        except openai.APIError as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
```
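For reference, a hedged usage sketch of the wrapper above. The API key is a placeholder and the import path is assumed; `chat` returns the full answer plus the provider-reported token count, while `chat_streamly` yields progressively longer partial answers and finally its chunk-based token estimate.

```python
# Sketch only: placeholder key; the module path is assumed, and the model
# names come from the llm_infos entries added above.
from rag.llm.chat_model import MistralChat

mdl = MistralChat(key="YOUR_MISTRAL_API_KEY", model_name="open-mistral-7b")

# Non-streaming: returns the full answer and the provider-reported token usage.
answer, used_tokens = mdl.chat(
    system="You are a concise assistant.",
    history=[{"role": "user", "content": "Summarize this PR in one sentence."}],
    gen_conf={"temperature": 0.3, "max_tokens": 256},
)

# Streaming: every yield but the last is the accumulated answer so far;
# the final yield is the chunk count kept by chat_streamly.
total_tokens = 0
for chunk in mdl.chat_streamly(
        system="You are a concise assistant.",
        history=[{"role": "user", "content": "Say hello."}],
        gen_conf={"temperature": 0.3}):
    if isinstance(chunk, str):
        print(chunk, end="\r")
    else:
        total_tokens = chunk
```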
And the embedding wrapper:

```diff
     def encode_queries(self, text: str) -> tuple[np.ndarray, int]:
         # Using the internal tokenizer to encode the texts and get the total
         # number of tokens
         return self.encode([text])
+
+
+class MistralEmbed(Base):
+    def __init__(self, key, model_name="mistral-embed",
+                 base_url=None):
+        from mistralai.client import MistralClient
+        self.client = MistralClient(api_key=key)
+        self.model_name = model_name
+
+    def encode(self, texts: list, batch_size=32):
+        texts = [truncate(t, 8192) for t in texts]
+        res = self.client.embeddings(input=texts,
+                                     model=self.model_name)
+        return np.array([d.embedding for d in res.data]
+                        ), res.usage.total_tokens
+
+    def encode_queries(self, text):
+        res = self.client.embeddings(input=[truncate(text, 8192)],
+                                     model=self.model_name)
+        return np.array(res.data[0].embedding), res.usage.total_tokens
```
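A similar sketch for the embedding wrapper. The key is a placeholder and the import path is assumed; `mistral-embed` returns 1024-dimensional vectors.

```python
# Sketch only: placeholder key; module path assumed.
from rag.llm.embedding_model import MistralEmbed

embedder = MistralEmbed(key="YOUR_MISTRAL_API_KEY")

vectors, used_tokens = embedder.encode(["first passage", "second passage"])
print(vectors.shape)  # (2, 1024) for mistral-embed

query_vec, query_tokens = embedder.encode_queries("what does this PR add?")
```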