# cv_model.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import io
import json
import os
from abc import ABC
from io import BytesIO

import requests
from ollama import Client
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from PIL import Image
from zhipuai import ZhipuAI

from api.utils import get_uuid
from api.utils.file_utils import get_project_base_directory
from rag.nlp import is_english
from rag.prompts import vision_llm_describe_prompt
from rag.utils import num_tokens_from_string

class Base(ABC):
    def __init__(self, key, model_name):
        pass

    def describe(self, image):
        raise NotImplementedError("Please implement the describe method!")

    def describe_with_prompt(self, image, prompt=None):
        raise NotImplementedError("Please implement the describe_with_prompt method!")

    def chat(self, system, history, gen_conf, image=""):
        if system:
            history[-1]["content"] = system + "\nuser query: " + history[-1]["content"]
        try:
            for his in history:
                if his["role"] == "user":
                    his["content"] = self.chat_prompt(his["content"], image)

            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                temperature=gen_conf.get("temperature", 0.3),
                top_p=gen_conf.get("top_p", 0.7)
            )
            return response.choices[0].message.content.strip(), response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, image=""):
        if system:
            history[-1]["content"] = system + "\nuser query: " + history[-1]["content"]

        ans = ""
        tk_count = 0
        try:
            for his in history:
                if his["role"] == "user":
                    his["content"] = self.chat_prompt(his["content"], image)

            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                temperature=gen_conf.get("temperature", 0.3),
                top_p=gen_conf.get("top_p", 0.7),
                stream=True
            )
            for resp in response:
                if not resp.choices[0].delta.content:
                    continue
                delta = resp.choices[0].delta.content
                ans += delta
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    tk_count = resp.usage.total_tokens
                if resp.choices[0].finish_reason == "stop":
                    tk_count = resp.usage.total_tokens
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count

    def image2base64(self, image):
        if isinstance(image, bytes):
            return base64.b64encode(image).decode("utf-8")
        if isinstance(image, BytesIO):
            return base64.b64encode(image.getvalue()).decode("utf-8")
        buffered = BytesIO()
        try:
            image.save(buffered, format="JPEG")
        except Exception:
            image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    def prompt(self, b64):
        return [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{b64}"
                        },
                    },
                    {
                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
                        "Please describe the content of this picture, such as where and when it was taken, who is in it, what is happening, and how the people feel. If it contains numeric data, please extract that data.",
                    },
                ],
            }
        ]

    def vision_llm_prompt(self, b64, prompt=None):
        return [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{b64}"
                        },
                    },
                    {
                        "type": "text",
                        "text": prompt if prompt else vision_llm_describe_prompt(),
                    },
                ],
            }
        ]

    def chat_prompt(self, text, b64):
        return [
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{b64}",
                },
            },
            {
                "type": "text",
                "text": text
            },
        ]
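

# Added illustration (not part of the original module): for a user turn with an
# attached image, chat_prompt("What is shown?", b64) yields the two-part
# OpenAI-style content list below; chat()/chat_streamly() substitute it for the
# plain string content of each user message.
#
#     [
#         {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,<b64>"}},
#         {"type": "text", "text": "What is shown?"},
#     ]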


class GptV4(Base):
    def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
        self.lang = lang

    def describe(self, image):
        b64 = self.image2base64(image)
        prompt = self.prompt(b64)
        for i in range(len(prompt)):
            for c in prompt[i]["content"]:
                if "text" in c:
                    c["type"] = "text"

        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=prompt
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens

    def describe_with_prompt(self, image, prompt=None):
        b64 = self.image2base64(image)
        vision_prompt = self.vision_llm_prompt(b64, prompt)
        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=vision_prompt,
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens
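

# A minimal usage sketch (added for illustration; the key, model name, and file
# path below are placeholders, and the call goes out to the real OpenAI API):
#
#     cv = GptV4(key="sk-...", model_name="gpt-4o", lang="English")
#     with open("photo.jpg", "rb") as f:
#         caption, total_tokens = cv.describe(f.read())
#     print(caption)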


class AzureGptV4(Base):
    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        creds = json.loads(key)
        api_key = creds.get("api_key", "")
        api_version = creds.get("api_version", "2024-02-01")
        self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
        self.model_name = model_name
        self.lang = lang

    def describe(self, image):
        b64 = self.image2base64(image)
        prompt = self.prompt(b64)
        for i in range(len(prompt)):
            for c in prompt[i]["content"]:
                if "text" in c:
                    c["type"] = "text"

        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=prompt
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens

    def describe_with_prompt(self, image, prompt=None):
        b64 = self.image2base64(image)
        vision_prompt = self.vision_llm_prompt(b64, prompt)
        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=vision_prompt,
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens


class QWenCV(Base):
    def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
        import dashscope
        dashscope.api_key = key
        self.model_name = model_name
        self.lang = lang

    def prompt(self, binary):
        # DashScope's multimodal API expects a local file URI, so the image
        # bytes must first be written out to a temporary file.
        tmp_dir = get_project_base_directory("tmp")
        if not os.path.exists(tmp_dir):
            os.mkdir(tmp_dir)
        path = os.path.join(tmp_dir, "%s.jpg" % get_uuid())
        Image.open(io.BytesIO(binary)).save(path)
        return [
            {
                "role": "user",
                "content": [
                    {
                        "image": f"file://{path}"
                    },
                    {
                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
                        "Please describe the content of this picture, such as where and when it was taken, who is in it, what is happening, and how the people feel. If it contains numeric data, please extract that data.",
                    },
                ],
            }
        ]

    def vision_llm_prompt(self, binary, prompt=None):
        # Same workaround as in prompt(): write the image to a temporary file
        # so it can be passed to DashScope as a file URI.
        tmp_dir = get_project_base_directory("tmp")
        if not os.path.exists(tmp_dir):
            os.mkdir(tmp_dir)
        path = os.path.join(tmp_dir, "%s.jpg" % get_uuid())
        Image.open(io.BytesIO(binary)).save(path)
        return [
            {
                "role": "user",
                "content": [
                    {
                        "image": f"file://{path}"
                    },
                    {
                        "text": prompt if prompt else vision_llm_describe_prompt(),
                    },
                ],
            }
        ]

    def chat_prompt(self, text, b64):
        return [
            {"image": f"{b64}"},
            {"text": text},
        ]

    def describe(self, image):
        from http import HTTPStatus
        from dashscope import MultiModalConversation
        response = MultiModalConversation.call(model=self.model_name, messages=self.prompt(image))
        if response.status_code == HTTPStatus.OK:
            return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
        return response.message, 0

    def describe_with_prompt(self, image, prompt=None):
        from http import HTTPStatus
        from dashscope import MultiModalConversation
        vision_prompt = self.vision_llm_prompt(image, prompt)
        response = MultiModalConversation.call(model=self.model_name, messages=vision_prompt)
        if response.status_code == HTTPStatus.OK:
            return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
        return response.message, 0

    def chat(self, system, history, gen_conf, image=""):
        from http import HTTPStatus
        from dashscope import MultiModalConversation
        if system:
            history[-1]["content"] = system + "\nuser query: " + history[-1]["content"]
        for his in history:
            if his["role"] == "user":
                his["content"] = self.chat_prompt(his["content"], image)
        response = MultiModalConversation.call(model=self.model_name, messages=history,
                                               temperature=gen_conf.get("temperature", 0.3),
                                               top_p=gen_conf.get("top_p", 0.7))

        ans = ""
        tk_count = 0
        if response.status_code == HTTPStatus.OK:
            ans += response.output.choices[0]['message']['content']
            tk_count += response.usage.total_tokens
            if response.output.choices[0].get("finish_reason", "") == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, tk_count

        return "**ERROR**: " + response.message, tk_count

    def chat_streamly(self, system, history, gen_conf, image=""):
        from http import HTTPStatus
        from dashscope import MultiModalConversation
        if system:
            history[-1]["content"] = system + "\nuser query: " + history[-1]["content"]
        for his in history:
            if his["role"] == "user":
                his["content"] = self.chat_prompt(his["content"], image)

        ans = ""
        tk_count = 0
        try:
            response = MultiModalConversation.call(model=self.model_name, messages=history,
                                                   temperature=gen_conf.get("temperature", 0.3),
                                                   top_p=gen_conf.get("top_p", 0.7),
                                                   stream=True)
            for resp in response:
                if resp.status_code == HTTPStatus.OK:
                    ans = resp.output.choices[0]['message']['content']
                    tk_count = resp.usage.total_tokens
                    if resp.output.choices[0].get("finish_reason", "") == "length":
                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    yield ans
                else:
                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
                        "Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count


class Zhipu4V(Base):
    def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name
        self.lang = lang

    def describe(self, image):
        b64 = self.image2base64(image)

        prompt = self.prompt(b64)
        prompt[0]["content"][1]["type"] = "text"

        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=prompt,
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens

    def describe_with_prompt(self, image, prompt=None):
        b64 = self.image2base64(image)
        vision_prompt = self.vision_llm_prompt(b64, prompt)
        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=vision_prompt
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens

    def chat(self, system, history, gen_conf, image=""):
        if system:
            history[-1]["content"] = system + "\nuser query: " + history[-1]["content"]
        try:
            for his in history:
                if his["role"] == "user":
                    his["content"] = self.chat_prompt(his["content"], image)

            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                temperature=gen_conf.get("temperature", 0.3),
                top_p=gen_conf.get("top_p", 0.7)
            )
            return response.choices[0].message.content.strip(), response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, image=""):
        if system:
            history[-1]["content"] = system + "\nuser query: " + history[-1]["content"]

        ans = ""
        tk_count = 0
        try:
            for his in history:
                if his["role"] == "user":
                    his["content"] = self.chat_prompt(his["content"], image)

            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                temperature=gen_conf.get("temperature", 0.3),
                top_p=gen_conf.get("top_p", 0.7),
                stream=True
            )
            for resp in response:
                if not resp.choices[0].delta.content:
                    continue
                delta = resp.choices[0].delta.content
                ans += delta
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    tk_count = resp.usage.total_tokens
                if resp.choices[0].finish_reason == "stop":
                    tk_count = resp.usage.total_tokens
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count


class OllamaCV(Base):
    def __init__(self, key, model_name, lang="Chinese", **kwargs):
        self.client = Client(host=kwargs["base_url"])
        self.model_name = model_name
        self.lang = lang

    def describe(self, image):
        prompt = self.prompt("")
        try:
            response = self.client.generate(
                model=self.model_name,
                prompt=prompt[0]["content"][1]["text"],
                images=[image]
            )
            ans = response["response"].strip()
            # The generate endpoint's token usage is not read here; 128 is a
            # fixed placeholder count.
            return ans, 128
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def describe_with_prompt(self, image, prompt=None):
        vision_prompt = self.vision_llm_prompt("", prompt)
        try:
            response = self.client.generate(
                model=self.model_name,
                prompt=vision_prompt[0]["content"][1]["text"],
                images=[image],
            )
            ans = response["response"].strip()
            # Same fixed placeholder token count as in describe().
            return ans, 128
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat(self, system, history, gen_conf, image=""):
        if system:
            history[-1]["content"] = system + "\nuser query: " + history[-1]["content"]
        try:
            for his in history:
                if his["role"] == "user":
                    his["images"] = [image]
            options = {}
            if "temperature" in gen_conf:
                options["temperature"] = gen_conf["temperature"]
            if "top_p" in gen_conf:
                options["top_p"] = gen_conf["top_p"]
            if "presence_penalty" in gen_conf:
                options["presence_penalty"] = gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf:
                options["frequency_penalty"] = gen_conf["frequency_penalty"]
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                options=options
            )
            ans = response["message"]["content"].strip()
            return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, image=""):
        if system:
            history[-1]["content"] = system + "\nuser query: " + history[-1]["content"]
        for his in history:
            if his["role"] == "user":
                his["images"] = [image]
        options = {}
        if "temperature" in gen_conf:
            options["temperature"] = gen_conf["temperature"]
        if "top_p" in gen_conf:
            options["top_p"] = gen_conf["top_p"]
        if "presence_penalty" in gen_conf:
            options["presence_penalty"] = gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            options["frequency_penalty"] = gen_conf["frequency_penalty"]
        ans = ""
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                stream=True,
                options=options
            )
            for resp in response:
                if resp["done"]:
                    yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
                ans += resp["message"]["content"]
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield 0
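

# A hedged usage sketch for the Ollama wrapper (added for illustration; it
# assumes a local Ollama server with a vision-capable model such as "llava"
# already pulled -- adjust the host and model name to your setup):
#
#     cv = OllamaCV(key="", model_name="llava", lang="English",
#                   base_url="http://localhost:11434")
#     with open("chart.png", "rb") as f:
#         b64 = cv.image2base64(f.read())
#     history = [{"role": "user", "content": "Summarize this chart."}]
#     for chunk in cv.chat_streamly("", history, {"temperature": 0.2}, image=b64):
#         print(chunk)  # accumulated text, with integer token counts yielded along the way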


class LocalAICV(GptV4):
    def __init__(self, key, model_name, base_url, lang="Chinese"):
        if not base_url:
            raise ValueError("Local cv model url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        self.client = OpenAI(api_key="empty", base_url=base_url)
        self.model_name = model_name.split("___")[0]
        self.lang = lang


class XinferenceCV(Base):
    def __init__(self, key, model_name="", lang="Chinese", base_url=""):
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
        self.lang = lang

    def describe(self, image):
        b64 = self.image2base64(image)

        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=self.prompt(b64)
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens

    def describe_with_prompt(self, image, prompt=None):
        b64 = self.image2base64(image)
        vision_prompt = self.vision_llm_prompt(b64, prompt)
        res = self.client.chat.completions.create(
            model=self.model_name,
            messages=vision_prompt,
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens


class GeminiCV(Base):
    def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
        from google.generativeai import GenerativeModel, client
        client.configure(api_key=key)
        _client = client.get_default_generative_client()
        self.model_name = model_name
        self.model = GenerativeModel(model_name=self.model_name)
        self.model._client = _client
        self.lang = lang

    def describe(self, image):
        from PIL.Image import open
        prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
            "Please describe the content of this picture, such as where and when it was taken, who is in it, what is happening, and how the people feel. If it contains numeric data, please extract that data."
        b64 = self.image2base64(image)
        img = open(BytesIO(base64.b64decode(b64)))
        inputs = [prompt, img]
        res = self.model.generate_content(
            inputs
        )
        return res.text, res.usage_metadata.total_token_count

    def describe_with_prompt(self, image, prompt=None):
        from PIL.Image import open
        b64 = self.image2base64(image)
        vision_prompt = self.vision_llm_prompt(b64, prompt)
        img = open(BytesIO(base64.b64decode(b64)))
        inputs = [vision_prompt, img]
        res = self.model.generate_content(
            inputs,
        )
        return res.text, res.usage_metadata.total_token_count

    def chat(self, system, history, gen_conf, image=""):
        # google.generativeai ships its own GenerationConfig; the original
        # import pointed at transformers, which is the wrong library here.
        from google.generativeai.types import GenerationConfig
        if system:
            history[-1]["content"] = system + "\nuser query: " + history[-1]["content"]
        try:
            for his in history:
                if his["role"] == "assistant":
                    his["role"] = "model"
                    his["parts"] = [his["content"]]
                    his.pop("content")
                if his["role"] == "user":
                    his["parts"] = [his["content"]]
                    his.pop("content")
            history[-1]["parts"].append("data:image/jpeg;base64," + image)

            response = self.model.generate_content(history, generation_config=GenerationConfig(
                temperature=gen_conf.get("temperature", 0.3),
                top_p=gen_conf.get("top_p", 0.7)))
            ans = response.text
            return ans, response.usage_metadata.total_token_count
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf, image=""):
        from google.generativeai.types import GenerationConfig
        if system:
            history[-1]["content"] = system + "\nuser query: " + history[-1]["content"]

        ans = ""
        try:
            for his in history:
                if his["role"] == "assistant":
                    his["role"] = "model"
                    his["parts"] = [his["content"]]
                    his.pop("content")
                if his["role"] == "user":
                    his["parts"] = [his["content"]]
                    his.pop("content")
            history[-1]["parts"].append("data:image/jpeg;base64," + image)

            response = self.model.generate_content(history, generation_config=GenerationConfig(
                temperature=gen_conf.get("temperature", 0.3),
                top_p=gen_conf.get("top_p", 0.7)), stream=True)
            for resp in response:
                if not resp.text:
                    continue
                ans += resp.text
                yield ans

            yield response._chunks[-1].usage_metadata.total_token_count
        except Exception as e:
            # If the call failed before any chunk arrived, `response` may be
            # unbound, so the token-count yield lives inside the try block.
            yield ans + "\n**ERROR**: " + str(e)
            yield 0


class OpenRouterCV(GptV4):
    def __init__(
        self,
        key,
        model_name,
        lang="Chinese",
        base_url="https://openrouter.ai/api/v1",
    ):
        if not base_url:
            base_url = "https://openrouter.ai/api/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
        self.lang = lang


class LocalCV(Base):
    def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
        pass

    def describe(self, image):
        return "", 0


class NvidiaCV(Base):
    def __init__(
        self,
        key,
        model_name,
        lang="Chinese",
        base_url="https://ai.api.nvidia.com/v1/vlm",
    ):
        if not base_url:
            # Assign the string itself; the original wrapped it in a
            # one-element tuple, which would break os.path.join below.
            base_url = "https://ai.api.nvidia.com/v1/vlm"
        self.lang = lang
        factory, llm_name = model_name.split("/")
        if factory != "liuhaotian":
            self.base_url = os.path.join(base_url, factory, llm_name)
        else:
            self.base_url = os.path.join(
                base_url, "community", llm_name.replace("-v1.6", "16")
            )
        self.key = key

    def describe(self, image):
        b64 = self.image2base64(image)
        response = requests.post(
            url=self.base_url,
            headers={
                "accept": "application/json",
                "content-type": "application/json",
                "Authorization": f"Bearer {self.key}",
            },
            json={
                "messages": self.prompt(b64)
            },
        )
        response = response.json()
        return (
            response["choices"][0]["message"]["content"].strip(),
            response["usage"]["total_tokens"],
        )

    def describe_with_prompt(self, image, prompt=None):
        b64 = self.image2base64(image)
        vision_prompt = self.vision_llm_prompt(b64, prompt)
        response = requests.post(
            url=self.base_url,
            headers={
                "accept": "application/json",
                "content-type": "application/json",
                "Authorization": f"Bearer {self.key}",
            },
            json={
                "messages": vision_prompt,
            },
        )
        response = response.json()
        return (
            response["choices"][0]["message"]["content"].strip(),
            response["usage"]["total_tokens"],
        )

    def prompt(self, b64):
        return [
            {
                "role": "user",
                "content": (
                    "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
                    if self.lang.lower() == "chinese"
                    else "Please describe the content of this picture, such as where and when it was taken, who is in it, what is happening, and how the people feel. If it contains numeric data, please extract that data."
                )
                + f' <img src="data:image/jpeg;base64,{b64}"/>',
            }
        ]

    def vision_llm_prompt(self, b64, prompt=None):
        return [
            {
                "role": "user",
                "content": (
                    prompt if prompt else vision_llm_describe_prompt()
                )
                + f' <img src="data:image/jpeg;base64,{b64}"/>',
            }
        ]

    def chat_prompt(self, text, b64):
        return [
            {
                "role": "user",
                "content": text + f' <img src="data:image/jpeg;base64,{b64}"/>',
            }
        ]


class StepFunCV(GptV4):
    def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
        if not base_url:
            base_url = "https://api.stepfun.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
        self.lang = lang


class LmStudioCV(GptV4):
    def __init__(self, key, model_name, lang="Chinese", base_url=""):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        self.client = OpenAI(api_key="lm-studio", base_url=base_url)
        self.model_name = model_name
        self.lang = lang


class OpenAI_APICV(GptV4):
    def __init__(self, key, model_name, lang="Chinese", base_url=""):
        if not base_url:
            raise ValueError("url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name.split("___")[0]
        self.lang = lang


class TogetherAICV(GptV4):
    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"):
        if not base_url:
            base_url = "https://api.together.xyz/v1"
        super().__init__(key, model_name, lang, base_url)


class YiCV(GptV4):
    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1"):
        if not base_url:
            base_url = "https://api.lingyiwanwu.com/v1"
        super().__init__(key, model_name, lang, base_url)


class SILICONFLOWCV(GptV4):
    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.siliconflow.cn/v1"):
        if not base_url:
            base_url = "https://api.siliconflow.cn/v1"
        super().__init__(key, model_name, lang, base_url)


class HunyuanCV(Base):
    def __init__(self, key, model_name, lang="Chinese", base_url=None):
        from tencentcloud.common import credential
        from tencentcloud.hunyuan.v20230901 import hunyuan_client

        key = json.loads(key)
        sid = key.get("hunyuan_sid", "")
        sk = key.get("hunyuan_sk", "")
        cred = credential.Credential(sid, sk)
        self.model_name = model_name
        self.client = hunyuan_client.HunyuanClient(cred, "")
        self.lang = lang

    def describe(self, image):
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
            TencentCloudSDKException,
        )
        from tencentcloud.hunyuan.v20230901 import models

        b64 = self.image2base64(image)
        req = models.ChatCompletionsRequest()
        params = {"Model": self.model_name, "Messages": self.prompt(b64)}
        req.from_json_string(json.dumps(params))
        ans = ""
        try:
            response = self.client.ChatCompletions(req)
            ans = response.Choices[0].Message.Content
            return ans, response.Usage.TotalTokens
        except TencentCloudSDKException as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def describe_with_prompt(self, image, prompt=None):
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
        from tencentcloud.hunyuan.v20230901 import models

        b64 = self.image2base64(image)
        vision_prompt = self.vision_llm_prompt(b64, prompt)
        req = models.ChatCompletionsRequest()
        params = {"Model": self.model_name, "Messages": vision_prompt}
        req.from_json_string(json.dumps(params))
        ans = ""
        try:
            response = self.client.ChatCompletions(req)
            ans = response.Choices[0].Message.Content
            return ans, response.Usage.TotalTokens
        except TencentCloudSDKException as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def prompt(self, b64):
        return [
            {
                "Role": "user",
                "Contents": [
                    {
                        "Type": "image_url",
                        "ImageUrl": {
                            "Url": f"data:image/jpeg;base64,{b64}"
                        },
                    },
                    {
                        "Type": "text",
                        "Text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
                        "Please describe the content of this picture, such as where and when it was taken, who is in it, what is happening, and how the people feel. If it contains numeric data, please extract that data.",
                    },
                ],
            }
        ]


class AnthropicCV(Base):
    def __init__(self, key, model_name, lang="Chinese", base_url=None):
        import anthropic

        self.client = anthropic.Anthropic(api_key=key)
        self.model_name = model_name
        # describe() branches on self.lang, so it must be set here; the
        # original __init__ omitted it.
        self.lang = lang
        self.system = ""
        self.max_tokens = 8192
        if "haiku" in self.model_name or "opus" in self.model_name:
            self.max_tokens = 4096

    def prompt(self, b64, prompt):
        return [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": b64,
                        },
                    },
                    {
                        "type": "text",
                        "text": prompt
                    }
                ],
            }
        ]

    def describe(self, image):
        b64 = self.image2base64(image)
        prompt = self.prompt(
            b64,
            "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
            "Please describe the content of this picture, such as where and when it was taken, who is in it, what is happening, and how the people feel. If it contains numeric data, please extract that data."
        )
        # messages.create returns a Message object, not a dict, so convert it
        # before subscripting (as chat() below already does).
        response = self.client.messages.create(
            model=self.model_name,
            max_tokens=self.max_tokens,
            messages=prompt
        ).to_dict()
        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]

    def describe_with_prompt(self, image, prompt=None):
        b64 = self.image2base64(image)
        prompt = self.prompt(b64, prompt if prompt else vision_llm_describe_prompt())
        response = self.client.messages.create(
            model=self.model_name,
            max_tokens=self.max_tokens,
            messages=prompt
        ).to_dict()
        return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]

    def chat(self, system, history, gen_conf):
        if "presence_penalty" in gen_conf:
            del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            del gen_conf["frequency_penalty"]
        gen_conf["max_tokens"] = self.max_tokens

        ans = ""
        try:
            response = self.client.messages.create(
                model=self.model_name,
                messages=history,
                system=system,
                stream=False,
                **gen_conf,
            ).to_dict()
            ans = response["content"][0]["text"]
            if response["stop_reason"] == "max_tokens":
                ans += (
                    "...\nFor the content length reason, it stopped, continue?"
                    if is_english([ans])
                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
                )
            return (
                ans,
                response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
            )
        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if "presence_penalty" in gen_conf:
            del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            del gen_conf["frequency_penalty"]
        gen_conf["max_tokens"] = self.max_tokens

        ans = ""
        total_tokens = 0
        try:
            response = self.client.messages.create(
                model=self.model_name,
                messages=history,
                system=system,
                stream=True,
                **gen_conf,
            )
            for res in response:
                if res.type == 'content_block_delta':
                    if res.delta.type == "thinking_delta" and res.delta.thinking:
                        if ans.find("<think>") < 0:
                            ans += "<think>"
                        ans = ans.replace("</think>", "")
                        ans += res.delta.thinking + "</think>"
                    else:
                        text = res.delta.text
                        ans += text
                        total_tokens += num_tokens_from_string(text)
                    yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class GPUStackCV(GptV4):
    def __init__(self, key, model_name, lang="Chinese", base_url=""):
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        if base_url.split("/")[-1] != "v1":
            base_url = os.path.join(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
        self.lang = lang
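

# Hypothetical convenience registry (an addition for illustration only; these
# provider-name keys are assumptions and are not used elsewhere in this file):
EXAMPLE_CV_MODELS = {
    "openai": GptV4,
    "azure-openai": AzureGptV4,
    "dashscope": QWenCV,
    "zhipu": Zhipu4V,
    "ollama": OllamaCV,
    "anthropic": AnthropicCV,
}
#
# Usage sketch (placeholder key and model name):
#     cv = EXAMPLE_CV_MODELS["openai"](key="sk-...", model_name="gpt-4o")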