Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

cv_model.py 29KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from openai.lib.azure import AzureOpenAI
  17. from zhipuai import ZhipuAI
  18. import io
  19. from abc import ABC
  20. from ollama import Client
  21. from PIL import Image
  22. from openai import OpenAI
  23. import os
  24. import base64
  25. from io import BytesIO
  26. import json
  27. import requests
  28. from transformers import GenerationConfig
  29. from rag.nlp import is_english
  30. from api.utils import get_uuid
  31. from api.utils.file_utils import get_project_base_directory
  32. class Base(ABC):
  33. def __init__(self, key, model_name):
  34. pass
  35. def describe(self, image, max_tokens=300):
  36. raise NotImplementedError("Please implement encode method!")
  37. def chat(self, system, history, gen_conf, image=""):
  38. if system:
  39. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  40. try:
  41. for his in history:
  42. if his["role"] == "user":
  43. his["content"] = self.chat_prompt(his["content"], image)
  44. response = self.client.chat.completions.create(
  45. model=self.model_name,
  46. messages=history,
  47. max_tokens=gen_conf.get("max_tokens", 1000),
  48. temperature=gen_conf.get("temperature", 0.3),
  49. top_p=gen_conf.get("top_p", 0.7)
  50. )
  51. return response.choices[0].message.content.strip(), response.usage.total_tokens
  52. except Exception as e:
  53. return "**ERROR**: " + str(e), 0
  54. def chat_streamly(self, system, history, gen_conf, image=""):
  55. if system:
  56. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  57. ans = ""
  58. tk_count = 0
  59. try:
  60. for his in history:
  61. if his["role"] == "user":
  62. his["content"] = self.chat_prompt(his["content"], image)
  63. response = self.client.chat.completions.create(
  64. model=self.model_name,
  65. messages=history,
  66. max_tokens=gen_conf.get("max_tokens", 1000),
  67. temperature=gen_conf.get("temperature", 0.3),
  68. top_p=gen_conf.get("top_p", 0.7),
  69. stream=True
  70. )
  71. for resp in response:
  72. if not resp.choices[0].delta.content:
  73. continue
  74. delta = resp.choices[0].delta.content
  75. ans += delta
  76. if resp.choices[0].finish_reason == "length":
  77. ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
  78. [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
  79. tk_count = resp.usage.total_tokens
  80. if resp.choices[0].finish_reason == "stop":
  81. tk_count = resp.usage.total_tokens
  82. yield ans
  83. except Exception as e:
  84. yield ans + "\n**ERROR**: " + str(e)
  85. yield tk_count
  86. def image2base64(self, image):
  87. if isinstance(image, bytes):
  88. return base64.b64encode(image).decode("utf-8")
  89. if isinstance(image, BytesIO):
  90. return base64.b64encode(image.getvalue()).decode("utf-8")
  91. buffered = BytesIO()
  92. try:
  93. image.save(buffered, format="JPEG")
  94. except Exception:
  95. image.save(buffered, format="PNG")
  96. return base64.b64encode(buffered.getvalue()).decode("utf-8")
  97. def prompt(self, b64):
  98. return [
  99. {
  100. "role": "user",
  101. "content": [
  102. {
  103. "type": "image_url",
  104. "image_url": {
  105. "url": f"data:image/jpeg;base64,{b64}"
  106. },
  107. },
  108. {
  109. "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
  110. "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
  111. },
  112. ],
  113. }
  114. ]
  115. def chat_prompt(self, text, b64):
  116. return [
  117. {
  118. "type": "image_url",
  119. "image_url": {
  120. "url": f"data:image/jpeg;base64,{b64}",
  121. },
  122. },
  123. {
  124. "type": "text",
  125. "text": text
  126. },
  127. ]
  128. class GptV4(Base):
  129. def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
  130. if not base_url:
  131. base_url="https://api.openai.com/v1"
  132. self.client = OpenAI(api_key=key, base_url=base_url)
  133. self.model_name = model_name
  134. self.lang = lang
  135. def describe(self, image, max_tokens=300):
  136. b64 = self.image2base64(image)
  137. prompt = self.prompt(b64)
  138. for i in range(len(prompt)):
  139. for c in prompt[i]["content"]:
  140. if "text" in c:
  141. c["type"] = "text"
  142. res = self.client.chat.completions.create(
  143. model=self.model_name,
  144. messages=prompt,
  145. max_tokens=max_tokens,
  146. )
  147. return res.choices[0].message.content.strip(), res.usage.total_tokens
  148. class AzureGptV4(Base):
  149. def __init__(self, key, model_name, lang="Chinese", **kwargs):
  150. api_key = json.loads(key).get('api_key', '')
  151. api_version = json.loads(key).get('api_version', '2024-02-01')
  152. self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
  153. self.model_name = model_name
  154. self.lang = lang
  155. def describe(self, image, max_tokens=300):
  156. b64 = self.image2base64(image)
  157. prompt = self.prompt(b64)
  158. for i in range(len(prompt)):
  159. for c in prompt[i]["content"]:
  160. if "text" in c:
  161. c["type"] = "text"
  162. res = self.client.chat.completions.create(
  163. model=self.model_name,
  164. messages=prompt,
  165. max_tokens=max_tokens,
  166. )
  167. return res.choices[0].message.content.strip(), res.usage.total_tokens
  168. class QWenCV(Base):
  169. def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
  170. import dashscope
  171. dashscope.api_key = key
  172. self.model_name = model_name
  173. self.lang = lang
  174. def prompt(self, binary):
  175. # stupid as hell
  176. tmp_dir = get_project_base_directory("tmp")
  177. if not os.path.exists(tmp_dir):
  178. os.mkdir(tmp_dir)
  179. path = os.path.join(tmp_dir, "%s.jpg" % get_uuid())
  180. Image.open(io.BytesIO(binary)).save(path)
  181. return [
  182. {
  183. "role": "user",
  184. "content": [
  185. {
  186. "image": f"file://{path}"
  187. },
  188. {
  189. "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
  190. "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
  191. },
  192. ],
  193. }
  194. ]
  195. def chat_prompt(self, text, b64):
  196. return [
  197. {"image": f"{b64}"},
  198. {"text": text},
  199. ]
  200. def describe(self, image, max_tokens=300):
  201. from http import HTTPStatus
  202. from dashscope import MultiModalConversation
  203. response = MultiModalConversation.call(model=self.model_name,
  204. messages=self.prompt(image))
  205. if response.status_code == HTTPStatus.OK:
  206. return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
  207. return response.message, 0
  208. def chat(self, system, history, gen_conf, image=""):
  209. from http import HTTPStatus
  210. from dashscope import MultiModalConversation
  211. if system:
  212. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  213. for his in history:
  214. if his["role"] == "user":
  215. his["content"] = self.chat_prompt(his["content"], image)
  216. response = MultiModalConversation.call(model=self.model_name, messages=history,
  217. max_tokens=gen_conf.get("max_tokens", 1000),
  218. temperature=gen_conf.get("temperature", 0.3),
  219. top_p=gen_conf.get("top_p", 0.7))
  220. ans = ""
  221. tk_count = 0
  222. if response.status_code == HTTPStatus.OK:
  223. ans += response.output.choices[0]['message']['content']
  224. tk_count += response.usage.total_tokens
  225. if response.output.choices[0].get("finish_reason", "") == "length":
  226. ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
  227. [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
  228. return ans, tk_count
  229. return "**ERROR**: " + response.message, tk_count
  230. def chat_streamly(self, system, history, gen_conf, image=""):
  231. from http import HTTPStatus
  232. from dashscope import MultiModalConversation
  233. if system:
  234. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  235. for his in history:
  236. if his["role"] == "user":
  237. his["content"] = self.chat_prompt(his["content"], image)
  238. ans = ""
  239. tk_count = 0
  240. try:
  241. response = MultiModalConversation.call(model=self.model_name, messages=history,
  242. max_tokens=gen_conf.get("max_tokens", 1000),
  243. temperature=gen_conf.get("temperature", 0.3),
  244. top_p=gen_conf.get("top_p", 0.7),
  245. stream=True)
  246. for resp in response:
  247. if resp.status_code == HTTPStatus.OK:
  248. ans = resp.output.choices[0]['message']['content']
  249. tk_count = resp.usage.total_tokens
  250. if resp.output.choices[0].get("finish_reason", "") == "length":
  251. ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
  252. [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
  253. yield ans
  254. else:
  255. yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
  256. "Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
  257. except Exception as e:
  258. yield ans + "\n**ERROR**: " + str(e)
  259. yield tk_count
  260. class Zhipu4V(Base):
  261. def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
  262. self.client = ZhipuAI(api_key=key)
  263. self.model_name = model_name
  264. self.lang = lang
  265. def describe(self, image, max_tokens=1024):
  266. b64 = self.image2base64(image)
  267. prompt = self.prompt(b64)
  268. prompt[0]["content"][1]["type"] = "text"
  269. res = self.client.chat.completions.create(
  270. model=self.model_name,
  271. messages=prompt,
  272. max_tokens=max_tokens,
  273. )
  274. return res.choices[0].message.content.strip(), res.usage.total_tokens
  275. def chat(self, system, history, gen_conf, image=""):
  276. if system:
  277. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  278. try:
  279. for his in history:
  280. if his["role"] == "user":
  281. his["content"] = self.chat_prompt(his["content"], image)
  282. response = self.client.chat.completions.create(
  283. model=self.model_name,
  284. messages=history,
  285. max_tokens=gen_conf.get("max_tokens", 1000),
  286. temperature=gen_conf.get("temperature", 0.3),
  287. top_p=gen_conf.get("top_p", 0.7)
  288. )
  289. return response.choices[0].message.content.strip(), response.usage.total_tokens
  290. except Exception as e:
  291. return "**ERROR**: " + str(e), 0
  292. def chat_streamly(self, system, history, gen_conf, image=""):
  293. if system:
  294. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  295. ans = ""
  296. tk_count = 0
  297. try:
  298. for his in history:
  299. if his["role"] == "user":
  300. his["content"] = self.chat_prompt(his["content"], image)
  301. response = self.client.chat.completions.create(
  302. model=self.model_name,
  303. messages=history,
  304. max_tokens=gen_conf.get("max_tokens", 1000),
  305. temperature=gen_conf.get("temperature", 0.3),
  306. top_p=gen_conf.get("top_p", 0.7),
  307. stream=True
  308. )
  309. for resp in response:
  310. if not resp.choices[0].delta.content:
  311. continue
  312. delta = resp.choices[0].delta.content
  313. ans += delta
  314. if resp.choices[0].finish_reason == "length":
  315. ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
  316. [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
  317. tk_count = resp.usage.total_tokens
  318. if resp.choices[0].finish_reason == "stop":
  319. tk_count = resp.usage.total_tokens
  320. yield ans
  321. except Exception as e:
  322. yield ans + "\n**ERROR**: " + str(e)
  323. yield tk_count
  324. class OllamaCV(Base):
  325. def __init__(self, key, model_name, lang="Chinese", **kwargs):
  326. self.client = Client(host=kwargs["base_url"])
  327. self.model_name = model_name
  328. self.lang = lang
  329. def describe(self, image, max_tokens=1024):
  330. prompt = self.prompt("")
  331. try:
  332. options = {"num_predict": max_tokens}
  333. response = self.client.generate(
  334. model=self.model_name,
  335. prompt=prompt[0]["content"][1]["text"],
  336. images=[image],
  337. options=options
  338. )
  339. ans = response["response"].strip()
  340. return ans, 128
  341. except Exception as e:
  342. return "**ERROR**: " + str(e), 0
  343. def chat(self, system, history, gen_conf, image=""):
  344. if system:
  345. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  346. try:
  347. for his in history:
  348. if his["role"] == "user":
  349. his["images"] = [image]
  350. options = {}
  351. if "temperature" in gen_conf:
  352. options["temperature"] = gen_conf["temperature"]
  353. if "max_tokens" in gen_conf:
  354. options["num_predict"] = gen_conf["max_tokens"]
  355. if "top_p" in gen_conf:
  356. options["top_k"] = gen_conf["top_p"]
  357. if "presence_penalty" in gen_conf:
  358. options["presence_penalty"] = gen_conf["presence_penalty"]
  359. if "frequency_penalty" in gen_conf:
  360. options["frequency_penalty"] = gen_conf["frequency_penalty"]
  361. response = self.client.chat(
  362. model=self.model_name,
  363. messages=history,
  364. options=options,
  365. keep_alive=-1
  366. )
  367. ans = response["message"]["content"].strip()
  368. return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
  369. except Exception as e:
  370. return "**ERROR**: " + str(e), 0
  371. def chat_streamly(self, system, history, gen_conf, image=""):
  372. if system:
  373. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  374. for his in history:
  375. if his["role"] == "user":
  376. his["images"] = [image]
  377. options = {}
  378. if "temperature" in gen_conf:
  379. options["temperature"] = gen_conf["temperature"]
  380. if "max_tokens" in gen_conf:
  381. options["num_predict"] = gen_conf["max_tokens"]
  382. if "top_p" in gen_conf:
  383. options["top_k"] = gen_conf["top_p"]
  384. if "presence_penalty" in gen_conf:
  385. options["presence_penalty"] = gen_conf["presence_penalty"]
  386. if "frequency_penalty" in gen_conf:
  387. options["frequency_penalty"] = gen_conf["frequency_penalty"]
  388. ans = ""
  389. try:
  390. response = self.client.chat(
  391. model=self.model_name,
  392. messages=history,
  393. stream=True,
  394. options=options,
  395. keep_alive=-1
  396. )
  397. for resp in response:
  398. if resp["done"]:
  399. yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
  400. ans += resp["message"]["content"]
  401. yield ans
  402. except Exception as e:
  403. yield ans + "\n**ERROR**: " + str(e)
  404. yield 0
  405. class LocalAICV(GptV4):
  406. def __init__(self, key, model_name, base_url, lang="Chinese"):
  407. if not base_url:
  408. raise ValueError("Local cv model url cannot be None")
  409. if base_url.split("/")[-1] != "v1":
  410. base_url = os.path.join(base_url, "v1")
  411. self.client = OpenAI(api_key="empty", base_url=base_url)
  412. self.model_name = model_name.split("___")[0]
  413. self.lang = lang
  414. class XinferenceCV(Base):
  415. def __init__(self, key, model_name="", lang="Chinese", base_url=""):
  416. if base_url.split("/")[-1] != "v1":
  417. base_url = os.path.join(base_url, "v1")
  418. self.client = OpenAI(api_key=key, base_url=base_url)
  419. self.model_name = model_name
  420. self.lang = lang
  421. def describe(self, image, max_tokens=300):
  422. b64 = self.image2base64(image)
  423. res = self.client.chat.completions.create(
  424. model=self.model_name,
  425. messages=self.prompt(b64),
  426. max_tokens=max_tokens,
  427. )
  428. return res.choices[0].message.content.strip(), res.usage.total_tokens
  429. class GeminiCV(Base):
  430. def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
  431. from google.generativeai import client, GenerativeModel
  432. client.configure(api_key=key)
  433. _client = client.get_default_generative_client()
  434. self.model_name = model_name
  435. self.model = GenerativeModel(model_name=self.model_name)
  436. self.model._client = _client
  437. self.lang = lang
  438. def describe(self, image, max_tokens=2048):
  439. from PIL.Image import open
  440. gen_config = {'max_output_tokens':max_tokens}
  441. prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
  442. "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
  443. b64 = self.image2base64(image)
  444. img = open(BytesIO(base64.b64decode(b64)))
  445. input = [prompt,img]
  446. res = self.model.generate_content(
  447. input,
  448. generation_config=gen_config,
  449. )
  450. return res.text,res.usage_metadata.total_token_count
  451. def chat(self, system, history, gen_conf, image=""):
  452. if system:
  453. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  454. try:
  455. for his in history:
  456. if his["role"] == "assistant":
  457. his["role"] = "model"
  458. his["parts"] = [his["content"]]
  459. his.pop("content")
  460. if his["role"] == "user":
  461. his["parts"] = [his["content"]]
  462. his.pop("content")
  463. history[-1]["parts"].append("data:image/jpeg;base64," + image)
  464. response = self.model.generate_content(history, generation_config=GenerationConfig(
  465. max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
  466. top_p=gen_conf.get("top_p", 0.7)))
  467. ans = response.text
  468. return ans, response.usage_metadata.total_token_count
  469. except Exception as e:
  470. return "**ERROR**: " + str(e), 0
  471. def chat_streamly(self, system, history, gen_conf, image=""):
  472. if system:
  473. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  474. ans = ""
  475. try:
  476. for his in history:
  477. if his["role"] == "assistant":
  478. his["role"] = "model"
  479. his["parts"] = [his["content"]]
  480. his.pop("content")
  481. if his["role"] == "user":
  482. his["parts"] = [his["content"]]
  483. his.pop("content")
  484. history[-1]["parts"].append("data:image/jpeg;base64," + image)
  485. response = self.model.generate_content(history, generation_config=GenerationConfig(
  486. max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
  487. top_p=gen_conf.get("top_p", 0.7)), stream=True)
  488. for resp in response:
  489. if not resp.text:
  490. continue
  491. ans += resp.text
  492. yield ans
  493. except Exception as e:
  494. yield ans + "\n**ERROR**: " + str(e)
  495. yield response._chunks[-1].usage_metadata.total_token_count
  496. class OpenRouterCV(GptV4):
  497. def __init__(
  498. self,
  499. key,
  500. model_name,
  501. lang="Chinese",
  502. base_url="https://openrouter.ai/api/v1",
  503. ):
  504. if not base_url:
  505. base_url = "https://openrouter.ai/api/v1"
  506. self.client = OpenAI(api_key=key, base_url=base_url)
  507. self.model_name = model_name
  508. self.lang = lang
  509. class LocalCV(Base):
  510. def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
  511. pass
  512. def describe(self, image, max_tokens=1024):
  513. return "", 0
  514. class NvidiaCV(Base):
  515. def __init__(
  516. self,
  517. key,
  518. model_name,
  519. lang="Chinese",
  520. base_url="https://ai.api.nvidia.com/v1/vlm",
  521. ):
  522. if not base_url:
  523. base_url = ("https://ai.api.nvidia.com/v1/vlm",)
  524. self.lang = lang
  525. factory, llm_name = model_name.split("/")
  526. if factory != "liuhaotian":
  527. self.base_url = os.path.join(base_url, factory, llm_name)
  528. else:
  529. self.base_url = os.path.join(
  530. base_url, "community", llm_name.replace("-v1.6", "16")
  531. )
  532. self.key = key
  533. def describe(self, image, max_tokens=1024):
  534. b64 = self.image2base64(image)
  535. response = requests.post(
  536. url=self.base_url,
  537. headers={
  538. "accept": "application/json",
  539. "content-type": "application/json",
  540. "Authorization": f"Bearer {self.key}",
  541. },
  542. json={
  543. "messages": self.prompt(b64),
  544. "max_tokens": max_tokens,
  545. },
  546. )
  547. response = response.json()
  548. return (
  549. response["choices"][0]["message"]["content"].strip(),
  550. response["usage"]["total_tokens"],
  551. )
  552. def prompt(self, b64):
  553. return [
  554. {
  555. "role": "user",
  556. "content": (
  557. "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
  558. if self.lang.lower() == "chinese"
  559. else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
  560. )
  561. + f' <img src="data:image/jpeg;base64,{b64}"/>',
  562. }
  563. ]
  564. def chat_prompt(self, text, b64):
  565. return [
  566. {
  567. "role": "user",
  568. "content": text + f' <img src="data:image/jpeg;base64,{b64}"/>',
  569. }
  570. ]
  571. class StepFunCV(GptV4):
  572. def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
  573. if not base_url:
  574. base_url="https://api.stepfun.com/v1"
  575. self.client = OpenAI(api_key=key, base_url=base_url)
  576. self.model_name = model_name
  577. self.lang = lang
  578. class LmStudioCV(GptV4):
  579. def __init__(self, key, model_name, lang="Chinese", base_url=""):
  580. if not base_url:
  581. raise ValueError("Local llm url cannot be None")
  582. if base_url.split("/")[-1] != "v1":
  583. base_url = os.path.join(base_url, "v1")
  584. self.client = OpenAI(api_key="lm-studio", base_url=base_url)
  585. self.model_name = model_name
  586. self.lang = lang
  587. class OpenAI_APICV(GptV4):
  588. def __init__(self, key, model_name, lang="Chinese", base_url=""):
  589. if not base_url:
  590. raise ValueError("url cannot be None")
  591. if base_url.split("/")[-1] != "v1":
  592. base_url = os.path.join(base_url, "v1")
  593. self.client = OpenAI(api_key=key, base_url=base_url)
  594. self.model_name = model_name.split("___")[0]
  595. self.lang = lang
  596. class TogetherAICV(GptV4):
  597. def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"):
  598. if not base_url:
  599. base_url = "https://api.together.xyz/v1"
  600. super().__init__(key, model_name,lang,base_url)
  601. class YiCV(GptV4):
  602. def __init__(self, key, model_name, lang="Chinese",base_url="https://api.lingyiwanwu.com/v1",):
  603. if not base_url:
  604. base_url = "https://api.lingyiwanwu.com/v1"
  605. super().__init__(key, model_name,lang,base_url)
  606. class HunyuanCV(Base):
  607. def __init__(self, key, model_name, lang="Chinese",base_url=None):
  608. from tencentcloud.common import credential
  609. from tencentcloud.hunyuan.v20230901 import hunyuan_client
  610. key = json.loads(key)
  611. sid = key.get("hunyuan_sid", "")
  612. sk = key.get("hunyuan_sk", "")
  613. cred = credential.Credential(sid, sk)
  614. self.model_name = model_name
  615. self.client = hunyuan_client.HunyuanClient(cred, "")
  616. self.lang = lang
  617. def describe(self, image, max_tokens=4096):
  618. from tencentcloud.hunyuan.v20230901 import models
  619. from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
  620. TencentCloudSDKException,
  621. )
  622. b64 = self.image2base64(image)
  623. req = models.ChatCompletionsRequest()
  624. params = {"Model": self.model_name, "Messages": self.prompt(b64)}
  625. req.from_json_string(json.dumps(params))
  626. ans = ""
  627. try:
  628. response = self.client.ChatCompletions(req)
  629. ans = response.Choices[0].Message.Content
  630. return ans, response.Usage.TotalTokens
  631. except TencentCloudSDKException as e:
  632. return ans + "\n**ERROR**: " + str(e), 0
  633. def prompt(self, b64):
  634. return [
  635. {
  636. "Role": "user",
  637. "Contents": [
  638. {
  639. "Type": "image_url",
  640. "ImageUrl": {
  641. "Url": f"data:image/jpeg;base64,{b64}"
  642. },
  643. },
  644. {
  645. "Type": "text",
  646. "Text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
  647. "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
  648. },
  649. ],
  650. }
  651. ]