### What problem does this PR solve?

Add vision LLM PDF parser.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>

(tag: v0.18.0)
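For reviewers, here is a minimal sketch of how the new vision-LLM parsing path is exercised end to end, based on the `naive.py` and `pdf_parser.py` changes below. The tenant id, model name, and file name are placeholders, not values from this PR:

```python
# Sketch only: mirrors the dispatch added in rag/app/naive.py.
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from deepdoc.parser.pdf_parser import VisionParser

# Any layout_recognize value other than "DeepDOC" / "Plain Text" is treated
# as an IMAGE2TEXT model name ("gpt-4o" and "my-tenant-id" are placeholders).
layout_recognizer = "gpt-4o"
vision_model = LLMBundle("my-tenant-id", LLMType.IMAGE2TEXT, llm_name=layout_recognizer, lang="English")
pdf_parser = VisionParser(vision_model=vision_model)

# Each rendered page image is described by the VLM; sections come back as
# (markdown_text, "") tuples and tables is always an empty list.
sections, tables = pdf_parser("sample.pdf", from_page=0, to_page=8, callback=lambda prog, msg: None)
```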
| @@ -15,13 +15,12 @@ | |||
| # | |||
| import logging | |||
| from api.db.services.user_service import TenantService | |||
| from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel | |||
| from api import settings | |||
| from api.db import LLMType | |||
| from api.db.db_models import DB | |||
| from api.db.db_models import LLMFactories, LLM, TenantLLM | |||
| from api.db.db_models import DB, LLM, LLMFactories, TenantLLM | |||
| from api.db.services.common_service import CommonService | |||
| from api.db.services.user_service import TenantService | |||
| from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel | |||
| class LLMFactoriesService(CommonService): | |||
| @@ -266,6 +265,14 @@ class LLMBundle: | |||
| "LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens)) | |||
| return txt | |||
| def describe_with_prompt(self, image, prompt): | |||
| txt, used_tokens = self.mdl.describe_with_prompt(image, prompt) | |||
| if not TenantLLMService.increase_usage( | |||
| self.tenant_id, self.llm_type, used_tokens): | |||
| logging.error( | |||
| "LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens)) | |||
| return txt | |||
| def transcription(self, audio): | |||
| txt, used_tokens = self.mdl.transcription(audio) | |||
| if not TenantLLMService.increase_usage( | |||
| @@ -17,26 +17,27 @@ | |||
| import logging | |||
| import os | |||
| import random | |||
| from timeit import default_timer as timer | |||
| import re | |||
| import sys | |||
| import threading | |||
| import trio | |||
| import xgboost as xgb | |||
| from copy import deepcopy | |||
| from io import BytesIO | |||
| import re | |||
| from timeit import default_timer as timer | |||
| import numpy as np | |||
| import pdfplumber | |||
| import trio | |||
| import xgboost as xgb | |||
| from huggingface_hub import snapshot_download | |||
| from PIL import Image | |||
| import numpy as np | |||
| from pypdf import PdfReader as pdf2_read | |||
| from api import settings | |||
| from api.utils.file_utils import get_project_base_directory | |||
| from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer | |||
| from deepdoc.vision import OCR, LayoutRecognizer, Recognizer, TableStructureRecognizer | |||
| from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk | |||
| from rag.nlp import rag_tokenizer | |||
| from copy import deepcopy | |||
| from huggingface_hub import snapshot_download | |||
| from rag.prompts import vision_llm_describe_prompt | |||
| from rag.settings import PARALLEL_DEVICES | |||
| LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber" | |||
| @@ -45,7 +46,7 @@ if LOCK_KEY_pdfplumber not in sys.modules: | |||
| class RAGFlowPdfParser: | |||
| def __init__(self): | |||
| def __init__(self, **kwargs): | |||
| """ | |||
| If you have trouble downloading HuggingFace models, -_^ this might help!! | |||
| @@ -57,12 +58,12 @@ class RAGFlowPdfParser: | |||
| ^_- | |||
| """ | |||
| self.ocr = OCR() | |||
| self.parallel_limiter = None | |||
| if PARALLEL_DEVICES is not None and PARALLEL_DEVICES > 1: | |||
| self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)] | |||
| if hasattr(self, "model_speciess"): | |||
| self.layouter = LayoutRecognizer("layout." + self.model_speciess) | |||
| else: | |||
| @@ -106,7 +107,7 @@ class RAGFlowPdfParser: | |||
| def _y_dis( | |||
| self, a, b): | |||
| return ( | |||
| b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2 | |||
| b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2 | |||
| def _match_proj(self, b): | |||
| proj_patt = [ | |||
| @@ -129,9 +130,9 @@ class RAGFlowPdfParser: | |||
| tks_down = rag_tokenizer.tokenize(down["text"][:LEN]).split() | |||
| tks_up = rag_tokenizer.tokenize(up["text"][-LEN:]).split() | |||
| tks_all = up["text"][-LEN:].strip() \ | |||
| + (" " if re.match(r"[a-zA-Z0-9]+", | |||
| up["text"][-1] + down["text"][0]) else "") \ | |||
| + down["text"][:LEN].strip() | |||
| + (" " if re.match(r"[a-zA-Z0-9]+", | |||
| up["text"][-1] + down["text"][0]) else "") \ | |||
| + down["text"][:LEN].strip() | |||
| tks_all = rag_tokenizer.tokenize(tks_all).split() | |||
| fea = [ | |||
| up.get("R", -1) == down.get("R", -1), | |||
| @@ -153,7 +154,7 @@ class RAGFlowPdfParser: | |||
| True if re.search(r"[,,][^。.]+$", up["text"]) else False, | |||
| True if re.search(r"[,,][^。.]+$", up["text"]) else False, | |||
| True if re.search(r"[\((][^\))]+$", up["text"]) | |||
| and re.search(r"[\))]", down["text"]) else False, | |||
| and re.search(r"[\))]", down["text"]) else False, | |||
| self._match_proj(down), | |||
| True if re.match(r"[A-Z]", down["text"]) else False, | |||
| True if re.match(r"[A-Z]", up["text"][-1]) else False, | |||
| @@ -215,7 +216,7 @@ class RAGFlowPdfParser: | |||
| continue | |||
| for tb in tbls: # for table | |||
| left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \ | |||
| tb["x1"] + MARGIN, tb["bottom"] + MARGIN | |||
| tb["x1"] + MARGIN, tb["bottom"] + MARGIN | |||
| left *= ZM | |||
| top *= ZM | |||
| right *= ZM | |||
| @@ -309,7 +310,7 @@ class RAGFlowPdfParser: | |||
| "page_number": pagenum} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]], | |||
| self.mean_height[-1] / 3 | |||
| ) | |||
| # merge chars in the same rect | |||
| for c in Recognizer.sort_Y_firstly( | |||
| chars, self.mean_height[pagenum - 1] // 4): | |||
| @@ -457,7 +458,7 @@ class RAGFlowPdfParser: | |||
| b_["text"], | |||
| any(feats), | |||
| any(concatting_feats), | |||
| )) | |||
| )) | |||
| i += 1 | |||
| continue | |||
| # merge up and down | |||
| @@ -665,7 +666,7 @@ class RAGFlowPdfParser: | |||
| i += 1 | |||
| continue | |||
| lout_no = str(self.boxes[i]["page_number"]) + \ | |||
| "-" + str(self.boxes[i]["layoutno"]) | |||
| "-" + str(self.boxes[i]["layoutno"]) | |||
| if TableStructureRecognizer.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption", | |||
| "title", | |||
| "figure caption", | |||
| @@ -968,7 +969,7 @@ class RAGFlowPdfParser: | |||
| fnm) if not binary else pdfplumber.open(BytesIO(binary)) | |||
| total_page = len(pdf.pages) | |||
| pdf.close() | |||
| return total_page | |||
| return total_page | |||
| except Exception: | |||
| logging.exception("total_page_number") | |||
| @@ -994,7 +995,7 @@ class RAGFlowPdfParser: | |||
| except Exception as e: | |||
| logging.warning(f"Failed to extract characters for pages {page_from}-{page_to}: {str(e)}") | |||
| self.page_chars = [[] for _ in range(page_to - page_from)] # If failed to extract, using empty list instead. | |||
| self.total_page = len(self.pdf.pages) | |||
| except Exception: | |||
| logging.exception("RAGFlowPdfParser __images__") | |||
| @@ -1023,7 +1024,7 @@ class RAGFlowPdfParser: | |||
| logging.debug("Images converted.") | |||
| self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join( | |||
| random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in | |||
| range(len(self.page_chars))] | |||
| range(len(self.page_chars))] | |||
| if sum([1 if e else 0 for e in self.is_english]) > len( | |||
| self.page_images) / 2: | |||
| self.is_english = True | |||
| @@ -1036,7 +1037,7 @@ class RAGFlowPdfParser: | |||
| if chars[j]["text"] and chars[j + 1]["text"] \ | |||
| and re.match(r"[0-9a-zA-Z,.:;!%]+", chars[j]["text"] + chars[j + 1]["text"]) \ | |||
| and chars[j + 1]["x0"] - chars[j]["x1"] >= min(chars[j + 1]["width"], | |||
| chars[j]["width"]) / 2: | |||
| chars[j]["width"]) / 2: | |||
| chars[j]["text"] += " " | |||
| j += 1 | |||
| @@ -1045,7 +1046,7 @@ class RAGFlowPdfParser: | |||
| await trio.to_thread.run_sync(lambda: self.__ocr(i + 1, img, chars, zoomin, id)) | |||
| else: | |||
| self.__ocr(i + 1, img, chars, zoomin, id) | |||
| if callback and i % 6 == 5: | |||
| callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="") | |||
| @@ -1060,14 +1061,14 @@ class RAGFlowPdfParser: | |||
| ) | |||
| self.page_cum_height.append(img.size[1] / zoomin) | |||
| return chars | |||
| if self.parallel_limiter: | |||
| async with trio.open_nursery() as nursery: | |||
| for i, img in enumerate(self.page_images): | |||
| chars = __ocr_preprocess() | |||
| nursery.start_soon(__img_ocr, i, i % PARALLEL_DEVICES, img, chars, | |||
| self.parallel_limiter[i % PARALLEL_DEVICES]) | |||
| self.parallel_limiter[i % PARALLEL_DEVICES]) | |||
| await trio.sleep(0.1) | |||
| else: | |||
| for i, img in enumerate(self.page_images): | |||
| @@ -1075,9 +1076,9 @@ class RAGFlowPdfParser: | |||
| await __img_ocr(i, 0, img, chars, None) | |||
| start = timer() | |||
| trio.run(__img_ocr_launcher) | |||
| logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s") | |||
| if not self.is_english and not any( | |||
| @@ -1142,7 +1143,7 @@ class RAGFlowPdfParser: | |||
| self.page_images[pns[0]].crop((left * ZM, top * ZM, | |||
| right * | |||
| ZM, min( | |||
| bottom, self.page_images[pns[0]].size[1]) | |||
| bottom, self.page_images[pns[0]].size[1]) | |||
| )) | |||
| ) | |||
| if 0 < ii < len(poss) - 1: | |||
| @@ -1240,5 +1241,52 @@ class PlainParser: | |||
| raise NotImplementedError | |||
| class VisionParser(RAGFlowPdfParser): | |||
| def __init__(self, vision_model, *args, **kwargs): | |||
| super().__init__(*args, **kwargs) | |||
| self.vision_model = vision_model | |||
| def __images__(self, fnm, zoomin=3, page_from=0, page_to=299, callback=None): | |||
| try: | |||
| with sys.modules[LOCK_KEY_pdfplumber]: | |||
| self.pdf = pdfplumber.open(fnm) if isinstance( | |||
| fnm, str) else pdfplumber.open(BytesIO(fnm)) | |||
| self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in | |||
| enumerate(self.pdf.pages[page_from:page_to])] | |||
| self.total_page = len(self.pdf.pages) | |||
| except Exception: | |||
| self.page_images = None | |||
| self.total_page = 0 | |||
| logging.exception("VisionParser __images__") | |||
| def __call__(self, filename, from_page=0, to_page=100000, **kwargs): | |||
| callback = kwargs.get("callback", lambda prog, msg: None) | |||
| self.__images__(fnm=filename, zoomin=3, page_from=from_page, page_to=to_page, **kwargs) | |||
| total_pdf_pages = self.total_page | |||
| start_page = max(0, from_page) | |||
| end_page = min(to_page, total_pdf_pages) | |||
| all_docs = [] | |||
| for idx, img_binary in enumerate(self.page_images or []): | |||
| pdf_page_num = idx # 0-based | |||
| if pdf_page_num < start_page or pdf_page_num >= end_page: | |||
| continue | |||
| docs = picture_vision_llm_chunk( | |||
| binary=img_binary, | |||
| vision_model=self.vision_model, | |||
| prompt=vision_llm_describe_prompt(page=pdf_page_num+1), | |||
| callback=callback, | |||
| ) | |||
| if docs: | |||
| all_docs.append(docs) | |||
| return [(doc, "") for doc in all_docs], [] | |||
| if __name__ == "__main__": | |||
| pass | |||
| @@ -26,8 +26,10 @@ from markdown import markdown | |||
| from PIL import Image | |||
| from tika import parser | |||
| from api.db import LLMType | |||
| from api.db.services.llm_service import LLMBundle | |||
| from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownParser, PdfParser, TxtParser | |||
| from deepdoc.parser.pdf_parser import PlainParser | |||
| from deepdoc.parser.pdf_parser import PlainParser, VisionParser | |||
| from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_docx, rag_tokenizer, tokenize_chunks, tokenize_chunks_docx, tokenize_table | |||
| from rag.utils import num_tokens_from_string | |||
| @@ -237,9 +239,16 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, | |||
| return res | |||
| elif re.search(r"\.pdf$", filename, re.IGNORECASE): | |||
| pdf_parser = Pdf() | |||
| if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text": | |||
| layout_recognizer = parser_config.get("layout_recognize", "DeepDOC") | |||
| if layout_recognizer == "DeepDOC": | |||
| pdf_parser = Pdf() | |||
| elif layout_recognizer == "Plain Text": | |||
| pdf_parser = PlainParser() | |||
| else: | |||
| vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT, llm_name=layout_recognizer, lang=lang) | |||
| pdf_parser = VisionParser(vision_model=vision_model, **kwargs) | |||
| sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page, | |||
| callback=callback) | |||
| res = tokenize_table(tables, doc, is_english) | |||
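For context, the three branches above are driven solely by `parser_config["layout_recognize"]`; a quick sketch of the values that select each parser (the vision model name is a placeholder):

```python
# "DeepDOC"     -> Pdf() (the built-in DeepDoc layout pipeline)
# "Plain Text"  -> PlainParser() (no layout analysis)
# anything else -> interpreted as an IMAGE2TEXT model name and routed to VisionParser
parser_config_deepdoc = {"layout_recognize": "DeepDOC"}
parser_config_plain   = {"layout_recognize": "Plain Text"}
parser_config_vision  = {"layout_recognize": "gpt-4o"}  # placeholder model name
```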
| @@ -21,8 +21,9 @@ from PIL import Image | |||
| from api.db import LLMType | |||
| from api.db.services.llm_service import LLMBundle | |||
| from rag.nlp import tokenize | |||
| from deepdoc.vision import OCR | |||
| from rag.nlp import tokenize | |||
| from rag.utils import clean_markdown_block | |||
| ocr = OCR() | |||
| @@ -57,3 +58,32 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs): | |||
| callback(prog=-1, msg=str(e)) | |||
| return [] | |||
| def vision_llm_chunk(binary, vision_model, prompt=None, callback=None): | |||
| """ | |||
| A simple wrapper that converts an image into Markdown text via a VLM. | |||
| Returns: | |||
| Markdown text generated by the VLM. | |||
| """ | |||
| callback = callback or (lambda prog, msg: None) | |||
| img = binary | |||
| txt = "" | |||
| try: | |||
| img_binary = io.BytesIO() | |||
| img.save(img_binary, format='JPEG') | |||
| img_binary.seek(0) | |||
| ans = clean_markdown_block(vision_model.describe_with_prompt(img_binary.read(), prompt)) | |||
| txt += "\n" + ans | |||
| return txt | |||
| except Exception as e: | |||
| callback(-1, str(e)) | |||
| return [] | |||
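A minimal usage sketch for `vision_llm_chunk`. Despite the parameter name, `binary` is expected to be a PIL image (the function calls `.save()` on it); the tenant id, model name, and file path below are placeholders:

```python
from PIL import Image

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from rag.app.picture import vision_llm_chunk
from rag.prompts import vision_llm_describe_prompt

page_img = Image.open("page_1.png")  # placeholder page render
vision_model = LLMBundle("my-tenant-id", LLMType.IMAGE2TEXT, llm_name="gpt-4o")  # placeholders

# Returns the VLM's Markdown transcription of the page as a plain string
# (or [] if the model call fails and the error is reported via callback).
markdown_text = vision_llm_chunk(
    binary=page_img,
    vision_model=vision_model,
    prompt=vision_llm_describe_prompt(page=1),
)
```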
| @@ -13,31 +13,36 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from openai.lib.azure import AzureOpenAI | |||
| from zhipuai import ZhipuAI | |||
| import base64 | |||
| import io | |||
| from abc import ABC | |||
| from ollama import Client | |||
| from PIL import Image | |||
| from openai import OpenAI | |||
| import json | |||
| import os | |||
| import base64 | |||
| from abc import ABC | |||
| from io import BytesIO | |||
| import json | |||
| import requests | |||
| from ollama import Client | |||
| from openai import OpenAI | |||
| from openai.lib.azure import AzureOpenAI | |||
| from PIL import Image | |||
| from zhipuai import ZhipuAI | |||
| from rag.nlp import is_english | |||
| from api.utils import get_uuid | |||
| from api.utils.file_utils import get_project_base_directory | |||
| from rag.nlp import is_english | |||
| from rag.prompts import vision_llm_describe_prompt | |||
| class Base(ABC): | |||
| def __init__(self, key, model_name): | |||
| pass | |||
| def describe(self, image, max_tokens=300): | |||
| def describe(self, image): | |||
| raise NotImplementedError("Please implement encode method!") | |||
| def describe_with_prompt(self, image, prompt=None): | |||
| raise NotImplementedError("Please implement encode method!") | |||
| def chat(self, system, history, gen_conf, image=""): | |||
| if system: | |||
| history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] | |||
| @@ -90,7 +95,7 @@ class Base(ABC): | |||
| yield ans + "\n**ERROR**: " + str(e) | |||
| yield tk_count | |||
| def image2base64(self, image): | |||
| if isinstance(image, bytes): | |||
| return base64.b64encode(image).decode("utf-8") | |||
| @@ -122,6 +127,25 @@ class Base(ABC): | |||
| } | |||
| ] | |||
| def vision_llm_prompt(self, b64, prompt=None): | |||
| return [ | |||
| { | |||
| "role": "user", | |||
| "content": [ | |||
| { | |||
| "type": "image_url", | |||
| "image_url": { | |||
| "url": f"data:image/jpeg;base64,{b64}" | |||
| }, | |||
| }, | |||
| { | |||
| "type": "text", | |||
| "text": prompt if prompt else vision_llm_describe_prompt(), | |||
| }, | |||
| ], | |||
| } | |||
| ] | |||
| def chat_prompt(self, text, b64): | |||
| return [ | |||
| { | |||
| @@ -140,12 +164,12 @@ class Base(ABC): | |||
| class GptV4(Base): | |||
| def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"): | |||
| if not base_url: | |||
| base_url="https://api.openai.com/v1" | |||
| base_url = "https://api.openai.com/v1" | |||
| self.client = OpenAI(api_key=key, base_url=base_url) | |||
| self.model_name = model_name | |||
| self.lang = lang | |||
| def describe(self, image, max_tokens=300): | |||
| def describe(self, image): | |||
| b64 = self.image2base64(image) | |||
| prompt = self.prompt(b64) | |||
| for i in range(len(prompt)): | |||
| @@ -159,6 +183,16 @@ class GptV4(Base): | |||
| ) | |||
| return res.choices[0].message.content.strip(), res.usage.total_tokens | |||
| def describe_with_prompt(self, image, prompt=None): | |||
| b64 = self.image2base64(image) | |||
| vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64) | |||
| res = self.client.chat.completions.create( | |||
| model=self.model_name, | |||
| messages=vision_prompt, | |||
| ) | |||
| return res.choices[0].message.content.strip(), res.usage.total_tokens | |||
| class AzureGptV4(Base): | |||
| def __init__(self, key, model_name, lang="Chinese", **kwargs): | |||
| @@ -168,7 +202,7 @@ class AzureGptV4(Base): | |||
| self.model_name = model_name | |||
| self.lang = lang | |||
| def describe(self, image, max_tokens=300): | |||
| def describe(self, image): | |||
| b64 = self.image2base64(image) | |||
| prompt = self.prompt(b64) | |||
| for i in range(len(prompt)): | |||
| @@ -182,6 +216,16 @@ class AzureGptV4(Base): | |||
| ) | |||
| return res.choices[0].message.content.strip(), res.usage.total_tokens | |||
| def describe_with_prompt(self, image, prompt=None): | |||
| b64 = self.image2base64(image) | |||
| vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64) | |||
| res = self.client.chat.completions.create( | |||
| model=self.model_name, | |||
| messages=vision_prompt, | |||
| ) | |||
| return res.choices[0].message.content.strip(), res.usage.total_tokens | |||
| class QWenCV(Base): | |||
| def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs): | |||
| @@ -212,23 +256,57 @@ class QWenCV(Base): | |||
| } | |||
| ] | |||
| def vision_llm_prompt(self, binary, prompt=None): | |||
| # stupid as hell | |||
| tmp_dir = get_project_base_directory("tmp") | |||
| if not os.path.exists(tmp_dir): | |||
| os.mkdir(tmp_dir) | |||
| path = os.path.join(tmp_dir, "%s.jpg" % get_uuid()) | |||
| Image.open(io.BytesIO(binary)).save(path) | |||
| return [ | |||
| { | |||
| "role": "user", | |||
| "content": [ | |||
| { | |||
| "image": f"file://{path}" | |||
| }, | |||
| { | |||
| "text": prompt if prompt else vision_llm_describe_prompt(), | |||
| }, | |||
| ], | |||
| } | |||
| ] | |||
| def chat_prompt(self, text, b64): | |||
| return [ | |||
| {"image": f"{b64}"}, | |||
| {"text": text}, | |||
| ] | |||
| def describe(self, image, max_tokens=300): | |||
| def describe(self, image): | |||
| from http import HTTPStatus | |||
| from dashscope import MultiModalConversation | |||
| response = MultiModalConversation.call(model=self.model_name, messages=self.prompt(image)) | |||
| if response.status_code == HTTPStatus.OK: | |||
| return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens | |||
| return response.message, 0 | |||
| def describe_with_prompt(self, image, prompt=None): | |||
| from http import HTTPStatus | |||
| from dashscope import MultiModalConversation | |||
| response = MultiModalConversation.call(model=self.model_name, | |||
| messages=self.prompt(image)) | |||
| vision_prompt = self.vision_llm_prompt(image, prompt) if prompt else self.vision_llm_prompt(image) | |||
| response = MultiModalConversation.call(model=self.model_name, messages=vision_prompt) | |||
| if response.status_code == HTTPStatus.OK: | |||
| return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens | |||
| return response.message, 0 | |||
| def chat(self, system, history, gen_conf, image=""): | |||
| from http import HTTPStatus | |||
| from dashscope import MultiModalConversation | |||
| if system: | |||
| history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] | |||
| @@ -254,6 +332,7 @@ class QWenCV(Base): | |||
| def chat_streamly(self, system, history, gen_conf, image=""): | |||
| from http import HTTPStatus | |||
| from dashscope import MultiModalConversation | |||
| if system: | |||
| history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] | |||
| @@ -292,15 +371,25 @@ class Zhipu4V(Base): | |||
| self.model_name = model_name | |||
| self.lang = lang | |||
| def describe(self, image, max_tokens=1024): | |||
| def describe(self, image): | |||
| b64 = self.image2base64(image) | |||
| prompt = self.prompt(b64) | |||
| prompt[0]["content"][1]["type"] = "text" | |||
| res = self.client.chat.completions.create( | |||
| model=self.model_name, | |||
| messages=prompt | |||
| messages=prompt, | |||
| ) | |||
| return res.choices[0].message.content.strip(), res.usage.total_tokens | |||
| def describe_with_prompt(self, image, prompt=None): | |||
| b64 = self.image2base64(image) | |||
| vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64) | |||
| res = self.client.chat.completions.create( | |||
| model=self.model_name, | |||
| messages=vision_prompt | |||
| ) | |||
| return res.choices[0].message.content.strip(), res.usage.total_tokens | |||
| @@ -334,7 +423,7 @@ class Zhipu4V(Base): | |||
| his["content"] = self.chat_prompt(his["content"], image) | |||
| response = self.client.chat.completions.create( | |||
| model=self.model_name, | |||
| model=self.model_name, | |||
| messages=history, | |||
| temperature=gen_conf.get("temperature", 0.3), | |||
| top_p=gen_conf.get("top_p", 0.7), | |||
| @@ -364,7 +453,7 @@ class OllamaCV(Base): | |||
| self.model_name = model_name | |||
| self.lang = lang | |||
| def describe(self, image, max_tokens=1024): | |||
| def describe(self, image): | |||
| prompt = self.prompt("") | |||
| try: | |||
| response = self.client.generate( | |||
| @@ -377,6 +466,19 @@ class OllamaCV(Base): | |||
| except Exception as e: | |||
| return "**ERROR**: " + str(e), 0 | |||
| def describe_with_prompt(self, image, prompt=None): | |||
| vision_prompt = self.vision_llm_prompt("", prompt) if prompt else self.vision_llm_prompt("") | |||
| try: | |||
| response = self.client.generate( | |||
| model=self.model_name, | |||
| prompt=vision_prompt[0]["content"][1]["text"], | |||
| images=[image], | |||
| ) | |||
| ans = response["response"].strip() | |||
| return ans, 128 | |||
| except Exception as e: | |||
| return "**ERROR**: " + str(e), 0 | |||
| def chat(self, system, history, gen_conf, image=""): | |||
| if system: | |||
| history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"] | |||
| @@ -460,7 +562,7 @@ class XinferenceCV(Base): | |||
| self.model_name = model_name | |||
| self.lang = lang | |||
| def describe(self, image, max_tokens=300): | |||
| def describe(self, image): | |||
| b64 = self.image2base64(image) | |||
| res = self.client.chat.completions.create( | |||
| @@ -469,27 +571,49 @@ class XinferenceCV(Base): | |||
| ) | |||
| return res.choices[0].message.content.strip(), res.usage.total_tokens | |||
| def describe_with_prompt(self, image, prompt=None): | |||
| b64 = self.image2base64(image) | |||
| vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64) | |||
| res = self.client.chat.completions.create( | |||
| model=self.model_name, | |||
| messages=vision_prompt, | |||
| ) | |||
| return res.choices[0].message.content.strip(), res.usage.total_tokens | |||
| class GeminiCV(Base): | |||
| def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs): | |||
| from google.generativeai import client, GenerativeModel | |||
| from google.generativeai import GenerativeModel, client | |||
| client.configure(api_key=key) | |||
| _client = client.get_default_generative_client() | |||
| self.model_name = model_name | |||
| self.model = GenerativeModel(model_name=self.model_name) | |||
| self.model._client = _client | |||
| self.lang = lang | |||
| self.lang = lang | |||
| def describe(self, image, max_tokens=2048): | |||
| def describe(self, image): | |||
| from PIL.Image import open | |||
| prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \ | |||
| "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out." | |||
| b64 = self.image2base64(image) | |||
| img = open(BytesIO(base64.b64decode(b64))) | |||
| input = [prompt,img] | |||
| b64 = self.image2base64(image) | |||
| img = open(BytesIO(base64.b64decode(b64))) | |||
| input = [prompt, img] | |||
| res = self.model.generate_content( | |||
| input | |||
| ) | |||
| return res.text,res.usage_metadata.total_token_count | |||
| return res.text, res.usage_metadata.total_token_count | |||
| def describe_with_prompt(self, image, prompt=None): | |||
| from PIL.Image import open | |||
| b64 = self.image2base64(image) | |||
| vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64) | |||
| img = open(BytesIO(base64.b64decode(b64))) | |||
| input = [vision_prompt, img] | |||
| res = self.model.generate_content( | |||
| input, | |||
| ) | |||
| return res.text, res.usage_metadata.total_token_count | |||
| def chat(self, system, history, gen_conf, image=""): | |||
| from transformers import GenerationConfig | |||
| @@ -566,7 +690,7 @@ class LocalCV(Base): | |||
| def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs): | |||
| pass | |||
| def describe(self, image, max_tokens=1024): | |||
| def describe(self, image): | |||
| return "", 0 | |||
| @@ -590,7 +714,7 @@ class NvidiaCV(Base): | |||
| ) | |||
| self.key = key | |||
| def describe(self, image, max_tokens=1024): | |||
| def describe(self, image): | |||
| b64 = self.image2base64(image) | |||
| response = requests.post( | |||
| url=self.base_url, | |||
| @@ -609,6 +733,27 @@ class NvidiaCV(Base): | |||
| response["usage"]["total_tokens"], | |||
| ) | |||
| def describe_with_prompt(self, image, prompt=None): | |||
| b64 = self.image2base64(image) | |||
| vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64) | |||
| response = requests.post( | |||
| url=self.base_url, | |||
| headers={ | |||
| "accept": "application/json", | |||
| "content-type": "application/json", | |||
| "Authorization": f"Bearer {self.key}", | |||
| }, | |||
| json={ | |||
| "messages": vision_prompt, | |||
| }, | |||
| ) | |||
| response = response.json() | |||
| return ( | |||
| response["choices"][0]["message"]["content"].strip(), | |||
| response["usage"]["total_tokens"], | |||
| ) | |||
| def prompt(self, b64): | |||
| return [ | |||
| { | |||
| @@ -622,6 +767,17 @@ class NvidiaCV(Base): | |||
| } | |||
| ] | |||
| def vision_llm_prompt(self, b64, prompt=None): | |||
| return [ | |||
| { | |||
| "role": "user", | |||
| "content": ( | |||
| prompt if prompt else vision_llm_describe_prompt() | |||
| ) | |||
| + f' <img src="data:image/jpeg;base64,{b64}"/>', | |||
| } | |||
| ] | |||
| def chat_prompt(self, text, b64): | |||
| return [ | |||
| { | |||
| @@ -634,7 +790,7 @@ class NvidiaCV(Base): | |||
| class StepFunCV(GptV4): | |||
| def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"): | |||
| if not base_url: | |||
| base_url="https://api.stepfun.com/v1" | |||
| base_url = "https://api.stepfun.com/v1" | |||
| self.client = OpenAI(api_key=key, base_url=base_url) | |||
| self.model_name = model_name | |||
| self.lang = lang | |||
| @@ -666,18 +822,18 @@ class TogetherAICV(GptV4): | |||
| def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"): | |||
| if not base_url: | |||
| base_url = "https://api.together.xyz/v1" | |||
| super().__init__(key, model_name,lang,base_url) | |||
| super().__init__(key, model_name, lang, base_url) | |||
| class YiCV(GptV4): | |||
| def __init__(self, key, model_name, lang="Chinese",base_url="https://api.lingyiwanwu.com/v1",): | |||
| def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1",): | |||
| if not base_url: | |||
| base_url = "https://api.lingyiwanwu.com/v1" | |||
| super().__init__(key, model_name,lang,base_url) | |||
| super().__init__(key, model_name, lang, base_url) | |||
| class HunyuanCV(Base): | |||
| def __init__(self, key, model_name, lang="Chinese",base_url=None): | |||
| def __init__(self, key, model_name, lang="Chinese", base_url=None): | |||
| from tencentcloud.common import credential | |||
| from tencentcloud.hunyuan.v20230901 import hunyuan_client | |||
| @@ -689,12 +845,12 @@ class HunyuanCV(Base): | |||
| self.client = hunyuan_client.HunyuanClient(cred, "") | |||
| self.lang = lang | |||
| def describe(self, image, max_tokens=4096): | |||
| from tencentcloud.hunyuan.v20230901 import models | |||
| def describe(self, image): | |||
| from tencentcloud.common.exception.tencent_cloud_sdk_exception import ( | |||
| TencentCloudSDKException, | |||
| ) | |||
| from tencentcloud.hunyuan.v20230901 import models | |||
| b64 = self.image2base64(image) | |||
| req = models.ChatCompletionsRequest() | |||
| params = {"Model": self.model_name, "Messages": self.prompt(b64)} | |||
| @@ -706,7 +862,24 @@ class HunyuanCV(Base): | |||
| return ans, response.Usage.TotalTokens | |||
| except TencentCloudSDKException as e: | |||
| return ans + "\n**ERROR**: " + str(e), 0 | |||
| def describe_with_prompt(self, image, prompt=None): | |||
| from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException | |||
| from tencentcloud.hunyuan.v20230901 import models | |||
| b64 = self.image2base64(image) | |||
| vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64) | |||
| req = models.ChatCompletionsRequest() | |||
| params = {"Model": self.model_name, "Messages": vision_prompt} | |||
| req.from_json_string(json.dumps(params)) | |||
| ans = "" | |||
| try: | |||
| response = self.client.ChatCompletions(req) | |||
| ans = response.Choices[0].Message.Content | |||
| return ans, response.Usage.TotalTokens | |||
| except TencentCloudSDKException as e: | |||
| return ans + "\n**ERROR**: " + str(e), 0 | |||
| def prompt(self, b64): | |||
| return [ | |||
| { | |||
| @@ -725,4 +898,4 @@ class HunyuanCV(Base): | |||
| }, | |||
| ], | |||
| } | |||
| ] | |||
| ] | |||
| @@ -18,13 +18,13 @@ import json | |||
| import logging | |||
| import re | |||
| from collections import defaultdict | |||
| import json_repair | |||
| from api import settings | |||
| from api.db import LLMType | |||
| from api.db.services.document_service import DocumentService | |||
| from api.db.services.llm_service import TenantLLMService, LLMBundle | |||
| from rag.settings import TAG_FLD | |||
| from rag.utils import num_tokens_from_string, encoder | |||
| from rag.utils import encoder, num_tokens_from_string | |||
| def chunks_format(reference): | |||
| @@ -44,9 +44,11 @@ def chunks_format(reference): | |||
| def llm_id2llm_type(llm_id): | |||
| from api.db.services.llm_service import TenantLLMService | |||
| llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id) | |||
| llm_factories = settings.FACTORY_LLM_INFOS | |||
| llm_factories = settings.FACTORY_LLM_INFOS | |||
| for llm_factory in llm_factories: | |||
| for llm in llm_factory["llm"]: | |||
| if llm_id == llm["llm_name"]: | |||
| @@ -92,6 +94,8 @@ def message_fit_in(msg, max_length=4000): | |||
| def kb_prompt(kbinfos, max_tokens): | |||
| from api.db.services.document_service import DocumentService | |||
| knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]] | |||
| used_token_count = 0 | |||
| chunks_num = 0 | |||
| @@ -166,15 +170,15 @@ Overall, while Musk enjoys Dogecoin and often promotes it, he also warns against | |||
| def keyword_extraction(chat_mdl, content, topn=3): | |||
| prompt = f""" | |||
| Role: You're a text analyzer. | |||
| Role: You're a text analyzer. | |||
| Task: extract the most important keywords/phrases of a given piece of text content. | |||
| Requirements: | |||
| Requirements: | |||
| - Summarize the text content, and give top {topn} important keywords/phrases. | |||
| - The keywords MUST be in language of the given piece of text content. | |||
| - The keywords are delimited by ENGLISH COMMA. | |||
| - Keywords ONLY in output. | |||
| ### Text Content | |||
| ### Text Content | |||
| {content} | |||
| """ | |||
| @@ -194,9 +198,9 @@ Requirements: | |||
| def question_proposal(chat_mdl, content, topn=3): | |||
| prompt = f""" | |||
| Role: You're a text analyzer. | |||
| Role: You're a text analyzer. | |||
| Task: propose {topn} questions about a given piece of text content. | |||
| Requirements: | |||
| Requirements: | |||
| - Understand and summarize the text content, and propose top {topn} important questions. | |||
| - The questions SHOULD NOT have overlapping meanings. | |||
| - The questions SHOULD cover the main content of the text as much as possible. | |||
| @@ -204,7 +208,7 @@ Requirements: | |||
| - One question per line. | |||
| - Question ONLY in output. | |||
| ### Text Content | |||
| ### Text Content | |||
| {content} | |||
| """ | |||
| @@ -223,6 +227,8 @@ Requirements: | |||
| def full_question(tenant_id, llm_id, messages, language=None): | |||
| from api.db.services.llm_service import LLMBundle | |||
| if llm_id2llm_type(llm_id) == "image2text": | |||
| chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id) | |||
| else: | |||
| @@ -239,7 +245,7 @@ def full_question(tenant_id, llm_id, messages, language=None): | |||
| prompt = f""" | |||
| Role: A helpful assistant | |||
| Task and steps: | |||
| Task and steps: | |||
| 1. Generate a full user question that would follow the conversation. | |||
| 2. If the user's question involves relative date, you need to convert it into absolute date based on the current date, which is {today}. For example: 'yesterday' would be converted to {yesterday}. | |||
| @@ -300,11 +306,11 @@ Output: What's the weather in Rochester on {tomorrow}? | |||
| def content_tagging(chat_mdl, content, all_tags, examples, topn=3): | |||
| prompt = f""" | |||
| Role: You're a text analyzer. | |||
| Role: You're a text analyzer. | |||
| Task: Tag (put on some labels) to a given piece of text content based on the examples and the entire tag set. | |||
| Steps:: | |||
| Steps:: | |||
| - Comprehend the tag/label set. | |||
| - Comprehend examples which all consist of both text content and assigned tags with relevance score in format of JSON. | |||
| - Summarize the text content, and tag it with top {topn} most relevant tags from the set of tag/label and the corresponding relevance score. | |||
| @@ -358,3 +364,32 @@ Output: | |||
| except Exception as e: | |||
| logging.exception(f"JSON parsing error: {result} -> {e}") | |||
| raise e | |||
| def vision_llm_describe_prompt(page=None) -> str: | |||
| prompt_en = """ | |||
| INSTRUCTION: | |||
| Transcribe the content from the provided PDF page image into clean Markdown format. | |||
| - Only output the content transcribed from the image. | |||
| - Do NOT output this instruction or any other explanation. | |||
| - If the content is missing or you do not understand the input, return an empty string. | |||
| RULES: | |||
| 1. Do NOT generate examples, demonstrations, or templates. | |||
| 2. Do NOT output any extra text such as 'Example', 'Example Output', or similar. | |||
| 3. Do NOT generate any tables, headings, or content that is not explicitly present in the image. | |||
| 4. Transcribe content word-for-word. Do NOT modify, translate, or omit any content. | |||
| 5. Do NOT explain Markdown or mention that you are using Markdown. | |||
| 6. Do NOT wrap the output in ```markdown or ``` blocks. | |||
| 7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image. | |||
| 8. Preserve the original language, information, and order exactly as shown in the image. | |||
| """ | |||
| if page is not None: | |||
| prompt_en += f"\nAt the end of the transcription, add the page divider: `--- Page {page} ---`." | |||
| prompt_en += """ | |||
| FAILURE HANDLING: | |||
| - If you do not detect valid content in the image, return an empty string. | |||
| """ | |||
| return prompt_en | |||
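A small check of how the optional `page` argument shapes the prompt: the page-divider instruction is appended for the given page number, and the failure-handling block always follows:

```python
from rag.prompts import vision_llm_describe_prompt

p = vision_llm_describe_prompt(page=3)
assert "--- Page 3 ---" in p        # divider instruction added for page 3
assert "FAILURE HANDLING" in p      # failure-handling rules are always appended
```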
| @@ -16,7 +16,9 @@ | |||
| import os | |||
| import re | |||
| import tiktoken | |||
| from api.utils.file_utils import get_project_base_directory | |||
| @@ -54,7 +56,7 @@ def findMaxDt(fnm): | |||
| pass | |||
| return m | |||
| def findMaxTm(fnm): | |||
| m = 0 | |||
| try: | |||
| @@ -91,11 +93,18 @@ def truncate(string: str, max_len: int) -> str: | |||
| """Returns truncated text if the length of text exceed max_len.""" | |||
| return encoder.decode(encoder.encode(string)[:max_len]) | |||
| def clean_markdown_block(text): | |||
| text = re.sub(r'^\s*```markdown\s*\n?', '', text) | |||
| text = re.sub(r'\n?\s*```\s*$', '', text) | |||
| return text.strip() | |||
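`clean_markdown_block` strips the Markdown code-fence wrapper that some VLMs put around their answers; a quick illustration:

```python
from rag.utils import clean_markdown_block

raw = "```markdown\n# Heading\n\nBody text.\n```"
print(clean_markdown_block(raw))  # -> "# Heading\n\nBody text."
```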
| def get_float(v: str | None): | |||
| if v is None: | |||
| return float('-inf') | |||
| try: | |||
| return float(v) | |||
| except Exception: | |||
| return float('-inf') | |||
| return float('-inf') | |||