Browse Source

Support sequence2txt and TTS models in Xinference (#2696)

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
tags/v0.13.0
JobSmithManipulation 1 year ago
parent
commit
a3ab5ba9ac
No account linked to committer's email address

+ 1
- 1
api/db/services/llm_service.py View File

self.llm_name = llm_name self.llm_name = llm_name
self.mdl = TenantLLMService.model_instance( self.mdl = TenantLLMService.model_instance(
tenant_id, llm_type, llm_name, lang=lang) tenant_id, llm_type, llm_name, lang=lang)
assert self.mdl, "Can't find mole for {}/{}/{}".format(
assert self.mdl, "Can't find model for {}/{}/{}".format(
tenant_id, llm_type, llm_name) tenant_id, llm_type, llm_name)
self.max_length = 8192 self.max_length = 8192
for lm in LLMService.query(llm_name=llm_name): for lm in LLMService.query(llm_name=llm_name):

+ 7
- 11
rag/llm/__init__.py View File

"Replicate": ReplicateEmbed, "Replicate": ReplicateEmbed,
"BaiduYiyan": BaiduYiyanEmbed, "BaiduYiyan": BaiduYiyanEmbed,
"Voyage AI": VoyageEmbed, "Voyage AI": VoyageEmbed,
"HuggingFace":HuggingFaceEmbed,
"HuggingFace": HuggingFaceEmbed,
} }



CvModel = { CvModel = {
"OpenAI": GptV4, "OpenAI": GptV4,
"Azure-OpenAI": AzureGptV4, "Azure-OpenAI": AzureGptV4,
"LocalAI": LocalAICV, "LocalAI": LocalAICV,
"NVIDIA": NvidiaCV, "NVIDIA": NvidiaCV,
"LM-Studio": LmStudioCV, "LM-Studio": LmStudioCV,
"StepFun":StepFunCV,
"StepFun": StepFunCV,
"OpenAI-API-Compatible": OpenAI_APICV, "OpenAI-API-Compatible": OpenAI_APICV,
"TogetherAI": TogetherAICV, "TogetherAI": TogetherAICV,
"01.AI": YiCV, "01.AI": YiCV,
"Tencent Hunyuan": HunyuanCV "Tencent Hunyuan": HunyuanCV
} }



ChatModel = { ChatModel = {
"OpenAI": GptTurbo, "OpenAI": GptTurbo,
"Azure-OpenAI": AzureChat, "Azure-OpenAI": AzureChat,
"LeptonAI": LeptonAIChat, "LeptonAI": LeptonAIChat,
"TogetherAI": TogetherAIChat, "TogetherAI": TogetherAIChat,
"PerfXCloud": PerfXCloudChat, "PerfXCloud": PerfXCloudChat,
"Upstage":UpstageChat,
"Upstage": UpstageChat,
"novita.ai": NovitaAIChat, "novita.ai": NovitaAIChat,
"SILICONFLOW": SILICONFLOWChat, "SILICONFLOW": SILICONFLOWChat,
"01.AI": YiChat, "01.AI": YiChat,
"Google Cloud": GoogleChat, "Google Cloud": GoogleChat,
} }



RerankModel = { RerankModel = {
"BAAI": DefaultRerank, "BAAI": DefaultRerank,
"Jina": JinaRerank, "Jina": JinaRerank,
"Voyage AI": VoyageRerank "Voyage AI": VoyageRerank
} }



