You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

sequence2txt_model.py 7.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import os
  17. import requests
  18. from openai.lib.azure import AzureOpenAI
  19. import io
  20. from abc import ABC
  21. from openai import OpenAI
  22. import json
  23. from rag.utils import num_tokens_from_string
  24. import base64
  25. import re
  26. class Base(ABC):
  27. def __init__(self, key, model_name):
  28. pass
  29. def transcription(self, audio, **kwargs):
  30. transcription = self.client.audio.transcriptions.create(
  31. model=self.model_name,
  32. file=audio,
  33. response_format="text"
  34. )
  35. return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
  36. def audio2base64(self, audio):
  37. if isinstance(audio, bytes):
  38. return base64.b64encode(audio).decode("utf-8")
  39. if isinstance(audio, io.BytesIO):
  40. return base64.b64encode(audio.getvalue()).decode("utf-8")
  41. raise TypeError("The input audio file should be in binary format.")
  42. class GPTSeq2txt(Base):
  43. def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"):
  44. if not base_url:
  45. base_url = "https://api.openai.com/v1"
  46. self.client = OpenAI(api_key=key, base_url=base_url)
  47. self.model_name = model_name
  48. class QWenSeq2txt(Base):
  49. def __init__(self, key, model_name="paraformer-realtime-8k-v1", **kwargs):
  50. import dashscope
  51. dashscope.api_key = key
  52. self.model_name = model_name
  53. def transcription(self, audio, format):
  54. from http import HTTPStatus
  55. from dashscope.audio.asr import Recognition
  56. recognition = Recognition(model=self.model_name,
  57. format=format,
  58. sample_rate=16000,
  59. callback=None)
  60. result = recognition.call(audio)
  61. ans = ""
  62. if result.status_code == HTTPStatus.OK:
  63. for sentence in result.get_sentence():
  64. ans += sentence.text.decode('utf-8') + '\n'
  65. return ans, num_tokens_from_string(ans)
  66. return "**ERROR**: " + result.message, 0
  67. class AzureSeq2txt(Base):
  68. def __init__(self, key, model_name, lang="Chinese", **kwargs):
  69. self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
  70. self.model_name = model_name
  71. self.lang = lang
  72. class XinferenceSeq2txt(Base):
  73. def __init__(self, key, model_name="whisper-small", **kwargs):
  74. self.base_url = kwargs.get('base_url', None)
  75. self.model_name = model_name
  76. self.key = key
  77. def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
  78. if isinstance(audio, str):
  79. audio_file = open(audio, 'rb')
  80. audio_data = audio_file.read()
  81. audio_file_name = audio.split("/")[-1]
  82. else:
  83. audio_data = audio
  84. audio_file_name = "audio.wav"
  85. payload = {
  86. "model": self.model_name,
  87. "language": language,
  88. "prompt": prompt,
  89. "response_format": response_format,
  90. "temperature": temperature
  91. }
  92. files = {
  93. "file": (audio_file_name, audio_data, 'audio/wav')
  94. }
  95. try:
  96. response = requests.post(
  97. f"{self.base_url}/v1/audio/transcriptions",
  98. files=files,
  99. data=payload
  100. )
  101. response.raise_for_status()
  102. result = response.json()
  103. if 'text' in result:
  104. transcription_text = result['text'].strip()
  105. return transcription_text, num_tokens_from_string(transcription_text)
  106. else:
  107. return "**ERROR**: Failed to retrieve transcription.", 0
  108. except requests.exceptions.RequestException as e:
  109. return f"**ERROR**: {str(e)}", 0
  110. class TencentCloudSeq2txt(Base):
  111. def __init__(
  112. self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
  113. ):
  114. from tencentcloud.common import credential
  115. from tencentcloud.asr.v20190614 import asr_client
  116. key = json.loads(key)
  117. sid = key.get("tencent_cloud_sid", "")
  118. sk = key.get("tencent_cloud_sk", "")
  119. cred = credential.Credential(sid, sk)
  120. self.client = asr_client.AsrClient(cred, "")
  121. self.model_name = model_name
  122. def transcription(self, audio, max_retries=60, retry_interval=5):
  123. from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
  124. TencentCloudSDKException,
  125. )
  126. from tencentcloud.asr.v20190614 import models
  127. import time
  128. b64 = self.audio2base64(audio)
  129. try:
  130. # dispatch disk
  131. req = models.CreateRecTaskRequest()
  132. params = {
  133. "EngineModelType": self.model_name,
  134. "ChannelNum": 1,
  135. "ResTextFormat": 0,
  136. "SourceType": 1,
  137. "Data": b64,
  138. }
  139. req.from_json_string(json.dumps(params))
  140. resp = self.client.CreateRecTask(req)
  141. # loop query
  142. req = models.DescribeTaskStatusRequest()
  143. params = {"TaskId": resp.Data.TaskId}
  144. req.from_json_string(json.dumps(params))
  145. retries = 0
  146. while retries < max_retries:
  147. resp = self.client.DescribeTaskStatus(req)
  148. if resp.Data.StatusStr == "success":
  149. text = re.sub(
  150. r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result
  151. ).strip()
  152. return text, num_tokens_from_string(text)
  153. elif resp.Data.StatusStr == "failed":
  154. return (
  155. "**ERROR**: Failed to retrieve speech recognition results.",
  156. 0,
  157. )
  158. else:
  159. time.sleep(retry_interval)
  160. retries += 1
  161. return "**ERROR**: Max retries exceeded. Task may still be processing.", 0
  162. except TencentCloudSDKException as e:
  163. return "**ERROR**: " + str(e), 0
  164. except Exception as e:
  165. return "**ERROR**: " + str(e), 0
  166. class GPUStackSeq2txt(Base):
  167. def __init__(self, key, model_name, base_url):
  168. if not base_url:
  169. raise ValueError("url cannot be None")
  170. if base_url.split("/")[-1] != "v1":
  171. base_url = os.path.join(base_url, "v1")
  172. self.base_url = base_url
  173. self.model_name = model_name
  174. self.key = key