Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

sequence2txt_model.py 5.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from openai.lib.azure import AzureOpenAI
  17. from zhipuai import ZhipuAI
  18. import io
  19. from abc import ABC
  20. from ollama import Client
  21. from openai import OpenAI
  22. import os
  23. import json
  24. from rag.utils import num_tokens_from_string
  25. import base64
  26. import re
  27. class Base(ABC):
  28. def __init__(self, key, model_name):
  29. pass
  30. def transcription(self, audio, **kwargs):
  31. transcription = self.client.audio.transcriptions.create(
  32. model=self.model_name,
  33. file=audio,
  34. response_format="text"
  35. )
  36. return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
  37. def audio2base64(self,audio):
  38. if isinstance(audio, bytes):
  39. return base64.b64encode(audio).decode("utf-8")
  40. if isinstance(audio, io.BytesIO):
  41. return base64.b64encode(audio.getvalue()).decode("utf-8")
  42. raise TypeError("The input audio file should be in binary format.")
  43. class GPTSeq2txt(Base):
  44. def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"):
  45. if not base_url: base_url = "https://api.openai.com/v1"
  46. self.client = OpenAI(api_key=key, base_url=base_url)
  47. self.model_name = model_name
  48. class QWenSeq2txt(Base):
  49. def __init__(self, key, model_name="paraformer-realtime-8k-v1", **kwargs):
  50. import dashscope
  51. dashscope.api_key = key
  52. self.model_name = model_name
  53. def transcription(self, audio, format):
  54. from http import HTTPStatus
  55. from dashscope.audio.asr import Recognition
  56. recognition = Recognition(model=self.model_name,
  57. format=format,
  58. sample_rate=16000,
  59. callback=None)
  60. result = recognition.call(audio)
  61. ans = ""
  62. if result.status_code == HTTPStatus.OK:
  63. for sentence in result.get_sentence():
  64. ans += sentence.text.decode('utf-8') + '\n'
  65. return ans, num_tokens_from_string(ans)
  66. return "**ERROR**: " + result.message, 0
  67. class OllamaSeq2txt(Base):
  68. def __init__(self, key, model_name, lang="Chinese", **kwargs):
  69. self.client = Client(host=kwargs["base_url"])
  70. self.model_name = model_name
  71. self.lang = lang
  72. class AzureSeq2txt(Base):
  73. def __init__(self, key, model_name, lang="Chinese", **kwargs):
  74. self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
  75. self.model_name = model_name
  76. self.lang = lang
  77. class XinferenceSeq2txt(Base):
  78. def __init__(self, key, model_name="", base_url=""):
  79. if base_url.split("/")[-1] != "v1":
  80. base_url = os.path.join(base_url, "v1")
  81. self.client = OpenAI(api_key="xxx", base_url=base_url)
  82. self.model_name = model_name
  83. class TencentCloudSeq2txt(Base):
  84. def __init__(
  85. self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
  86. ):
  87. from tencentcloud.common import credential
  88. from tencentcloud.asr.v20190614 import asr_client
  89. key = json.loads(key)
  90. sid = key.get("tencent_cloud_sid", "")
  91. sk = key.get("tencent_cloud_sk", "")
  92. cred = credential.Credential(sid, sk)
  93. self.client = asr_client.AsrClient(cred, "")
  94. self.model_name = model_name
  95. def transcription(self, audio, max_retries=60, retry_interval=5):
  96. from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
  97. TencentCloudSDKException,
  98. )
  99. from tencentcloud.asr.v20190614 import models
  100. import time
  101. b64 = self.audio2base64(audio)
  102. try:
  103. # dispatch disk
  104. req = models.CreateRecTaskRequest()
  105. params = {
  106. "EngineModelType": self.model_name,
  107. "ChannelNum": 1,
  108. "ResTextFormat": 0,
  109. "SourceType": 1,
  110. "Data": b64,
  111. }
  112. req.from_json_string(json.dumps(params))
  113. resp = self.client.CreateRecTask(req)
  114. # loop query
  115. req = models.DescribeTaskStatusRequest()
  116. params = {"TaskId": resp.Data.TaskId}
  117. req.from_json_string(json.dumps(params))
  118. retries = 0
  119. while retries < max_retries:
  120. resp = self.client.DescribeTaskStatus(req)
  121. if resp.Data.StatusStr == "success":
  122. text = re.sub(
  123. r"\[\d+:\d+\.\d+,\d+:\d+\.\d+\]\s*", "", resp.Data.Result
  124. ).strip()
  125. return text, num_tokens_from_string(text)
  126. elif resp.Data.StatusStr == "failed":
  127. return (
  128. "**ERROR**: Failed to retrieve speech recognition results.",
  129. 0,
  130. )
  131. else:
  132. time.sleep(retry_interval)
  133. retries += 1
  134. return "**ERROR**: Max retries exceeded. Task may still be processing.", 0
  135. except TencentCloudSDKException as e:
  136. return "**ERROR**: " + str(e), 0
  137. except Exception as e:
  138. return "**ERROR**: " + str(e), 0