chat_model.py

#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
from abc import ABC

import openai
from openai import OpenAI

from rag.nlp import is_english
from rag.utils import num_tokens_from_string


class Base(ABC):
    def __init__(self, key, model_name):
        pass

    def chat(self, system, history, gen_conf):
        raise NotImplementedError("Please implement chat method!")
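

# All chat model classes below share the same contract:
#   chat(system, history, gen_conf) -> (answer_text, completion_token_count)
# `system` is an optional system prompt prepended to the OpenAI-style
# `history` message list, and `gen_conf` carries the provider's generation
# parameters. On failure, implementations return ("**ERROR**: <detail>", 0).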


class GptTurbo(Base):
    def __init__(self, key, model_name="gpt-3.5-turbo"):
        self.client = OpenAI(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf)
            # The OpenAI SDK returns attribute-style objects; there is no
            # `response.output` dict wrapper (that shape belongs to dashscope).
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.completion_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0


class MoonshotChat(GptTurbo):
    # Moonshot exposes an OpenAI-compatible API, so only the base_url differs.
    def __init__(self, key, model_name="moonshot-v1-8k"):
        self.client = OpenAI(api_key=key, base_url="https://api.moonshot.cn/v1")
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.completion_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0


from dashscope import Generation


class QWenChat(Base):
    def __init__(self, key, model_name=Generation.Models.qwen_turbo):
        import dashscope
        dashscope.api_key = key
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        from http import HTTPStatus
        if system:
            history.insert(0, {"role": "system", "content": system})
        response = Generation.call(
            self.model_name,
            messages=history,
            result_format='message',
            **gen_conf
        )
        ans = ""
        tk_count = 0
        if response.status_code == HTTPStatus.OK:
            # dashscope responses are dict-style, hence the subscript access.
            ans += response.output.choices[0]['message']['content']
            tk_count += response.usage.output_tokens
            if response.output.choices[0].get("finish_reason", "") == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, tk_count
        return "**ERROR**: " + response.message, tk_count


from zhipuai import ZhipuAI


class ZhipuChat(Base):
    def __init__(self, key, model_name="glm-3-turbo"):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            # `model` must be passed as a keyword argument, and the ZhipuAI SDK
            # returns OpenAI-style response objects (no `.output` wrapper).
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf
            )
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.completion_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0


class LocalLLM(Base):
    class RPCProxy:
        """Forwards method calls to a local model server as pickled
        (name, args, kwargs) tuples over an authenticated connection."""

        def __init__(self, host, port):
            self.host = host
            self.port = int(port)
            self.__conn()

        def __conn(self):
            from multiprocessing.connection import Client
            self._connection = Client(
                (self.host, self.port), authkey=b'infiniflow-token4kevinhu')

        def __getattr__(self, name):
            import pickle

            def do_rpc(*args, **kwargs):
                # Retry up to three times, reconnecting after each failure;
                # give up only once all attempts are exhausted.
                for _ in range(3):
                    try:
                        self._connection.send(pickle.dumps((name, args, kwargs)))
                        return pickle.loads(self._connection.recv())
                    except Exception:
                        self.__conn()
                raise Exception("RPC connection lost!")

            return do_rpc

    def __init__(self, key, model_name="glm-3-turbo"):
        self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            ans = self.client.chat(
                history,
                gen_conf
            )
            return ans, num_tokens_from_string(ans)
        except Exception as e:
            return "**ERROR**: " + str(e), 0
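

# Minimal usage sketch (an illustration, not part of the upstream file): every
# class above follows the same Base.chat contract, so a caller only changes the
# constructor to switch providers. "YOUR_API_KEY" is a placeholder, and the
# gen_conf keys shown are standard OpenAI generation parameters; adjust both
# for your provider.
if __name__ == "__main__":
    mdl = GptTurbo(key="YOUR_API_KEY")
    answer, used_tokens = mdl.chat(
        system="You are a helpful assistant.",
        history=[{"role": "user", "content": "Summarize RAG in one sentence."}],
        gen_conf={"temperature": 0.7, "max_tokens": 256},
    )
    print(answer, used_tokens)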