
cv_model.py

#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import base64
import io
import json
import os
from abc import ABC
from io import BytesIO

import requests
from ollama import Client
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from PIL import Image
from zhipuai import ZhipuAI

from api.utils import get_uuid
from api.utils.file_utils import get_project_base_directory
from rag.nlp import is_english
from rag.prompts import vision_llm_describe_prompt
from rag.utils import num_tokens_from_string

class Base(ABC):
    def __init__(self, key, model_name):
        pass

    def describe(self, image):
        raise NotImplementedError("Please implement describe method!")

    def describe_with_prompt(self, image, prompt=None):
        raise NotImplementedError("Please implement describe_with_prompt method!")

    def chat(self, system, history, gen_conf, image=""):
        if system:
            # Prepend the system prompt to the latest user message.
            history[-1]["content"] = system + "\nuser query: " + history[-1]["content"]
        try:
            for his in history:
                if his["role"] == "user":
                    his["content"] = self.chat_prompt(his["content"], image)
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                temperature=gen_conf.get("temperature", 0.3),
                top_p=gen_conf.get("top_p", 0.7)
            )
            return response.choices[0].message.content.strip(), response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, image=""):
        if system:
            history[-1]["content"] = system + "\nuser query: " + history[-1]["content"]
        ans = ""
        tk_count = 0
        try:
            for his in history:
                if his["role"] == "user":
                    his["content"] = self.chat_prompt(his["content"], image)
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                temperature=gen_conf.get("temperature", 0.3),
                top_p=gen_conf.get("top_p", 0.7),
                stream=True
            )
            for resp in response:
                if not resp.choices[0].delta.content:
                    continue
                delta = resp.choices[0].delta.content
                ans += delta
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    tk_count = resp.usage.total_tokens
                if resp.choices[0].finish_reason == "stop":
                    tk_count = resp.usage.total_tokens
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield tk_count

    def image2base64(self, image):
        if isinstance(image, bytes):
            return base64.b64encode(image).decode("utf-8")
        if isinstance(image, BytesIO):
            return base64.b64encode(image.getvalue()).decode("utf-8")
        buffered = BytesIO()
        try:
            image.save(buffered, format="JPEG")
        except Exception:
            image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    def prompt(self, b64):
        # Note: the text part deliberately omits "type": "text"; callers such
        # as GptV4.describe patch that key in before sending.
        return [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{b64}"
                        },
                    },
                    {
                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
                        "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                    },
                ],
            }
        ]

    def vision_llm_prompt(self, b64, prompt=None):
        return [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{b64}"
                        },
                    },
                    {
                        "type": "text",
                        "text": prompt if prompt else vision_llm_describe_prompt(),
                    },
                ],
            }
        ]

    def chat_prompt(self, text, b64):
        return [
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{b64}",
                },
            },
            {
                "type": "text",
                "text": text
            },
        ]

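def _example_image2base64():
    # Minimal usage sketch for Base.image2base64, which accepts raw bytes, a
    # BytesIO, or a PIL image. GptV4 (defined below) is used only because Base
    # is abstract; the key and model name are placeholders, no API call is made.
    img = Image.new("RGB", (8, 8), color="red")
    buf = BytesIO()
    img.save(buf, format="JPEG")
    cv = GptV4(key="sk-placeholder", model_name="gpt-4o")
    # The bytes and BytesIO paths encode the identical payload.
    assert cv.image2base64(buf.getvalue()) == cv.image2base64(BytesIO(buf.getvalue()))
    return cv.image2base64(img)
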
class GptV4(Base):
    def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
        self.lang = lang

    def describe(self, image):
        b64 = self.image2base64(image)
        prompt = self.prompt(b64)
        for i in range(len(prompt)):
            for c in prompt[i]["content"]:
                if "text" in c:
                    c["type"] = "text"
        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=prompt
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens

    def describe_with_prompt(self, image, prompt=None):
        b64 = self.image2base64(image)
        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=vision_prompt,
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens

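def _example_gptv4_describe():
    # Usage sketch for GptV4.describe, assuming a valid OpenAI API key; the
    # key, the model name "gpt-4o", and "photo.jpg" are placeholders.
    cv = GptV4(key="sk-...", model_name="gpt-4o", lang="English")
    with open("photo.jpg", "rb") as f:
        caption, total_tokens = cv.describe(f.read())
    print(caption, total_tokens)
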
class AzureGptV4(Base):
    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        # "key" arrives as a JSON string bundling the Azure credentials.
        api_key = json.loads(key).get('api_key', '')
        api_version = json.loads(key).get('api_version', '2024-02-01')
        self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
        self.model_name = model_name
        self.lang = lang

    def describe(self, image):
        b64 = self.image2base64(image)
        prompt = self.prompt(b64)
        for i in range(len(prompt)):
            for c in prompt[i]["content"]:
                if "text" in c:
                    c["type"] = "text"
        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=prompt
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens

    def describe_with_prompt(self, image, prompt=None):
        b64 = self.image2base64(image)
        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=vision_prompt,
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens

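def _example_azure_credentials():
    # Usage sketch: AzureGptV4 expects "key" to be a JSON string carrying
    # api_key and api_version, with the Azure endpoint passed via base_url.
    # Every value below is a placeholder.
    key = json.dumps({"api_key": "azure-key", "api_version": "2024-02-01"})
    return AzureGptV4(key, "my-gpt4v-deployment", base_url="https://example.openai.azure.com")
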
class QWenCV(Base):
    def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
        import dashscope
        dashscope.api_key = key
        self.model_name = model_name
        self.lang = lang

    def prompt(self, binary):
        # DashScope's multimodal API only accepts local file paths, so the
        # image bytes are staged in a temporary file first.
        tmp_dir = get_project_base_directory("tmp")
        os.makedirs(tmp_dir, exist_ok=True)
        path = os.path.join(tmp_dir, "%s.jpg" % get_uuid())
        Image.open(io.BytesIO(binary)).save(path)
        return [
            {
                "role": "user",
                "content": [
                    {
                        "image": f"file://{path}"
                    },
                    {
                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
                        "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                    },
                ],
            }
        ]

    def vision_llm_prompt(self, binary, prompt=None):
        # Same file-staging workaround as prompt() above.
        tmp_dir = get_project_base_directory("tmp")
        os.makedirs(tmp_dir, exist_ok=True)
        path = os.path.join(tmp_dir, "%s.jpg" % get_uuid())
        Image.open(io.BytesIO(binary)).save(path)
        return [
            {
                "role": "user",
                "content": [
                    {
                        "image": f"file://{path}"
                    },
                    {
                        "text": prompt if prompt else vision_llm_describe_prompt(),
                    },
                ],
            }
        ]

    def chat_prompt(self, text, b64):
        return [
            {"image": f"{b64}"},
            {"text": text},
        ]

    def describe(self, image):
        from http import HTTPStatus
        from dashscope import MultiModalConversation
        response = MultiModalConversation.call(model=self.model_name, messages=self.prompt(image))
        if response.status_code == HTTPStatus.OK:
            return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
        return response.message, 0

    def describe_with_prompt(self, image, prompt=None):
        from http import HTTPStatus
        from dashscope import MultiModalConversation
        vision_prompt = self.vision_llm_prompt(image, prompt) if prompt else self.vision_llm_prompt(image)
        response = MultiModalConversation.call(model=self.model_name, messages=vision_prompt)
        if response.status_code == HTTPStatus.OK:
            return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
        return response.message, 0

    def chat(self, system, history, gen_conf, image=""):
        from http import HTTPStatus
        from dashscope import MultiModalConversation
        if system:
            history[-1]["content"] = system + "\nuser query: " + history[-1]["content"]
        for his in history:
            if his["role"] == "user":
                his["content"] = self.chat_prompt(his["content"], image)
        response = MultiModalConversation.call(model=self.model_name, messages=history,
                                               temperature=gen_conf.get("temperature", 0.3),
                                               top_p=gen_conf.get("top_p", 0.7))
        ans = ""
        tk_count = 0
        if response.status_code == HTTPStatus.OK:
            ans += response.output.choices[0]['message']['content']
            tk_count += response.usage.total_tokens
            if response.output.choices[0].get("finish_reason", "") == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, tk_count
        return "**ERROR**: " + response.message, tk_count

    def chat_streamly(self, system, history, gen_conf, image=""):
        from http import HTTPStatus
        from dashscope import MultiModalConversation
        if system:
            history[-1]["content"] = system + "\nuser query: " + history[-1]["content"]
        for his in history:
            if his["role"] == "user":
                his["content"] = self.chat_prompt(his["content"], image)
        ans = ""
        tk_count = 0
        try:
            response = MultiModalConversation.call(model=self.model_name, messages=history,
                                                   temperature=gen_conf.get("temperature", 0.3),
                                                   top_p=gen_conf.get("top_p", 0.7),
                                                   stream=True)
            for resp in response:
                if resp.status_code == HTTPStatus.OK:
                    ans = resp.output.choices[0]['message']['content']
                    tk_count = resp.usage.total_tokens
                    if resp.output.choices[0].get("finish_reason", "") == "length":
                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    yield ans
                else:
                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
                        "Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield tk_count

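def _example_qwen_describe():
    # Usage sketch for QWenCV.describe, which takes raw image bytes (staged to
    # a temp file for DashScope). The key, model name, and path are placeholders.
    cv = QWenCV(key="dashscope-key", model_name="qwen-vl-plus", lang="English")
    with open("photo.jpg", "rb") as f:
        text, output_tokens = cv.describe(f.read())
    print(text, output_tokens)
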
class Zhipu4V(Base):
    def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name
        self.lang = lang

    def describe(self, image):
        b64 = self.image2base64(image)
        prompt = self.prompt(b64)
        prompt[0]["content"][1]["type"] = "text"
        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=prompt,
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens

    def describe_with_prompt(self, image, prompt=None):
        b64 = self.image2base64(image)
        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=vision_prompt
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens

    def chat(self, system, history, gen_conf, image=""):
        if system:
            history[-1]["content"] = system + "\nuser query: " + history[-1]["content"]
        try:
            for his in history:
                if his["role"] == "user":
                    his["content"] = self.chat_prompt(his["content"], image)
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                temperature=gen_conf.get("temperature", 0.3),
                top_p=gen_conf.get("top_p", 0.7)
            )
            return response.choices[0].message.content.strip(), response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, image=""):
        if system:
            history[-1]["content"] = system + "\nuser query: " + history[-1]["content"]
        ans = ""
        tk_count = 0
        try:
            for his in history:
                if his["role"] == "user":
                    his["content"] = self.chat_prompt(his["content"], image)
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                temperature=gen_conf.get("temperature", 0.3),
                top_p=gen_conf.get("top_p", 0.7),
                stream=True
            )
            for resp in response:
                if not resp.choices[0].delta.content:
                    continue
                delta = resp.choices[0].delta.content
                ans += delta
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    tk_count = resp.usage.total_tokens
                if resp.choices[0].finish_reason == "stop":
                    tk_count = resp.usage.total_tokens
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield tk_count

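def _example_stream_consumption():
    # Usage sketch: every chat_streamly() in this module yields progressively
    # longer answer strings and finally an int token count. Zhipu4V is used as
    # a concrete example; the key and image payload are placeholders.
    cv = Zhipu4V(key="zhipu-key", model_name="glm-4v")
    history = [{"role": "user", "content": "What is in this picture?"}]
    answer, tokens = "", 0
    for item in cv.chat_streamly("You are a helpful assistant.", history,
                                 {"temperature": 0.3}, image="<base64-jpeg>"):
        if isinstance(item, int):
            tokens = item  # final yield carries the token count
        else:
            answer = item  # intermediate yields carry the growing answer
    return answer, tokens
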
class OllamaCV(Base):
    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        self.client = Client(host=kwargs["base_url"])
        self.model_name = model_name
        self.lang = lang

    def describe(self, image):
        prompt = self.prompt("")
        try:
            response = self.client.generate(
                model=self.model_name,
                prompt=prompt[0]["content"][1]["text"],
                images=[image]
            )
            ans = response["response"].strip()
            return ans, 128  # no usage reported here; 128 is a fixed placeholder count
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def describe_with_prompt(self, image, prompt=None):
        vision_prompt = self.vision_llm_prompt("", prompt) if prompt else self.vision_llm_prompt("")
        try:
            response = self.client.generate(
                model=self.model_name,
                prompt=vision_prompt[0]["content"][1]["text"],
                images=[image],
            )
            ans = response["response"].strip()
            return ans, 128
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat(self, system, history, gen_conf, image=""):
        if system:
            history[-1]["content"] = system + "\nuser query: " + history[-1]["content"]
        try:
            for his in history:
                if his["role"] == "user":
                    his["images"] = [image]
            options = {}
            if "temperature" in gen_conf:
                options["temperature"] = gen_conf["temperature"]
            if "top_p" in gen_conf:
                options["top_p"] = gen_conf["top_p"]
            if "presence_penalty" in gen_conf:
                options["presence_penalty"] = gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf:
                options["frequency_penalty"] = gen_conf["frequency_penalty"]
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                options=options,
                keep_alive=-1
            )
            ans = response["message"]["content"].strip()
            return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, image=""):
        if system:
            history[-1]["content"] = system + "\nuser query: " + history[-1]["content"]
        for his in history:
            if his["role"] == "user":
                his["images"] = [image]
        options = {}
        if "temperature" in gen_conf:
            options["temperature"] = gen_conf["temperature"]
        if "top_p" in gen_conf:
            options["top_p"] = gen_conf["top_p"]
        if "presence_penalty" in gen_conf:
            options["presence_penalty"] = gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            options["frequency_penalty"] = gen_conf["frequency_penalty"]
        ans = ""
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                stream=True,
                options=options,
                keep_alive=-1
            )
            for resp in response:
                if resp["done"]:
                    yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
                ans += resp["message"]["content"]
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield 0

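def _example_ollama_describe():
    # Usage sketch for OllamaCV against a local Ollama server; the host URL,
    # the model name "llava", and the image path are placeholders. No API key
    # is needed, so key is left empty.
    cv = OllamaCV(key="", model_name="llava", lang="English",
                  base_url="http://localhost:11434")
    with open("photo.jpg", "rb") as f:
        text, tokens = cv.describe(f.read())
    print(text, tokens)
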
class LocalAICV(GptV4):
    def __init__(self, key, model_name, base_url, lang="Chinese"):
        if not base_url:
            raise ValueError("Local cv model url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        self.client = OpenAI(api_key="empty", base_url=base_url)
        self.model_name = model_name.split("___")[0]
        self.lang = lang

class XinferenceCV(Base):
    def __init__(self, key, model_name="", lang="Chinese", base_url=""):
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
        self.lang = lang

    def describe(self, image):
        b64 = self.image2base64(image)
        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=self.prompt(b64)
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens

    def describe_with_prompt(self, image, prompt=None):
        b64 = self.image2base64(image)
        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=vision_prompt,
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens

class GeminiCV(Base):
    def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
        from google.generativeai import GenerativeModel, client
        client.configure(api_key=key)
        _client = client.get_default_generative_client()
        self.model_name = model_name
        self.model = GenerativeModel(model_name=self.model_name)
        self.model._client = _client
        self.lang = lang

    def describe(self, image):
        from PIL.Image import open
        prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
            "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
        b64 = self.image2base64(image)
        img = open(BytesIO(base64.b64decode(b64)))
        inputs = [prompt, img]
        res = self.model.generate_content(inputs)
        return res.text, res.usage_metadata.total_token_count

    def describe_with_prompt(self, image, prompt=None):
        from PIL.Image import open
        b64 = self.image2base64(image)
        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
        img = open(BytesIO(base64.b64decode(b64)))
        inputs = [vision_prompt, img]
        res = self.model.generate_content(inputs)
        return res.text, res.usage_metadata.total_token_count

    def chat(self, system, history, gen_conf, image=""):
        # GenerationConfig comes from the Gemini SDK, not from transformers.
        from google.generativeai.types import GenerationConfig
        if system:
            history[-1]["content"] = system + "\nuser query: " + history[-1]["content"]
        try:
            for his in history:
                if his["role"] == "assistant":
                    his["role"] = "model"
                    his["parts"] = [his["content"]]
                    his.pop("content")
                if his["role"] == "user":
                    his["parts"] = [his["content"]]
                    his.pop("content")
            history[-1]["parts"].append("data:image/jpeg;base64," + image)
            response = self.model.generate_content(history, generation_config=GenerationConfig(
                temperature=gen_conf.get("temperature", 0.3),
                top_p=gen_conf.get("top_p", 0.7)))
            ans = response.text
            return ans, response.usage_metadata.total_token_count
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, image=""):
        from google.generativeai.types import GenerationConfig
        if system:
            history[-1]["content"] = system + "\nuser query: " + history[-1]["content"]
        ans = ""
        try:
            for his in history:
                if his["role"] == "assistant":
                    his["role"] = "model"
                    his["parts"] = [his["content"]]
                    his.pop("content")
                if his["role"] == "user":
                    his["parts"] = [his["content"]]
                    his.pop("content")
            history[-1]["parts"].append("data:image/jpeg;base64," + image)
            response = self.model.generate_content(history, generation_config=GenerationConfig(
                temperature=gen_conf.get("temperature", 0.3),
                top_p=gen_conf.get("top_p", 0.7)), stream=True)
            for resp in response:
                if not resp.text:
                    continue
                ans += resp.text
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        # Relies on the SDK's private _chunks attribute to recover usage.
        yield response._chunks[-1].usage_metadata.total_token_count

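def _example_gemini_describe():
    # Usage sketch for GeminiCV.describe, assuming a valid Google AI Studio
    # key; the key, model name, and image path are placeholders.
    cv = GeminiCV(key="google-api-key", model_name="gemini-1.5-flash", lang="English")
    with open("photo.jpg", "rb") as f:
        text, total_tokens = cv.describe(f.read())
    print(text, total_tokens)
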
class OpenRouterCV(GptV4):
    def __init__(
        self,
        key,
        model_name,
        lang="Chinese",
        base_url="https://openrouter.ai/api/v1",
    ):
        if not base_url:
            base_url = "https://openrouter.ai/api/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
        self.lang = lang

class LocalCV(Base):
    def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
        pass

    def describe(self, image):
        return "", 0

class NvidiaCV(Base):
    def __init__(
        self,
        key,
        model_name,
        lang="Chinese",
        base_url="https://ai.api.nvidia.com/v1/vlm",
    ):
        if not base_url:
            base_url = "https://ai.api.nvidia.com/v1/vlm"
        self.lang = lang
        # model_name is expected in "factory/name" form.
        factory, llm_name = model_name.split("/")
        if factory != "liuhaotian":
            self.base_url = os.path.join(base_url, factory, llm_name)
        else:
            self.base_url = os.path.join(
                base_url, "community", llm_name.replace("-v1.6", "16")
            )
        self.key = key

    def describe(self, image):
        b64 = self.image2base64(image)
        response = requests.post(
            url=self.base_url,
            headers={
                "accept": "application/json",
                "content-type": "application/json",
                "Authorization": f"Bearer {self.key}",
            },
            json={
                "messages": self.prompt(b64)
            },
        )
        response = response.json()
        return (
            response["choices"][0]["message"]["content"].strip(),
            response["usage"]["total_tokens"],
        )

    def describe_with_prompt(self, image, prompt=None):
        b64 = self.image2base64(image)
        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
        response = requests.post(
            url=self.base_url,
            headers={
                "accept": "application/json",
                "content-type": "application/json",
                "Authorization": f"Bearer {self.key}",
            },
            json={
                "messages": vision_prompt,
            },
        )
        response = response.json()
        return (
            response["choices"][0]["message"]["content"].strip(),
            response["usage"]["total_tokens"],
        )

    def prompt(self, b64):
        return [
            {
                "role": "user",
                "content": (
                    "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
                    if self.lang.lower() == "chinese"
                    else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
                )
                + f' <img src="data:image/jpeg;base64,{b64}"/>',
            }
        ]

    def vision_llm_prompt(self, b64, prompt=None):
        return [
            {
                "role": "user",
                "content": (
                    prompt if prompt else vision_llm_describe_prompt()
                )
                + f' <img src="data:image/jpeg;base64,{b64}"/>',
            }
        ]

    def chat_prompt(self, text, b64):
        return [
            {
                "role": "user",
                "content": text + f' <img src="data:image/jpeg;base64,{b64}"/>',
            }
        ]

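def _example_nvidia_url():
    # Usage sketch: NvidiaCV expects model_name in "factory/name" form and
    # derives the request URL from it (via os.path.join, so this assumes a
    # POSIX path separator). The key and model name are placeholders; the
    # model below would post to
    # https://ai.api.nvidia.com/v1/vlm/microsoft/phi-3-vision-128k-instruct.
    cv = NvidiaCV(key="nvapi-...", model_name="microsoft/phi-3-vision-128k-instruct",
                  lang="English")
    return cv.base_url
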
class StepFunCV(GptV4):
    def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
        if not base_url:
            base_url = "https://api.stepfun.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
        self.lang = lang


class LmStudioCV(GptV4):
    def __init__(self, key, model_name, lang="Chinese", base_url=""):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        self.client = OpenAI(api_key="lm-studio", base_url=base_url)
        self.model_name = model_name
        self.lang = lang


class OpenAI_APICV(GptV4):
    def __init__(self, key, model_name, lang="Chinese", base_url=""):
        if not base_url:
            raise ValueError("url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name.split("___")[0]
        self.lang = lang


class TogetherAICV(GptV4):
    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"):
        if not base_url:
            base_url = "https://api.together.xyz/v1"
        super().__init__(key, model_name, lang, base_url)


class YiCV(GptV4):
    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1"):
        if not base_url:
            base_url = "https://api.lingyiwanwu.com/v1"
        super().__init__(key, model_name, lang, base_url)

class HunyuanCV(Base):
    def __init__(self, key, model_name, lang="Chinese", base_url=None):
        from tencentcloud.common import credential
        from tencentcloud.hunyuan.v20230901 import hunyuan_client
        # "key" arrives as a JSON string carrying the Tencent Cloud credentials.
        key = json.loads(key)
        sid = key.get("hunyuan_sid", "")
        sk = key.get("hunyuan_sk", "")
        cred = credential.Credential(sid, sk)
        self.model_name = model_name
        self.client = hunyuan_client.HunyuanClient(cred, "")
        self.lang = lang

    def describe(self, image):
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
            TencentCloudSDKException,
        )
        from tencentcloud.hunyuan.v20230901 import models
        b64 = self.image2base64(image)
        req = models.ChatCompletionsRequest()
        params = {"Model": self.model_name, "Messages": self.prompt(b64)}
        req.from_json_string(json.dumps(params))
        ans = ""
        try:
            response = self.client.ChatCompletions(req)
            ans = response.Choices[0].Message.Content
            return ans, response.Usage.TotalTokens
        except TencentCloudSDKException as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def describe_with_prompt(self, image, prompt=None):
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
        from tencentcloud.hunyuan.v20230901 import models
        b64 = self.image2base64(image)
        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
        req = models.ChatCompletionsRequest()
        params = {"Model": self.model_name, "Messages": vision_prompt}
        req.from_json_string(json.dumps(params))
        ans = ""
        try:
            response = self.client.ChatCompletions(req)
            ans = response.Choices[0].Message.Content
            return ans, response.Usage.TotalTokens
        except TencentCloudSDKException as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def prompt(self, b64):
        return [
            {
                "Role": "user",
                "Contents": [
                    {
                        "Type": "image_url",
                        "ImageUrl": {
                            "Url": f"data:image/jpeg;base64,{b64}"
                        },
                    },
                    {
                        "Type": "text",
                        "Text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
                        "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
                    },
                ],
            }
        ]

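def _example_hunyuan_credentials():
    # Usage sketch: HunyuanCV expects "key" as a JSON string holding the
    # Tencent Cloud secret id/key pair; both values below are placeholders.
    key = json.dumps({"hunyuan_sid": "secret-id", "hunyuan_sk": "secret-key"})
    return HunyuanCV(key, "hunyuan-vision")
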
class AnthropicCV(Base):
    def __init__(self, key, model_name, lang="Chinese", base_url=None):
        import anthropic
        self.client = anthropic.Anthropic(api_key=key)
        self.model_name = model_name
        self.lang = lang  # describe() below depends on it
        self.system = ""
        self.max_tokens = 8192
        if "haiku" in self.model_name or "opus" in self.model_name:
            self.max_tokens = 4096

    def prompt(self, b64, prompt):
        return [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": b64,
                        },
                    },
                    {
                        "type": "text",
                        "text": prompt
                    }
                ],
            }
        ]

    def describe(self, image):
        b64 = self.image2base64(image)
        prompt = self.prompt(b64,
                             "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
                             "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
                             )
        response = self.client.messages.create(
            model=self.model_name,
            max_tokens=self.max_tokens,
            messages=prompt
        ).to_dict()  # convert to a dict, as chat() does, before subscripting
        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]

    def describe_with_prompt(self, image, prompt=None):
        b64 = self.image2base64(image)
        prompt = self.prompt(b64, prompt if prompt else vision_llm_describe_prompt())
        response = self.client.messages.create(
            model=self.model_name,
            max_tokens=self.max_tokens,
            messages=prompt
        ).to_dict()
        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]

    def chat(self, system, history, gen_conf):
        if "presence_penalty" in gen_conf:
            del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            del gen_conf["frequency_penalty"]
        gen_conf["max_tokens"] = self.max_tokens
        ans = ""
        try:
            response = self.client.messages.create(
                model=self.model_name,
                messages=history,
                system=system,
                stream=False,
                **gen_conf,
            ).to_dict()
            ans = response["content"][0]["text"]
            if response["stop_reason"] == "max_tokens":
                ans += (
                    "...\nFor the content length reason, it stopped, continue?"
                    if is_english([ans])
                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
                )
            return (
                ans,
                response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
            )
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if "presence_penalty" in gen_conf:
            del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            del gen_conf["frequency_penalty"]
        gen_conf["max_tokens"] = self.max_tokens
        ans = ""
        total_tokens = 0
        try:
            response = self.client.messages.create(
                model=self.model_name,
                messages=history,
                system=system,
                stream=True,
                **gen_conf,
            )
            for res in response:
                if res.type == 'content_block_delta':
                    if res.delta.type == "thinking_delta" and res.delta.thinking:
                        # Wrap streamed "thinking" tokens in a single <think> block.
                        if ans.find("<think>") < 0:
                            ans += "<think>"
                        ans = ans.replace("</think>", "")
                        ans += res.delta.thinking + "</think>"
                    else:
                        text = res.delta.text
                        ans += text
                        total_tokens += num_tokens_from_string(text)
                    yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield total_tokens

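def _example_anthropic_describe():
    # Usage sketch for AnthropicCV.describe, assuming a valid Anthropic key;
    # the key, model name, and image path are placeholders.
    cv = AnthropicCV(key="anthropic-key", model_name="claude-3-5-sonnet-20240620")
    with open("photo.jpg", "rb") as f:
        text, total_tokens = cv.describe(f.read())
    print(text, total_tokens)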