
parse_user_docs.py

#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import os
import hashlib
import copy
import time
import random
import re
from timeit import default_timer as timer

from rag.llm import EmbeddingModel, CvModel
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
from rag.utils import ELASTICSEARCH
from rag.utils import MINIO
from rag.utils import rmSpace, findMaxTm
from rag.nlp import huchunk, huqie, search
from io import BytesIO
import pandas as pd
from elasticsearch_dsl import Q
from PIL import Image
from rag.parser import (
    PdfParser,
    DocxParser,
    ExcelParser
)
from rag.nlp.huchunk import (
    PdfChunker,
    DocxChunker,
    ExcelChunker,
    PptChunker,
    TextChunker
)
from web_server.db import LLMType
from web_server.db.services.document_service import DocumentService
from web_server.db.services.llm_service import TenantLLMService
from web_server.settings import database_logger
from web_server.utils import get_format_time
from web_server.utils.file_utils import get_project_base_directory

BATCH_SIZE = 64

PDF = PdfChunker(PdfParser())
DOC = DocxChunker(DocxParser())
EXC = ExcelChunker(ExcelParser())
PPT = PptChunker()
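
# Route a file to the matching chunker based on its extension: PDF, DOC,
# spreadsheet and PPT files go to dedicated chunkers, recognized image formats
# are captioned with the vision model (when one is given), and anything else
# falls back to plain-text chunking.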
def chuck_doc(name, binary, cvmdl=None):
    suff = os.path.split(name)[-1].lower().split(".")[-1]
    if suff.find("pdf") >= 0:
        return PDF(binary)
    if suff.find("doc") >= 0:
        return DOC(binary)
    if re.match(r"(xlsx|xlsm|xltx|xltm)", suff):
        return EXC(binary)
    if suff.find("ppt") >= 0:
        return PPT(binary)
    if cvmdl and re.search(r"\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)$",
                           name.lower()):
        # Caption the image and wrap the description as a single text chunk.
        txt = cvmdl.describe(binary)
        field = TextChunker.Fields()
        field.text_chunks = [(txt, binary)]
        field.table_chunks = []
        return field

    return TextChunker()(binary)
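
# Pull documents uploaded after timestamp `tm` that fall to this worker
# (rank `mod` of `comm` processes) and return them as a DataFrame.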
def collect(comm, mod, tm):
    docs = DocumentService.get_newly_uploaded(tm, mod, comm)
    if len(docs) == 0:
        return pd.DataFrame()
    docs = pd.DataFrame(docs)
    mtm = docs["update_time"].max()
    cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
    return docs
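
# Persist parsing progress for a document; a progress of -1 marks a failure,
# and begin=True also records the processing start time.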
def set_progress(docid, prog, msg="Processing...", begin=False):
    d = {"progress": prog, "progress_msg": msg}
    if begin:
        d["process_begin_at"] = get_format_time()
    try:
        DocumentService.update_by_id(docid, d)
    except Exception as e:
        cron_logger.error("set_progress:({}), {}".format(docid, str(e)))
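
# Turn one document row into indexable chunks: enforce the size limit, skip
# files that are already indexed (only appending the new kb_id to their
# chunks), then slice the file and build one ES document per text/table
# chunk, storing chunk images in MinIO along the way.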
def build(row, cvmdl):
    if row["size"] > DOC_MAXIMUM_SIZE:
        set_progress(row["id"], -1, "File size exceeds the limit (%d MB)." %
                     int(DOC_MAXIMUM_SIZE / 1024 / 1024))
        return []

    # Already indexed: just add this kb_id to the existing chunks.
    res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
    if ELASTICSEARCH.getTotal(res) > 0:
        ELASTICSEARCH.updateScriptByQuery(
            Q("term", doc_id=row["id"]),
            scripts="""
            if(!ctx._source.kb_id.contains('%s'))
                ctx._source.kb_id.add('%s');
            """ % (str(row["kb_id"]), str(row["kb_id"])),
            idxnm=search.index_name(row["tenant_id"])
        )
        set_progress(row["id"], 1, "Done")
        return []
    random.seed(time.time())
    set_progress(row["id"], random.randint(0, 20) / 100.,
                 "Finished preparing! Start to slice the file!", True)
    try:
        cron_logger.info("Chunking {}/{}".format(row["location"], row["name"]))
        obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), cvmdl)
    except Exception as e:
        if re.search("(No such file|not found)", str(e)):
            set_progress(row["id"], -1, "Cannot find file <%s>" % row["doc_name"])
        else:
            set_progress(row["id"], -1, "Internal server error: %s" % str(e).replace("'", ""))
        cron_logger.warning("Chunking {}/{}: {}".format(row["location"], row["name"], str(e)))
        return []
    if not obj.text_chunks and not obj.table_chunks:
        set_progress(row["id"], 1,
                     "Nothing added! Most likely the file type is not supported yet.")
        return []

    set_progress(row["id"], random.randint(20, 60) / 100.,
                 "Finished slicing the file. Start to embed the content.")

    doc = {
        "doc_id": row["id"],
        "kb_id": [str(row["kb_id"])],
        "docnm_kwd": os.path.split(row["location"])[-1],
        "title_tks": huqie.qie(row["name"]),
        "updated_at": str(row["update_time"]).replace("T", " ")[:19]
    }
    doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
    output_buffer = BytesIO()
    docs = []
    md5 = hashlib.md5()
    for txt, img in obj.text_chunks:
        d = copy.deepcopy(doc)
        md5.update((txt + str(d["doc_id"])).encode("utf-8"))
        d["_id"] = md5.hexdigest()
        d["content_ltks"] = huqie.qie(txt)
        d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
        if not img:
            docs.append(d)
            continue
        if isinstance(img, bytes):
            output_buffer = BytesIO(img)
        else:
            img.save(output_buffer, format='JPEG')
        MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
        d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
        docs.append(d)

    for arr, img in obj.table_chunks:
        for i, txt in enumerate(arr):
            d = copy.deepcopy(doc)
            d["content_ltks"] = huqie.qie(txt)
            md5.update((txt + str(d["doc_id"])).encode("utf-8"))
            d["_id"] = md5.hexdigest()
            if not img:
                docs.append(d)
                continue
            img.save(output_buffer, format='JPEG')
            MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
            d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
            docs.append(d)

    set_progress(row["id"], random.randint(60, 70) / 100.,
                 "Continue embedding the content.")
    return docs
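
# Create the tenant's Elasticsearch index from conf/mapping.json if it does
# not exist yet.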
def init_kb(row):
    idxnm = search.index_name(row["tenant_id"])
    if ELASTICSEARCH.indexExist(idxnm):
        return
    return ELASTICSEARCH.createIdx(idxnm, json.load(
        open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))
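
# Embed titles and contents separately and mix them 1:9 into each chunk's
# query vector `q_vec`; returns the number of tokens consumed by the model.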
def embedding(docs, mdl):
    tts, cnts = [rmSpace(d["title_tks"]) for d in docs], [rmSpace(d["content_ltks"]) for d in docs]
    tk_count = 0
    tts, c = mdl.encode(tts)
    tk_count += c
    cnts, c = mdl.encode(cnts)
    tk_count += c
    vects = 0.1 * tts + 0.9 * cnts
    assert len(vects) == len(docs)
    for i, d in enumerate(docs):
        d["q_vec"] = vects[i].tolist()
    return tk_count
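
# Instantiate the tenant's embedding or image-to-text model from its stored
# API configuration, falling back to the local factory when none is
# configured.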
def model_instance(tenant_id, llm_type):
    model_config = TenantLLMService.get_api_key(tenant_id, model_type=llm_type)
    if not model_config:
        model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
    else:
        model_config = model_config[0].to_dict()
    if llm_type == LLMType.EMBEDDING:
        if model_config["llm_factory"] not in EmbeddingModel:
            return
        return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
    if llm_type == LLMType.IMAGE2TEXT:
        if model_config["llm_factory"] not in CvModel:
            return
        return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
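
# Worker loop: pick up documents uploaded since the last recorded run, chunk
# and embed them, bulk-index the chunks into Elasticsearch, and append each
# handled update_time to a per-worker .tm checkpoint file.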
def main(comm, mod):
    global model
    from rag.llm import HuEmbedding
    model = HuEmbedding()
    tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
    tm = findMaxTm(tm_fnm)
    rows = collect(comm, mod, tm)
    if len(rows) == 0:
        return

    tmf = open(tm_fnm, "a+")
    for _, r in rows.iterrows():
        embd_mdl = model_instance(r["tenant_id"], LLMType.EMBEDDING)
        if not embd_mdl:
            set_progress(r["id"], -1, "Can't find embedding model!")
            cron_logger.error("Tenant({}) can't find embedding model!".format(r["tenant_id"]))
            continue
        cv_mdl = model_instance(r["tenant_id"], LLMType.IMAGE2TEXT)
        st_tm = timer()
        cks = build(r, cv_mdl)
        if not cks:
            tmf.write(str(r["update_time"]) + "\n")
            continue

        # TODO: exception handler
        # set_progress(r["did"], -1, "ERROR: ")
        try:
            tk_count = embedding(cks, embd_mdl)
        except Exception as e:
            set_progress(r["id"], -1, "Embedding error:{}".format(str(e)))
            cron_logger.error(str(e))
            continue

        set_progress(r["id"], random.randint(70, 95) / 100.,
                     "Finished embedding! Start to build the index!")
        init_kb(r)
        es_r = ELASTICSEARCH.bulk(cks, search.index_name(r["tenant_id"]))
        if es_r:
            set_progress(r["id"], -1, "Index failure!")
            cron_logger.error(str(es_r))
        else:
            set_progress(r["id"], 1., "Done!")
            DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, len(cks), timer() - st_tm)
            cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))
        tmf.write(str(r["update_time"]) + "\n")
    tmf.close()
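
# Entry point: route peewee logs through the database logger, then start one
# worker per MPI process, identified by (world size, rank).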
if __name__ == "__main__":
    peewee_logger = logging.getLogger('peewee')
    peewee_logger.propagate = False
    peewee_logger.addHandler(database_logger.handlers[0])
    peewee_logger.setLevel(database_logger.level)

    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    main(comm.Get_size(), comm.Get_rank())