您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

cv_model.py 39KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import base64
  17. import io
  18. import json
  19. import os
  20. from abc import ABC
  21. from io import BytesIO
  22. from urllib.parse import urljoin
  23. import requests
  24. from ollama import Client
  25. from openai import OpenAI
  26. from openai.lib.azure import AzureOpenAI
  27. from PIL import Image
  28. from zhipuai import ZhipuAI
  29. from api.utils import get_uuid
  30. from api.utils.file_utils import get_project_base_directory
  31. from rag.nlp import is_english
  32. from rag.prompts import vision_llm_describe_prompt
  33. from rag.utils import num_tokens_from_string
  34. class Base(ABC):
  35. def __init__(self, key, model_name):
  36. pass
  37. def describe(self, image):
  38. raise NotImplementedError("Please implement encode method!")
  39. def describe_with_prompt(self, image, prompt=None):
  40. raise NotImplementedError("Please implement encode method!")
  41. def chat(self, system, history, gen_conf, image=""):
  42. if system:
  43. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  44. try:
  45. for his in history:
  46. if his["role"] == "user":
  47. his["content"] = self.chat_prompt(his["content"], image)
  48. response = self.client.chat.completions.create(
  49. model=self.model_name,
  50. messages=history,
  51. temperature=gen_conf.get("temperature", 0.3),
  52. top_p=gen_conf.get("top_p", 0.7)
  53. )
  54. return response.choices[0].message.content.strip(), response.usage.total_tokens
  55. except Exception as e:
  56. return "**ERROR**: " + str(e), 0
  57. def chat_streamly(self, system, history, gen_conf, image=""):
  58. if system:
  59. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  60. ans = ""
  61. tk_count = 0
  62. try:
  63. for his in history:
  64. if his["role"] == "user":
  65. his["content"] = self.chat_prompt(his["content"], image)
  66. response = self.client.chat.completions.create(
  67. model=self.model_name,
  68. messages=history,
  69. temperature=gen_conf.get("temperature", 0.3),
  70. top_p=gen_conf.get("top_p", 0.7),
  71. stream=True
  72. )
  73. for resp in response:
  74. if not resp.choices[0].delta.content:
  75. continue
  76. delta = resp.choices[0].delta.content
  77. ans += delta
  78. if resp.choices[0].finish_reason == "length":
  79. ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
  80. [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
  81. tk_count = resp.usage.total_tokens
  82. if resp.choices[0].finish_reason == "stop":
  83. tk_count = resp.usage.total_tokens
  84. yield ans
  85. except Exception as e:
  86. yield ans + "\n**ERROR**: " + str(e)
  87. yield tk_count
  88. def image2base64(self, image):
  89. if isinstance(image, bytes):
  90. return base64.b64encode(image).decode("utf-8")
  91. if isinstance(image, BytesIO):
  92. return base64.b64encode(image.getvalue()).decode("utf-8")
  93. buffered = BytesIO()
  94. try:
  95. image.save(buffered, format="JPEG")
  96. except Exception:
  97. image.save(buffered, format="PNG")
  98. return base64.b64encode(buffered.getvalue()).decode("utf-8")
  99. def prompt(self, b64):
  100. return [
  101. {
  102. "role": "user",
  103. "content": [
  104. {
  105. "type": "image_url",
  106. "image_url": {
  107. "url": f"data:image/jpeg;base64,{b64}"
  108. },
  109. },
  110. {
  111. "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
  112. "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
  113. },
  114. ],
  115. }
  116. ]
  117. def vision_llm_prompt(self, b64, prompt=None):
  118. return [
  119. {
  120. "role": "user",
  121. "content": [
  122. {
  123. "type": "image_url",
  124. "image_url": {
  125. "url": f"data:image/jpeg;base64,{b64}"
  126. },
  127. },
  128. {
  129. "type": "text",
  130. "text": prompt if prompt else vision_llm_describe_prompt(),
  131. },
  132. ],
  133. }
  134. ]
  135. def chat_prompt(self, text, b64):
  136. return [
  137. {
  138. "type": "image_url",
  139. "image_url": {
  140. "url": f"data:image/jpeg;base64,{b64}",
  141. },
  142. },
  143. {
  144. "type": "text",
  145. "text": text
  146. },
  147. ]
  148. class GptV4(Base):
  149. def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
  150. if not base_url:
  151. base_url = "https://api.openai.com/v1"
  152. self.client = OpenAI(api_key=key, base_url=base_url)
  153. self.model_name = model_name
  154. self.lang = lang
  155. def describe(self, image):
  156. b64 = self.image2base64(image)
  157. prompt = self.prompt(b64)
  158. for i in range(len(prompt)):
  159. for c in prompt[i]["content"]:
  160. if "text" in c:
  161. c["type"] = "text"
  162. res = self.client.chat.completions.create(
  163. model=self.model_name,
  164. messages=prompt
  165. )
  166. return res.choices[0].message.content.strip(), res.usage.total_tokens
  167. def describe_with_prompt(self, image, prompt=None):
  168. b64 = self.image2base64(image)
  169. vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
  170. res = self.client.chat.completions.create(
  171. model=self.model_name,
  172. messages=vision_prompt,
  173. )
  174. return res.choices[0].message.content.strip(), res.usage.total_tokens
  175. class AzureGptV4(Base):
  176. def __init__(self, key, model_name, lang="Chinese", **kwargs):
  177. api_key = json.loads(key).get('api_key', '')
  178. api_version = json.loads(key).get('api_version', '2024-02-01')
  179. self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
  180. self.model_name = model_name
  181. self.lang = lang
  182. def describe(self, image):
  183. b64 = self.image2base64(image)
  184. prompt = self.prompt(b64)
  185. for i in range(len(prompt)):
  186. for c in prompt[i]["content"]:
  187. if "text" in c:
  188. c["type"] = "text"
  189. res = self.client.chat.completions.create(
  190. model=self.model_name,
  191. messages=prompt
  192. )
  193. return res.choices[0].message.content.strip(), res.usage.total_tokens
  194. def describe_with_prompt(self, image, prompt=None):
  195. b64 = self.image2base64(image)
  196. vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
  197. res = self.client.chat.completions.create(
  198. model=self.model_name,
  199. messages=vision_prompt,
  200. )
  201. return res.choices[0].message.content.strip(), res.usage.total_tokens
  202. class QWenCV(Base):
  203. def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
  204. import dashscope
  205. dashscope.api_key = key
  206. self.model_name = model_name
  207. self.lang = lang
  208. def prompt(self, binary):
  209. # stupid as hell
  210. tmp_dir = get_project_base_directory("tmp")
  211. if not os.path.exists(tmp_dir):
  212. os.mkdir(tmp_dir)
  213. path = os.path.join(tmp_dir, "%s.jpg" % get_uuid())
  214. Image.open(io.BytesIO(binary)).save(path)
  215. return [
  216. {
  217. "role": "user",
  218. "content": [
  219. {
  220. "image": f"file://{path}"
  221. },
  222. {
  223. "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
  224. "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
  225. },
  226. ],
  227. }
  228. ]
  229. def vision_llm_prompt(self, binary, prompt=None):
  230. # stupid as hell
  231. tmp_dir = get_project_base_directory("tmp")
  232. if not os.path.exists(tmp_dir):
  233. os.mkdir(tmp_dir)
  234. path = os.path.join(tmp_dir, "%s.jpg" % get_uuid())
  235. Image.open(io.BytesIO(binary)).save(path)
  236. return [
  237. {
  238. "role": "user",
  239. "content": [
  240. {
  241. "image": f"file://{path}"
  242. },
  243. {
  244. "text": prompt if prompt else vision_llm_describe_prompt(),
  245. },
  246. ],
  247. }
  248. ]
  249. def chat_prompt(self, text, b64):
  250. return [
  251. {"image": f"{b64}"},
  252. {"text": text},
  253. ]
  254. def describe(self, image):
  255. from http import HTTPStatus
  256. from dashscope import MultiModalConversation
  257. response = MultiModalConversation.call(model=self.model_name, messages=self.prompt(image))
  258. if response.status_code == HTTPStatus.OK:
  259. return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
  260. return response.message, 0
  261. def describe_with_prompt(self, image, prompt=None):
  262. from http import HTTPStatus
  263. from dashscope import MultiModalConversation
  264. vision_prompt = self.vision_llm_prompt(image, prompt) if prompt else self.vision_llm_prompt(image)
  265. response = MultiModalConversation.call(model=self.model_name, messages=vision_prompt)
  266. if response.status_code == HTTPStatus.OK:
  267. return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
  268. return response.message, 0
  269. def chat(self, system, history, gen_conf, image=""):
  270. from http import HTTPStatus
  271. from dashscope import MultiModalConversation
  272. if system:
  273. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  274. for his in history:
  275. if his["role"] == "user":
  276. his["content"] = self.chat_prompt(his["content"], image)
  277. response = MultiModalConversation.call(model=self.model_name, messages=history,
  278. temperature=gen_conf.get("temperature", 0.3),
  279. top_p=gen_conf.get("top_p", 0.7))
  280. ans = ""
  281. tk_count = 0
  282. if response.status_code == HTTPStatus.OK:
  283. ans += response.output.choices[0]['message']['content']
  284. tk_count += response.usage.total_tokens
  285. if response.output.choices[0].get("finish_reason", "") == "length":
  286. ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
  287. [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
  288. return ans, tk_count
  289. return "**ERROR**: " + response.message, tk_count
  290. def chat_streamly(self, system, history, gen_conf, image=""):
  291. from http import HTTPStatus
  292. from dashscope import MultiModalConversation
  293. if system:
  294. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  295. for his in history:
  296. if his["role"] == "user":
  297. his["content"] = self.chat_prompt(his["content"], image)
  298. ans = ""
  299. tk_count = 0
  300. try:
  301. response = MultiModalConversation.call(model=self.model_name, messages=history,
  302. temperature=gen_conf.get("temperature", 0.3),
  303. top_p=gen_conf.get("top_p", 0.7),
  304. stream=True)
  305. for resp in response:
  306. if resp.status_code == HTTPStatus.OK:
  307. ans = resp.output.choices[0]['message']['content']
  308. tk_count = resp.usage.total_tokens
  309. if resp.output.choices[0].get("finish_reason", "") == "length":
  310. ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
  311. [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
  312. yield ans
  313. else:
  314. yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
  315. "Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
  316. except Exception as e:
  317. yield ans + "\n**ERROR**: " + str(e)
  318. yield tk_count
  319. class Zhipu4V(Base):
  320. def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
  321. self.client = ZhipuAI(api_key=key)
  322. self.model_name = model_name
  323. self.lang = lang
  324. def describe(self, image):
  325. b64 = self.image2base64(image)
  326. prompt = self.prompt(b64)
  327. prompt[0]["content"][1]["type"] = "text"
  328. res = self.client.chat.completions.create(
  329. model=self.model_name,
  330. messages=prompt,
  331. )
  332. return res.choices[0].message.content.strip(), res.usage.total_tokens
  333. def describe_with_prompt(self, image, prompt=None):
  334. b64 = self.image2base64(image)
  335. vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
  336. res = self.client.chat.completions.create(
  337. model=self.model_name,
  338. messages=vision_prompt
  339. )
  340. return res.choices[0].message.content.strip(), res.usage.total_tokens
  341. def chat(self, system, history, gen_conf, image=""):
  342. if system:
  343. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  344. try:
  345. for his in history:
  346. if his["role"] == "user":
  347. his["content"] = self.chat_prompt(his["content"], image)
  348. response = self.client.chat.completions.create(
  349. model=self.model_name,
  350. messages=history,
  351. temperature=gen_conf.get("temperature", 0.3),
  352. top_p=gen_conf.get("top_p", 0.7)
  353. )
  354. return response.choices[0].message.content.strip(), response.usage.total_tokens
  355. except Exception as e:
  356. return "**ERROR**: " + str(e), 0
  357. def chat_streamly(self, system, history, gen_conf, image=""):
  358. if system:
  359. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  360. ans = ""
  361. tk_count = 0
  362. try:
  363. for his in history:
  364. if his["role"] == "user":
  365. his["content"] = self.chat_prompt(his["content"], image)
  366. response = self.client.chat.completions.create(
  367. model=self.model_name,
  368. messages=history,
  369. temperature=gen_conf.get("temperature", 0.3),
  370. top_p=gen_conf.get("top_p", 0.7),
  371. stream=True
  372. )
  373. for resp in response:
  374. if not resp.choices[0].delta.content:
  375. continue
  376. delta = resp.choices[0].delta.content
  377. ans += delta
  378. if resp.choices[0].finish_reason == "length":
  379. ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
  380. [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
  381. tk_count = resp.usage.total_tokens
  382. if resp.choices[0].finish_reason == "stop":
  383. tk_count = resp.usage.total_tokens
  384. yield ans
  385. except Exception as e:
  386. yield ans + "\n**ERROR**: " + str(e)
  387. yield tk_count
  388. class OllamaCV(Base):
  389. def __init__(self, key, model_name, lang="Chinese", **kwargs):
  390. self.client = Client(host=kwargs["base_url"])
  391. self.model_name = model_name
  392. self.lang = lang
  393. def describe(self, image):
  394. prompt = self.prompt("")
  395. try:
  396. response = self.client.generate(
  397. model=self.model_name,
  398. prompt=prompt[0]["content"][1]["text"],
  399. images=[image]
  400. )
  401. ans = response["response"].strip()
  402. return ans, 128
  403. except Exception as e:
  404. return "**ERROR**: " + str(e), 0
  405. def describe_with_prompt(self, image, prompt=None):
  406. vision_prompt = self.vision_llm_prompt("", prompt) if prompt else self.vision_llm_prompt("")
  407. try:
  408. response = self.client.generate(
  409. model=self.model_name,
  410. prompt=vision_prompt[0]["content"][1]["text"],
  411. images=[image],
  412. )
  413. ans = response["response"].strip()
  414. return ans, 128
  415. except Exception as e:
  416. return "**ERROR**: " + str(e), 0
  417. def chat(self, system, history, gen_conf, image=""):
  418. if system:
  419. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  420. try:
  421. for his in history:
  422. if his["role"] == "user":
  423. his["images"] = [image]
  424. options = {}
  425. if "temperature" in gen_conf:
  426. options["temperature"] = gen_conf["temperature"]
  427. if "top_p" in gen_conf:
  428. options["top_k"] = gen_conf["top_p"]
  429. if "presence_penalty" in gen_conf:
  430. options["presence_penalty"] = gen_conf["presence_penalty"]
  431. if "frequency_penalty" in gen_conf:
  432. options["frequency_penalty"] = gen_conf["frequency_penalty"]
  433. response = self.client.chat(
  434. model=self.model_name,
  435. messages=history,
  436. options=options
  437. )
  438. ans = response["message"]["content"].strip()
  439. return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
  440. except Exception as e:
  441. return "**ERROR**: " + str(e), 0
  442. def chat_streamly(self, system, history, gen_conf, image=""):
  443. if system:
  444. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  445. for his in history:
  446. if his["role"] == "user":
  447. his["images"] = [image]
  448. options = {}
  449. if "temperature" in gen_conf:
  450. options["temperature"] = gen_conf["temperature"]
  451. if "top_p" in gen_conf:
  452. options["top_k"] = gen_conf["top_p"]
  453. if "presence_penalty" in gen_conf:
  454. options["presence_penalty"] = gen_conf["presence_penalty"]
  455. if "frequency_penalty" in gen_conf:
  456. options["frequency_penalty"] = gen_conf["frequency_penalty"]
  457. ans = ""
  458. try:
  459. response = self.client.chat(
  460. model=self.model_name,
  461. messages=history,
  462. stream=True,
  463. options=options
  464. )
  465. for resp in response:
  466. if resp["done"]:
  467. yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
  468. ans += resp["message"]["content"]
  469. yield ans
  470. except Exception as e:
  471. yield ans + "\n**ERROR**: " + str(e)
  472. yield 0
  473. class LocalAICV(GptV4):
  474. def __init__(self, key, model_name, base_url, lang="Chinese"):
  475. if not base_url:
  476. raise ValueError("Local cv model url cannot be None")
  477. base_url = urljoin(base_url, "v1")
  478. self.client = OpenAI(api_key="empty", base_url=base_url)
  479. self.model_name = model_name.split("___")[0]
  480. self.lang = lang
  481. class XinferenceCV(Base):
  482. def __init__(self, key, model_name="", lang="Chinese", base_url=""):
  483. base_url = urljoin(base_url, "v1")
  484. self.client = OpenAI(api_key=key, base_url=base_url)
  485. self.model_name = model_name
  486. self.lang = lang
  487. def describe(self, image):
  488. b64 = self.image2base64(image)
  489. res = self.client.chat.completions.create(
  490. model=self.model_name,
  491. messages=self.prompt(b64)
  492. )
  493. return res.choices[0].message.content.strip(), res.usage.total_tokens
  494. def describe_with_prompt(self, image, prompt=None):
  495. b64 = self.image2base64(image)
  496. vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
  497. res = self.client.chat.completions.create(
  498. model=self.model_name,
  499. messages=vision_prompt,
  500. )
  501. return res.choices[0].message.content.strip(), res.usage.total_tokens
  502. class GeminiCV(Base):
  503. def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
  504. from google.generativeai import GenerativeModel, client
  505. client.configure(api_key=key)
  506. _client = client.get_default_generative_client()
  507. self.model_name = model_name
  508. self.model = GenerativeModel(model_name=self.model_name)
  509. self.model._client = _client
  510. self.lang = lang
  511. def describe(self, image):
  512. from PIL.Image import open
  513. prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
  514. "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
  515. b64 = self.image2base64(image)
  516. img = open(BytesIO(base64.b64decode(b64)))
  517. input = [prompt, img]
  518. res = self.model.generate_content(
  519. input
  520. )
  521. return res.text, res.usage_metadata.total_token_count
  522. def describe_with_prompt(self, image, prompt=None):
  523. from PIL.Image import open
  524. b64 = self.image2base64(image)
  525. vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
  526. img = open(BytesIO(base64.b64decode(b64)))
  527. input = [vision_prompt, img]
  528. res = self.model.generate_content(
  529. input,
  530. )
  531. return res.text, res.usage_metadata.total_token_count
  532. def chat(self, system, history, gen_conf, image=""):
  533. from transformers import GenerationConfig
  534. if system:
  535. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  536. try:
  537. for his in history:
  538. if his["role"] == "assistant":
  539. his["role"] = "model"
  540. his["parts"] = [his["content"]]
  541. his.pop("content")
  542. if his["role"] == "user":
  543. his["parts"] = [his["content"]]
  544. his.pop("content")
  545. history[-1]["parts"].append("data:image/jpeg;base64," + image)
  546. response = self.model.generate_content(history, generation_config=GenerationConfig(
  547. temperature=gen_conf.get("temperature", 0.3),
  548. top_p=gen_conf.get("top_p", 0.7)))
  549. ans = response.text
  550. return ans, response.usage_metadata.total_token_count
  551. except Exception as e:
  552. return "**ERROR**: " + str(e), 0
  553. def chat_streamly(self, system, history, gen_conf, image=""):
  554. from transformers import GenerationConfig
  555. if system:
  556. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  557. ans = ""
  558. try:
  559. for his in history:
  560. if his["role"] == "assistant":
  561. his["role"] = "model"
  562. his["parts"] = [his["content"]]
  563. his.pop("content")
  564. if his["role"] == "user":
  565. his["parts"] = [his["content"]]
  566. his.pop("content")
  567. history[-1]["parts"].append("data:image/jpeg;base64," + image)
  568. response = self.model.generate_content(history, generation_config=GenerationConfig(
  569. temperature=gen_conf.get("temperature", 0.3),
  570. top_p=gen_conf.get("top_p", 0.7)), stream=True)
  571. for resp in response:
  572. if not resp.text:
  573. continue
  574. ans += resp.text
  575. yield ans
  576. except Exception as e:
  577. yield ans + "\n**ERROR**: " + str(e)
  578. yield response._chunks[-1].usage_metadata.total_token_count
  579. class OpenRouterCV(GptV4):
  580. def __init__(
  581. self,
  582. key,
  583. model_name,
  584. lang="Chinese",
  585. base_url="https://openrouter.ai/api/v1",
  586. ):
  587. if not base_url:
  588. base_url = "https://openrouter.ai/api/v1"
  589. self.client = OpenAI(api_key=key, base_url=base_url)
  590. self.model_name = model_name
  591. self.lang = lang
  592. class LocalCV(Base):
  593. def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
  594. pass
  595. def describe(self, image):
  596. return "", 0
  597. class NvidiaCV(Base):
  598. def __init__(
  599. self,
  600. key,
  601. model_name,
  602. lang="Chinese",
  603. base_url="https://ai.api.nvidia.com/v1/vlm",
  604. ):
  605. if not base_url:
  606. base_url = ("https://ai.api.nvidia.com/v1/vlm",)
  607. self.lang = lang
  608. factory, llm_name = model_name.split("/")
  609. if factory != "liuhaotian":
  610. self.base_url = urljoin(base_url, f"{factory}/{llm_name}")
  611. else:
  612. self.base_url = urljoin(f"{base_url}/community", llm_name.replace("-v1.6", "16"))
  613. self.key = key
  614. def describe(self, image):
  615. b64 = self.image2base64(image)
  616. response = requests.post(
  617. url=self.base_url,
  618. headers={
  619. "accept": "application/json",
  620. "content-type": "application/json",
  621. "Authorization": f"Bearer {self.key}",
  622. },
  623. json={
  624. "messages": self.prompt(b64)
  625. },
  626. )
  627. response = response.json()
  628. return (
  629. response["choices"][0]["message"]["content"].strip(),
  630. response["usage"]["total_tokens"],
  631. )
  632. def describe_with_prompt(self, image, prompt=None):
  633. b64 = self.image2base64(image)
  634. vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
  635. response = requests.post(
  636. url=self.base_url,
  637. headers={
  638. "accept": "application/json",
  639. "content-type": "application/json",
  640. "Authorization": f"Bearer {self.key}",
  641. },
  642. json={
  643. "messages": vision_prompt,
  644. },
  645. )
  646. response = response.json()
  647. return (
  648. response["choices"][0]["message"]["content"].strip(),
  649. response["usage"]["total_tokens"],
  650. )
  651. def prompt(self, b64):
  652. return [
  653. {
  654. "role": "user",
  655. "content": (
  656. "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
  657. if self.lang.lower() == "chinese"
  658. else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
  659. )
  660. + f' <img src="data:image/jpeg;base64,{b64}"/>',
  661. }
  662. ]
  663. def vision_llm_prompt(self, b64, prompt=None):
  664. return [
  665. {
  666. "role": "user",
  667. "content": (
  668. prompt if prompt else vision_llm_describe_prompt()
  669. )
  670. + f' <img src="data:image/jpeg;base64,{b64}"/>',
  671. }
  672. ]
  673. def chat_prompt(self, text, b64):
  674. return [
  675. {
  676. "role": "user",
  677. "content": text + f' <img src="data:image/jpeg;base64,{b64}"/>',
  678. }
  679. ]
  680. class StepFunCV(GptV4):
  681. def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
  682. if not base_url:
  683. base_url = "https://api.stepfun.com/v1"
  684. self.client = OpenAI(api_key=key, base_url=base_url)
  685. self.model_name = model_name
  686. self.lang = lang
  687. class LmStudioCV(GptV4):
  688. def __init__(self, key, model_name, lang="Chinese", base_url=""):
  689. if not base_url:
  690. raise ValueError("Local llm url cannot be None")
  691. base_url = urljoin(base_url, "v1")
  692. self.client = OpenAI(api_key="lm-studio", base_url=base_url)
  693. self.model_name = model_name
  694. self.lang = lang
  695. class OpenAI_APICV(GptV4):
  696. def __init__(self, key, model_name, lang="Chinese", base_url=""):
  697. if not base_url:
  698. raise ValueError("url cannot be None")
  699. base_url = urljoin(base_url, "v1")
  700. self.client = OpenAI(api_key=key, base_url=base_url)
  701. self.model_name = model_name.split("___")[0]
  702. self.lang = lang
  703. class TogetherAICV(GptV4):
  704. def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"):
  705. if not base_url:
  706. base_url = "https://api.together.xyz/v1"
  707. super().__init__(key, model_name, lang, base_url)
  708. class YiCV(GptV4):
  709. def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1",):
  710. if not base_url:
  711. base_url = "https://api.lingyiwanwu.com/v1"
  712. super().__init__(key, model_name, lang, base_url)
  713. class SILICONFLOWCV(GptV4):
  714. def __init__(self, key, model_name, lang="Chinese", base_url="https://api.siliconflow.cn/v1",):
  715. if not base_url:
  716. base_url = "https://api.siliconflow.cn/v1"
  717. super().__init__(key, model_name, lang, base_url)
  class HunyuanCV(Base):
      """Image-description model backed by Tencent Hunyuan's ChatCompletions API."""

      def __init__(self, key, model_name, lang="Chinese", base_url=None):
          """Create the Hunyuan client.

          Args:
              key: JSON object string; its "hunyuan_sid" and "hunyuan_sk"
                  entries (default "") form the Tencent Cloud credential.
              model_name: Hunyuan model identifier, sent as "Model".
              lang: Language of the built-in description prompt; "chinese"
                  (case-insensitive) selects the Chinese prompt.
              base_url: Accepted for signature parity with other CV models; unused.
          """
          from tencentcloud.common import credential
          from tencentcloud.hunyuan.v20230901 import hunyuan_client

          key = json.loads(key)
          sid = key.get("hunyuan_sid", "")
          sk = key.get("hunyuan_sk", "")
          cred = credential.Credential(sid, sk)
          self.model_name = model_name
          # NOTE(review): second argument is presumably the region; it is left empty.
          self.client = hunyuan_client.HunyuanClient(cred, "")
          self.lang = lang

      def _chat_completions(self, messages):
          """Send one ChatCompletions request built from ``messages``.

          Returns:
              ``(answer, total_tokens)`` on success, or
              ``(answer + "\\n**ERROR**: ...", 0)`` when the SDK raises
              ``TencentCloudSDKException``.
          """
          from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
              TencentCloudSDKException,
          )
          from tencentcloud.hunyuan.v20230901 import models

          req = models.ChatCompletionsRequest()
          params = {"Model": self.model_name, "Messages": messages}
          req.from_json_string(json.dumps(params))
          ans = ""
          try:
              response = self.client.ChatCompletions(req)
              ans = response.Choices[0].Message.Content
              return ans, response.Usage.TotalTokens
          except TencentCloudSDKException as e:
              return ans + "\n**ERROR**: " + str(e), 0

      def describe(self, image):
          """Describe ``image`` with the built-in prompt (see ``prompt``).

          Returns the same tuple as ``_chat_completions``.
          """
          b64 = self.image2base64(image)
          return self._chat_completions(self.prompt(b64))

      def describe_with_prompt(self, image, prompt=None):
          """Describe ``image`` using ``vision_llm_prompt``, optionally with a custom ``prompt``.

          Returns the same tuple as ``_chat_completions``.
          """
          b64 = self.image2base64(image)
          # NOTE(review): ``prompt`` below builds Hunyuan-style "Role"/"Contents"
          # messages; confirm that ``vision_llm_prompt`` (defined outside this
          # class) produces a message format the Hunyuan API accepts.
          vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
          return self._chat_completions(vision_prompt)

      def prompt(self, b64):
          """Build a single Hunyuan user message with the base64 JPEG and a language-specific instruction."""
          return [
              {
                  "Role": "user",
                  "Contents": [
                      {
                          "Type": "image_url",
                          "ImageUrl": {
                              "Url": f"data:image/jpeg;base64,{b64}"
                          },
                      },
                      {
                          "Type": "text",
                          "Text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
                          "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                      },
                  ],
              }
          ]
  class AnthropicCV(Base):
      """Vision and chat model backed by the Anthropic Messages API."""

      def __init__(self, key, model_name, base_url=None, lang="Chinese"):
          """Create the Anthropic client.

          Args:
              key: Anthropic API key.
              model_name: Anthropic model identifier.
              base_url: Accepted for signature parity with other CV models; unused.
              lang: Language of the built-in prompt used by ``describe``;
                  "chinese" (case-insensitive) selects the Chinese prompt.
                  It comes after ``base_url`` so existing positional calls
                  still work. It is needed because ``describe`` reads
                  ``self.lang``, which this constructor did not previously set.
          """
          import anthropic

          self.client = anthropic.Anthropic(api_key=key)
          self.model_name = model_name
          self.lang = lang
          self.system = ""
          self.max_tokens = 8192
          # Smaller output budget for model names containing "haiku" or "opus".
          if "haiku" in self.model_name or "opus" in self.model_name:
              self.max_tokens = 4096

      def prompt(self, b64, prompt):
          """Build one user message holding a base64 JPEG image followed by ``prompt`` text."""
          return [
              {
                  "role": "user",
                  "content": [
                      {
                          "type": "image",
                          "source": {
                              "type": "base64",
                              "media_type": "image/jpeg",
                              "data": b64,
                          },
                      },
                      {
                          "type": "text",
                          "text": prompt
                      }
                  ],
              }
          ]

      def _vision_request(self, messages):
          """Send a non-streaming request for ``messages``.

          Returns:
              ``(text, tokens)``: the stripped text of the first content
              block, and input tokens plus output tokens.
          """
          # The object returned by ``messages.create`` is converted with
          # ``to_dict()`` before subscripting, the same way ``chat`` does.
          response = self.client.messages.create(
              model=self.model_name,
              max_tokens=self.max_tokens,
              messages=messages,
          ).to_dict()
          usage = response["usage"]
          return response["content"][0]["text"].strip(), usage["input_tokens"] + usage["output_tokens"]

      def describe(self, image):
          """Describe ``image`` with a built-in prompt chosen by ``self.lang``."""
          b64 = self.image2base64(image)
          instruction = (
              "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
              "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
          )
          return self._vision_request(self.prompt(b64, instruction))

      def describe_with_prompt(self, image, prompt=None):
          """Describe ``image`` with ``prompt``, or with ``vision_llm_describe_prompt()`` when none is given."""
          b64 = self.image2base64(image)
          return self._vision_request(self.prompt(b64, prompt if prompt else vision_llm_describe_prompt()))

      def _prepare_gen_conf(self, gen_conf):
          """Return a copy of ``gen_conf`` for the Messages API call.

          The copy has "presence_penalty" and "frequency_penalty" removed and
          "max_tokens" forced to ``self.max_tokens``. Working on a copy keeps
          the caller's dict unmodified. ``None`` is treated as an empty config.
          """
          conf = dict(gen_conf or {})
          conf.pop("presence_penalty", None)
          conf.pop("frequency_penalty", None)
          conf["max_tokens"] = self.max_tokens
          return conf

      def chat(self, system, history, gen_conf):
          """Run a non-streaming chat completion.

          Returns:
              ``(answer, input_tokens + output_tokens)``. A continuation hint is
              appended when the stop reason is "max_tokens". On any exception
              it returns ``(partial_answer + "\\n**ERROR**: ...", 0)``.
          """
          gen_conf = self._prepare_gen_conf(gen_conf)
          ans = ""
          try:
              response = self.client.messages.create(
                  model=self.model_name,
                  messages=history,
                  system=system,
                  stream=False,
                  **gen_conf,
              ).to_dict()
              ans = response["content"][0]["text"]
              if response["stop_reason"] == "max_tokens":
                  ans += (
                      "...\nFor the content length reason, it stopped, continue?"
                      if is_english([ans])
                      else "······\n由于长度的原因,回答被截断了,要继续吗?"
                  )
              return (
                  ans,
                  response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
              )
          except Exception as e:
              return ans + "\n**ERROR**: " + str(e), 0

      def chat_streamly(self, system, history, gen_conf):
          """Stream a chat completion.

          Yields the accumulated answer after each non-empty thinking delta
          (wrapped in <think>...</think>) and after each text delta. If the
          request fails, it yields the answer with an error suffix. The final
          yield is an int: the sum of ``num_tokens_from_string`` over the
          streamed text deltas (thinking text is not counted).
          """
          gen_conf = self._prepare_gen_conf(gen_conf)
          ans = ""
          total_tokens = 0
          try:
              response = self.client.messages.create(
                  model=self.model_name,
                  messages=history,
                  system=system,
                  stream=True,
                  **gen_conf,
              )
              for res in response:
                  if res.type != "content_block_delta":
                      continue
                  delta = res.delta
                  if delta.type == "thinking_delta":
                      if not delta.thinking:
                          continue
                      if ans.find("<think>") < 0:
                          ans += "<think>"
                      # Keep exactly one closing tag, at the end of the thinking text.
                      ans = ans.replace("</think>", "")
                      ans += delta.thinking + "</think>"
                  elif delta.type == "text_delta":
                      text = delta.text
                      ans += text
                      total_tokens += num_tokens_from_string(text)
                  else:
                      # Only thinking and text deltas are handled. Other delta
                      # kinds were previously read via ``.text`` and raised.
                      continue
                  yield ans
          except Exception as e:
              yield ans + "\n**ERROR**: " + str(e)
          yield total_tokens
  888. class GPUStackCV(GptV4):
  def __init__(self, key, model_name, lang="Chinese", base_url=""):
      """Create an OpenAI SDK client for a GPUStack server at ``urljoin(base_url, "v1")``.

      Args:
          key: API key handed to the OpenAI client.
          model_name: Model identifier, stored unchanged.
          lang: Language preference, stored on the instance.
          base_url: Server root URL; required.

      Raises:
          ValueError: if ``base_url`` is empty or None.
      """
      if not base_url:
          raise ValueError("Local llm url cannot be None")
      endpoint = urljoin(base_url, "v1")
      self.client = OpenAI(api_key=key, base_url=endpoint)
      self.model_name, self.lang = model_name, lang