
sequence2txt_model.py 6.7KB

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import io
import json
import re
from abc import ABC

import requests
from openai import OpenAI
from openai.lib.azure import AzureOpenAI

from rag.utils import num_tokens_from_string


class Base(ABC):
    def __init__(self, key, model_name):
        pass

    def transcription(self, audio, **kwargs):
        # Default implementation for OpenAI-compatible clients: send the audio
        # to the provider's transcription endpoint and return the transcript
        # together with its token count.
        transcription = self.client.audio.transcriptions.create(
            model=self.model_name,
            file=audio,
            response_format="text"
        )
        return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())

    def audio2base64(self, audio):
        # Accept raw bytes or an in-memory buffer and return a base64 string.
        if isinstance(audio, bytes):
            return base64.b64encode(audio).decode("utf-8")
        if isinstance(audio, io.BytesIO):
            return base64.b64encode(audio.getvalue()).decode("utf-8")
        raise TypeError("The input audio file should be in binary format.")


class GPTSeq2txt(Base):
    def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

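# A minimal usage sketch (illustrative, not part of the original module): assuming
# a valid OpenAI API key and a local "sample.wav" file, GPTSeq2txt relies on the
# inherited Base.transcription(), which expects a file-like object.
#
#     model = GPTSeq2txt(key="sk-...", model_name="whisper-1")
#     with open("sample.wav", "rb") as f:
#         text, num_tokens = model.transcription(f)

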
class QWenSeq2txt(Base):
    def __init__(self, key, model_name="paraformer-realtime-8k-v1", **kwargs):
        import dashscope
        dashscope.api_key = key
        self.model_name = model_name

    def transcription(self, audio, format):
        from http import HTTPStatus
        from dashscope.audio.asr import Recognition

        recognition = Recognition(model=self.model_name,
                                  format=format,
                                  sample_rate=16000,
                                  callback=None)
        result = recognition.call(audio)

        ans = ""
        if result.status_code == HTTPStatus.OK:
            # Concatenate the recognized sentences into a single transcript.
            for sentence in result.get_sentence():
                ans += sentence.text.decode('utf-8') + '\n'
            return ans, num_tokens_from_string(ans)

        return "**ERROR**: " + result.message, 0

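# A minimal usage sketch (illustrative only): the audio argument is handed
# straight to DashScope's Recognition.call(); the file path and format below
# are placeholders.
#
#     model = QWenSeq2txt(key="<dashscope-api-key>")
#     text, num_tokens = model.transcription("sample.wav", "wav")

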
class AzureSeq2txt(Base):
    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
        self.model_name = model_name
        self.lang = lang


class XinferenceSeq2txt(Base):
    def __init__(self, key, model_name="whisper-small", **kwargs):
        self.base_url = kwargs.get('base_url', None)
        self.model_name = model_name
        self.key = key

    def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
        if isinstance(audio, str):
            # A file path: read the audio from disk and keep the basename for the upload.
            with open(audio, 'rb') as audio_file:
                audio_data = audio_file.read()
            audio_file_name = audio.split("/")[-1]
        else:
            # Raw bytes: upload under a generic file name.
            audio_data = audio
            audio_file_name = "audio.wav"

        payload = {
            "model": self.model_name,
            "language": language,
            "prompt": prompt,
            "response_format": response_format,
            "temperature": temperature
        }
        files = {
            "file": (audio_file_name, audio_data, 'audio/wav')
        }

        try:
            response = requests.post(
                f"{self.base_url}/v1/audio/transcriptions",
                files=files,
                data=payload
            )
            response.raise_for_status()
            result = response.json()

            if 'text' in result:
                transcription_text = result['text'].strip()
                return transcription_text, num_tokens_from_string(transcription_text)
            else:
                return "**ERROR**: Failed to retrieve transcription.", 0

        except requests.exceptions.RequestException as e:
            return f"**ERROR**: {str(e)}", 0

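# A minimal usage sketch (illustrative only): the server address below is a
# placeholder for wherever the Xinference instance is actually reachable.
#
#     model = XinferenceSeq2txt(key="", model_name="whisper-small",
#                               base_url="http://localhost:9997")
#     text, num_tokens = model.transcription("/path/to/audio.wav", language="zh")

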
class TencentCloudSeq2txt(Base):
    def __init__(
        self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
    ):
        from tencentcloud.common import credential
        from tencentcloud.asr.v20190614 import asr_client

        key = json.loads(key)
        sid = key.get("tencent_cloud_sid", "")
        sk = key.get("tencent_cloud_sk", "")
        cred = credential.Credential(sid, sk)

        self.client = asr_client.AsrClient(cred, "")
        self.model_name = model_name

    def transcription(self, audio, max_retries=60, retry_interval=5):
        import time
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
            TencentCloudSDKException,
        )
        from tencentcloud.asr.v20190614 import models

        b64 = self.audio2base64(audio)

        try:
            # Submit the recognition task.
            req = models.CreateRecTaskRequest()
            params = {
                "EngineModelType": self.model_name,
                "ChannelNum": 1,
                "ResTextFormat": 0,
                "SourceType": 1,
                "Data": b64,
            }
            req.from_json_string(json.dumps(params))
            resp = self.client.CreateRecTask(req)

            # Poll the task status until it succeeds, fails, or the retry budget runs out.
            req = models.DescribeTaskStatusRequest()
            params = {"TaskId": resp.Data.TaskId}
            req.from_json_string(json.dumps(params))
            retries = 0
            while retries < max_retries:
                resp = self.client.DescribeTaskStatus(req)
                if resp.Data.StatusStr == "success":
                    # Strip the "[mm:ss.sss,mm:ss.sss]" timestamps from the result text.
                    text = re.sub(
                        r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result
                    ).strip()
                    return text, num_tokens_from_string(text)
                elif resp.Data.StatusStr == "failed":
                    return (
                        "**ERROR**: Failed to retrieve speech recognition results.",
                        0,
                    )
                else:
                    time.sleep(retry_interval)
                    retries += 1
            return "**ERROR**: Max retries exceeded. Task may still be processing.", 0
        except TencentCloudSDKException as e:
            return "**ERROR**: " + str(e), 0
        except Exception as e:
            return "**ERROR**: " + str(e), 0
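

# A minimal usage sketch (illustrative only): the constructor expects the key as
# a JSON string carrying the Tencent Cloud credentials, and transcription()
# accepts raw bytes or io.BytesIO (see Base.audio2base64). The values below are
# placeholders.
#
#     key = json.dumps({"tencent_cloud_sid": "<sid>", "tencent_cloud_sk": "<sk>"})
#     model = TencentCloudSeq2txt(key=key, model_name="16k_zh")
#     with open("sample.wav", "rb") as f:
#         text, num_tokens = model.transcription(f.read())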