Seq2txtModel = { Seq2txtModel = {
"OpenAI": GPTSeq2txt, "OpenAI": GPTSeq2txt,
"Tongyi-Qianwen": QWenSeq2txt, "Tongyi-Qianwen": QWenSeq2txt,
"Ollama": OllamaSeq2txt,
"Azure-OpenAI": AzureSeq2txt, "Azure-OpenAI": AzureSeq2txt,
"Xinference": XinferenceSeq2txt, "Xinference": XinferenceSeq2txt,
"Tencent Cloud": TencentCloudSeq2txt "Tencent Cloud": TencentCloudSeq2txt
TTSModel = { TTSModel = {
"Fish Audio": FishAudioTTS, "Fish Audio": FishAudioTTS,
"Tongyi-Qianwen": QwenTTS, "Tongyi-Qianwen": QwenTTS,
"OpenAI":OpenAITTS,
"XunFei Spark":SparkTTS
}
"OpenAI": OpenAITTS,
"XunFei Spark": SparkTTS,
"Xinference": XinferenceTTS,
}

+ 46
- 14
rag/llm/sequence2txt_model.py View File

# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import requests
from openai.lib.azure import AzureOpenAI from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI from zhipuai import ZhipuAI
import io import io
import base64 import base64
import re import re



class Base(ABC): class Base(ABC):
def __init__(self, key, model_name): def __init__(self, key, model_name):
pass pass
response_format="text" response_format="text"
) )
return transcription.text.strip(), num_tokens_from_string(transcription.text.strip()) return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
def audio2base64(self,audio):
def audio2base64(self, audio):
if isinstance(audio, bytes): if isinstance(audio, bytes):
return base64.b64encode(audio).decode("utf-8") return base64.b64encode(audio).decode("utf-8")
if isinstance(audio, io.BytesIO): if isinstance(audio, io.BytesIO):
return "**ERROR**: " + result.message, 0 return "**ERROR**: " + result.message, 0




class OllamaSeq2txt(Base):
    """Sequence-to-text (ASR) backend served by a local Ollama instance."""

    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        # Ollama does not authenticate, so `key` is accepted only to keep the
        # constructor signature uniform with the other Seq2txt backends.
        host = kwargs["base_url"]
        self.client = Client(host=host)
        self.model_name = model_name
        self.lang = lang


class AzureSeq2txt(Base): class AzureSeq2txt(Base):
def __init__(self, key, model_name, lang="Chinese", **kwargs): def __init__(self, key, model_name, lang="Chinese", **kwargs):
self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01") self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")




class XinferenceSeq2txt(Base): class XinferenceSeq2txt(Base):
def __init__(self, key, model_name="", base_url=""):
if base_url.split("/")[-1] != "v1":
base_url = os.path.join(base_url, "v1")
self.client = OpenAI(api_key="xxx", base_url=base_url)
def __init__(self,key,model_name="whisper-small",**kwargs):
self.base_url = kwargs.get('base_url', None)
self.model_name = model_name self.model_name = model_name


def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
    """Transcribe audio through the Xinference /v1/audio/transcriptions endpoint.

    Args:
        audio: Either a filesystem path (str) to an audio file, or the raw
            audio bytes themselves.
        language: Language hint forwarded to the server.
        prompt: Optional prompt forwarded to the server (a None value is
            dropped from the form data by `requests`).
        response_format: Requested server response format; the reply is
            parsed as JSON and its "text" field is used.
        temperature: Sampling temperature forwarded to the server.

    Returns:
        A tuple of (transcribed_text, token_count) on success, or
        ("**ERROR**: ...", 0) on any request/parse failure.
    """
    if isinstance(audio, str):
        # A path was supplied: read the bytes and keep the real file name.
        # Fixed: the original leaked this file handle; `with` guarantees
        # it is closed even if read() raises.
        with open(audio, 'rb') as audio_file:
            audio_data = audio_file.read()
        audio_file_name = audio.split("/")[-1]
    else:
        # Raw bytes: the server still needs a file name, so use a dummy.
        audio_data = audio
        audio_file_name = "audio.wav"

    payload = {
        "model": self.model_name,
        "language": language,
        "prompt": prompt,
        "response_format": response_format,
        "temperature": temperature
    }

    files = {
        "file": (audio_file_name, audio_data, 'audio/wav')
    }

    try:
        # NOTE(review): assumes self.base_url has no trailing slash and no
        # "/v1" suffix — confirm against how base_url is configured.
        response = requests.post(
            f"{self.base_url}/v1/audio/transcriptions",
            files=files,
            data=payload
        )
        response.raise_for_status()
        result = response.json()

        if 'text' in result:
            transcription_text = result['text'].strip()
            return transcription_text, num_tokens_from_string(transcription_text)
        # Server answered 2xx but without a "text" field.
        return "**ERROR**: Failed to retrieve transcription.", 0

    except requests.exceptions.RequestException as e:
        # Network / HTTP errors are reported in-band rather than raised,
        # matching the error convention of the other Seq2txt backends.
        return f"**ERROR**: {str(e)}", 0



