Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

sequence2txt_model.py 8.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import base64
  17. import io
  18. import json
  19. import os
  20. import re
  21. from abc import ABC
  22. import requests
  23. from openai import OpenAI
  24. from openai.lib.azure import AzureOpenAI
  25. from rag.utils import num_tokens_from_string
  26. class Base(ABC):
  27. def __init__(self, key, model_name, **kwargs):
  28. """
  29. Abstract base class constructor.
  30. Parameters are not stored; initialization is left to subclasses.
  31. """
  32. pass
  33. def transcription(self, audio_path, **kwargs):
  34. audio_file = open(audio_path, "rb")
  35. transcription = self.client.audio.transcriptions.create(model=self.model_name, file=audio_file)
  36. return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
  37. def audio2base64(self, audio):
  38. if isinstance(audio, bytes):
  39. return base64.b64encode(audio).decode("utf-8")
  40. if isinstance(audio, io.BytesIO):
  41. return base64.b64encode(audio.getvalue()).decode("utf-8")
  42. raise TypeError("The input audio file should be in binary format.")
  43. class GPTSeq2txt(Base):
  44. _FACTORY_NAME = "OpenAI"
  45. def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1", **kwargs):
  46. if not base_url:
  47. base_url = "https://api.openai.com/v1"
  48. self.client = OpenAI(api_key=key, base_url=base_url)
  49. self.model_name = model_name
  50. class QWenSeq2txt(Base):
  51. _FACTORY_NAME = "Tongyi-Qianwen"
  52. def __init__(self, key, model_name="qwen-audio-asr", **kwargs):
  53. import dashscope
  54. dashscope.api_key = key
  55. self.model_name = model_name
  56. def transcription(self, audio_path):
  57. if "paraformer" in self.model_name or "sensevoice" in self.model_name:
  58. return f"**ERROR**: model {self.model_name} is not suppported yet.", 0
  59. from dashscope import MultiModalConversation
  60. audio_path = f"file://{audio_path}"
  61. messages = [
  62. {
  63. "role": "user",
  64. "content": [{"audio": audio_path}],
  65. }
  66. ]
  67. response = None
  68. full_content = ""
  69. try:
  70. response = MultiModalConversation.call(model="qwen-audio-asr", messages=messages, result_format="message", stream=True)
  71. for response in response:
  72. try:
  73. full_content += response["output"]["choices"][0]["message"].content[0]["text"]
  74. except Exception:
  75. pass
  76. return full_content, num_tokens_from_string(full_content)
  77. except Exception as e:
  78. return "**ERROR**: " + str(e), 0
  79. class AzureSeq2txt(Base):
  80. _FACTORY_NAME = "Azure-OpenAI"
  81. def __init__(self, key, model_name, lang="Chinese", **kwargs):
  82. self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
  83. self.model_name = model_name
  84. self.lang = lang
  85. class XinferenceSeq2txt(Base):
  86. _FACTORY_NAME = "Xinference"
  87. def __init__(self, key, model_name="whisper-small", **kwargs):
  88. self.base_url = kwargs.get("base_url", None)
  89. self.model_name = model_name
  90. self.key = key
  91. def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
  92. if isinstance(audio, str):
  93. audio_file = open(audio, "rb")
  94. audio_data = audio_file.read()
  95. audio_file_name = audio.split("/")[-1]
  96. else:
  97. audio_data = audio
  98. audio_file_name = "audio.wav"
  99. payload = {"model": self.model_name, "language": language, "prompt": prompt, "response_format": response_format, "temperature": temperature}
  100. files = {"file": (audio_file_name, audio_data, "audio/wav")}
  101. try:
  102. response = requests.post(f"{self.base_url}/v1/audio/transcriptions", files=files, data=payload)
  103. response.raise_for_status()
  104. result = response.json()
  105. if "text" in result:
  106. transcription_text = result["text"].strip()
  107. return transcription_text, num_tokens_from_string(transcription_text)
  108. else:
  109. return "**ERROR**: Failed to retrieve transcription.", 0
  110. except requests.exceptions.RequestException as e:
  111. return f"**ERROR**: {str(e)}", 0
  112. class TencentCloudSeq2txt(Base):
  113. _FACTORY_NAME = "Tencent Cloud"
  114. def __init__(self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"):
  115. from tencentcloud.asr.v20190614 import asr_client
  116. from tencentcloud.common import credential
  117. key = json.loads(key)
  118. sid = key.get("tencent_cloud_sid", "")
  119. sk = key.get("tencent_cloud_sk", "")
  120. cred = credential.Credential(sid, sk)
  121. self.client = asr_client.AsrClient(cred, "")
  122. self.model_name = model_name
  123. def transcription(self, audio, max_retries=60, retry_interval=5):
  124. import time
  125. from tencentcloud.asr.v20190614 import models
  126. from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
  127. TencentCloudSDKException,
  128. )
  129. b64 = self.audio2base64(audio)
  130. try:
  131. # dispatch disk
  132. req = models.CreateRecTaskRequest()
  133. params = {
  134. "EngineModelType": self.model_name,
  135. "ChannelNum": 1,
  136. "ResTextFormat": 0,
  137. "SourceType": 1,
  138. "Data": b64,
  139. }
  140. req.from_json_string(json.dumps(params))
  141. resp = self.client.CreateRecTask(req)
  142. # loop query
  143. req = models.DescribeTaskStatusRequest()
  144. params = {"TaskId": resp.Data.TaskId}
  145. req.from_json_string(json.dumps(params))
  146. retries = 0
  147. while retries < max_retries:
  148. resp = self.client.DescribeTaskStatus(req)
  149. if resp.Data.StatusStr == "success":
  150. text = re.sub(r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result).strip()
  151. return text, num_tokens_from_string(text)
  152. elif resp.Data.StatusStr == "failed":
  153. return (
  154. "**ERROR**: Failed to retrieve speech recognition results.",
  155. 0,
  156. )
  157. else:
  158. time.sleep(retry_interval)
  159. retries += 1
  160. return "**ERROR**: Max retries exceeded. Task may still be processing.", 0
  161. except TencentCloudSDKException as e:
  162. return "**ERROR**: " + str(e), 0
  163. except Exception as e:
  164. return "**ERROR**: " + str(e), 0
  165. class GPUStackSeq2txt(Base):
  166. _FACTORY_NAME = "GPUStack"
  167. def __init__(self, key, model_name, base_url):
  168. if not base_url:
  169. raise ValueError("url cannot be None")
  170. if base_url.split("/")[-1] != "v1":
  171. base_url = os.path.join(base_url, "v1")
  172. self.base_url = base_url
  173. self.model_name = model_name
  174. self.key = key
  175. class GiteeSeq2txt(Base):
  176. _FACTORY_NAME = "GiteeAI"
  177. def __init__(self, key, model_name="whisper-1", base_url="https://ai.gitee.com/v1/"):
  178. if not base_url:
  179. base_url = "https://ai.gitee.com/v1/"
  180. self.client = OpenAI(api_key=key, base_url=base_url)
  181. self.model_name = model_name
  182. class DeepInfraSeq2txt(Base):
  183. _FACTORY_NAME = "DeepInfra"
  184. def __init__(self, key, model_name, base_url="https://api.deepinfra.com/v1/openai", **kwargs):
  185. if not base_url:
  186. base_url = "https://api.deepinfra.com/v1/openai"
  187. self.client = OpenAI(api_key=key, base_url=base_url)
  188. self.model_name = model_name