您最多选择25个主题。主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符。

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from openai.lib.azure import AzureOpenAI
  17. from zhipuai import ZhipuAI
  18. import io
  19. from abc import ABC
  20. from ollama import Client
  21. from PIL import Image
  22. from openai import OpenAI
  23. import os
  24. import base64
  25. from io import BytesIO
  26. import json
  27. import requests
  28. from rag.nlp import is_english
  29. from api.utils import get_uuid
  30. from api.utils.file_utils import get_project_base_directory
  31. class Base(ABC):
  32. def __init__(self, key, model_name):
  33. pass
  34. def describe(self, image, max_tokens=300):
  35. raise NotImplementedError("Please implement encode method!")
  36. def chat(self, system, history, gen_conf, image=""):
  37. if system:
  38. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  39. try:
  40. for his in history:
  41. if his["role"] == "user":
  42. his["content"] = self.chat_prompt(his["content"], image)
  43. response = self.client.chat.completions.create(
  44. model=self.model_name,
  45. messages=history,
  46. max_tokens=gen_conf.get("max_tokens", 1000),
  47. temperature=gen_conf.get("temperature", 0.3),
  48. top_p=gen_conf.get("top_p", 0.7)
  49. )
  50. return response.choices[0].message.content.strip(), response.usage.total_tokens
  51. except Exception as e:
  52. return "**ERROR**: " + str(e), 0
  53. def chat_streamly(self, system, history, gen_conf, image=""):
  54. if system:
  55. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  56. ans = ""
  57. tk_count = 0
  58. try:
  59. for his in history:
  60. if his["role"] == "user":
  61. his["content"] = self.chat_prompt(his["content"], image)
  62. response = self.client.chat.completions.create(
  63. model=self.model_name,
  64. messages=history,
  65. max_tokens=gen_conf.get("max_tokens", 1000),
  66. temperature=gen_conf.get("temperature", 0.3),
  67. top_p=gen_conf.get("top_p", 0.7),
  68. stream=True
  69. )
  70. for resp in response:
  71. if not resp.choices[0].delta.content: continue
  72. delta = resp.choices[0].delta.content
  73. ans += delta
  74. if resp.choices[0].finish_reason == "length":
  75. ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
  76. [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
  77. tk_count = resp.usage.total_tokens
  78. if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
  79. yield ans
  80. except Exception as e:
  81. yield ans + "\n**ERROR**: " + str(e)
  82. yield tk_count
  83. def image2base64(self, image):
  84. if isinstance(image, bytes):
  85. return base64.b64encode(image).decode("utf-8")
  86. if isinstance(image, BytesIO):
  87. return base64.b64encode(image.getvalue()).decode("utf-8")
  88. buffered = BytesIO()
  89. try:
  90. image.save(buffered, format="JPEG")
  91. except Exception as e:
  92. image.save(buffered, format="PNG")
  93. return base64.b64encode(buffered.getvalue()).decode("utf-8")
  94. def prompt(self, b64):
  95. return [
  96. {
  97. "role": "user",
  98. "content": [
  99. {
  100. "type": "image_url",
  101. "image_url": {
  102. "url": f"data:image/jpeg;base64,{b64}"
  103. },
  104. },
  105. {
  106. "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
  107. "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
  108. },
  109. ],
  110. }
  111. ]
  112. def chat_prompt(self, text, b64):
  113. return [
  114. {
  115. "type": "image_url",
  116. "image_url": {
  117. "url": f"data:image/jpeg;base64,{b64}",
  118. },
  119. },
  120. {
  121. "type": "text",
  122. "text": text
  123. },
  124. ]
  125. class GptV4(Base):
  126. def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
  127. if not base_url: base_url="https://api.openai.com/v1"
  128. self.client = OpenAI(api_key=key, base_url=base_url)
  129. self.model_name = model_name
  130. self.lang = lang
  131. def describe(self, image, max_tokens=300):
  132. b64 = self.image2base64(image)
  133. prompt = self.prompt(b64)
  134. for i in range(len(prompt)):
  135. for c in prompt[i]["content"]:
  136. if "text" in c: c["type"] = "text"
  137. res = self.client.chat.completions.create(
  138. model=self.model_name,
  139. messages=prompt,
  140. max_tokens=max_tokens,
  141. )
  142. return res.choices[0].message.content.strip(), res.usage.total_tokens
  143. class AzureGptV4(Base):
  144. def __init__(self, key, model_name, lang="Chinese", **kwargs):
  145. self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
  146. self.model_name = model_name
  147. self.lang = lang
  148. def describe(self, image, max_tokens=300):
  149. b64 = self.image2base64(image)
  150. prompt = self.prompt(b64)
  151. for i in range(len(prompt)):
  152. for c in prompt[i]["content"]:
  153. if "text" in c: c["type"] = "text"
  154. res = self.client.chat.completions.create(
  155. model=self.model_name,
  156. messages=prompt,
  157. max_tokens=max_tokens,
  158. )
  159. return res.choices[0].message.content.strip(), res.usage.total_tokens
  160. class QWenCV(Base):
  161. def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
  162. import dashscope
  163. dashscope.api_key = key
  164. self.model_name = model_name
  165. self.lang = lang
  166. def prompt(self, binary):
  167. # stupid as hell
  168. tmp_dir = get_project_base_directory("tmp")
  169. if not os.path.exists(tmp_dir):
  170. os.mkdir(tmp_dir)
  171. path = os.path.join(tmp_dir, "%s.jpg" % get_uuid())
  172. Image.open(io.BytesIO(binary)).save(path)
  173. return [
  174. {
  175. "role": "user",
  176. "content": [
  177. {
  178. "image": f"file://{path}"
  179. },
  180. {
  181. "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
  182. "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
  183. },
  184. ],
  185. }
  186. ]
  187. def chat_prompt(self, text, b64):
  188. return [
  189. {"image": f"{b64}"},
  190. {"text": text},
  191. ]
  192. def describe(self, image, max_tokens=300):
  193. from http import HTTPStatus
  194. from dashscope import MultiModalConversation
  195. response = MultiModalConversation.call(model=self.model_name,
  196. messages=self.prompt(image))
  197. if response.status_code == HTTPStatus.OK:
  198. return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
  199. return response.message, 0
  200. def chat(self, system, history, gen_conf, image=""):
  201. from http import HTTPStatus
  202. from dashscope import MultiModalConversation
  203. if system:
  204. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  205. for his in history:
  206. if his["role"] == "user":
  207. his["content"] = self.chat_prompt(his["content"], image)
  208. response = MultiModalConversation.call(model=self.model_name, messages=history,
  209. max_tokens=gen_conf.get("max_tokens", 1000),
  210. temperature=gen_conf.get("temperature", 0.3),
  211. top_p=gen_conf.get("top_p", 0.7))
  212. ans = ""
  213. tk_count = 0
  214. if response.status_code == HTTPStatus.OK:
  215. ans += response.output.choices[0]['message']['content']
  216. tk_count += response.usage.total_tokens
  217. if response.output.choices[0].get("finish_reason", "") == "length":
  218. ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
  219. [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
  220. return ans, tk_count
  221. return "**ERROR**: " + response.message, tk_count
  222. def chat_streamly(self, system, history, gen_conf, image=""):
  223. from http import HTTPStatus
  224. from dashscope import MultiModalConversation
  225. if system:
  226. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  227. for his in history:
  228. if his["role"] == "user":
  229. his["content"] = self.chat_prompt(his["content"], image)
  230. ans = ""
  231. tk_count = 0
  232. try:
  233. response = MultiModalConversation.call(model=self.model_name, messages=history,
  234. max_tokens=gen_conf.get("max_tokens", 1000),
  235. temperature=gen_conf.get("temperature", 0.3),
  236. top_p=gen_conf.get("top_p", 0.7),
  237. stream=True)
  238. for resp in response:
  239. if resp.status_code == HTTPStatus.OK:
  240. ans = resp.output.choices[0]['message']['content']
  241. tk_count = resp.usage.total_tokens
  242. if resp.output.choices[0].get("finish_reason", "") == "length":
  243. ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
  244. [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
  245. yield ans
  246. else:
  247. yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
  248. "Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
  249. except Exception as e:
  250. yield ans + "\n**ERROR**: " + str(e)
  251. yield tk_count
  252. class Zhipu4V(Base):
  253. def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
  254. self.client = ZhipuAI(api_key=key)
  255. self.model_name = model_name
  256. self.lang = lang
  257. def describe(self, image, max_tokens=1024):
  258. b64 = self.image2base64(image)
  259. prompt = self.prompt(b64)
  260. prompt[0]["content"][1]["type"] = "text"
  261. res = self.client.chat.completions.create(
  262. model=self.model_name,
  263. messages=prompt,
  264. max_tokens=max_tokens,
  265. )
  266. return res.choices[0].message.content.strip(), res.usage.total_tokens
  267. def chat(self, system, history, gen_conf, image=""):
  268. if system:
  269. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  270. try:
  271. for his in history:
  272. if his["role"] == "user":
  273. his["content"] = self.chat_prompt(his["content"], image)
  274. response = self.client.chat.completions.create(
  275. model=self.model_name,
  276. messages=history,
  277. max_tokens=gen_conf.get("max_tokens", 1000),
  278. temperature=gen_conf.get("temperature", 0.3),
  279. top_p=gen_conf.get("top_p", 0.7)
  280. )
  281. return response.choices[0].message.content.strip(), response.usage.total_tokens
  282. except Exception as e:
  283. return "**ERROR**: " + str(e), 0
  284. def chat_streamly(self, system, history, gen_conf, image=""):
  285. if system:
  286. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  287. ans = ""
  288. tk_count = 0
  289. try:
  290. for his in history:
  291. if his["role"] == "user":
  292. his["content"] = self.chat_prompt(his["content"], image)
  293. response = self.client.chat.completions.create(
  294. model=self.model_name,
  295. messages=history,
  296. max_tokens=gen_conf.get("max_tokens", 1000),
  297. temperature=gen_conf.get("temperature", 0.3),
  298. top_p=gen_conf.get("top_p", 0.7),
  299. stream=True
  300. )
  301. for resp in response:
  302. if not resp.choices[0].delta.content: continue
  303. delta = resp.choices[0].delta.content
  304. ans += delta
  305. if resp.choices[0].finish_reason == "length":
  306. ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
  307. [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
  308. tk_count = resp.usage.total_tokens
  309. if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
  310. yield ans
  311. except Exception as e:
  312. yield ans + "\n**ERROR**: " + str(e)
  313. yield tk_count
  314. class OllamaCV(Base):
  315. def __init__(self, key, model_name, lang="Chinese", **kwargs):
  316. self.client = Client(host=kwargs["base_url"])
  317. self.model_name = model_name
  318. self.lang = lang
  319. def describe(self, image, max_tokens=1024):
  320. prompt = self.prompt("")
  321. try:
  322. options = {"num_predict": max_tokens}
  323. response = self.client.generate(
  324. model=self.model_name,
  325. prompt=prompt[0]["content"][1]["text"],
  326. images=[image],
  327. options=options
  328. )
  329. ans = response["response"].strip()
  330. return ans, 128
  331. except Exception as e:
  332. return "**ERROR**: " + str(e), 0
  333. def chat(self, system, history, gen_conf, image=""):
  334. if system:
  335. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  336. try:
  337. for his in history:
  338. if his["role"] == "user":
  339. his["images"] = [image]
  340. options = {}
  341. if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
  342. if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
  343. if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
  344. if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
  345. if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
  346. response = self.client.chat(
  347. model=self.model_name,
  348. messages=history,
  349. options=options,
  350. keep_alive=-1
  351. )
  352. ans = response["message"]["content"].strip()
  353. return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
  354. except Exception as e:
  355. return "**ERROR**: " + str(e), 0
  356. def chat_streamly(self, system, history, gen_conf, image=""):
  357. if system:
  358. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  359. for his in history:
  360. if his["role"] == "user":
  361. his["images"] = [image]
  362. options = {}
  363. if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
  364. if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
  365. if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
  366. if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
  367. if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
  368. ans = ""
  369. try:
  370. response = self.client.chat(
  371. model=self.model_name,
  372. messages=history,
  373. stream=True,
  374. options=options,
  375. keep_alive=-1
  376. )
  377. for resp in response:
  378. if resp["done"]:
  379. yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
  380. ans += resp["message"]["content"]
  381. yield ans
  382. except Exception as e:
  383. yield ans + "\n**ERROR**: " + str(e)
  384. yield 0
  385. class LocalAICV(GptV4):
  386. def __init__(self, key, model_name, base_url, lang="Chinese"):
  387. if not base_url:
  388. raise ValueError("Local cv model url cannot be None")
  389. if base_url.split("/")[-1] != "v1":
  390. base_url = os.path.join(base_url, "v1")
  391. self.client = OpenAI(api_key="empty", base_url=base_url)
  392. self.model_name = model_name.split("___")[0]
  393. self.lang = lang
  394. class XinferenceCV(Base):
  395. def __init__(self, key, model_name="", lang="Chinese", base_url=""):
  396. if base_url.split("/")[-1] != "v1":
  397. base_url = os.path.join(base_url, "v1")
  398. self.client = OpenAI(api_key="xxx", base_url=base_url)
  399. self.model_name = model_name
  400. self.lang = lang
  401. def describe(self, image, max_tokens=300):
  402. b64 = self.image2base64(image)
  403. res = self.client.chat.completions.create(
  404. model=self.model_name,
  405. messages=self.prompt(b64),
  406. max_tokens=max_tokens,
  407. )
  408. return res.choices[0].message.content.strip(), res.usage.total_tokens
  409. class GeminiCV(Base):
  410. def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
  411. from google.generativeai import client, GenerativeModel, GenerationConfig
  412. client.configure(api_key=key)
  413. _client = client.get_default_generative_client()
  414. self.model_name = model_name
  415. self.model = GenerativeModel(model_name=self.model_name)
  416. self.model._client = _client
  417. self.lang = lang
  418. def describe(self, image, max_tokens=2048):
  419. from PIL.Image import open
  420. gen_config = {'max_output_tokens':max_tokens}
  421. prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
  422. "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
  423. b64 = self.image2base64(image)
  424. img = open(BytesIO(base64.b64decode(b64)))
  425. input = [prompt,img]
  426. res = self.model.generate_content(
  427. input,
  428. generation_config=gen_config,
  429. )
  430. return res.text,res.usage_metadata.total_token_count
  431. def chat(self, system, history, gen_conf, image=""):
  432. if system:
  433. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  434. try:
  435. for his in history:
  436. if his["role"] == "assistant":
  437. his["role"] = "model"
  438. his["parts"] = [his["content"]]
  439. his.pop("content")
  440. if his["role"] == "user":
  441. his["parts"] = [his["content"]]
  442. his.pop("content")
  443. history[-1]["parts"].append(f"data:image/jpeg;base64," + image)
  444. response = self.model.generate_content(history, generation_config=GenerationConfig(
  445. max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
  446. top_p=gen_conf.get("top_p", 0.7)))
  447. ans = response.text
  448. return ans, response.usage_metadata.total_token_count
  449. except Exception as e:
  450. return "**ERROR**: " + str(e), 0
  451. def chat_streamly(self, system, history, gen_conf, image=""):
  452. if system:
  453. history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
  454. ans = ""
  455. tk_count = 0
  456. try:
  457. for his in history:
  458. if his["role"] == "assistant":
  459. his["role"] = "model"
  460. his["parts"] = [his["content"]]
  461. his.pop("content")
  462. if his["role"] == "user":
  463. his["parts"] = [his["content"]]
  464. his.pop("content")
  465. history[-1]["parts"].append(f"data:image/jpeg;base64," + image)
  466. response = self.model.generate_content(history, generation_config=GenerationConfig(
  467. max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
  468. top_p=gen_conf.get("top_p", 0.7)), stream=True)
  469. for resp in response:
  470. if not resp.text: continue
  471. ans += resp.text
  472. yield ans
  473. except Exception as e:
  474. yield ans + "\n**ERROR**: " + str(e)
  475. yield response._chunks[-1].usage_metadata.total_token_count
  476. class OpenRouterCV(GptV4):
  477. def __init__(
  478. self,
  479. key,
  480. model_name,
  481. lang="Chinese",
  482. base_url="https://openrouter.ai/api/v1",
  483. ):
  484. if not base_url:
  485. base_url = "https://openrouter.ai/api/v1"
  486. self.client = OpenAI(api_key=key, base_url=base_url)
  487. self.model_name = model_name
  488. self.lang = lang
  489. class LocalCV(Base):
  490. def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
  491. pass
  492. def describe(self, image, max_tokens=1024):
  493. return "", 0
  494. class NvidiaCV(Base):
  495. def __init__(
  496. self,
  497. key,
  498. model_name,
  499. lang="Chinese",
  500. base_url="https://ai.api.nvidia.com/v1/vlm",
  501. ):
  502. if not base_url:
  503. base_url = ("https://ai.api.nvidia.com/v1/vlm",)
  504. self.lang = lang
  505. factory, llm_name = model_name.split("/")
  506. if factory != "liuhaotian":
  507. self.base_url = os.path.join(base_url, factory, llm_name)
  508. else:
  509. self.base_url = os.path.join(
  510. base_url, "community", llm_name.replace("-v1.6", "16")
  511. )
  512. self.key = key
  513. def describe(self, image, max_tokens=1024):
  514. b64 = self.image2base64(image)
  515. response = requests.post(
  516. url=self.base_url,
  517. headers={
  518. "accept": "application/json",
  519. "content-type": "application/json",
  520. "Authorization": f"Bearer {self.key}",
  521. },
  522. json={
  523. "messages": self.prompt(b64),
  524. "max_tokens": max_tokens,
  525. },
  526. )
  527. response = response.json()
  528. return (
  529. response["choices"][0]["message"]["content"].strip(),
  530. response["usage"]["total_tokens"],
  531. )
  532. def prompt(self, b64):
  533. return [
  534. {
  535. "role": "user",
  536. "content": (
  537. "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
  538. if self.lang.lower() == "chinese"
  539. else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
  540. )
  541. + f' <img src="data:image/jpeg;base64,{b64}"/>',
  542. }
  543. ]
  544. def chat_prompt(self, text, b64):
  545. return [
  546. {
  547. "role": "user",
  548. "content": text + f' <img src="data:image/jpeg;base64,{b64}"/>',
  549. }
  550. ]
  551. class StepFunCV(GptV4):
  552. def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
  553. if not base_url: base_url="https://api.stepfun.com/v1"
  554. self.client = OpenAI(api_key=key, base_url=base_url)
  555. self.model_name = model_name
  556. self.lang = lang
  557. class LmStudioCV(GptV4):
  558. def __init__(self, key, model_name, lang="Chinese", base_url=""):
  559. if not base_url:
  560. raise ValueError("Local llm url cannot be None")
  561. if base_url.split("/")[-1] != "v1":
  562. base_url = os.path.join(base_url, "v1")
  563. self.client = OpenAI(api_key="lm-studio", base_url=base_url)
  564. self.model_name = model_name
  565. self.lang = lang
  566. class OpenAI_APICV(GptV4):
  567. def __init__(self, key, model_name, lang="Chinese", base_url=""):
  568. if not base_url:
  569. raise ValueError("url cannot be None")
  570. if base_url.split("/")[-1] != "v1":
  571. base_url = os.path.join(base_url, "v1")
  572. self.client = OpenAI(api_key=key, base_url=base_url)
  573. self.model_name = model_name.split("___")[0]
  574. self.lang = lang
  575. class TogetherAICV(GptV4):
  576. def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"):
  577. if not base_url:
  578. base_url = "https://api.together.xyz/v1"
  579. super().__init__(key, model_name,lang,base_url)
  580. class YiCV(GptV4):
  581. def __init__(self, key, model_name, lang="Chinese",base_url="https://api.lingyiwanwu.com/v1",):
  582. if not base_url:
  583. base_url = "https://api.lingyiwanwu.com/v1"
  584. super().__init__(key, model_name,lang,base_url)
  585. class HunyuanCV(Base):
  586. def __init__(self, key, model_name, lang="Chinese",base_url=None):
  587. from tencentcloud.common import credential
  588. from tencentcloud.hunyuan.v20230901 import hunyuan_client
  589. key = json.loads(key)
  590. sid = key.get("hunyuan_sid", "")
  591. sk = key.get("hunyuan_sk", "")
  592. cred = credential.Credential(sid, sk)
  593. self.model_name = model_name
  594. self.client = hunyuan_client.HunyuanClient(cred, "")
  595. self.lang = lang
  596. def describe(self, image, max_tokens=4096):
  597. from tencentcloud.hunyuan.v20230901 import models
  598. from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
  599. TencentCloudSDKException,
  600. )
  601. b64 = self.image2base64(image)
  602. req = models.ChatCompletionsRequest()
  603. params = {"Model": self.model_name, "Messages": self.prompt(b64)}
  604. req.from_json_string(json.dumps(params))
  605. ans = ""
  606. try:
  607. response = self.client.ChatCompletions(req)
  608. ans = response.Choices[0].Message.Content
  609. return ans, response.Usage.TotalTokens
  610. except TencentCloudSDKException as e:
  611. return ans + "\n**ERROR**: " + str(e), 0
  612. def prompt(self, b64):
  613. return [
  614. {
  615. "Role": "user",
  616. "Contents": [
  617. {
  618. "Type": "image_url",
  619. "ImageUrl": {
  620. "Url": f"data:image/jpeg;base64,{b64}"
  621. },
  622. },
  623. {
  624. "Type": "text",
  625. "Text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else
  626. "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
  627. },
  628. ],
  629. }
  630. ]