class TencentCloudSeq2txt(Base): class TencentCloudSeq2txt(Base):
def __init__( def __init__(
self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
): ):
from tencentcloud.common import credential from tencentcloud.common import credential
from tencentcloud.asr.v20190614 import asr_client from tencentcloud.asr.v20190614 import asr_client

+ 33
- 0
rag/llm/tts_model.py View File

break break
status_code = 1 status_code = 1
yield audio_chunk yield audio_chunk




class XinferenceTTS:
    """Text-to-speech client for a Xinference server's OpenAI-compatible
    /v1/audio/speech endpoint. `tts` is a generator yielding raw audio bytes."""

    def __init__(self, key, model_name, **kwargs):
        # `key` is accepted for interface parity with the other TTS backends;
        # Xinference does not require authentication, so it is not stored.
        self.base_url = kwargs.get("base_url", None)
        self.model_name = model_name
        self.headers = {
            "accept": "application/json",
            "Content-Type": "application/json"
        }

    def tts(self, text, voice="中文女", stream=True, timeout=None):
        """Yield audio chunks (bytes) synthesized from `text`.

        Args:
            text: The text to synthesize.
            voice: Voice identifier understood by the served model.
            stream: Whether to stream the HTTP response body.
            timeout: Optional requests timeout in seconds. Defaults to None,
                preserving the previous no-timeout behaviour; pass a value to
                avoid hanging forever on an unresponsive server.

        Raises:
            Exception: If the server replies with a non-200 status.
        """
        payload = {
            "model": self.model_name,
            "input": text,
            "voice": voice
        }

        # Fixed: original never set a timeout (could block indefinitely) and
        # never released the streamed response back to the connection pool.
        response = requests.post(
            f"{self.base_url}/v1/audio/speech",
            headers=self.headers,
            json=payload,
            stream=stream,
            timeout=timeout
        )

        try:
            if response.status_code != 200:
                raise Exception(f"**Error**: {response.status_code}, {response.text}")

            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    yield chunk
        finally:
            # Runs on normal exhaustion AND when the consumer abandons the
            # generator early (GeneratorExit), releasing the connection.
            response.close()

+ 25
- 12
web/src/pages/user-setting/setting-model/ollama-modal/index.tsx View File

const url = const url =
llmFactoryToUrlMap[llmFactory as LlmFactory] || llmFactoryToUrlMap[llmFactory as LlmFactory] ||
'https://github.com/infiniflow/ragflow/blob/main/docs/guides/deploy_local_llm.mdx'; 'https://github.com/infiniflow/ragflow/blob/main/docs/guides/deploy_local_llm.mdx';
const optionsMap = {
HuggingFace: [{ value: 'embedding', label: 'embedding' }],
Xinference: [
{ value: 'chat', label: 'chat' },
{ value: 'embedding', label: 'embedding' },
{ value: 'rerank', label: 'rerank' },
{ value: 'image2text', label: 'image2text' },
{ value: 'speech2text', label: 'sequence2text' },
{ value: 'tts', label: 'tts' },
],
Default: [
{ value: 'chat', label: 'chat' },
{ value: 'embedding', label: 'embedding' },
{ value: 'rerank', label: 'rerank' },
{ value: 'image2text', label: 'image2text' },
],
};
const getOptions = (factory: string) => {
return optionsMap[factory as keyof typeof optionsMap] || optionsMap.Default;
};
return ( return (
<Modal <Modal
title={t('addLlmTitle', { name: llmFactory })} title={t('addLlmTitle', { name: llmFactory })}
rules={[{ required: true, message: t('modelTypeMessage') }]} rules={[{ required: true, message: t('modelTypeMessage') }]}
> >
<Select placeholder={t('modelTypeMessage')}> <Select placeholder={t('modelTypeMessage')}>
{llmFactory === 'HuggingFace' ? (
<Option value="embedding">embedding</Option>
) : (
<>
<Option value="chat">chat</Option>
<Option value="embedding">embedding</Option>
<Option value="rerank">rerank</Option>
<Option value="image2text">image2text</Option>
<Option value="audio2text">audio2text</Option>
<Option value="text2andio">text2andio</Option>
</>
)}
{getOptions(llmFactory).map((option) => (
<Option key={option.value} value={option.value}>
{option.label}
</Option>
))}
</Select> </Select>
</Form.Item> </Form.Item>
<Form.Item<FieldType> <Form.Item<FieldType>

Loading…
Cancel
Save