Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

chat_model.py 6.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from zhipuai import ZhipuAI
  17. from dashscope import Generation
  18. from abc import ABC
  19. from openai import OpenAI
  20. import openai
  21. from rag.nlp import is_english
  22. from rag.utils import num_tokens_from_string
  23. class Base(ABC):
  24. def __init__(self, key, model_name):
  25. pass
  26. def chat(self, system, history, gen_conf):
  27. raise NotImplementedError("Please implement encode method!")
  28. class GptTurbo(Base):
  29. def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
  30. if not base_url: base_url="https://api.openai.com/v1"
  31. self.client = OpenAI(api_key=key, base_url=base_url)
  32. self.model_name = model_name
  33. def chat(self, system, history, gen_conf):
  34. if system:
  35. history.insert(0, {"role": "system", "content": system})
  36. try:
  37. response = self.client.chat.completions.create(
  38. model=self.model_name,
  39. messages=history,
  40. **gen_conf)
  41. ans = response.choices[0].message.content.strip()
  42. if response.choices[0].finish_reason == "length":
  43. ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
  44. [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
  45. return ans, response.usage.completion_tokens
  46. except openai.APIError as e:
  47. return "**ERROR**: " + str(e), 0
  48. class MoonshotChat(GptTurbo):
  49. def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
  50. if not base_url: base_url="https://api.moonshot.cn/v1"
  51. self.client = OpenAI(
  52. api_key=key, base_url=base_url)
  53. self.model_name = model_name
  54. def chat(self, system, history, gen_conf):
  55. if system:
  56. history.insert(0, {"role": "system", "content": system})
  57. try:
  58. response = self.client.chat.completions.create(
  59. model=self.model_name,
  60. messages=history,
  61. **gen_conf)
  62. ans = response.choices[0].message.content.strip()
  63. if response.choices[0].finish_reason == "length":
  64. ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
  65. [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
  66. return ans, response.usage.completion_tokens
  67. except openai.APIError as e:
  68. return "**ERROR**: " + str(e), 0
  69. class QWenChat(Base):
  70. def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
  71. import dashscope
  72. dashscope.api_key = key
  73. self.model_name = model_name
  74. def chat(self, system, history, gen_conf):
  75. from http import HTTPStatus
  76. if system:
  77. history.insert(0, {"role": "system", "content": system})
  78. response = Generation.call(
  79. self.model_name,
  80. messages=history,
  81. result_format='message',
  82. **gen_conf
  83. )
  84. ans = ""
  85. tk_count = 0
  86. if response.status_code == HTTPStatus.OK:
  87. ans += response.output.choices[0]['message']['content']
  88. tk_count += response.usage.output_tokens
  89. if response.output.choices[0].get("finish_reason", "") == "length":
  90. ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
  91. [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
  92. return ans, tk_count
  93. return "**ERROR**: " + response.message, tk_count
  94. class ZhipuChat(Base):
  95. def __init__(self, key, model_name="glm-3-turbo", **kwargs):
  96. self.client = ZhipuAI(api_key=key)
  97. self.model_name = model_name
  98. def chat(self, system, history, gen_conf):
  99. if system:
  100. history.insert(0, {"role": "system", "content": system})
  101. try:
  102. response = self.client.chat.completions.create(
  103. model=self.model_name,
  104. messages=history,
  105. **gen_conf
  106. )
  107. ans = response.choices[0].message.content.strip()
  108. if response.choices[0].finish_reason == "length":
  109. ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
  110. [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
  111. return ans, response.usage.completion_tokens
  112. except Exception as e:
  113. return "**ERROR**: " + str(e), 0
  114. class LocalLLM(Base):
  115. class RPCProxy:
  116. def __init__(self, host, port):
  117. self.host = host
  118. self.port = int(port)
  119. self.__conn()
  120. def __conn(self):
  121. from multiprocessing.connection import Client
  122. self._connection = Client(
  123. (self.host, self.port), authkey=b'infiniflow-token4kevinhu')
  124. def __getattr__(self, name):
  125. import pickle
  126. def do_rpc(*args, **kwargs):
  127. for _ in range(3):
  128. try:
  129. self._connection.send(
  130. pickle.dumps((name, args, kwargs)))
  131. return pickle.loads(self._connection.recv())
  132. except Exception as e:
  133. self.__conn()
  134. raise Exception("RPC connection lost!")
  135. return do_rpc
  136. def __init__(self, **kwargs):
  137. self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)
  138. def chat(self, system, history, gen_conf):
  139. if system:
  140. history.insert(0, {"role": "system", "content": system})
  141. try:
  142. ans = self.client.chat(
  143. history,
  144. gen_conf
  145. )
  146. return ans, num_tokens_from_string(ans)
  147. except Exception as e:
  148. return "**ERROR**: " + str(e), 0