Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

sequence2txt_model.py 6.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import requests
  17. from openai.lib.azure import AzureOpenAI
  18. from zhipuai import ZhipuAI
  19. import io
  20. from abc import ABC
  21. from ollama import Client
  22. from openai import OpenAI
  23. import os
  24. import json
  25. from rag.utils import num_tokens_from_string
  26. import base64
  27. import re
  28. class Base(ABC):
  29. def __init__(self, key, model_name):
  30. pass
  31. def transcription(self, audio, **kwargs):
  32. transcription = self.client.audio.transcriptions.create(
  33. model=self.model_name,
  34. file=audio,
  35. response_format="text"
  36. )
  37. return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
  38. def audio2base64(self, audio):
  39. if isinstance(audio, bytes):
  40. return base64.b64encode(audio).decode("utf-8")
  41. if isinstance(audio, io.BytesIO):
  42. return base64.b64encode(audio.getvalue()).decode("utf-8")
  43. raise TypeError("The input audio file should be in binary format.")
  44. class GPTSeq2txt(Base):
  45. def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"):
  46. if not base_url: base_url = "https://api.openai.com/v1"
  47. self.client = OpenAI(api_key=key, base_url=base_url)
  48. self.model_name = model_name
  49. class QWenSeq2txt(Base):
  50. def __init__(self, key, model_name="paraformer-realtime-8k-v1", **kwargs):
  51. import dashscope
  52. dashscope.api_key = key
  53. self.model_name = model_name
  54. def transcription(self, audio, format):
  55. from http import HTTPStatus
  56. from dashscope.audio.asr import Recognition
  57. recognition = Recognition(model=self.model_name,
  58. format=format,
  59. sample_rate=16000,
  60. callback=None)
  61. result = recognition.call(audio)
  62. ans = ""
  63. if result.status_code == HTTPStatus.OK:
  64. for sentence in result.get_sentence():
  65. ans += sentence.text.decode('utf-8') + '\n'
  66. return ans, num_tokens_from_string(ans)
  67. return "**ERROR**: " + result.message, 0
  68. class AzureSeq2txt(Base):
  69. def __init__(self, key, model_name, lang="Chinese", **kwargs):
  70. self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
  71. self.model_name = model_name
  72. self.lang = lang
  73. class XinferenceSeq2txt(Base):
  74. def __init__(self, key, model_name="whisper-small", **kwargs):
  75. self.base_url = kwargs.get('base_url', None)
  76. self.model_name = model_name
  77. self.key = key
  78. def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
  79. if isinstance(audio, str):
  80. audio_file = open(audio, 'rb')
  81. audio_data = audio_file.read()
  82. audio_file_name = audio.split("/")[-1]
  83. else:
  84. audio_data = audio
  85. audio_file_name = "audio.wav"
  86. payload = {
  87. "model": self.model_name,
  88. "language": language,
  89. "prompt": prompt,
  90. "response_format": response_format,
  91. "temperature": temperature
  92. }
  93. files = {
  94. "file": (audio_file_name, audio_data, 'audio/wav')
  95. }
  96. try:
  97. response = requests.post(
  98. f"{self.base_url}/v1/audio/transcriptions",
  99. files=files,
  100. data=payload
  101. )
  102. response.raise_for_status()
  103. result = response.json()
  104. if 'text' in result:
  105. transcription_text = result['text'].strip()
  106. return transcription_text, num_tokens_from_string(transcription_text)
  107. else:
  108. return "**ERROR**: Failed to retrieve transcription.", 0
  109. except requests.exceptions.RequestException as e:
  110. return f"**ERROR**: {str(e)}", 0
  111. class TencentCloudSeq2txt(Base):
  112. def __init__(
  113. self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
  114. ):
  115. from tencentcloud.common import credential
  116. from tencentcloud.asr.v20190614 import asr_client
  117. key = json.loads(key)
  118. sid = key.get("tencent_cloud_sid", "")
  119. sk = key.get("tencent_cloud_sk", "")
  120. cred = credential.Credential(sid, sk)
  121. self.client = asr_client.AsrClient(cred, "")
  122. self.model_name = model_name
  123. def transcription(self, audio, max_retries=60, retry_interval=5):
  124. from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
  125. TencentCloudSDKException,
  126. )
  127. from tencentcloud.asr.v20190614 import models
  128. import time
  129. b64 = self.audio2base64(audio)
  130. try:
  131. # dispatch disk
  132. req = models.CreateRecTaskRequest()
  133. params = {
  134. "EngineModelType": self.model_name,
  135. "ChannelNum": 1,
  136. "ResTextFormat": 0,
  137. "SourceType": 1,
  138. "Data": b64,
  139. }
  140. req.from_json_string(json.dumps(params))
  141. resp = self.client.CreateRecTask(req)
  142. # loop query
  143. req = models.DescribeTaskStatusRequest()
  144. params = {"TaskId": resp.Data.TaskId}
  145. req.from_json_string(json.dumps(params))
  146. retries = 0
  147. while retries < max_retries:
  148. resp = self.client.DescribeTaskStatus(req)
  149. if resp.Data.StatusStr == "success":
  150. text = re.sub(
  151. r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result
  152. ).strip()
  153. return text, num_tokens_from_string(text)
  154. elif resp.Data.StatusStr == "failed":
  155. return (
  156. "**ERROR**: Failed to retrieve speech recognition results.",
  157. 0,
  158. )
  159. else:
  160. time.sleep(retry_interval)
  161. retries += 1
  162. return "**ERROR**: Max retries exceeded. Task may still be processing.", 0
  163. except TencentCloudSDKException as e:
  164. return "**ERROR**: " + str(e), 0
  165. except Exception as e:
  166. return "**ERROR**: " + str(e), 0