Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder mit einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

sequence2txt_model.py 7.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import base64
  17. import io
  18. import json
  19. import os
  20. import re
  21. from abc import ABC
  22. import requests
  23. from openai import OpenAI
  24. from openai.lib.azure import AzureOpenAI
  25. from rag.utils import num_tokens_from_string
  26. class Base(ABC):
  27. def __init__(self, key, model_name):
  28. pass
  29. def transcription(self, audio, **kwargs):
  30. transcription = self.client.audio.transcriptions.create(model=self.model_name, file=audio, response_format="text")
  31. return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
  32. def audio2base64(self, audio):
  33. if isinstance(audio, bytes):
  34. return base64.b64encode(audio).decode("utf-8")
  35. if isinstance(audio, io.BytesIO):
  36. return base64.b64encode(audio.getvalue()).decode("utf-8")
  37. raise TypeError("The input audio file should be in binary format.")
  38. class GPTSeq2txt(Base):
  39. _FACTORY_NAME = "OpenAI"
  40. def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"):
  41. if not base_url:
  42. base_url = "https://api.openai.com/v1"
  43. self.client = OpenAI(api_key=key, base_url=base_url)
  44. self.model_name = model_name
  45. class QWenSeq2txt(Base):
  46. _FACTORY_NAME = "Tongyi-Qianwen"
  47. def __init__(self, key, model_name="paraformer-realtime-8k-v1", **kwargs):
  48. import dashscope
  49. dashscope.api_key = key
  50. self.model_name = model_name
  51. def transcription(self, audio, format):
  52. from http import HTTPStatus
  53. from dashscope.audio.asr import Recognition
  54. recognition = Recognition(model=self.model_name, format=format, sample_rate=16000, callback=None)
  55. result = recognition.call(audio)
  56. ans = ""
  57. if result.status_code == HTTPStatus.OK:
  58. for sentence in result.get_sentence():
  59. ans += sentence.text.decode("utf-8") + "\n"
  60. return ans, num_tokens_from_string(ans)
  61. return "**ERROR**: " + result.message, 0
  62. class AzureSeq2txt(Base):
  63. _FACTORY_NAME = "Azure-OpenAI"
  64. def __init__(self, key, model_name, lang="Chinese", **kwargs):
  65. self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
  66. self.model_name = model_name
  67. self.lang = lang
  68. class XinferenceSeq2txt(Base):
  69. _FACTORY_NAME = "Xinference"
  70. def __init__(self, key, model_name="whisper-small", **kwargs):
  71. self.base_url = kwargs.get("base_url", None)
  72. self.model_name = model_name
  73. self.key = key
  74. def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
  75. if isinstance(audio, str):
  76. audio_file = open(audio, "rb")
  77. audio_data = audio_file.read()
  78. audio_file_name = audio.split("/")[-1]
  79. else:
  80. audio_data = audio
  81. audio_file_name = "audio.wav"
  82. payload = {"model": self.model_name, "language": language, "prompt": prompt, "response_format": response_format, "temperature": temperature}
  83. files = {"file": (audio_file_name, audio_data, "audio/wav")}
  84. try:
  85. response = requests.post(f"{self.base_url}/v1/audio/transcriptions", files=files, data=payload)
  86. response.raise_for_status()
  87. result = response.json()
  88. if "text" in result:
  89. transcription_text = result["text"].strip()
  90. return transcription_text, num_tokens_from_string(transcription_text)
  91. else:
  92. return "**ERROR**: Failed to retrieve transcription.", 0
  93. except requests.exceptions.RequestException as e:
  94. return f"**ERROR**: {str(e)}", 0
  95. class TencentCloudSeq2txt(Base):
  96. _FACTORY_NAME = "Tencent Cloud"
  97. def __init__(self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"):
  98. from tencentcloud.asr.v20190614 import asr_client
  99. from tencentcloud.common import credential
  100. key = json.loads(key)
  101. sid = key.get("tencent_cloud_sid", "")
  102. sk = key.get("tencent_cloud_sk", "")
  103. cred = credential.Credential(sid, sk)
  104. self.client = asr_client.AsrClient(cred, "")
  105. self.model_name = model_name
  106. def transcription(self, audio, max_retries=60, retry_interval=5):
  107. import time
  108. from tencentcloud.asr.v20190614 import models
  109. from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
  110. TencentCloudSDKException,
  111. )
  112. b64 = self.audio2base64(audio)
  113. try:
  114. # dispatch disk
  115. req = models.CreateRecTaskRequest()
  116. params = {
  117. "EngineModelType": self.model_name,
  118. "ChannelNum": 1,
  119. "ResTextFormat": 0,
  120. "SourceType": 1,
  121. "Data": b64,
  122. }
  123. req.from_json_string(json.dumps(params))
  124. resp = self.client.CreateRecTask(req)
  125. # loop query
  126. req = models.DescribeTaskStatusRequest()
  127. params = {"TaskId": resp.Data.TaskId}
  128. req.from_json_string(json.dumps(params))
  129. retries = 0
  130. while retries < max_retries:
  131. resp = self.client.DescribeTaskStatus(req)
  132. if resp.Data.StatusStr == "success":
  133. text = re.sub(r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result).strip()
  134. return text, num_tokens_from_string(text)
  135. elif resp.Data.StatusStr == "failed":
  136. return (
  137. "**ERROR**: Failed to retrieve speech recognition results.",
  138. 0,
  139. )
  140. else:
  141. time.sleep(retry_interval)
  142. retries += 1
  143. return "**ERROR**: Max retries exceeded. Task may still be processing.", 0
  144. except TencentCloudSDKException as e:
  145. return "**ERROR**: " + str(e), 0
  146. except Exception as e:
  147. return "**ERROR**: " + str(e), 0
  148. class GPUStackSeq2txt(Base):
  149. _FACTORY_NAME = "GPUStack"
  150. def __init__(self, key, model_name, base_url):
  151. if not base_url:
  152. raise ValueError("url cannot be None")
  153. if base_url.split("/")[-1] != "v1":
  154. base_url = os.path.join(base_url, "v1")
  155. self.base_url = base_url
  156. self.model_name = model_name
  157. self.key = key
  158. class GiteeSeq2txt(Base):
  159. _FACTORY_NAME = "GiteeAI"
  160. def __init__(self, key, model_name="whisper-1", base_url="https://ai.gitee.com/v1/"):
  161. if not base_url:
  162. base_url = "https://ai.gitee.com/v1/"
  163. self.client = OpenAI(api_key=key, base_url=base_url)
  164. self.model_name = model_name
  165. class DeepInfraSeq2txt(Base):
  166. _FACTORY_NAME = "DeepInfra"
  167. def __init__(self, key, model_name, base_url="https://api.deepinfra.com/v1/openai", **kwargs):
  168. if not base_url:
  169. base_url = "https://api.deepinfra.com/v1/openai"
  170. self.client = OpenAI(api_key=key, base_url=base_url)
  171. self.model_name = model_name