You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

cv_model.py 29KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from openai.lib.azure import AzureOpenAI
  17. from zhipuai import ZhipuAI
  18. import io
  19. from abc import ABC
  20. from ollama import Client
  21. from PIL import Image
  22. from openai import OpenAI
  23. import os
  24. import base64
  25. from io import BytesIO
  26. import json
  27. import requests
  28. from rag.nlp import is_english
  29. from api.utils import get_uuid
  30. from api.utils.file_utils import get_project_base_directory
  class Base(ABC):
      """Shared behaviour for image-to-text (vision) model wrappers.

      The methods below read ``self.client``, ``self.model_name`` and
      ``self.lang``; subclasses are responsible for setting them in their own
      ``__init__``.
      """

      def __init__(self, key, model_name):
          pass

      def describe(self, image, max_tokens=300):
          """Describe ``image``; must be overridden by subclasses."""
          raise NotImplementedError("Please implement describe method!")

      def chat(self, system, history, gen_conf, image=""):
          """Run one non-streaming chat turn about ``image``.

          ``history`` is modified in place: the system text is merged into the
          last entry and every user entry's content is replaced by the
          multimodal parts built by ``chat_prompt``.

          Returns:
              ``(answer, total_tokens)``, or ``("**ERROR**: ...", 0)`` when
              anything raises.
          """
          if system:
              history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
          try:
              for his in history:
                  if his["role"] == "user":
                      his["content"] = self.chat_prompt(his["content"], image)

              response = self.client.chat.completions.create(
                  model=self.model_name,
                  messages=history,
                  max_tokens=gen_conf.get("max_tokens", 1000),
                  temperature=gen_conf.get("temperature", 0.3),
                  top_p=gen_conf.get("top_p", 0.7)
              )
              return response.choices[0].message.content.strip(), response.usage.total_tokens
          except Exception as e:
              return "**ERROR**: " + str(e), 0

      def chat_streamly(self, system, history, gen_conf, image=""):
          """Streaming variant of ``chat``.

          Yields the accumulated answer after every non-empty delta and, as
          the very last item, the total token count (an ``int``; 0 when the
          stream never reported usage).
          """
          if system:
              history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

          ans = ""
          tk_count = 0
          try:
              for his in history:
                  if his["role"] == "user":
                      his["content"] = self.chat_prompt(his["content"], image)

              response = self.client.chat.completions.create(
                  model=self.model_name,
                  messages=history,
                  max_tokens=gen_conf.get("max_tokens", 1000),
                  temperature=gen_conf.get("temperature", 0.3),
                  top_p=gen_conf.get("top_p", 0.7),
                  stream=True
              )
              for resp in response:
                  choice = resp.choices[0]
                  if not choice.delta.content:
                      continue
                  ans += choice.delta.content
                  if choice.finish_reason == "length":
                      ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                          [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                  # Streamed chunks may carry no usage object; only read the
                  # token count when one is actually present.
                  usage = getattr(resp, "usage", None)
                  if choice.finish_reason in ("length", "stop") and usage is not None:
                      tk_count = usage.total_tokens
                  yield ans
          except Exception as e:
              yield ans + "\n**ERROR**: " + str(e)

          yield tk_count

      def image2base64(self, image):
          """Return ``image`` as a base64 string.

          Accepts raw ``bytes``, a ``BytesIO``, or any object with a
          ``save(fp, format=...)`` method; the latter is encoded as JPEG, or
          as PNG when the JPEG save raises.
          """
          if isinstance(image, bytes):
              return base64.b64encode(image).decode("utf-8")
          if isinstance(image, BytesIO):
              return base64.b64encode(image.getvalue()).decode("utf-8")
          buffered = BytesIO()
          try:
              image.save(buffered, format="JPEG")
          except Exception:
              # Start from a fresh buffer so bytes written by the failed JPEG
              # attempt cannot end up in front of the PNG data.
              buffered = BytesIO()
              image.save(buffered, format="PNG")
          return base64.b64encode(buffered.getvalue()).decode("utf-8")

      def prompt(self, b64):
          """Build the single-message "describe this picture" request for ``b64``."""
          instruction = (
              "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
              if self.lang.lower() == "chinese" else
              "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
          )
          return [
              {
                  "role": "user",
                  "content": [
                      {
                          "type": "image_url",
                          "image_url": {
                              "url": f"data:image/jpeg;base64,{b64}"
                          },
                      },
                      {
                          # Tagged like the image part so the message is
                          # complete without per-subclass patching.
                          "type": "text",
                          "text": instruction,
                      },
                  ],
              }
          ]

      def chat_prompt(self, text, b64):
          """Build the content parts (image + text) for one user chat message."""
          return [
              {
                  "type": "image_url",
                  "image_url": {
                      "url": f"data:image/jpeg;base64,{b64}",
                  },
              },
              {
                  "type": "text",
                  "text": text
              },
          ]
  class GptV4(Base):
      """Vision model reached through the ``OpenAI`` client."""

      def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
          # A falsy base_url (None / "") falls back to the public endpoint.
          endpoint = base_url or "https://api.openai.com/v1"
          self.client = OpenAI(api_key=key, base_url=endpoint)
          self.model_name = model_name
          self.lang = lang

      def describe(self, image, max_tokens=300):
          """Return ``(description, total_tokens)`` for ``image``."""
          messages = self.prompt(self.image2base64(image))
          # Tag every part that carries a "text" key as a text part.
          for message in messages:
              for part in message["content"]:
                  if "text" in part:
                      part["type"] = "text"

          completion = self.client.chat.completions.create(
              model=self.model_name,
              messages=messages,
              max_tokens=max_tokens,
          )
          answer = completion.choices[0].message.content.strip()
          return answer, completion.usage.total_tokens
  class AzureGptV4(Base):
      """Vision model reached through the ``AzureOpenAI`` client.

      ``key`` is a JSON string; its ``api_key`` and ``api_version`` entries are
      read (defaults: ``''`` and ``'2024-02-01'``). The endpoint comes from the
      required ``base_url`` keyword argument.
      """

      def __init__(self, key, model_name, lang="Chinese", **kwargs):
          settings = json.loads(key)
          self.client = AzureOpenAI(
              api_key=settings.get('api_key', ''),
              azure_endpoint=kwargs["base_url"],
              api_version=settings.get('api_version', '2024-02-01'),
          )
          self.model_name = model_name
          self.lang = lang

      def describe(self, image, max_tokens=300):
          """Return ``(description, total_tokens)`` for ``image``."""
          messages = self.prompt(self.image2base64(image))
          # Tag every part that carries a "text" key as a text part.
          for message in messages:
              for part in message["content"]:
                  if "text" in part:
                      part["type"] = "text"

          completion = self.client.chat.completions.create(
              model=self.model_name,
              messages=messages,
              max_tokens=max_tokens,
          )
          answer = completion.choices[0].message.content.strip()
          return answer, completion.usage.total_tokens
  class QWenCV(Base):
      """Vision model reached through dashscope's ``MultiModalConversation``."""

      def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
          import dashscope
          # The key is set on the dashscope module itself.
          dashscope.api_key = key
          self.model_name = model_name
          self.lang = lang

      @staticmethod
      def _content_text(content):
          """Return message content as text.

          ``describe`` reads the content as a list of parts
          (``content[0]["text"]``); accept that shape as well as a plain
          string so the chat methods do not fail on ``str + list``.
          """
          if isinstance(content, list):
              return content[0]["text"] if content else ""
          return content

      @staticmethod
      def _remove_prompt_files(messages):
          """Best-effort removal of the ``file://`` images referenced by ``messages``."""
          for message in messages:
              for part in message["content"]:
                  ref = part.get("image", "")
                  if not ref.startswith("file://"):
                      continue
                  try:
                      os.remove(ref[len("file://"):])
                  except OSError:
                      pass

      def prompt(self, binary):
          """Write ``binary`` (image bytes) to a temp .jpg file and build the describe request.

          The image is referenced as a ``file://`` URL in the returned message.
          """
          tmp_dir = get_project_base_directory("tmp")
          # exist_ok avoids the check-then-create race of exists() + mkdir().
          os.makedirs(tmp_dir, exist_ok=True)
          path = os.path.join(tmp_dir, "%s.jpg" % get_uuid())
          Image.open(io.BytesIO(binary)).save(path)
          instruction = (
              "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
              if self.lang.lower() == "chinese" else
              "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
          )
          return [
              {
                  "role": "user",
                  "content": [
                      {
                          "image": f"file://{path}"
                      },
                      {
                          "text": instruction,
                      },
                  ],
              }
          ]

      def chat_prompt(self, text, b64):
          """Build the content parts (image + text) for one user chat message."""
          return [
              {"image": f"{b64}"},
              {"text": text},
          ]

      def describe(self, image, max_tokens=300):
          """Return ``(description, output_tokens)``, or ``(error message, 0)`` on a non-OK status.

          ``max_tokens`` is accepted for interface compatibility but is not
          forwarded to the call.
          """
          from http import HTTPStatus
          from dashscope import MultiModalConversation
          messages = self.prompt(image)
          try:
              response = MultiModalConversation.call(model=self.model_name,
                                                     messages=messages)
          finally:
              # The temp image is only needed for the duration of the call.
              self._remove_prompt_files(messages)
          if response.status_code == HTTPStatus.OK:
              return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
          return response.message, 0

      def chat(self, system, history, gen_conf, image=""):
          """Run one non-streaming chat turn; returns ``(answer, total_tokens)``.

          ``history`` is modified in place. On a non-OK status the answer is
          ``"**ERROR**: " + response.message``.
          """
          from http import HTTPStatus
          from dashscope import MultiModalConversation
          if system:
              history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

          for his in history:
              if his["role"] == "user":
                  his["content"] = self.chat_prompt(his["content"], image)
          response = MultiModalConversation.call(model=self.model_name, messages=history,
                                                 max_tokens=gen_conf.get("max_tokens", 1000),
                                                 temperature=gen_conf.get("temperature", 0.3),
                                                 top_p=gen_conf.get("top_p", 0.7))

          ans = ""
          tk_count = 0
          if response.status_code == HTTPStatus.OK:
              ans += self._content_text(response.output.choices[0]['message']['content'])
              tk_count += response.usage.total_tokens
              if response.output.choices[0].get("finish_reason", "") == "length":
                  ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                      [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
              return ans, tk_count

          return "**ERROR**: " + response.message, tk_count

      def chat_streamly(self, system, history, gen_conf, image=""):
          """Streaming variant of ``chat``.

          Yields the answer text of each OK response (or an error string for a
          non-OK one) and, as the last item, the total token count.
          """
          from http import HTTPStatus
          from dashscope import MultiModalConversation
          if system:
              history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

          for his in history:
              if his["role"] == "user":
                  his["content"] = self.chat_prompt(his["content"], image)

          ans = ""
          tk_count = 0
          try:
              response = MultiModalConversation.call(model=self.model_name, messages=history,
                                                     max_tokens=gen_conf.get("max_tokens", 1000),
                                                     temperature=gen_conf.get("temperature", 0.3),
                                                     top_p=gen_conf.get("top_p", 0.7),
                                                     stream=True)
              for resp in response:
                  if resp.status_code == HTTPStatus.OK:
                      # Each response's content replaces (not extends) ``ans``.
                      ans = self._content_text(resp.output.choices[0]['message']['content'])
                      tk_count = resp.usage.total_tokens
                      if resp.output.choices[0].get("finish_reason", "") == "length":
                          ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                              [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                      yield ans
                  elif str(resp.message).find("Access") < 0:
                      yield ans + "\n**ERROR**: " + resp.message
                  else:
                      yield "Out of credit. Please set the API key in **settings > Model providers.**"
          except Exception as e:
              yield ans + "\n**ERROR**: " + str(e)

          yield tk_count
  class Zhipu4V(Base):
      """Vision model reached through the ``ZhipuAI`` client.

      ``chat`` and ``chat_streamly`` are inherited from ``Base``: the copies
      this class used to carry were line-for-line identical to the base
      implementations, so only ``describe`` is specific to this backend.
      """

      def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
          self.client = ZhipuAI(api_key=key)
          self.model_name = model_name
          self.lang = lang

      def describe(self, image, max_tokens=1024):
          """Return ``(description, total_tokens)`` for ``image``."""
          b64 = self.image2base64(image)
          prompt = self.prompt(b64)
          # Make sure the instruction part is tagged as text.
          prompt[0]["content"][1]["type"] = "text"

          res = self.client.chat.completions.create(
              model=self.model_name,
              messages=prompt,
              max_tokens=max_tokens,
          )
          return res.choices[0].message.content.strip(), res.usage.total_tokens
  class OllamaCV(Base):
      """Vision model served by Ollama (``ollama.Client``).

      The server address comes from the required ``base_url`` keyword argument.
      """

      # gen_conf key -> Ollama option name.
      _OPTION_NAMES = (
          ("temperature", "temperature"),
          ("max_tokens", "num_predict"),
          ("top_p", "top_p"),
          ("presence_penalty", "presence_penalty"),
          ("frequency_penalty", "frequency_penalty"),
      )

      def __init__(self, key, model_name, lang="Chinese", **kwargs):
          self.client = Client(host=kwargs["base_url"])
          self.model_name = model_name
          self.lang = lang

      @classmethod
      def _build_options(cls, gen_conf):
          """Translate the generation settings present in ``gen_conf`` into Ollama options."""
          return {target: gen_conf[source] for source, target in cls._OPTION_NAMES if source in gen_conf}

      def describe(self, image, max_tokens=1024):
          """Return ``(description, 128)`` for ``image``, or ``("**ERROR**: ...", 0)`` on failure.

          The token count on success is a fixed placeholder value, not a
          measured one.
          """
          # Only the instruction text of the generic prompt is used here; the
          # image itself is passed through ``images``.
          prompt = self.prompt("")
          try:
              options = {"num_predict": max_tokens}
              response = self.client.generate(
                  model=self.model_name,
                  prompt=prompt[0]["content"][1]["text"],
                  images=[image],
                  options=options
              )
              ans = response["response"].strip()
              return ans, 128
          except Exception as e:
              return "**ERROR**: " + str(e), 0

      def chat(self, system, history, gen_conf, image=""):
          """Run one non-streaming chat turn; returns ``(answer, token_count)``.

          ``history`` is modified in place (``image`` is attached to every user
          entry). Returns ``("**ERROR**: ...", 0)`` when anything raises.
          """
          if system:
              history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

          try:
              for his in history:
                  if his["role"] == "user":
                      his["images"] = [image]
              response = self.client.chat(
                  model=self.model_name,
                  messages=history,
                  options=self._build_options(gen_conf),
                  keep_alive=-1
              )

              ans = response["message"]["content"].strip()
              # Missing counters must not turn a valid answer into an error.
              return ans, response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
          except Exception as e:
              return "**ERROR**: " + str(e), 0

      def chat_streamly(self, system, history, gen_conf, image=""):
          """Streaming variant of ``chat``.

          Yields the accumulated answer after every chunk and, as the single
          final item, the token count reported by the ``done`` chunk (0 when
          none was seen).
          """
          if system:
              history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

          for his in history:
              if his["role"] == "user":
                  his["images"] = [image]
          options = self._build_options(gen_conf)
          ans = ""
          tk_count = 0
          try:
              response = self.client.chat(
                  model=self.model_name,
                  messages=history,
                  stream=True,
                  options=options,
                  keep_alive=-1
              )
              for resp in response:
                  ans += resp["message"]["content"]
                  yield ans
                  if resp["done"]:
                      # Remember the count and emit it once, after all text.
                      tk_count = resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
          except Exception as e:
              yield ans + "\n**ERROR**: " + str(e)
          yield tk_count
  class LocalAICV(GptV4):
      """OpenAI-compatible local vision model.

      ``key`` is ignored (the client is created with ``api_key="empty"``) and
      only the part of ``model_name`` before ``"___"`` is used.

      Raises:
          ValueError: if ``base_url`` is empty.
      """

      def __init__(self, key, model_name, base_url, lang="Chinese"):
          if not base_url:
              raise ValueError("Local cv model url cannot be None")
          # Build the URL with "/" (os.path.join is platform dependent) and
          # strip trailing slashes first so ".../v1/" is not turned into
          # ".../v1/v1".
          base_url = base_url.rstrip("/")
          if base_url.split("/")[-1] != "v1":
              base_url = base_url + "/v1"
          self.client = OpenAI(api_key="empty", base_url=base_url)
          self.model_name = model_name.split("___")[0]
          self.lang = lang
  class XinferenceCV(Base):
      """Vision model served behind an OpenAI-compatible endpoint (``OpenAI`` client)."""

      def __init__(self, key, model_name="", lang="Chinese", base_url=""):
          # Build the URL with "/" (os.path.join is platform dependent) and
          # strip trailing slashes first so ".../v1/" is not turned into
          # ".../v1/v1".
          base_url = base_url.rstrip("/")
          if base_url.split("/")[-1] != "v1":
              base_url = base_url + "/v1"
          self.client = OpenAI(api_key=key, base_url=base_url)
          self.model_name = model_name
          self.lang = lang

      def describe(self, image, max_tokens=300):
          """Return ``(description, total_tokens)`` for ``image``."""
          b64 = self.image2base64(image)
          messages = self.prompt(b64)
          # Tag text parts explicitly, as the other OpenAI-client based
          # classes in this file do before sending the prompt.
          for message in messages:
              for part in message["content"]:
                  if "text" in part:
                      part["type"] = "text"

          res = self.client.chat.completions.create(
              model=self.model_name,
              messages=messages,
              max_tokens=max_tokens,
          )
          return res.choices[0].message.content.strip(), res.usage.total_tokens
  class GeminiCV(Base):
      """Vision model reached through ``google.generativeai``."""

      def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
          from google.generativeai import client, GenerativeModel
          client.configure(api_key=key)
          _client = client.get_default_generative_client()
          self.model_name = model_name
          self.model = GenerativeModel(model_name=self.model_name)
          self.model._client = _client
          self.lang = lang

      @staticmethod
      def _generation_config(gen_conf):
          """Build the generation settings as a plain dict, the same form ``describe`` passes."""
          return {
              "max_output_tokens": gen_conf.get("max_tokens", 1000),
              "temperature": gen_conf.get("temperature", 0.3),
              "top_p": gen_conf.get("top_p", 0.7),
          }

      @staticmethod
      def _convert_history(history, image):
          """Rewrite ``history`` in place into role/parts entries and attach ``image``.

          "assistant" entries are renamed to "model"; for assistant and user
          entries ``content`` is moved into a one-element ``parts`` list. The
          image is appended to the last entry's parts as a data URL string.
          """
          for his in history:
              role = his["role"]
              if role == "assistant":
                  his["role"] = "model"
              if role in ("assistant", "user"):
                  his["parts"] = [his.pop("content")]
          history[-1]["parts"].append("data:image/jpeg;base64," + image)

      def describe(self, image, max_tokens=2048):
          """Return ``(description, total_token_count)`` for ``image``."""
          gen_config = {'max_output_tokens': max_tokens}
          prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
              "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
          b64 = self.image2base64(image)
          # Round-trip through base64 so bytes / BytesIO / image objects all
          # end up as a PIL image.
          img = Image.open(BytesIO(base64.b64decode(b64)))
          res = self.model.generate_content(
              [prompt, img],
              generation_config=gen_config,
          )
          return res.text, res.usage_metadata.total_token_count

      def chat(self, system, history, gen_conf, image=""):
          """Run one non-streaming chat turn; returns ``(answer, total_token_count)``.

          ``history`` is modified in place. Returns ``("**ERROR**: ...", 0)``
          when anything raises.
          """
          if system:
              history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
          try:
              self._convert_history(history, image)
              response = self.model.generate_content(
                  history, generation_config=self._generation_config(gen_conf))
              ans = response.text
              return ans, response.usage_metadata.total_token_count
          except Exception as e:
              return "**ERROR**: " + str(e), 0

      def chat_streamly(self, system, history, gen_conf, image=""):
          """Streaming variant of ``chat``.

          Yields the accumulated answer after every non-empty chunk and, as
          the last item, the total token count (0 when it is unavailable).
          """
          if system:
              history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

          ans = ""
          response = None
          try:
              self._convert_history(history, image)
              response = self.model.generate_content(
                  history, generation_config=self._generation_config(gen_conf), stream=True)

              for resp in response:
                  if not resp.text:
                      continue
                  ans += resp.text
                  yield ans
          except Exception as e:
              yield ans + "\n**ERROR**: " + str(e)

          # ``response`` is still None when the request itself failed, and it
          # may hold no chunks; report 0 instead of raising out of the generator.
          tk_count = 0
          if response is not None:
              try:
                  tk_count = response._chunks[-1].usage_metadata.total_token_count
              except (AttributeError, IndexError):
                  pass
          yield tk_count
  class OpenRouterCV(GptV4):
      """``GptV4`` pointed at the OpenRouter endpoint by default."""

      def __init__(
          self,
          key,
          model_name,
          lang="Chinese",
          base_url="https://openrouter.ai/api/v1",
      ):
          # A falsy base_url (None / "") falls back to the default endpoint.
          endpoint = base_url if base_url else "https://openrouter.ai/api/v1"
          self.model_name = model_name
          self.lang = lang
          self.client = OpenAI(base_url=endpoint, api_key=key)
  class LocalCV(Base):
      """No-op vision model: construction does nothing and ``describe`` returns an empty result."""

      def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
          # Nothing to configure.
          return None

      def describe(self, image, max_tokens=1024):
          """Return an empty description and a token count of 0."""
          description, token_count = "", 0
          return description, token_count
  class NvidiaCV(Base):
      """Vision model called over HTTP with ``requests``.

      ``model_name`` must have the form ``"<factory>/<llm_name>"``; the request
      URL is derived from ``base_url`` and those two parts.
      """

      def __init__(
          self,
          key,
          model_name,
          lang="Chinese",
          base_url="https://ai.api.nvidia.com/v1/vlm",
      ):
          if not base_url:
              # Must be a plain string: a one-element tuple here made the
              # URL join below raise TypeError.
              base_url = "https://ai.api.nvidia.com/v1/vlm"
          self.lang = lang
          factory, llm_name = model_name.split("/")
          if factory != "liuhaotian":
              segments = (factory, llm_name)
          else:
              # "liuhaotian" models live under "community", with "-v1.6"
              # rewritten to "16" in the model name.
              segments = ("community", llm_name.replace("-v1.6", "16"))
          # Join with "/" explicitly; os.path.join is platform dependent and
          # not meant for URLs.
          self.base_url = "/".join((base_url.rstrip("/"),) + segments)
          self.key = key

      def describe(self, image, max_tokens=1024):
          """POST the describe prompt for ``image``; returns ``(description, total_tokens)``."""
          b64 = self.image2base64(image)
          response = requests.post(
              url=self.base_url,
              headers={
                  "accept": "application/json",
                  "content-type": "application/json",
                  "Authorization": f"Bearer {self.key}",
              },
              json={
                  "messages": self.prompt(b64),
                  "max_tokens": max_tokens,
              },
          )
          response = response.json()
          return (
              response["choices"][0]["message"]["content"].strip(),
              response["usage"]["total_tokens"],
          )

      def prompt(self, b64):
          """Build the describe request with the image embedded as an ``<img>`` tag in the text."""
          return [
              {
                  "role": "user",
                  "content": (
                      "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
                      if self.lang.lower() == "chinese"
                      else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
                  )
                  + f' <img src="data:image/jpeg;base64,{b64}"/>',
              }
          ]

      def chat_prompt(self, text, b64):
          """Build a user message whose text ends with the image as an ``<img>`` tag."""
          return [
              {
                  "role": "user",
                  "content": text + f' <img src="data:image/jpeg;base64,{b64}"/>',
              }
          ]
  class StepFunCV(GptV4):
      """``GptV4`` pointed at the StepFun endpoint by default."""

      def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
          # A falsy base_url (None / "") falls back to the default endpoint.
          endpoint = base_url if base_url else "https://api.stepfun.com/v1"
          self.model_name = model_name
          self.lang = lang
          self.client = OpenAI(base_url=endpoint, api_key=key)
  class LmStudioCV(GptV4):
      """OpenAI-compatible vision model served by LM Studio.

      ``key`` is ignored; the client is created with ``api_key="lm-studio"``.

      Raises:
          ValueError: if ``base_url`` is empty.
      """

      def __init__(self, key, model_name, lang="Chinese", base_url=""):
          if not base_url:
              raise ValueError("Local llm url cannot be None")
          # Build the URL with "/" (os.path.join is platform dependent) and
          # strip trailing slashes first so ".../v1/" is not turned into
          # ".../v1/v1".
          base_url = base_url.rstrip("/")
          if base_url.split("/")[-1] != "v1":
              base_url = base_url + "/v1"
          self.client = OpenAI(api_key="lm-studio", base_url=base_url)
          self.model_name = model_name
          self.lang = lang
  class OpenAI_APICV(GptV4):
      """Vision model behind a user-supplied OpenAI-compatible endpoint.

      Only the part of ``model_name`` before ``"___"`` is used.

      Raises:
          ValueError: if ``base_url`` is empty.
      """

      def __init__(self, key, model_name, lang="Chinese", base_url=""):
          if not base_url:
              raise ValueError("url cannot be None")
          # Build the URL with "/" (os.path.join is platform dependent) and
          # strip trailing slashes first so ".../v1/" is not turned into
          # ".../v1/v1".
          base_url = base_url.rstrip("/")
          if base_url.split("/")[-1] != "v1":
              base_url = base_url + "/v1"
          self.client = OpenAI(api_key=key, base_url=base_url)
          self.model_name = model_name.split("___")[0]
          self.lang = lang
  class TogetherAICV(GptV4):
      """``GptV4`` pointed at the Together AI endpoint by default."""

      def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"):
          # A falsy base_url (None / "") falls back to the default endpoint.
          endpoint = base_url or "https://api.together.xyz/v1"
          super().__init__(key, model_name, lang, endpoint)
  class YiCV(GptV4):
      """``GptV4`` pointed at the lingyiwanwu endpoint by default."""

      def __init__(self, key, model_name, lang="Chinese",base_url="https://api.lingyiwanwu.com/v1",):
          # A falsy base_url (None / "") falls back to the default endpoint.
          endpoint = base_url or "https://api.lingyiwanwu.com/v1"
          super().__init__(key, model_name, lang, endpoint)
  class HunyuanCV(Base):
      """Vision model reached through the Tencent Cloud Hunyuan client.

      ``key`` is a JSON string; its ``hunyuan_sid`` and ``hunyuan_sk`` entries
      (default ``""``) are used to build the credential. ``base_url`` is
      accepted but not used.
      """

      def __init__(self, key, model_name, lang="Chinese",base_url=None):
          from tencentcloud.common import credential
          from tencentcloud.hunyuan.v20230901 import hunyuan_client

          secrets = json.loads(key)
          secret_id = secrets.get("hunyuan_sid", "")
          secret_key = secrets.get("hunyuan_sk", "")
          cred = credential.Credential(secret_id, secret_key)
          self.model_name = model_name
          self.client = hunyuan_client.HunyuanClient(cred, "")
          self.lang = lang

      def describe(self, image, max_tokens=4096):
          """Return ``(description, total_tokens)`` for ``image``.

          An SDK exception is reported as ``("\\n**ERROR**: ...", 0)`` instead of
          being raised. ``max_tokens`` is not forwarded to the request.
          """
          from tencentcloud.hunyuan.v20230901 import models
          from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
              TencentCloudSDKException,
          )

          encoded = self.image2base64(image)
          request = models.ChatCompletionsRequest()
          payload = {"Model": self.model_name, "Messages": self.prompt(encoded)}
          request.from_json_string(json.dumps(payload))
          ans = ""
          try:
              response = self.client.ChatCompletions(request)
              ans = response.Choices[0].Message.Content
              return ans, response.Usage.TotalTokens
          except TencentCloudSDKException as e:
              return ans + "\n**ERROR**: " + str(e), 0

      def prompt(self, b64):
          """Build the describe request (image part + instruction part) for ``b64``."""
          if self.lang.lower() == "chinese":
              instruction = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
          else:
              instruction = "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
          image_part = {
              "Type": "image_url",
              "ImageUrl": {
                  "Url": f"data:image/jpeg;base64,{b64}"
              },
          }
          text_part = {
              "Type": "text",
              "Text": instruction,
          }
          return [
              {
                  "Role": "user",
                  "Contents": [image_part, text_part],
              }
          ]