# sequence2txt_model.py (7.7KB)
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import io
import json
import os
import re
from abc import ABC

import requests
from openai import OpenAI
from openai.lib.azure import AzureOpenAI

from rag.utils import num_tokens_from_string
  26. class Base(ABC):
  27. def __init__(self, key, model_name, **kwargs):
  28. """
  29. Abstract base class constructor.
  30. Parameters are not stored; initialization is left to subclasses.
  31. """
  32. pass
  33. def transcription(self, audio, **kwargs):
  34. transcription = self.client.audio.transcriptions.create(model=self.model_name, file=audio, response_format="text")
  35. return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
  36. def audio2base64(self, audio):
  37. if isinstance(audio, bytes):
  38. return base64.b64encode(audio).decode("utf-8")
  39. if isinstance(audio, io.BytesIO):
  40. return base64.b64encode(audio.getvalue()).decode("utf-8")
  41. raise TypeError("The input audio file should be in binary format.")
  42. class GPTSeq2txt(Base):
  43. _FACTORY_NAME = "OpenAI"
  44. def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"):
  45. if not base_url:
  46. base_url = "https://api.openai.com/v1"
  47. self.client = OpenAI(api_key=key, base_url=base_url)
  48. self.model_name = model_name
  49. class QWenSeq2txt(Base):
  50. _FACTORY_NAME = "Tongyi-Qianwen"
  51. def __init__(self, key, model_name="paraformer-realtime-8k-v1", **kwargs):
  52. import dashscope
  53. dashscope.api_key = key
  54. self.model_name = model_name
  55. def transcription(self, audio, format):
  56. from http import HTTPStatus
  57. from dashscope.audio.asr import Recognition
  58. recognition = Recognition(model=self.model_name, format=format, sample_rate=16000, callback=None)
  59. result = recognition.call(audio)
  60. ans = ""
  61. if result.status_code == HTTPStatus.OK:
  62. for sentence in result.get_sentence():
  63. ans += sentence.text.decode("utf-8") + "\n"
  64. return ans, num_tokens_from_string(ans)
  65. return "**ERROR**: " + result.message, 0
  66. class AzureSeq2txt(Base):
  67. _FACTORY_NAME = "Azure-OpenAI"
  68. def __init__(self, key, model_name, lang="Chinese", **kwargs):
  69. self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
  70. self.model_name = model_name
  71. self.lang = lang
  72. class XinferenceSeq2txt(Base):
  73. _FACTORY_NAME = "Xinference"
  74. def __init__(self, key, model_name="whisper-small", **kwargs):
  75. self.base_url = kwargs.get("base_url", None)
  76. self.model_name = model_name
  77. self.key = key
  78. def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
  79. if isinstance(audio, str):
  80. audio_file = open(audio, "rb")
  81. audio_data = audio_file.read()
  82. audio_file_name = audio.split("/")[-1]
  83. else:
  84. audio_data = audio
  85. audio_file_name = "audio.wav"
  86. payload = {"model": self.model_name, "language": language, "prompt": prompt, "response_format": response_format, "temperature": temperature}
  87. files = {"file": (audio_file_name, audio_data, "audio/wav")}
  88. try:
  89. response = requests.post(f"{self.base_url}/v1/audio/transcriptions", files=files, data=payload)
  90. response.raise_for_status()
  91. result = response.json()
  92. if "text" in result:
  93. transcription_text = result["text"].strip()
  94. return transcription_text, num_tokens_from_string(transcription_text)
  95. else:
  96. return "**ERROR**: Failed to retrieve transcription.", 0
  97. except requests.exceptions.RequestException as e:
  98. return f"**ERROR**: {str(e)}", 0
  99. class TencentCloudSeq2txt(Base):
  100. _FACTORY_NAME = "Tencent Cloud"
  101. def __init__(self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"):
  102. from tencentcloud.asr.v20190614 import asr_client
  103. from tencentcloud.common import credential
  104. key = json.loads(key)
  105. sid = key.get("tencent_cloud_sid", "")
  106. sk = key.get("tencent_cloud_sk", "")
  107. cred = credential.Credential(sid, sk)
  108. self.client = asr_client.AsrClient(cred, "")
  109. self.model_name = model_name
  110. def transcription(self, audio, max_retries=60, retry_interval=5):
  111. import time
  112. from tencentcloud.asr.v20190614 import models
  113. from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
  114. TencentCloudSDKException,
  115. )
  116. b64 = self.audio2base64(audio)
  117. try:
  118. # dispatch disk
  119. req = models.CreateRecTaskRequest()
  120. params = {
  121. "EngineModelType": self.model_name,
  122. "ChannelNum": 1,
  123. "ResTextFormat": 0,
  124. "SourceType": 1,
  125. "Data": b64,
  126. }
  127. req.from_json_string(json.dumps(params))
  128. resp = self.client.CreateRecTask(req)
  129. # loop query
  130. req = models.DescribeTaskStatusRequest()
  131. params = {"TaskId": resp.Data.TaskId}
  132. req.from_json_string(json.dumps(params))
  133. retries = 0
  134. while retries < max_retries:
  135. resp = self.client.DescribeTaskStatus(req)
  136. if resp.Data.StatusStr == "success":
  137. text = re.sub(r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result).strip()
  138. return text, num_tokens_from_string(text)
  139. elif resp.Data.StatusStr == "failed":
  140. return (
  141. "**ERROR**: Failed to retrieve speech recognition results.",
  142. 0,
  143. )
  144. else:
  145. time.sleep(retry_interval)
  146. retries += 1
  147. return "**ERROR**: Max retries exceeded. Task may still be processing.", 0
  148. except TencentCloudSDKException as e:
  149. return "**ERROR**: " + str(e), 0
  150. except Exception as e:
  151. return "**ERROR**: " + str(e), 0
  152. class GPUStackSeq2txt(Base):
  153. _FACTORY_NAME = "GPUStack"
  154. def __init__(self, key, model_name, base_url):
  155. if not base_url:
  156. raise ValueError("url cannot be None")
  157. if base_url.split("/")[-1] != "v1":
  158. base_url = os.path.join(base_url, "v1")
  159. self.base_url = base_url
  160. self.model_name = model_name
  161. self.key = key
  162. class GiteeSeq2txt(Base):
  163. _FACTORY_NAME = "GiteeAI"
  164. def __init__(self, key, model_name="whisper-1", base_url="https://ai.gitee.com/v1/"):
  165. if not base_url:
  166. base_url = "https://ai.gitee.com/v1/"
  167. self.client = OpenAI(api_key=key, base_url=base_url)
  168. self.model_name = model_name
  169. class DeepInfraSeq2txt(Base):
  170. _FACTORY_NAME = "DeepInfra"
  171. def __init__(self, key, model_name, base_url="https://api.deepinfra.com/v1/openai", **kwargs):
  172. if not base_url:
  173. base_url = "https://api.deepinfra.com/v1/openai"
  174. self.client = OpenAI(api_key=key, base_url=base_url)
  175. self.model_name = model_name