
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
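"""Task executor service.

Collects pending parsing tasks, chunks each document with the parser module
registered for its parser_id, embeds the chunk contents, and bulk-indexes the
resulting chunks into Elasticsearch, reporting progress back through
TaskService along the way.
"""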
import datetime
import json
import logging
import os
import hashlib
import copy
import re
import sys
import traceback
from functools import partial
from timeit import default_timer as timer

from elasticsearch_dsl import Q

from api.db.services.task_service import TaskService
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
from rag.utils import ELASTICSEARCH
from rag.utils import MINIO
from rag.utils import rmSpace, findMaxTm
from rag.nlp import search

from io import BytesIO
import pandas as pd

from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive

from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle
from api.settings import database_logger
from api.utils.file_utils import get_project_base_directory

BATCH_SIZE = 64
FACTORY = {
    ParserType.NAIVE.value: naive,
    ParserType.PAPER.value: paper,
    ParserType.BOOK.value: book,
    ParserType.PRESENTATION.value: presentation,
    ParserType.MANUAL.value: manual,
    ParserType.LAWS.value: laws,
    ParserType.QA.value: qa,
    ParserType.TABLE.value: table,
    ParserType.RESUME.value: resume,
    ParserType.PICTURE.value: picture,
}

def set_progress(task_id, from_page=0, to_page=-1,
                 prog=None, msg="Processing..."):
    """Record progress and a message for a task; exit if the task was canceled."""
    if prog is not None and prog < 0:
        msg = "[ERROR]" + msg
    cancel = TaskService.do_cancel(task_id)
    if cancel:
        msg += " [Canceled]"
        prog = -1

    if to_page > 0:
        msg = f"Page({from_page}~{to_page}): " + msg
    d = {"progress_msg": msg}
    if prog is not None:
        d["progress"] = prog
    try:
        TaskService.update_progress(task_id, d)
    except Exception as e:
        cron_logger.error("set_progress:({}), {}".format(task_id, str(e)))

    if cancel:
        sys.exit()

def collect(comm, mod, tm):
    """Fetch pending tasks updated after `tm`; returns an empty DataFrame when there is nothing to do."""
    tasks = TaskService.get_tasks(tm, mod, comm)
    if len(tasks) == 0:
        return pd.DataFrame()
    tasks = pd.DataFrame(tasks)
    mtm = tasks["update_time"].max()
    cron_logger.info("TOTAL:{}, To:{}".format(len(tasks), mtm))
    return tasks

def build(row):
    """Chunk one document and return the chunk dicts ready for embedding and indexing."""
    if row["size"] > DOC_MAXIMUM_SIZE:
        set_progress(row["id"], prog=-1, msg="File size exceeds the limit (<= %dMB)." %
                     (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
        return []

    callback = partial(
        set_progress,
        row["id"],
        row["from_page"],
        row["to_page"])
    chunker = FACTORY[row["parser_id"].lower()]
    try:
        cron_logger.info(
            "Chunking {}/{}".format(row["location"], row["name"]))
        cks = chunker.chunk(row["name"], binary=MINIO.get(row["kb_id"], row["location"]), from_page=row["from_page"],
                            to_page=row["to_page"], lang=row["language"], callback=callback,
                            kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
    except Exception as e:
        if re.search("(No such file|not found)", str(e)):
            callback(-1, "Can not find file <%s>" % row["doc_name"])
        else:
            callback(-1, "Internal server error: %s" %
                     str(e).replace("'", ""))
        traceback.print_exc()

        cron_logger.warning(
            "Chunking {}/{}: {}".format(row["location"], row["name"], str(e)))

        return

    callback(msg="Finished slicing files (%d). Start to embed the content." % len(cks))

    docs = []
    doc = {
        "doc_id": row["doc_id"],
        "kb_id": [str(row["kb_id"])]
    }
    for ck in cks:
        d = copy.deepcopy(doc)
        d.update(ck)
        # Deterministic chunk id: md5 of the chunk content plus the owning doc id.
        md5 = hashlib.md5()
        md5.update((ck["content_with_weight"] +
                    str(d["doc_id"])).encode("utf-8"))
        d["_id"] = md5.hexdigest()
        d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
        d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
        if not d.get("image"):
            docs.append(d)
            continue

        # Chunks that carry an image store the image in MinIO and keep only its id.
        output_buffer = BytesIO()
        if isinstance(d["image"], bytes):
            output_buffer = BytesIO(d["image"])
        else:
            d["image"].save(output_buffer, format='JPEG')

        MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
        d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
        del d["image"]
        docs.append(d)

    return docs

def init_kb(row):
    idxnm = search.index_name(row["tenant_id"])
    if ELASTICSEARCH.indexExist(idxnm):
        return
    return ELASTICSEARCH.createIdx(idxnm, json.load(
        open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))

def embedding(docs, mdl, parser_config=None):
    """Embed chunk contents (blended with filename embeddings when available); returns the token count."""
    parser_config = parser_config or {}
    tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [
        d["content_with_weight"] for d in docs]
    tk_count = 0
    if len(tts) == len(cnts):
        tts, c = mdl.encode(tts)
        tk_count += c

    cnts, c = mdl.encode(cnts)
    tk_count += c
    # Weighted blend of filename (title) and content vectors; falls back to
    # content-only vectors when not every chunk has title tokens.
    title_w = float(parser_config.get("filename_embd_weight", 0.1))
    vects = (title_w * tts + (1 - title_w) *
             cnts) if len(tts) == len(cnts) else cnts

    assert len(vects) == len(docs)
    for i, d in enumerate(docs):
        v = vects[i].tolist()
        d["q_%d_vec" % len(v)] = v
    return tk_count

def main(comm, mod):
    """Process this worker's share of tasks: chunk, embed, and index each document."""
    tm_fnm = os.path.join(
        get_project_base_directory(),
        "rag/res",
        f"{comm}-{mod}.tm")
    tm = findMaxTm(tm_fnm)
    rows = collect(comm, mod, tm)
    if len(rows) == 0:
        return

    tmf = open(tm_fnm, "a+")
    for _, r in rows.iterrows():
        callback = partial(set_progress, r["id"], r["from_page"], r["to_page"])
        try:
            embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING)
        except Exception as e:
            callback(prog=-1, msg=str(e))
            continue

        cks = build(r)
        if cks is None:
            continue
        if not cks:
            tmf.write(str(r["update_time"]) + "\n")
            callback(1., "No chunk! Done!")
            continue

        # TODO: exception handler
        # set_progress(r["did"], -1, "ERROR: ")
        try:
            tk_count = embedding(cks, embd_mdl, r["parser_config"])
        except Exception as e:
            callback(-1, "Embedding error:{}".format(str(e)))
            cron_logger.error(str(e))
            tk_count = 0  # keep going with zero tokens so the bookkeeping below does not fail

        callback(msg="Finished embedding! Start to build index!")
        init_kb(r)
        chunk_count = len(set([c["_id"] for c in cks]))
        es_r = ELASTICSEARCH.bulk(cks, search.index_name(r["tenant_id"]))
        if es_r:
            callback(-1, "Index failure!")
            cron_logger.error(str(es_r))
        else:
            if TaskService.do_cancel(r["id"]):
                ELASTICSEARCH.deleteByQuery(
                    Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
                continue
            callback(1., "Done!")
            DocumentService.increment_chunk_num(
                r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
            cron_logger.info(
                "Chunk doc({}), token({}), chunks({})".format(
                    r["id"], tk_count, len(cks)))

        tmf.write(str(r["update_time"]) + "\n")
    tmf.close()

if __name__ == "__main__":
    peewee_logger = logging.getLogger('peewee')
    peewee_logger.propagate = False
    peewee_logger.addHandler(database_logger.handlers[0])
    peewee_logger.setLevel(database_logger.level)

    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    main(comm.Get_size(), comm.Get_rank())
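    # Note: comm.Get_size()/Get_rank() are handed to collect() as comm/mod, so
    # TaskService.get_tasks() is presumably what shards the task queue across
    # MPI ranks; the launch command (e.g. mpirun) is deployment-specific and
    # not shown here.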