You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

cv_model.py 30KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import base64
  17. import json
  18. import os
  19. from abc import ABC
  20. from copy import deepcopy
  21. from io import BytesIO
  22. from urllib.parse import urljoin
  23. import requests
  24. from openai import OpenAI
  25. from openai.lib.azure import AzureOpenAI
  26. from zhipuai import ZhipuAI
  27. from rag.nlp import is_english
  28. from rag.prompts import vision_llm_describe_prompt
  29. from rag.utils import num_tokens_from_string
  30. class Base(ABC):
  31. def __init__(self, **kwargs):
  32. # Configure retry parameters
  33. self.max_retries = kwargs.get("max_retries", int(os.environ.get("LLM_MAX_RETRIES", 5)))
  34. self.base_delay = kwargs.get("retry_interval", float(os.environ.get("LLM_BASE_DELAY", 2.0)))
  35. self.max_rounds = kwargs.get("max_rounds", 5)
  36. self.is_tools = False
  37. self.tools = []
  38. self.toolcall_sessions = {}
  39. def describe(self, image):
  40. raise NotImplementedError("Please implement encode method!")
  41. def describe_with_prompt(self, image, prompt=None):
  42. raise NotImplementedError("Please implement encode method!")
  43. def _form_history(self, system, history, images=[]):
  44. hist = []
  45. if system:
  46. hist.append({"role": "system", "content": system})
  47. for h in history:
  48. if images and h["role"] == "user":
  49. h["content"] = self._image_prompt(h["content"], images)
  50. images = []
  51. hist.append(h)
  52. return hist
  53. def _image_prompt(self, text, images):
  54. if not images:
  55. return text
  56. if isinstance(images, str) or "bytes" in type(images).__name__:
  57. images = [images]
  58. pmpt = [{"type": "text", "text": text}]
  59. for img in images:
  60. pmpt.append({
  61. "type": "image_url",
  62. "image_url": {
  63. "url": img if isinstance(img, str) and img.startswith("data:") else f"data:image/png;base64,{img}"
  64. }
  65. })
  66. return pmpt
  67. def chat(self, system, history, gen_conf, images=[], **kwargs):
  68. try:
  69. response = self.client.chat.completions.create(
  70. model=self.model_name,
  71. messages=self._form_history(system, history, images)
  72. )
  73. return response.choices[0].message.content.strip(), response.usage.total_tokens
  74. except Exception as e:
  75. return "**ERROR**: " + str(e), 0
  76. def chat_streamly(self, system, history, gen_conf, images=[], **kwargs):
  77. ans = ""
  78. tk_count = 0
  79. try:
  80. response = self.client.chat.completions.create(
  81. model=self.model_name,
  82. messages=self._form_history(system, history, images),
  83. stream=True
  84. )
  85. for resp in response:
  86. if not resp.choices[0].delta.content:
  87. continue
  88. delta = resp.choices[0].delta.content
  89. ans = delta
  90. if resp.choices[0].finish_reason == "length":
  91. ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
  92. if resp.choices[0].finish_reason == "stop":
  93. tk_count += resp.usage.total_tokens
  94. yield ans
  95. except Exception as e:
  96. yield ans + "\n**ERROR**: " + str(e)
  97. yield tk_count
  98. @staticmethod
  99. def image2base64(image):
  100. # Return a data URL with the correct MIME to avoid provider mismatches
  101. if isinstance(image, bytes):
  102. # Best-effort magic number sniffing
  103. mime = "image/png"
  104. if len(image) >= 2 and image[0] == 0xFF and image[1] == 0xD8:
  105. mime = "image/jpeg"
  106. b64 = base64.b64encode(image).decode("utf-8")
  107. return f"data:{mime};base64,{b64}"
  108. if isinstance(image, BytesIO):
  109. data = image.getvalue()
  110. mime = "image/png"
  111. if len(data) >= 2 and data[0] == 0xFF and data[1] == 0xD8:
  112. mime = "image/jpeg"
  113. b64 = base64.b64encode(data).decode("utf-8")
  114. return f"data:{mime};base64,{b64}"
  115. buffered = BytesIO()
  116. fmt = "JPEG"
  117. try:
  118. image.save(buffered, format="JPEG")
  119. except Exception:
  120. buffered = BytesIO() # reset buffer before saving PNG
  121. image.save(buffered, format="PNG")
  122. fmt = "PNG"
  123. data = buffered.getvalue()
  124. b64 = base64.b64encode(data).decode("utf-8")
  125. mime = f"image/{fmt.lower()}"
  126. return f"data:{mime};base64,{b64}"
  127. def prompt(self, b64):
  128. return [
  129. {
  130. "role": "user",
  131. "content": self._image_prompt(
  132. "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
  133. if self.lang.lower() == "chinese"
  134. else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
  135. b64
  136. )
  137. }
  138. ]
  139. def vision_llm_prompt(self, b64, prompt=None):
  140. return [
  141. {
  142. "role": "user",
  143. "content": self._image_prompt(prompt if prompt else vision_llm_describe_prompt(), b64)
  144. }
  145. ]
  146. class GptV4(Base):
  147. _FACTORY_NAME = "OpenAI"
  148. def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1", **kwargs):
  149. if not base_url:
  150. base_url = "https://api.openai.com/v1"
  151. self.client = OpenAI(api_key=key, base_url=base_url)
  152. self.model_name = model_name
  153. self.lang = lang
  154. super().__init__(**kwargs)
  155. def describe(self, image):
  156. b64 = self.image2base64(image)
  157. res = self.client.chat.completions.create(
  158. model=self.model_name,
  159. messages=self.prompt(b64),
  160. )
  161. return res.choices[0].message.content.strip(), res.usage.total_tokens
  162. def describe_with_prompt(self, image, prompt=None):
  163. b64 = self.image2base64(image)
  164. res = self.client.chat.completions.create(
  165. model=self.model_name,
  166. messages=self.vision_llm_prompt(b64, prompt),
  167. )
  168. return res.choices[0].message.content.strip(), res.usage.total_tokens
  169. class AzureGptV4(GptV4):
  170. _FACTORY_NAME = "Azure-OpenAI"
  171. def __init__(self, key, model_name, lang="Chinese", **kwargs):
  172. api_key = json.loads(key).get("api_key", "")
  173. api_version = json.loads(key).get("api_version", "2024-02-01")
  174. self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
  175. self.model_name = model_name
  176. self.lang = lang
  177. Base.__init__(self, **kwargs)
  178. class xAICV(GptV4):
  179. _FACTORY_NAME = "xAI"
  180. def __init__(self, key, model_name="grok-3", lang="Chinese", base_url=None, **kwargs):
  181. if not base_url:
  182. base_url = "https://api.x.ai/v1"
  183. super().__init__(key, model_name, lang=lang, base_url=base_url, **kwargs)
  184. class QWenCV(GptV4):
  185. _FACTORY_NAME = "Tongyi-Qianwen"
  186. def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", base_url=None, **kwargs):
  187. if not base_url:
  188. base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
  189. super().__init__(key, model_name, lang=lang, base_url=base_url, **kwargs)
  190. class HunyuanCV(GptV4):
  191. _FACTORY_NAME = "Tencent Hunyuan"
  192. def __init__(self, key, model_name, lang="Chinese", base_url=None, **kwargs):
  193. if not base_url:
  194. base_url = "https://api.hunyuan.cloud.tencent.com/v1"
  195. super().__init__(key, model_name, lang=lang, base_url=base_url, **kwargs)
  196. class Zhipu4V(GptV4):
  197. _FACTORY_NAME = "ZHIPU-AI"
  198. def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
  199. self.client = ZhipuAI(api_key=key)
  200. self.model_name = model_name
  201. self.lang = lang
  202. Base.__init__(self, **kwargs)
  203. class StepFunCV(GptV4):
  204. _FACTORY_NAME = "StepFun"
  205. def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1", **kwargs):
  206. if not base_url:
  207. base_url = "https://api.stepfun.com/v1"
  208. self.client = OpenAI(api_key=key, base_url=base_url)
  209. self.model_name = model_name
  210. self.lang = lang
  211. Base.__init__(self, **kwargs)
  212. class LmStudioCV(GptV4):
  213. _FACTORY_NAME = "LM-Studio"
  214. def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs):
  215. if not base_url:
  216. raise ValueError("Local llm url cannot be None")
  217. base_url = urljoin(base_url, "v1")
  218. self.client = OpenAI(api_key="lm-studio", base_url=base_url)
  219. self.model_name = model_name
  220. self.lang = lang
  221. Base.__init__(self, **kwargs)
  222. class OpenAI_APICV(GptV4):
  223. _FACTORY_NAME = ["VLLM", "OpenAI-API-Compatible"]
  224. def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs):
  225. if not base_url:
  226. raise ValueError("url cannot be None")
  227. base_url = urljoin(base_url, "v1")
  228. self.client = OpenAI(api_key=key, base_url=base_url)
  229. self.model_name = model_name.split("___")[0]
  230. self.lang = lang
  231. Base.__init__(self, **kwargs)
  232. class TogetherAICV(GptV4):
  233. _FACTORY_NAME = "TogetherAI"
  234. def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1", **kwargs):
  235. if not base_url:
  236. base_url = "https://api.together.xyz/v1"
  237. super().__init__(key, model_name, lang, base_url, **kwargs)
  238. class YiCV(GptV4):
  239. _FACTORY_NAME = "01.AI"
  240. def __init__(
  241. self,
  242. key,
  243. model_name,
  244. lang="Chinese",
  245. base_url="https://api.lingyiwanwu.com/v1", **kwargs
  246. ):
  247. if not base_url:
  248. base_url = "https://api.lingyiwanwu.com/v1"
  249. super().__init__(key, model_name, lang, base_url, **kwargs)
  250. class SILICONFLOWCV(GptV4):
  251. _FACTORY_NAME = "SILICONFLOW"
  252. def __init__(
  253. self,
  254. key,
  255. model_name,
  256. lang="Chinese",
  257. base_url="https://api.siliconflow.cn/v1", **kwargs
  258. ):
  259. if not base_url:
  260. base_url = "https://api.siliconflow.cn/v1"
  261. super().__init__(key, model_name, lang, base_url, **kwargs)
  262. class OpenRouterCV(GptV4):
  263. _FACTORY_NAME = "OpenRouter"
  264. def __init__(
  265. self,
  266. key,
  267. model_name,
  268. lang="Chinese",
  269. base_url="https://openrouter.ai/api/v1", **kwargs
  270. ):
  271. if not base_url:
  272. base_url = "https://openrouter.ai/api/v1"
  273. self.client = OpenAI(api_key=key, base_url=base_url)
  274. self.model_name = model_name
  275. self.lang = lang
  276. Base.__init__(self, **kwargs)
  277. class LocalAICV(GptV4):
  278. _FACTORY_NAME = "LocalAI"
  279. def __init__(self, key, model_name, base_url, lang="Chinese", **kwargs):
  280. if not base_url:
  281. raise ValueError("Local cv model url cannot be None")
  282. base_url = urljoin(base_url, "v1")
  283. self.client = OpenAI(api_key="empty", base_url=base_url)
  284. self.model_name = model_name.split("___")[0]
  285. self.lang = lang
  286. Base.__init__(self, **kwargs)
  287. class XinferenceCV(GptV4):
  288. _FACTORY_NAME = "Xinference"
  289. def __init__(self, key, model_name="", lang="Chinese", base_url="", **kwargs):
  290. base_url = urljoin(base_url, "v1")
  291. self.client = OpenAI(api_key=key, base_url=base_url)
  292. self.model_name = model_name
  293. self.lang = lang
  294. Base.__init__(self, **kwargs)
  295. class GPUStackCV(GptV4):
  296. _FACTORY_NAME = "GPUStack"
  297. def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs):
  298. if not base_url:
  299. raise ValueError("Local llm url cannot be None")
  300. base_url = urljoin(base_url, "v1")
  301. self.client = OpenAI(api_key=key, base_url=base_url)
  302. self.model_name = model_name
  303. self.lang = lang
  304. Base.__init__(self, **kwargs)
  305. class LocalCV(Base):
  306. _FACTORY_NAME = "Moonshot"
  307. def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
  308. pass
  309. def describe(self, image):
  310. return "", 0
  311. class OllamaCV(Base):
  312. _FACTORY_NAME = "Ollama"
  313. def __init__(self, key, model_name, lang="Chinese", **kwargs):
  314. from ollama import Client
  315. self.client = Client(host=kwargs["base_url"])
  316. self.model_name = model_name
  317. self.lang = lang
  318. self.keep_alive = kwargs.get("ollama_keep_alive", int(os.environ.get("OLLAMA_KEEP_ALIVE", -1)))
  319. Base.__init__(self, **kwargs)
  320. def _clean_img(self, img):
  321. if not isinstance(img, str):
  322. return img
  323. #remove the header like "data/*;base64,"
  324. if img.startswith("data:") and ";base64," in img:
  325. img = img.split(";base64,")[1]
  326. return img
  327. def _clean_conf(self, gen_conf):
  328. options = {}
  329. if "temperature" in gen_conf:
  330. options["temperature"] = gen_conf["temperature"]
  331. if "top_p" in gen_conf:
  332. options["top_k"] = gen_conf["top_p"]
  333. if "presence_penalty" in gen_conf:
  334. options["presence_penalty"] = gen_conf["presence_penalty"]
  335. if "frequency_penalty" in gen_conf:
  336. options["frequency_penalty"] = gen_conf["frequency_penalty"]
  337. return options
  338. def _form_history(self, system, history, images=[]):
  339. hist = deepcopy(history)
  340. if system and hist[0]["role"] == "user":
  341. hist.insert(0, {"role": "system", "content": system})
  342. if not images:
  343. return hist
  344. temp_images = []
  345. for img in images:
  346. temp_images.append(self._clean_img(img))
  347. for his in hist:
  348. if his["role"] == "user":
  349. his["images"] = temp_images
  350. break
  351. return hist
  352. def describe(self, image):
  353. prompt = self.prompt("")
  354. try:
  355. response = self.client.generate(
  356. model=self.model_name,
  357. prompt=prompt[0]["content"][0]["text"],
  358. images=[image],
  359. )
  360. ans = response["response"].strip()
  361. return ans, 128
  362. except Exception as e:
  363. return "**ERROR**: " + str(e), 0
  364. def describe_with_prompt(self, image, prompt=None):
  365. vision_prompt = self.vision_llm_prompt("", prompt) if prompt else self.vision_llm_prompt("")
  366. try:
  367. response = self.client.generate(
  368. model=self.model_name,
  369. prompt=vision_prompt[0]["content"][0]["text"],
  370. images=[image],
  371. )
  372. ans = response["response"].strip()
  373. return ans, 128
  374. except Exception as e:
  375. return "**ERROR**: " + str(e), 0
  376. def chat(self, system, history, gen_conf, images=[]):
  377. try:
  378. response = self.client.chat(
  379. model=self.model_name,
  380. messages=self._form_history(system, history, images),
  381. options=self._clean_conf(gen_conf),
  382. keep_alive=self.keep_alive
  383. )
  384. ans = response["message"]["content"].strip()
  385. return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
  386. except Exception as e:
  387. return "**ERROR**: " + str(e), 0
  388. def chat_streamly(self, system, history, gen_conf, images=[]):
  389. ans = ""
  390. try:
  391. response = self.client.chat(
  392. model=self.model_name,
  393. messages=self._form_history(system, history, images),
  394. stream=True,
  395. options=self._clean_conf(gen_conf),
  396. keep_alive=self.keep_alive
  397. )
  398. for resp in response:
  399. if resp["done"]:
  400. yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
  401. ans = resp["message"]["content"]
  402. yield ans
  403. except Exception as e:
  404. yield ans + "\n**ERROR**: " + str(e)
  405. yield 0
  406. class GeminiCV(Base):
  407. _FACTORY_NAME = "Gemini"
  408. def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
  409. from google.generativeai import GenerativeModel, client
  410. client.configure(api_key=key)
  411. _client = client.get_default_generative_client()
  412. self.model_name = model_name
  413. self.model = GenerativeModel(model_name=self.model_name)
  414. self.model._client = _client
  415. self.lang = lang
  416. Base.__init__(self, **kwargs)
  417. def _form_history(self, system, history, images=[]):
  418. hist = []
  419. if system:
  420. hist.append({"role": "user", "parts": [system, history[0]["content"]]})
  421. for img in images:
  422. hist[0]["parts"].append(("data:image/jpeg;base64," + img) if img[:4]!="data" else img)
  423. for h in history[1:]:
  424. hist.append({"role": "user" if h["role"]=="user" else "model", "parts": [h["content"]]})
  425. return hist
  426. def describe(self, image):
  427. from PIL.Image import open
  428. prompt = (
  429. "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
  430. if self.lang.lower() == "chinese"
  431. else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
  432. )
  433. b64 = self.image2base64(image)
  434. img = open(BytesIO(base64.b64decode(b64)))
  435. input = [prompt, img]
  436. res = self.model.generate_content(input)
  437. img.close()
  438. return res.text, res.usage_metadata.total_token_count
  439. def describe_with_prompt(self, image, prompt=None):
  440. from PIL.Image import open
  441. b64 = self.image2base64(image)
  442. vision_prompt = prompt if prompt else vision_llm_describe_prompt()
  443. img = open(BytesIO(base64.b64decode(b64)))
  444. input = [vision_prompt, img]
  445. res = self.model.generate_content(
  446. input,
  447. )
  448. img.close()
  449. return res.text, res.usage_metadata.total_token_count
  450. def chat(self, system, history, gen_conf, images=[]):
  451. generation_config = dict(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
  452. try:
  453. response = self.model.generate_content(
  454. self._form_history(system, history, images),
  455. generation_config=generation_config)
  456. ans = response.text
  457. return ans, response.usage_metadata.total_token_count
  458. except Exception as e:
  459. return "**ERROR**: " + str(e), 0
  460. def chat_streamly(self, system, history, gen_conf, images=[]):
  461. ans = ""
  462. response = None
  463. try:
  464. generation_config = dict(temperature=gen_conf.get("temperature", 0.3), top_p=gen_conf.get("top_p", 0.7))
  465. response = self.model.generate_content(
  466. self._form_history(system, history, images),
  467. generation_config=generation_config,
  468. stream=True,
  469. )
  470. for resp in response:
  471. if not resp.text:
  472. continue
  473. ans = resp.text
  474. yield ans
  475. except Exception as e:
  476. yield ans + "\n**ERROR**: " + str(e)
  477. if response and hasattr(response, "usage_metadata") and hasattr(response.usage_metadata, "total_token_count"):
  478. yield response.usage_metadata.total_token_count
  479. else:
  480. yield 0
  481. class NvidiaCV(Base):
  482. _FACTORY_NAME = "NVIDIA"
  483. def __init__(
  484. self,
  485. key,
  486. model_name,
  487. lang="Chinese",
  488. base_url="https://ai.api.nvidia.com/v1/vlm", **kwargs
  489. ):
  490. if not base_url:
  491. base_url = ("https://ai.api.nvidia.com/v1/vlm",)
  492. self.lang = lang
  493. factory, llm_name = model_name.split("/")
  494. if factory != "liuhaotian":
  495. self.base_url = urljoin(base_url, f"{factory}/{llm_name}")
  496. else:
  497. self.base_url = urljoin(f"{base_url}/community", llm_name.replace("-v1.6", "16"))
  498. self.key = key
  499. Base.__init__(self, **kwargs)
  500. def _image_prompt(self, text, images):
  501. if not images:
  502. return text
  503. htmls = ""
  504. for img in images:
  505. htmls += ' <img src="{}"/>'.format(f"data:image/jpeg;base64,{img}" if img[:4] != "data" else img)
  506. return text + htmls
  507. def describe(self, image):
  508. b64 = self.image2base64(image)
  509. response = requests.post(
  510. url=self.base_url,
  511. headers={
  512. "accept": "application/json",
  513. "content-type": "application/json",
  514. "Authorization": f"Bearer {self.key}",
  515. },
  516. json={"messages": self.prompt(b64)},
  517. )
  518. response = response.json()
  519. return (
  520. response["choices"][0]["message"]["content"].strip(),
  521. response["usage"]["total_tokens"],
  522. )
  523. def _request(self, msg, gen_conf={}):
  524. response = requests.post(
  525. url=self.base_url,
  526. headers={
  527. "accept": "application/json",
  528. "content-type": "application/json",
  529. "Authorization": f"Bearer {self.key}",
  530. },
  531. json={
  532. "messages": msg, **gen_conf
  533. },
  534. )
  535. return response.json()
  536. def describe_with_prompt(self, image, prompt=None):
  537. b64 = self.image2base64(image)
  538. vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
  539. response = self._request(vision_prompt)
  540. return (
  541. response["choices"][0]["message"]["content"].strip(),
  542. response["usage"]["total_tokens"],
  543. )
  544. def chat(self, system, history, gen_conf, images=[], **kwargs):
  545. try:
  546. response = self._request(self._form_history(system, history, images), gen_conf)
  547. return (
  548. response["choices"][0]["message"]["content"].strip(),
  549. response["usage"]["total_tokens"],
  550. )
  551. except Exception as e:
  552. return "**ERROR**: " + str(e), 0
  553. def chat_streamly(self, system, history, gen_conf, images=[], **kwargs):
  554. total_tokens = 0
  555. try:
  556. response = self._request(self._form_history(system, history, images), gen_conf)
  557. cnt = response["choices"][0]["message"]["content"]
  558. if "usage" in response and "total_tokens" in response["usage"]:
  559. total_tokens += response["usage"]["total_tokens"]
  560. for resp in cnt:
  561. yield resp
  562. except Exception as e:
  563. yield "\n**ERROR**: " + str(e)
  564. yield total_tokens
  565. class AnthropicCV(Base):
  566. _FACTORY_NAME = "Anthropic"
  567. def __init__(self, key, model_name, base_url=None, **kwargs):
  568. import anthropic
  569. self.client = anthropic.Anthropic(api_key=key)
  570. self.model_name = model_name
  571. self.system = ""
  572. self.max_tokens = 8192
  573. if "haiku" in self.model_name or "opus" in self.model_name:
  574. self.max_tokens = 4096
  575. Base.__init__(self, **kwargs)
  576. def _image_prompt(self, text, images):
  577. if not images:
  578. return text
  579. pmpt = [{"type": "text", "text": text}]
  580. for img in images:
  581. pmpt.append({
  582. "type": "image",
  583. "source": {
  584. "type": "base64",
  585. "media_type": (img.split(":")[1].split(";")[0] if isinstance(img, str) and img[:4] == "data" else "image/png"),
  586. "data": (img.split(",")[1] if isinstance(img, str) and img[:4] == "data" else img)
  587. },
  588. }
  589. )
  590. return pmpt
  591. def describe(self, image):
  592. b64 = self.image2base64(image)
  593. response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=self.prompt(b64))
  594. return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]
  595. def describe_with_prompt(self, image, prompt=None):
  596. b64 = self.image2base64(image)
  597. prompt = self.prompt(b64, prompt if prompt else vision_llm_describe_prompt())
  598. response = self.client.messages.create(model=self.model_name, max_tokens=self.max_tokens, messages=prompt)
  599. return response["content"][0]["text"].strip(), response["usage"]["input_tokens"] + response["usage"]["output_tokens"]
  600. def _clean_conf(self, gen_conf):
  601. if "presence_penalty" in gen_conf:
  602. del gen_conf["presence_penalty"]
  603. if "frequency_penalty" in gen_conf:
  604. del gen_conf["frequency_penalty"]
  605. if "max_token" in gen_conf:
  606. gen_conf["max_tokens"] = self.max_tokens
  607. return gen_conf
  608. def chat(self, system, history, gen_conf, images=[]):
  609. gen_conf = self._clean_conf(gen_conf)
  610. ans = ""
  611. try:
  612. response = self.client.messages.create(
  613. model=self.model_name,
  614. messages=self._form_history(system, history, images),
  615. system=system,
  616. stream=False,
  617. **gen_conf,
  618. ).to_dict()
  619. ans = response["content"][0]["text"]
  620. if response["stop_reason"] == "max_tokens":
  621. ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
  622. return (
  623. ans,
  624. response["usage"]["input_tokens"] + response["usage"]["output_tokens"],
  625. )
  626. except Exception as e:
  627. return ans + "\n**ERROR**: " + str(e), 0
  628. def chat_streamly(self, system, history, gen_conf, images=[]):
  629. gen_conf = self._clean_conf(gen_conf)
  630. total_tokens = 0
  631. try:
  632. response = self.client.messages.create(
  633. model=self.model_name,
  634. messages=self._form_history(system, history, images),
  635. system=system,
  636. stream=True,
  637. **gen_conf,
  638. )
  639. think = False
  640. for res in response:
  641. if res.type == "content_block_delta":
  642. if res.delta.type == "thinking_delta" and res.delta.thinking:
  643. if not think:
  644. yield "<think>"
  645. think = True
  646. yield res.delta.thinking
  647. total_tokens += num_tokens_from_string(res.delta.thinking)
  648. elif think:
  649. yield "</think>"
  650. else:
  651. yield res.delta.text
  652. total_tokens += num_tokens_from_string(res.delta.text)
  653. except Exception as e:
  654. yield "\n**ERROR**: " + str(e)
  655. yield total_tokens
  656. class GoogleCV(AnthropicCV, GeminiCV):
  657. _FACTORY_NAME = "Google Cloud"
  658. def __init__(self, key, model_name, lang="Chinese", base_url=None, **kwargs):
  659. import base64
  660. from google.oauth2 import service_account
  661. key = json.loads(key)
  662. access_token = json.loads(base64.b64decode(key.get("google_service_account_key", "")))
  663. project_id = key.get("google_project_id", "")
  664. region = key.get("google_region", "")
  665. scopes = ["https://www.googleapis.com/auth/cloud-platform"]
  666. self.model_name = model_name
  667. self.lang = lang
  668. if "claude" in self.model_name:
  669. from anthropic import AnthropicVertex
  670. from google.auth.transport.requests import Request
  671. if access_token:
  672. credits = service_account.Credentials.from_service_account_info(access_token, scopes=scopes)
  673. request = Request()
  674. credits.refresh(request)
  675. token = credits.token
  676. self.client = AnthropicVertex(region=region, project_id=project_id, access_token=token)
  677. else:
  678. self.client = AnthropicVertex(region=region, project_id=project_id)
  679. else:
  680. import vertexai.generative_models as glm
  681. from google.cloud import aiplatform
  682. if access_token:
  683. credits = service_account.Credentials.from_service_account_info(access_token)
  684. aiplatform.init(credentials=credits, project=project_id, location=region)
  685. else:
  686. aiplatform.init(project=project_id, location=region)
  687. self.client = glm.GenerativeModel(model_name=self.model_name)
  688. Base.__init__(self, **kwargs)
  689. def describe(self, image):
  690. if "claude" in self.model_name:
  691. return AnthropicCV.describe(self, image)
  692. else:
  693. return GeminiCV.describe(self, image)
  694. def describe_with_prompt(self, image, prompt=None):
  695. if "claude" in self.model_name:
  696. return AnthropicCV.describe_with_prompt(self, image, prompt)
  697. else:
  698. return GeminiCV.describe_with_prompt(self, image, prompt)
  699. def chat(self, system, history, gen_conf, images=[]):
  700. if "claude" in self.model_name:
  701. return AnthropicCV.chat(self, system, history, gen_conf, images)
  702. else:
  703. return GeminiCV.chat(self, system, history, gen_conf, images)
  704. def chat_streamly(self, system, history, gen_conf, images=[]):
  705. if "claude" in self.model_name:
  706. for ans in AnthropicCV.chat_streamly(self, system, history, gen_conf, images):
  707. yield ans
  708. else:
  709. for ans in GeminiCV.chat_streamly(self, system, history, gen_conf, images):
  710. yield ans