### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>tags/v0.13.0
| @@ -195,7 +195,7 @@ class LLMBundle(object): | |||
| self.llm_name = llm_name | |||
| self.mdl = TenantLLMService.model_instance( | |||
| tenant_id, llm_type, llm_name, lang=lang) | |||
| assert self.mdl, "Can't find mole for {}/{}/{}".format( | |||
| assert self.mdl, "Can't find model for {}/{}/{}".format( | |||
| tenant_id, llm_type, llm_name) | |||
| self.max_length = 8192 | |||
| for lm in LLMService.query(llm_name=llm_name): | |||
| @@ -47,10 +47,9 @@ EmbeddingModel = { | |||
| "Replicate": ReplicateEmbed, | |||
| "BaiduYiyan": BaiduYiyanEmbed, | |||
| "Voyage AI": VoyageEmbed, | |||
| "HuggingFace":HuggingFaceEmbed, | |||
| "HuggingFace": HuggingFaceEmbed, | |||
| } | |||
| CvModel = { | |||
| "OpenAI": GptV4, | |||
| "Azure-OpenAI": AzureGptV4, | |||
| @@ -64,14 +63,13 @@ CvModel = { | |||
| "LocalAI": LocalAICV, | |||
| "NVIDIA": NvidiaCV, | |||
| "LM-Studio": LmStudioCV, | |||
| "StepFun":StepFunCV, | |||
| "StepFun": StepFunCV, | |||
| "OpenAI-API-Compatible": OpenAI_APICV, | |||
| "TogetherAI": TogetherAICV, | |||
| "01.AI": YiCV, | |||
| "Tencent Hunyuan": HunyuanCV | |||
| } | |||
| ChatModel = { | |||
| "OpenAI": GptTurbo, | |||
| "Azure-OpenAI": AzureChat, | |||
| @@ -99,7 +97,7 @@ ChatModel = { | |||
| "LeptonAI": LeptonAIChat, | |||
| "TogetherAI": TogetherAIChat, | |||
| "PerfXCloud": PerfXCloudChat, | |||
| "Upstage":UpstageChat, | |||
| "Upstage": UpstageChat, | |||
| "novita.ai": NovitaAIChat, | |||
| "SILICONFLOW": SILICONFLOWChat, | |||
| "01.AI": YiChat, | |||
| @@ -111,7 +109,6 @@ ChatModel = { | |||
| "Google Cloud": GoogleChat, | |||
| } | |||
| RerankModel = { | |||
| "BAAI": DefaultRerank, | |||
| "Jina": JinaRerank, | |||
| @@ -127,11 +124,9 @@ RerankModel = { | |||
| "Voyage AI": VoyageRerank | |||
| } | |||
| Seq2txtModel = { | |||
| "OpenAI": GPTSeq2txt, | |||
| "Tongyi-Qianwen": QWenSeq2txt, | |||
| "Ollama": OllamaSeq2txt, | |||
| "Azure-OpenAI": AzureSeq2txt, | |||
| "Xinference": XinferenceSeq2txt, | |||
| "Tencent Cloud": TencentCloudSeq2txt | |||
| @@ -140,6 +135,7 @@ Seq2txtModel = { | |||
| TTSModel = { | |||
| "Fish Audio": FishAudioTTS, | |||
| "Tongyi-Qianwen": QwenTTS, | |||
| "OpenAI":OpenAITTS, | |||
| "XunFei Spark":SparkTTS | |||
| } | |||
| "OpenAI": OpenAITTS, | |||
| "XunFei Spark": SparkTTS, | |||
| "Xinference": XinferenceTTS, | |||
| } | |||
| @@ -13,6 +13,7 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import requests | |||
| from openai.lib.azure import AzureOpenAI | |||
| from zhipuai import ZhipuAI | |||
| import io | |||
| @@ -25,6 +26,7 @@ from rag.utils import num_tokens_from_string | |||
| import base64 | |||
| import re | |||
| class Base(ABC): | |||
| def __init__(self, key, model_name): | |||
| pass | |||
| @@ -36,8 +38,8 @@ class Base(ABC): | |||
| response_format="text" | |||
| ) | |||
| return transcription.text.strip(), num_tokens_from_string(transcription.text.strip()) | |||
| def audio2base64(self,audio): | |||
| def audio2base64(self, audio): | |||
| if isinstance(audio, bytes): | |||
| return base64.b64encode(audio).decode("utf-8") | |||
| if isinstance(audio, io.BytesIO): | |||
| @@ -77,13 +79,6 @@ class QWenSeq2txt(Base): | |||
| return "**ERROR**: " + result.message, 0 | |||
| class OllamaSeq2txt(Base): | |||
| def __init__(self, key, model_name, lang="Chinese", **kwargs): | |||
| self.client = Client(host=kwargs["base_url"]) | |||
| self.model_name = model_name | |||
| self.lang = lang | |||
| class AzureSeq2txt(Base): | |||
| def __init__(self, key, model_name, lang="Chinese", **kwargs): | |||
| self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01") | |||
| @@ -92,16 +87,53 @@ class AzureSeq2txt(Base): | |||
| class XinferenceSeq2txt(Base): | |||
| def __init__(self, key, model_name="", base_url=""): | |||
| if base_url.split("/")[-1] != "v1": | |||
| base_url = os.path.join(base_url, "v1") | |||
| self.client = OpenAI(api_key="xxx", base_url=base_url) | |||
| def __init__(self,key,model_name="whisper-small",**kwargs): | |||
| self.base_url = kwargs.get('base_url', None) | |||
| self.model_name = model_name | |||
| def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7): | |||
| if isinstance(audio, str): | |||
| audio_file = open(audio, 'rb') | |||
| audio_data = audio_file.read() | |||
| audio_file_name = audio.split("/")[-1] | |||
| else: | |||
| audio_data = audio | |||
| audio_file_name = "audio.wav" | |||
| payload = { | |||
| "model": self.model_name, | |||
| "language": language, | |||
| "prompt": prompt, | |||
| "response_format": response_format, | |||
| "temperature": temperature | |||
| } | |||
| files = { | |||
| "file": (audio_file_name, audio_data, 'audio/wav') | |||
| } | |||
| try: | |||
| response = requests.post( | |||
| f"{self.base_url}/v1/audio/transcriptions", | |||
| files=files, | |||
| data=payload | |||
| ) | |||
| response.raise_for_status() | |||
| result = response.json() | |||
| if 'text' in result: | |||
| transcription_text = result['text'].strip() | |||
| return transcription_text, num_tokens_from_string(transcription_text) | |||
| else: | |||
| return "**ERROR**: Failed to retrieve transcription.", 0 | |||
| except requests.exceptions.RequestException as e: | |||
| return f"**ERROR**: {str(e)}", 0 | |||
| class TencentCloudSeq2txt(Base): | |||
| def __init__( | |||
| self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com" | |||
| self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com" | |||
| ): | |||
| from tencentcloud.common import credential | |||
| from tencentcloud.asr.v20190614 import asr_client | |||
| @@ -297,3 +297,36 @@ class SparkTTS: | |||
| break | |||
| status_code = 1 | |||
| yield audio_chunk | |||
| class XinferenceTTS: | |||
| def __init__(self, key, model_name, **kwargs): | |||
| self.base_url = kwargs.get("base_url", None) | |||
| self.model_name = model_name | |||
| self.headers = { | |||
| "accept": "application/json", | |||
| "Content-Type": "application/json" | |||
| } | |||
| def tts(self, text, voice="中文女", stream=True): | |||
| payload = { | |||
| "model": self.model_name, | |||
| "input": text, | |||
| "voice": voice | |||
| } | |||
| response = requests.post( | |||
| f"{self.base_url}/v1/audio/speech", | |||
| headers=self.headers, | |||
| json=payload, | |||
| stream=stream | |||
| ) | |||
| if response.status_code != 200: | |||
| raise Exception(f"**Error**: {response.status_code}, {response.text}") | |||
| for chunk in response.iter_content(chunk_size=1024): | |||
| if chunk: | |||
| yield chunk | |||
| @@ -53,6 +53,26 @@ const OllamaModal = ({ | |||
| const url = | |||
| llmFactoryToUrlMap[llmFactory as LlmFactory] || | |||
| 'https://github.com/infiniflow/ragflow/blob/main/docs/guides/deploy_local_llm.mdx'; | |||
| const optionsMap = { | |||
| HuggingFace: [{ value: 'embedding', label: 'embedding' }], | |||
| Xinference: [ | |||
| { value: 'chat', label: 'chat' }, | |||
| { value: 'embedding', label: 'embedding' }, | |||
| { value: 'rerank', label: 'rerank' }, | |||
| { value: 'image2text', label: 'image2text' }, | |||
| { value: 'speech2text', label: 'sequence2text' }, | |||
| { value: 'tts', label: 'tts' }, | |||
| ], | |||
| Default: [ | |||
| { value: 'chat', label: 'chat' }, | |||
| { value: 'embedding', label: 'embedding' }, | |||
| { value: 'rerank', label: 'rerank' }, | |||
| { value: 'image2text', label: 'image2text' }, | |||
| ], | |||
| }; | |||
| const getOptions = (factory: string) => { | |||
| return optionsMap[factory as keyof typeof optionsMap] || optionsMap.Default; | |||
| }; | |||
| return ( | |||
| <Modal | |||
| title={t('addLlmTitle', { name: llmFactory })} | |||
| @@ -85,18 +105,11 @@ const OllamaModal = ({ | |||
| rules={[{ required: true, message: t('modelTypeMessage') }]} | |||
| > | |||
| <Select placeholder={t('modelTypeMessage')}> | |||
| {llmFactory === 'HuggingFace' ? ( | |||
| <Option value="embedding">embedding</Option> | |||
| ) : ( | |||
| <> | |||
| <Option value="chat">chat</Option> | |||
| <Option value="embedding">embedding</Option> | |||
| <Option value="rerank">rerank</Option> | |||
| <Option value="image2text">image2text</Option> | |||
| <Option value="audio2text">audio2text</Option> | |||
| <Option value="text2andio">text2andio</Option> | |||
| </> | |||
| )} | |||
| {getOptions(llmFactory).map((option) => ( | |||
| <Option key={option.value} value={option.value}> | |||
| {option.label} | |||
| </Option> | |||
| ))} | |||
| </Select> | |||
| </Form.Item> | |||
| <Form.Item<FieldType> | |||