Nelze vybrat více než 25 témat. Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

cv_model.py 46KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import base64
  17. import io
  18. import json
  19. import os
  20. from abc import ABC
  21. from io import BytesIO
  22. from urllib.parse import urljoin
  23. import requests
  24. from ollama import Client
  25. from openai import OpenAI
  26. from openai.lib.azure import AzureOpenAI
  27. from PIL import Image
  28. from zhipuai import ZhipuAI
  29. from api.utils import get_uuid
  30. from api.utils.file_utils import get_project_base_directory
  31. from rag.nlp import is_english
  32. from rag.prompts import vision_llm_describe_prompt
  33. from rag.utils import num_tokens_from_string
  class Base(ABC):
      """Shared base for the image-understanding ("CV") model wrappers.

      Subclasses are expected to set ``self.client``, ``self.model_name`` and
      ``self.lang`` in their constructors; the helpers here read those
      attributes and build OpenAI-style chat-completion message payloads.
      """

      def __init__(self, key, model_name):
          pass

      def describe(self, image):
          """Describe *image* and return ``(text, token_count)``; subclasses must override."""
          raise NotImplementedError("Please implement describe method!")

      def describe_with_prompt(self, image, prompt=None):
          """Describe *image* using *prompt* and return ``(text, token_count)``; subclasses must override."""
          raise NotImplementedError("Please implement describe_with_prompt method!")

      @staticmethod
      def _prepare_history(system, history):
          """Return a shallow copy of *history* with *system* folded into the last message.

          Each message dict is copied so that replacing user contents with
          multimodal part lists does not mutate the caller's history.
          """
          history = [dict(message) for message in history]
          if system and history:
              # The user's text appears exactly once, after the "user query: " marker.
              history[-1]["content"] = system + "user query: " + history[-1]["content"]
          return history

      def chat(self, system, history, gen_conf, image=""):
          """Run a non-streaming chat completion and return ``(answer, total_tokens)``.

          *image* is base64 data that is attached (as a JPEG data URL) to every
          user turn. Any exception is reported as an ``"**ERROR**: ..."``
          answer with a token count of 0.
          """
          history = self._prepare_history(system, history)
          try:
              for message in history:
                  if message["role"] == "user":
                      message["content"] = self.chat_prompt(message["content"], image)
              response = self.client.chat.completions.create(
                  model=self.model_name,
                  messages=history,
                  temperature=gen_conf.get("temperature", 0.3),
                  top_p=gen_conf.get("top_p", 0.7),
              )
              return (response.choices[0].message.content or "").strip(), response.usage.total_tokens
          except Exception as e:
              return "**ERROR**: " + str(e), 0

      def chat_streamly(self, system, history, gen_conf, image=""):
          """Stream a chat completion.

          Yields the accumulated answer after each new piece of text (with a
          truncation notice appended when the model stops for length). If an
          exception occurs, yields the answer with an ``**ERROR**`` suffix.
          Always yields the token count (an ``int``) last. The count comes from
          the stream's ``usage`` when the server reports it; otherwise it is
          estimated with ``num_tokens_from_string`` on the answer.
          """
          history = self._prepare_history(system, history)
          ans = ""
          tk_count = 0
          try:
              for message in history:
                  if message["role"] == "user":
                      message["content"] = self.chat_prompt(message["content"], image)
              response = self.client.chat.completions.create(
                  model=self.model_name,
                  messages=history,
                  temperature=gen_conf.get("temperature", 0.3),
                  top_p=gen_conf.get("top_p", 0.7),
                  stream=True,
              )
              for resp in response:
                  usage = getattr(resp, "usage", None)
                  if not resp.choices:
                      # A chunk without choices can only carry usage information.
                      if usage:
                          tk_count = usage.total_tokens
                      continue
                  choice = resp.choices[0]
                  delta = choice.delta.content
                  if delta:
                      ans += delta
                  # finish_reason may arrive on a chunk with no text, so it is
                  # checked independently of whether this chunk added content.
                  if choice.finish_reason == "length":
                      ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                  if choice.finish_reason in ("length", "stop"):
                      tk_count = usage.total_tokens if usage else num_tokens_from_string(ans)
                  if delta or choice.finish_reason == "length":
                      yield ans
          except Exception as e:
              yield ans + "\n**ERROR**: " + str(e)
          yield tk_count

      def image2base64(self, image):
          """Return *image* (bytes, ``BytesIO`` or a PIL-style image) as a base64 string."""
          if isinstance(image, bytes):
              return base64.b64encode(image).decode("utf-8")
          if isinstance(image, BytesIO):
              return base64.b64encode(image.getvalue()).decode("utf-8")
          buffered = BytesIO()
          try:
              image.save(buffered, format="JPEG")
          except Exception:
              # JPEG cannot hold every image mode; retry as PNG in a fresh buffer so
              # no partially written JPEG bytes end up in the payload.
              buffered = BytesIO()
              image.save(buffered, format="PNG")
          return base64.b64encode(buffered.getvalue()).decode("utf-8")

      def prompt(self, b64):
          """Build the default describe-the-image request for base64 image data *b64*."""
          if self.lang.lower() == "chinese":
              text = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
          else:
              text = "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
          return [
              {
                  "role": "user",
                  "content": [
                      {
                          "type": "image_url",
                          "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
                      },
                      # Tagged explicitly: OpenAI-style content parts each carry a "type".
                      {"type": "text", "text": text},
                  ],
              }
          ]

      def vision_llm_prompt(self, b64, prompt=None):
          """Build a request pairing *b64* with *prompt* (or the shared default describe prompt)."""
          return [
              {
                  "role": "user",
                  "content": [
                      {
                          "type": "image_url",
                          "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
                      },
                      {
                          "type": "text",
                          "text": prompt if prompt else vision_llm_describe_prompt(),
                      },
                  ],
              }
          ]

      def chat_prompt(self, text, b64):
          """Return OpenAI-style content parts: the image data URL followed by *text*."""
          return [
              {
                  "type": "image_url",
                  "image_url": {
                      "url": f"data:image/jpeg;base64,{b64}",
                  },
              },
              {"type": "text", "text": text},
          ]
  class GptV4(Base):
      """OpenAI vision models accessed through the official ``openai`` client."""

      _FACTORY_NAME = "OpenAI"

      def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
          # An empty/None base_url falls back to the public OpenAI endpoint.
          self.client = OpenAI(api_key=key, base_url=base_url or "https://api.openai.com/v1")
          self.model_name = model_name
          self.lang = lang

      def describe(self, image):
          """Describe *image* with the default prompt; returns ``(text, total_tokens)``."""
          messages = self.prompt(self.image2base64(image))
          # Ensure each text part is tagged with "type": "text" before sending.
          for message in messages:
              for part in message["content"]:
                  if "text" in part:
                      part["type"] = "text"
          completion = self.client.chat.completions.create(
              model=self.model_name,
              messages=messages,
          )
          return completion.choices[0].message.content.strip(), completion.usage.total_tokens

      def describe_with_prompt(self, image, prompt=None):
          """Describe *image* with *prompt* (or the default vision prompt); returns ``(text, total_tokens)``."""
          b64 = self.image2base64(image)
          args = (b64, prompt) if prompt else (b64,)
          completion = self.client.chat.completions.create(
              model=self.model_name,
              messages=self.vision_llm_prompt(*args),
          )
          return completion.choices[0].message.content.strip(), completion.usage.total_tokens
  class AzureGptV4(Base):
      """Azure-hosted OpenAI vision deployments.

      *key* is a JSON string holding ``api_key`` and, optionally, ``api_version``.
      """

      _FACTORY_NAME = "Azure-OpenAI"

      def __init__(self, key, model_name, lang="Chinese", **kwargs):
          credentials = json.loads(key)
          self.client = AzureOpenAI(
              api_key=credentials.get("api_key", ""),
              azure_endpoint=kwargs["base_url"],
              api_version=credentials.get("api_version", "2024-02-01"),
          )
          self.model_name = model_name
          self.lang = lang

      def describe(self, image):
          """Describe *image* with the default prompt; returns ``(text, total_tokens)``."""
          messages = self.prompt(self.image2base64(image))
          # Tag text parts with an explicit "type".
          for message in messages:
              for part in message["content"]:
                  if "text" in part:
                      part["type"] = "text"
          completion = self.client.chat.completions.create(model=self.model_name, messages=messages)
          return completion.choices[0].message.content.strip(), completion.usage.total_tokens

      def describe_with_prompt(self, image, prompt=None):
          """Describe *image* with *prompt* (or the default vision prompt); returns ``(text, total_tokens)``."""
          b64 = self.image2base64(image)
          args = (b64, prompt) if prompt else (b64,)
          completion = self.client.chat.completions.create(
              model=self.model_name,
              messages=self.vision_llm_prompt(*args),
          )
          return completion.choices[0].message.content.strip(), completion.usage.total_tokens
  class QWenCV(Base):
      """Alibaba Tongyi-Qianwen vision models called through the ``dashscope`` SDK.

      Images for ``describe``/``describe_with_prompt`` are written to temporary
      ``.jpg`` files and referenced by ``file://`` URL; those files are removed
      once the request has been made.
      """

      _FACTORY_NAME = "Tongyi-Qianwen"

      def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
          import dashscope

          # The key is stored on the dashscope module, not on this instance.
          dashscope.api_key = key
          self.model_name = model_name
          self.lang = lang

      @staticmethod
      def _save_temp_image(binary):
          """Write image bytes to a uniquely named ``.jpg`` under the project tmp dir and return its path."""
          tmp_dir = get_project_base_directory("tmp")
          os.makedirs(tmp_dir, exist_ok=True)
          path = os.path.join(tmp_dir, "%s.jpg" % get_uuid())
          img = Image.open(io.BytesIO(binary))
          if img.mode not in ("RGB", "L"):
              # The ".jpg" extension makes PIL write JPEG, which cannot store alpha
              # or palette modes (e.g. RGBA PNGs), so flatten to RGB first.
              img = img.convert("RGB")
          img.save(path)
          return path

      @staticmethod
      def _remove_temp_images(messages):
          """Best-effort deletion of the ``file://`` images referenced by *messages*."""
          for message in messages:
              for part in message.get("content", []):
                  uri = part.get("image", "") if isinstance(part, dict) else ""
                  if uri.startswith("file://"):
                      try:
                          os.remove(uri[len("file://"):])
                      except OSError:
                          # A leftover temp file is not worth failing the request over.
                          pass

      def prompt(self, binary):
          """Build the default describe request; *binary* is raw image bytes saved to a temp file."""
          path = self._save_temp_image(binary)
          text = (
              "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
              if self.lang.lower() == "chinese"
              else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
          )
          return [
              {
                  "role": "user",
                  "content": [
                      {"image": f"file://{path}"},
                      {"text": text},
                  ],
              }
          ]

      def vision_llm_prompt(self, binary, prompt=None):
          """Build a request pairing the saved image with *prompt* (or the shared default prompt)."""
          path = self._save_temp_image(binary)
          return [
              {
                  "role": "user",
                  "content": [
                      {"image": f"file://{path}"},
                      {"text": prompt if prompt else vision_llm_describe_prompt()},
                  ],
              }
          ]

      def chat_prompt(self, text, b64):
          """Return dashscope content parts: the image (passed through as given) followed by *text*."""
          return [
              {"image": f"{b64}"},
              {"text": text},
          ]

      def describe(self, image):
          """Describe raw image bytes *image*; returns ``(text, output_tokens)`` or ``(message, 0)`` on failure."""
          from http import HTTPStatus

          from dashscope import MultiModalConversation

          messages = self.prompt(image)
          try:
              response = MultiModalConversation.call(model=self.model_name, messages=messages)
          finally:
              self._remove_temp_images(messages)
          if response.status_code == HTTPStatus.OK:
              return response.output.choices[0]["message"]["content"][0]["text"], response.usage.output_tokens
          return response.message, 0

      def describe_with_prompt(self, image, prompt=None):
          """Describe raw image bytes *image* with *prompt*; same return contract as ``describe``."""
          from http import HTTPStatus

          from dashscope import MultiModalConversation

          messages = self.vision_llm_prompt(image, prompt) if prompt else self.vision_llm_prompt(image)
          try:
              response = MultiModalConversation.call(model=self.model_name, messages=messages)
          finally:
              self._remove_temp_images(messages)
          if response.status_code == HTTPStatus.OK:
              return response.output.choices[0]["message"]["content"][0]["text"], response.usage.output_tokens
          return response.message, 0

      @staticmethod
      def _prepare_history(system, history):
          """Return a shallow copy of *history* with *system* folded into the last message once."""
          history = [dict(message) for message in history]
          if system and history:
              history[-1]["content"] = system + "user query: " + history[-1]["content"]
          return history

      def chat(self, system, history, gen_conf, image=""):
          """Run a non-streaming multimodal chat; returns ``(answer, total_tokens)``."""
          from http import HTTPStatus

          from dashscope import MultiModalConversation

          history = self._prepare_history(system, history)
          for message in history:
              if message["role"] == "user":
                  message["content"] = self.chat_prompt(message["content"], image)
          response = MultiModalConversation.call(
              model=self.model_name,
              messages=history,
              temperature=gen_conf.get("temperature", 0.3),
              top_p=gen_conf.get("top_p", 0.7),
          )
          ans = ""
          tk_count = 0
          if response.status_code == HTTPStatus.OK:
              ans = response.output.choices[0]["message"]["content"]
              if isinstance(ans, list):
                  ans = ans[0]["text"] if ans else ""
              tk_count += response.usage.total_tokens
              if response.output.choices[0].get("finish_reason", "") == "length":
                  ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
              return ans, tk_count
          return "**ERROR**: " + str(response.message), tk_count

      def chat_streamly(self, system, history, gen_conf, image=""):
          """Stream a multimodal chat, yielding the accumulated answer and finally the token count."""
          from http import HTTPStatus

          from dashscope import MultiModalConversation

          history = self._prepare_history(system, history)
          for message in history:
              if message["role"] == "user":
                  message["content"] = self.chat_prompt(message["content"], image)
          ans = ""
          tk_count = 0
          try:
              response = MultiModalConversation.call(
                  model=self.model_name,
                  messages=history,
                  temperature=gen_conf.get("temperature", 0.3),
                  top_p=gen_conf.get("top_p", 0.7),
                  stream=True,
                  # Ask for deltas so each chunk's text can be appended to ``ans``.
                  incremental_output=True,
              )
              for resp in response:
                  if resp.status_code == HTTPStatus.OK:
                      cnt = resp.output.choices[0]["message"]["content"]
                      if isinstance(cnt, list):
                          # Guard on the chunk's own content (it may be an empty list).
                          cnt = cnt[0]["text"] if cnt else ""
                      ans += cnt
                      tk_count = resp.usage.total_tokens
                      if resp.output.choices[0].get("finish_reason", "") == "length":
                          ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                      yield ans
                  elif str(resp.message).find("Access") < 0:
                      yield ans + "\n**ERROR**: " + str(resp.message)
                  else:
                      yield "Out of credit. Please set the API key in **settings > Model providers.**"
          except Exception as e:
              yield ans + "\n**ERROR**: " + str(e)
          yield tk_count
  class Zhipu4V(Base):
      """ZHIPU-AI GLM vision models (``glm-4v`` by default) via the ``zhipuai`` client."""

      _FACTORY_NAME = "ZHIPU-AI"

      def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
          self.client = ZhipuAI(api_key=key)
          self.model_name = model_name
          self.lang = lang

      def describe(self, image):
          """Describe *image* with the default prompt; returns ``(text, total_tokens)``."""
          b64 = self.image2base64(image)
          prompt = self.prompt(b64)
          # Tag the text part (index 1 of the default prompt) explicitly.
          prompt[0]["content"][1]["type"] = "text"
          res = self.client.chat.completions.create(
              model=self.model_name,
              messages=prompt,
          )
          return res.choices[0].message.content.strip(), res.usage.total_tokens

      def describe_with_prompt(self, image, prompt=None):
          """Describe *image* with *prompt* (or the default vision prompt); returns ``(text, total_tokens)``."""
          b64 = self.image2base64(image)
          vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
          res = self.client.chat.completions.create(model=self.model_name, messages=vision_prompt)
          return res.choices[0].message.content.strip(), res.usage.total_tokens

      @staticmethod
      def _prepare_history(system, history):
          """Return a shallow copy of *history* with *system* folded into the last message once."""
          history = [dict(message) for message in history]
          if system and history:
              history[-1]["content"] = system + "user query: " + history[-1]["content"]
          return history

      def chat(self, system, history, gen_conf, image=""):
          """Run a non-streaming chat with *image* on each user turn; returns ``(answer, total_tokens)``."""
          history = self._prepare_history(system, history)
          try:
              for message in history:
                  if message["role"] == "user":
                      message["content"] = self.chat_prompt(message["content"], image)
              response = self.client.chat.completions.create(
                  model=self.model_name,
                  messages=history,
                  temperature=gen_conf.get("temperature", 0.3),
                  top_p=gen_conf.get("top_p", 0.7),
              )
              return (response.choices[0].message.content or "").strip(), response.usage.total_tokens
          except Exception as e:
              return "**ERROR**: " + str(e), 0

      def chat_streamly(self, system, history, gen_conf, image=""):
          """Stream a chat, yielding the accumulated answer and, last, the token count.

          The count comes from the chunk's ``usage`` when present, otherwise it is
          estimated with ``num_tokens_from_string``.
          """
          history = self._prepare_history(system, history)
          ans = ""
          tk_count = 0
          try:
              for message in history:
                  if message["role"] == "user":
                      message["content"] = self.chat_prompt(message["content"], image)
              response = self.client.chat.completions.create(
                  model=self.model_name,
                  messages=history,
                  temperature=gen_conf.get("temperature", 0.3),
                  top_p=gen_conf.get("top_p", 0.7),
                  stream=True,
              )
              for resp in response:
                  if not resp.choices:
                      continue
                  choice = resp.choices[0]
                  delta = choice.delta.content
                  if delta:
                      ans += delta
                  # finish_reason may arrive on a chunk without text; check it regardless.
                  if choice.finish_reason == "length":
                      ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                  if choice.finish_reason in ("length", "stop"):
                      usage = getattr(resp, "usage", None)
                      tk_count = usage.total_tokens if usage else num_tokens_from_string(ans)
                  if delta or choice.finish_reason == "length":
                      yield ans
          except Exception as e:
              yield ans + "\n**ERROR**: " + str(e)
          yield tk_count
  class OllamaCV(Base):
      """Vision models served by a local Ollama instance (``kwargs["base_url"]`` is the host)."""

      _FACTORY_NAME = "Ollama"

      def __init__(self, key, model_name, lang="Chinese", **kwargs):
          self.client = Client(host=kwargs["base_url"])
          self.model_name = model_name
          self.lang = lang

      def describe(self, image):
          """Describe *image* with the default prompt text; returns ``(text, 128)``."""
          prompt = self.prompt("")
          try:
              response = self.client.generate(
                  model=self.model_name,
                  prompt=prompt[0]["content"][1]["text"],
                  images=[image],
              )
              ans = response["response"].strip()
              # Fixed placeholder token count; the response's counters are not consulted.
              return ans, 128
          except Exception as e:
              return "**ERROR**: " + str(e), 0

      def describe_with_prompt(self, image, prompt=None):
          """Describe *image* with *prompt* (or the default vision prompt); returns ``(text, 128)``."""
          vision_prompt = self.vision_llm_prompt("", prompt) if prompt else self.vision_llm_prompt("")
          try:
              response = self.client.generate(
                  model=self.model_name,
                  prompt=vision_prompt[0]["content"][1]["text"],
                  images=[image],
              )
              ans = response["response"].strip()
              return ans, 128
          except Exception as e:
              return "**ERROR**: " + str(e), 0

      @staticmethod
      def _prepare_history(system, history, image):
          """Copy *history*, fold *system* into the last message once, and attach *image* to user turns."""
          history = [dict(message) for message in history]
          if system and history:
              history[-1]["content"] = system + "user query: " + history[-1]["content"]
          for message in history:
              if message["role"] == "user":
                  message["images"] = [image]
          return history

      @staticmethod
      def _build_options(gen_conf):
          """Copy the supported sampling keys from *gen_conf* into Ollama ``options``."""
          # Each key maps to the Ollama option of the same name (top_p -> top_p).
          supported = ("temperature", "top_p", "presence_penalty", "frequency_penalty")
          return {name: gen_conf[name] for name in supported if name in gen_conf}

      def chat(self, system, history, gen_conf, image=""):
          """Run a non-streaming chat; returns ``(answer, eval_count + prompt_eval_count)``."""
          history = self._prepare_history(system, history, image)
          try:
              response = self.client.chat(
                  model=self.model_name,
                  messages=history,
                  options=self._build_options(gen_conf),
                  keep_alive=-1,
              )
              ans = response["message"]["content"].strip()
              return ans, (response["eval_count"] or 0) + (response.get("prompt_eval_count", 0) or 0)
          except Exception as e:
              return "**ERROR**: " + str(e), 0

      def chat_streamly(self, system, history, gen_conf, image=""):
          """Stream a chat, yielding the accumulated answer and, always last, the token count."""
          history = self._prepare_history(system, history, image)
          options = self._build_options(gen_conf)
          ans = ""
          tk_count = 0
          try:
              response = self.client.chat(
                  model=self.model_name,
                  messages=history,
                  stream=True,
                  options=options,
                  keep_alive=-1,
              )
              for resp in response:
                  delta = resp["message"]["content"]
                  if delta:
                      ans += delta
                      yield ans
                  if resp["done"]:
                      # Record the count here but emit it only after the loop, so the
                      # token count is always the final value yielded.
                      tk_count = (resp.get("prompt_eval_count", 0) or 0) + (resp.get("eval_count", 0) or 0)
          except Exception as e:
              yield ans + "\n**ERROR**: " + str(e)
          yield tk_count
  class LocalAICV(GptV4):
      """LocalAI server reached through its OpenAI-compatible ``v1`` API."""

      _FACTORY_NAME = "LocalAI"

      def __init__(self, key, model_name, base_url, lang="Chinese"):
          if not base_url:
              raise ValueError("Local cv model url cannot be None")
          # A fixed placeholder key is sent; *key* is not used.
          self.client = OpenAI(api_key="empty", base_url=urljoin(base_url, "v1"))
          # Keep only the part of the model name before any "___" separator.
          self.model_name, _, _ = model_name.partition("___")
          self.lang = lang
  class XinferenceCV(Base):
      """Xinference server reached through its OpenAI-compatible ``v1`` API."""

      _FACTORY_NAME = "Xinference"

      def __init__(self, key, model_name="", lang="Chinese", base_url=""):
          base_url = urljoin(base_url, "v1")
          self.client = OpenAI(api_key=key, base_url=base_url)
          self.model_name = model_name
          self.lang = lang

      def describe(self, image):
          """Describe *image* with the default prompt; returns ``(text, total_tokens)``."""
          b64 = self.image2base64(image)
          messages = self.prompt(b64)
          # OpenAI-style content parts each need a "type"; tag the text parts,
          # as the other OpenAI-client wrappers in this module do.
          for message in messages:
              for part in message["content"]:
                  if "text" in part:
                      part["type"] = "text"
          res = self.client.chat.completions.create(model=self.model_name, messages=messages)
          return res.choices[0].message.content.strip(), res.usage.total_tokens

      def describe_with_prompt(self, image, prompt=None):
          """Describe *image* with *prompt* (or the default vision prompt); returns ``(text, total_tokens)``."""
          b64 = self.image2base64(image)
          vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
          res = self.client.chat.completions.create(
              model=self.model_name,
              messages=vision_prompt,
          )
          return res.choices[0].message.content.strip(), res.usage.total_tokens
  class GeminiCV(Base):
      """Google Gemini vision models via the ``google.generativeai`` SDK."""

      _FACTORY_NAME = "Gemini"

      def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
          from google.generativeai import GenerativeModel, client

          client.configure(api_key=key)
          _client = client.get_default_generative_client()
          self.model_name = model_name
          self.model = GenerativeModel(model_name=self.model_name)
          self.model._client = _client
          self.lang = lang

      def _to_pil(self, image):
          """Decode *image* (anything ``image2base64`` accepts) into a PIL image."""
          from PIL.Image import open as open_image

          return open_image(BytesIO(base64.b64decode(self.image2base64(image))))

      @staticmethod
      def _generation_config(gen_conf):
          """Build Gemini sampling settings from *gen_conf*.

          The SDK accepts a plain dict here. Hugging Face transformers'
          ``GenerationConfig`` is an unrelated class and must not be used.
          """
          return {
              "temperature": gen_conf.get("temperature", 0.3),
              "top_p": gen_conf.get("top_p", 0.7),
          }

      @staticmethod
      def _to_gemini_contents(system, history, image):
          """Convert OpenAI-style *history* into Gemini ``role``/``parts`` contents (a copy).

          *system* is folded into the last message once. *image* (base64) is
          appended to the last message as a decoded PIL image, matching how
          ``describe`` passes images.
          """
          from PIL.Image import open as open_image

          contents = [dict(message) for message in history]
          if system and contents:
              contents[-1]["content"] = system + "user query: " + contents[-1]["content"]
          for message in contents:
              if message["role"] == "assistant":
                  message["role"] = "model"
                  message["parts"] = [message.pop("content")]
              elif message["role"] == "user":
                  message["parts"] = [message.pop("content")]
          if image:
              # NOTE(review): a "data:" URL string would presumably be sent as plain
              # text, so the bytes are decoded into an image part instead.
              contents[-1]["parts"].append(open_image(BytesIO(base64.b64decode(image))))
          return contents

      def describe(self, image):
          """Describe *image* with the default prompt; returns ``(text, total_token_count)``."""
          prompt = (
              "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
              if self.lang.lower() == "chinese"
              else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
          )
          res = self.model.generate_content([prompt, self._to_pil(image)])
          return res.text, res.usage_metadata.total_token_count

      def describe_with_prompt(self, image, prompt=None):
          """Describe *image* with *prompt* (or the shared default prompt); returns ``(text, total_token_count)``."""
          # Gemini takes the prompt as a text part, not as OpenAI-style message dicts.
          text = prompt if prompt else vision_llm_describe_prompt()
          res = self.model.generate_content([text, self._to_pil(image)])
          return res.text, res.usage_metadata.total_token_count

      def chat(self, system, history, gen_conf, image=""):
          """Run a non-streaming chat; returns ``(answer, total_token_count)`` or an ``**ERROR**`` answer."""
          try:
              contents = self._to_gemini_contents(system, history, image)
              response = self.model.generate_content(contents, generation_config=self._generation_config(gen_conf))
              return response.text, response.usage_metadata.total_token_count
          except Exception as e:
              return "**ERROR**: " + str(e), 0

      def chat_streamly(self, system, history, gen_conf, image=""):
          """Stream a chat, yielding the accumulated answer and, last, the latest reported token count."""
          ans = ""
          tk_count = 0
          try:
              contents = self._to_gemini_contents(system, history, image)
              response = self.model.generate_content(
                  contents,
                  generation_config=self._generation_config(gen_conf),
                  stream=True,
              )
              for resp in response:
                  usage = getattr(resp, "usage_metadata", None)
                  if usage:
                      tk_count = usage.total_token_count
                  if not resp.text:
                      continue
                  ans += resp.text
                  yield ans
          except Exception as e:
              yield ans + "\n**ERROR**: " + str(e)
          yield tk_count
  class OpenRouterCV(GptV4):
      """OpenRouter-hosted vision models through its OpenAI-compatible API."""

      _FACTORY_NAME = "OpenRouter"

      def __init__(
          self,
          key,
          model_name,
          lang="Chinese",
          base_url="https://openrouter.ai/api/v1",
      ):
          # Empty/None base_url falls back to the public OpenRouter endpoint.
          endpoint = base_url if base_url else "https://openrouter.ai/api/v1"
          self.client = OpenAI(api_key=key, base_url=endpoint)
          self.model_name = model_name
          self.lang = lang
  class LocalCV(Base):
      """Stub model whose ``describe`` always returns an empty result.

      NOTE(review): registered under the "Moonshot" factory name — confirm intended.
      """

      _FACTORY_NAME = "Moonshot"

      def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
          """Accept and ignore the standard constructor arguments."""

      def describe(self, image):
          """Return an empty description with a token count of 0."""
          return "", 0
  class NvidiaCV(Base):
      """NVIDIA-hosted vision models called over HTTP with ``requests``.

      Images are embedded as ``<img src="data:...">`` tags inside the text
      content rather than as separate content parts.
      """

      _FACTORY_NAME = "NVIDIA"

      def __init__(
          self,
          key,
          model_name,
          lang="Chinese",
          base_url="https://ai.api.nvidia.com/v1/vlm",
      ):
          # Fall back to the public endpoint (as a string) when no URL is configured.
          base_url = (base_url or "https://ai.api.nvidia.com/v1/vlm").rstrip("/")
          self.lang = lang
          factory, llm_name = model_name.split("/", 1)
          # Join with an explicit "/": urljoin() would drop the last path segment
          # of a base URL without a trailing slash (".../v1/vlm" -> ".../v1/...").
          if factory != "liuhaotian":
              self.base_url = f"{base_url}/{factory}/{llm_name}"
          else:
              # liuhaotian models are addressed under "community" with "-v1.6" compacted to "16".
              self.base_url = f"{base_url}/community/{llm_name.replace('-v1.6', '16')}"
          self.key = key

      def _post(self, messages):
          """POST *messages* to the model endpoint; returns ``(content, total_tokens)``."""
          response = requests.post(
              url=self.base_url,
              headers={
                  "accept": "application/json",
                  "content-type": "application/json",
                  "Authorization": f"Bearer {self.key}",
              },
              json={"messages": messages},
          ).json()
          return (
              response["choices"][0]["message"]["content"].strip(),
              response["usage"]["total_tokens"],
          )

      def describe(self, image):
          """Describe *image* with the default prompt; returns ``(text, total_tokens)``."""
          return self._post(self.prompt(self.image2base64(image)))

      def describe_with_prompt(self, image, prompt=None):
          """Describe *image* with *prompt* (or the default vision prompt); returns ``(text, total_tokens)``."""
          b64 = self.image2base64(image)
          vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
          return self._post(vision_prompt)

      def prompt(self, b64):
          """Build the default describe request with the image inlined as an ``<img>`` tag."""
          return [
              {
                  "role": "user",
                  "content": (
                      "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
                      if self.lang.lower() == "chinese"
                      else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
                  )
                  + f' <img src="data:image/jpeg;base64,{b64}"/>',
              }
          ]

      def vision_llm_prompt(self, b64, prompt=None):
          """Build a request with *prompt* (or the shared default) followed by the inlined image."""
          return [
              {
                  "role": "user",
                  "content": (prompt if prompt else vision_llm_describe_prompt()) + f' <img src="data:image/jpeg;base64,{b64}"/>',
              }
          ]

      def chat_prompt(self, text, b64):
          """Return a single user message: *text* followed by the inlined image."""
          return [
              {
                  "role": "user",
                  "content": text + f' <img src="data:image/jpeg;base64,{b64}"/>',
              }
          ]
  class StepFunCV(GptV4):
      """StepFun vision models through their OpenAI-compatible API."""

      _FACTORY_NAME = "StepFun"

      def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
          # Empty/None base_url falls back to the public StepFun endpoint.
          self.client = OpenAI(api_key=key, base_url=base_url or "https://api.stepfun.com/v1")
          self.model_name = model_name
          self.lang = lang
  class LmStudioCV(GptV4):
      """LM Studio local server through its OpenAI-compatible ``v1`` API."""

      _FACTORY_NAME = "LM-Studio"

      def __init__(self, key, model_name, lang="Chinese", base_url=""):
          if not base_url:
              raise ValueError("Local llm url cannot be None")
          # A fixed placeholder key is sent; *key* is not used.
          endpoint = urljoin(base_url, "v1")
          self.client = OpenAI(api_key="lm-studio", base_url=endpoint)
          self.model_name = model_name
          self.lang = lang
  class OpenAI_APICV(GptV4):
      """Generic OpenAI-compatible endpoints (registered for VLLM and OpenAI-API-Compatible)."""

      _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]

      def __init__(self, key, model_name, lang="Chinese", base_url=""):
          if not base_url:
              raise ValueError("url cannot be None")
          endpoint = urljoin(base_url, "v1")
          self.client = OpenAI(api_key=key, base_url=endpoint)
          # Drop any "___<suffix>" tag from the configured model name.
          self.model_name, _, _ = model_name.partition("___")
          self.lang = lang
  class TogetherAICV(GptV4):
      """Together AI vision models; delegates client setup to ``GptV4``."""

      _FACTORY_NAME = "TogetherAI"

      def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"):
          endpoint = base_url if base_url else "https://api.together.xyz/v1"
          super().__init__(key, model_name, lang=lang, base_url=endpoint)
  709. class YiCV(GptV4):
  710. _FACTORY_NAME = "01.AI"
  711. def __init__(
  712. self,
  713. key,
  714. model_name,
  715. lang="Chinese",
  716. base_url="https://api.lingyiwanwu.com/v1",
  717. ):
  718. if not base_url:
  719. base_url = "https://api.lingyiwanwu.com/v1"
  720. super().__init__(key, model_name, lang, base_url)
  721. class SILICONFLOWCV(GptV4):
  722. _FACTORY_NAME = "SILICONFLOW"
  723. def __init__(
  724. self,
  725. key,
  726. model_name,
  727. lang="Chinese",
  728. base_url="https://api.siliconflow.cn/v1",
  729. ):
  730. if not base_url:
  731. base_url = "https://api.siliconflow.cn/v1"
  732. super().__init__(key, model_name, lang, base_url)
  733. class HunyuanCV(Base):
  734. _FACTORY_NAME = "Tencent Hunyuan"
  735. def __init__(self, key, model_name, lang="Chinese", base_url=None):
  736. from tencentcloud.common import credential
  737. from tencentcloud.hunyuan.v20230901 import hunyuan_client
  738. key = json.loads(key)
  739. sid = key.get("hunyuan_sid", "")
  740. sk = key.get("hunyuan_sk", "")
  741. cred = credential.Credential(sid, sk)
  742. self.model_name = model_name
  743. self.client = hunyuan_client.HunyuanClient(cred, "")
  744. self.lang = lang
  745. def describe(self, image):
  746. from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
  747. TencentCloudSDKException,
  748. )
  749. from tencentcloud.hunyuan.v20230901 import models
  750. b64 = self.image2base64(image)
  751. req = models.ChatCompletionsRequest()
  752. params = {"Model": self.model_name, "Messages": self.prompt(b64)}
  753. req.from_json_string(json.dumps(params))
  754. ans = ""
  755. try:
  756. response = self.client.ChatCompletions(req)
  757. ans = response.Choices[0].Message.Content
  758. return ans, response.Usage.TotalTokens
  759. except TencentCloudSDKException as e:
  760. return ans + "\n**ERROR**: " + str(e), 0
  761. def describe_with_prompt(self, image, prompt=None):
  762. from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
  763. from tencentcloud.hunyuan.v20230901 import models
  764. b64 = self.image2base64(image)
  765. vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
  766. req = models.ChatCompletionsRequest()
  767. params = {"Model": self.model_name, "Messages": vision_prompt}
  768. req.from_json_string(json.dumps(params))
  769. ans = ""
  770. try:
  771. response = self.client.ChatCompletions(req)
  772. ans = response.Choices[0].Message.Content
  773. return ans, response.Usage.TotalTokens
  774. except TencentCloudSDKException as e:
  775. return ans + "\n**ERROR**: " + str(e), 0
  776. def prompt(self, b64):
  777. return [
  778. {
  779. "Role": "user",
  780. "Contents": [
  781. {
  782. "Type": "image_url",
  783. "ImageUrl": {"Url": f"data:image/jpeg;base64,{b64}"},
  784. },
  785. {
  786. "Type": "text",
  787. "Text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
  788. if self.lang.lower() == "chinese"
  789. else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
  790. },
  791. ],
  792. }
  793. ]
  794. class AnthropicCV(Base):
  795. _FACTORY_NAME = "Anthropic"
  796. def __init__(self, key, model_name, base_url=None):
  797. import anthropic
  798. self.client = anthropic.Anthropic(api_key=key)
  799. self.model_name = model_name
  800. self.system = ""
  801. self.max_tokens = 8192
  802. if "haiku" in self.model_name or "opus" in self.model_name:
  803. self.max_tokens = 4096
  804. def prompt(self, b64, prompt):
  805. return [
  806. {
  807. "role": "user",
  808. "content": [
  809. {
  810. "type": "image",
  811. "source": {
  812. "type": "base64",
  813. "media_type": "image/jpeg",
  814. "data": b64,
  815. },
  816. },
  817. {"type": "text", "text": prompt},
  818. ],
  819. }
  820. ]
  821. def describe(self, image):
  822. b64 = self.image2base64(image)
  823. prompt = self.prompt(
  824. b64,
  825. "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
  826. if self.lang.lower() == "chinese"
  827. else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
  828. )
  829. response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=prompt)
  830. return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]
  831. def describe_with_prompt(self, image, prompt=None):
  832. b64 = self.image2base64(image)
  833. prompt = self.prompt(b64, prompt if prompt else vision_llm_describe_prompt())
  834. response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=prompt)
  835. return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]
  836. def chat(self, system, history, gen_conf):
  837. if "presence_penalty" in gen_conf:
  838. del gen_conf["presence_penalty"]
  839. if "frequency_penalty" in gen_conf:
  840. del gen_conf["frequency_penalty"]
  841. gen_conf["max_tokens"] = self.max_tokens
  842. ans = ""
  843. try:
  844. response = self.client.messages.create(
  845. model=self.model_name,
  846. messages=history,
  847. system=system,
  848. stream=False,
  849. **gen_conf,
  850. ).to_dict()
  851. ans = response["content"][0]["text"]
  852. if response["stop_reason"] == "max_tokens":
  853. ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
  854. return (
  855. ans,
  856. response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
  857. )
  858. except Exception as e:
  859. return ans + "\n**ERROR**: " + str(e), 0
  860. def chat_streamly(self, system, history, gen_conf):
  861. if "presence_penalty" in gen_conf:
  862. del gen_conf["presence_penalty"]
  863. if "frequency_penalty" in gen_conf:
  864. del gen_conf["frequency_penalty"]
  865. gen_conf["max_tokens"] = self.max_tokens
  866. ans = ""
  867. total_tokens = 0
  868. try:
  869. response = self.client.messages.create(
  870. model=self.model_name,
  871. messages=history,
  872. system=system,
  873. stream=True,
  874. **gen_conf,
  875. )
  876. for res in response:
  877. if res.type == "content_block_delta":
  878. if res.delta.type == "thinking_delta" and res.delta.thinking:
  879. if ans.find("<think>") < 0:
  880. ans += "<think>"
  881. ans = ans.replace("</think>", "")
  882. ans += res.delta.thinking + "</think>"
  883. else:
  884. text = res.delta.text
  885. ans += text
  886. total_tokens += num_tokens_from_string(text)
  887. yield ans
  888. except Exception as e:
  889. yield ans + "\n**ERROR**: " + str(e)
  890. yield total_tokens
  891. class GPUStackCV(GptV4):
  892. _FACTORY_NAME = "GPUStack"
  893. def __init__(self, key, model_name, lang="Chinese", base_url=""):
  894. if not base_url:
  895. raise ValueError("Local llm url cannot be None")
  896. base_url = urljoin(base_url, "v1")
  897. self.client = OpenAI(api_key=key, base_url=base_url)
  898. self.model_name = model_name
  899. self.lang = lang
  900. class GoogleCV(Base):
  901. _FACTORY_NAME = "Google Cloud"
  902. def __init__(self, key, model_name, lang="Chinese", base_url=None, **kwargs):
  903. import base64
  904. from google.oauth2 import service_account
  905. key = json.loads(key)
  906. access_token = json.loads(base64.b64decode(key.get("google_service_account_key", "")))
  907. project_id = key.get("google_project_id", "")
  908. region = key.get("google_region", "")
  909. scopes = ["https://www.googleapis.com/auth/cloud-platform"]
  910. self.model_name = model_name
  911. self.lang = lang
  912. if "claude" in self.model_name:
  913. from anthropic import AnthropicVertex
  914. from google.auth.transport.requests import Request
  915. if access_token:
  916. credits = service_account.Credentials.from_service_account_info(access_token, scopes=scopes)
  917. request = Request()
  918. credits.refresh(request)
  919. token = credits.token
  920. self.client = AnthropicVertex(region=region, project_id=project_id, access_token=token)
  921. else:
  922. self.client = AnthropicVertex(region=region, project_id=project_id)
  923. else:
  924. import vertexai.generative_models as glm
  925. from google.cloud import aiplatform
  926. if access_token:
  927. credits = service_account.Credentials.from_service_account_info(access_token)
  928. aiplatform.init(credentials=credits, project=project_id, location=region)
  929. else:
  930. aiplatform.init(project=project_id, location=region)
  931. self.client = glm.GenerativeModel(model_name=self.model_name)
  932. def describe(self, image):
  933. prompt = (
  934. "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
  935. if self.lang.lower() == "chinese"
  936. else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
  937. )
  938. if "claude" in self.model_name:
  939. b64 = self.image2base64(image)
  940. vision_prompt = [
  941. {
  942. "role": "user",
  943. "content": [
  944. {
  945. "type": "image",
  946. "source": {
  947. "type": "base64",
  948. "media_type": "image/jpeg",
  949. "data": b64,
  950. },
  951. },
  952. {"type": "text", "text": prompt},
  953. ],
  954. }
  955. ]
  956. response = self.client.messages.create(
  957. model=self.model_name,
  958. max_tokens=8192,
  959. messages=vision_prompt,
  960. )
  961. return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
  962. else:
  963. import vertexai.generative_models as glm
  964. b64 = self.image2base64(image)
  965. # Create proper image part for Gemini
  966. image_part = glm.Part.from_data(data=base64.b64decode(b64), mime_type="image/jpeg")
  967. input = [prompt, image_part]
  968. res = self.client.generate_content(input)
  969. return res.text, res.usage_metadata.total_token_count
  970. def describe_with_prompt(self, image, prompt=None):
  971. if "claude" in self.model_name:
  972. b64 = self.image2base64(image)
  973. vision_prompt = [
  974. {
  975. "role": "user",
  976. "content": [
  977. {
  978. "type": "image",
  979. "source": {
  980. "type": "base64",
  981. "media_type": "image/jpeg",
  982. "data": b64,
  983. },
  984. },
  985. {"type": "text", "text": prompt if prompt else vision_llm_describe_prompt()},
  986. ],
  987. }
  988. ]
  989. response = self.client.messages.create(model=self.model_name, max_tokens=8192, messages=vision_prompt)
  990. return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
  991. else:
  992. import vertexai.generative_models as glm
  993. b64 = self.image2base64(image)
  994. vision_prompt = prompt if prompt else vision_llm_describe_prompt()
  995. # Create proper image part for Gemini
  996. image_part = glm.Part.from_data(data=base64.b64decode(b64), mime_type="image/jpeg")
  997. input = [vision_prompt, image_part]
  998. res = self.client.generate_content(input)
  999. return res.text, res.usage_metadata.total_token_count
  1000. def chat(self, system, history, gen_conf, image=""):
  1001. if "claude" in self.model_name:
  1002. if system:
  1003. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  1004. try:
  1005. for his in history:
  1006. if his["role"] == "user":
  1007. his["content"] = [
  1008. {
  1009. "type": "image",
  1010. "source": {
  1011. "type": "base64",
  1012. "media_type": "image/jpeg",
  1013. "data": image,
  1014. },
  1015. },
  1016. {"type": "text", "text": his["content"]},
  1017. ]
  1018. response = self.client.messages.create(model=self.model_name, max_tokens=8192, messages=history, temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
  1019. return response.content[0].text.strip(), response.usage.input_tokens + response.usage.output_tokens
  1020. except Exception as e:
  1021. return "**ERROR**: " + str(e), 0
  1022. else:
  1023. import vertexai.generative_models as glm
  1024. from transformers import GenerationConfig
  1025. if system:
  1026. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  1027. try:
  1028. for his in history:
  1029. if his["role"] == "assistant":
  1030. his["role"] = "model"
  1031. his["parts"] = [his["content"]]
  1032. his.pop("content")
  1033. if his["role"] == "user":
  1034. his["parts"] = [his["content"]]
  1035. his.pop("content")
  1036. # Create proper image part for Gemini
  1037. img_bytes = base64.b64decode(image)
  1038. image_part = glm.Part.from_data(data=img_bytes, mime_type="image/jpeg")
  1039. history[-1]["parts"].append(image_part)
  1040. response = self.client.generate_content(history, generation_config=GenerationConfig(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7)))
  1041. ans = response.text
  1042. return ans, response.usage_metadata.total_token_count
  1043. except Exception as e:
  1044. return "**ERROR**: " + str(e), 0