You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

table.py 6.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. import copy
  2. import re
  3. from io import BytesIO
  4. from xpinyin import Pinyin
  5. import numpy as np
  6. import pandas as pd
  7. from openpyxl import load_workbook
  8. from dateutil.parser import parse as datetime_parse
  9. from api.db.services.knowledgebase_service import KnowledgebaseService
  10. from rag.parser import is_english, tokenize
  11. from rag.nlp import huqie, stemmer
  12. class Excel(object):
  13. def __call__(self, fnm, binary=None, callback=None):
  14. if not binary:
  15. wb = load_workbook(fnm)
  16. else:
  17. wb = load_workbook(BytesIO(binary))
  18. total = 0
  19. for sheetname in wb.sheetnames:
  20. total += len(list(wb[sheetname].rows))
  21. res, fails, done = [], [], 0
  22. for sheetname in wb.sheetnames:
  23. ws = wb[sheetname]
  24. rows = list(ws.rows)
  25. headers = [cell.value for cell in rows[0]]
  26. missed = set([i for i, h in enumerate(headers) if h is None])
  27. headers = [cell.value for i, cell in enumerate(rows[0]) if i not in missed]
  28. data = []
  29. for i, r in enumerate(rows[1:]):
  30. row = [cell.value for ii, cell in enumerate(r) if ii not in missed]
  31. if len(row) != len(headers):
  32. fails.append(str(i))
  33. continue
  34. data.append(row)
  35. done += 1
  36. if done % 999 == 0:
  37. callback(done * 0.6 / total, ("Extract records: {}".format(len(res)) + (
  38. f"{len(fails)} failure({sheetname}), line: %s..." % (",".join(fails[:3])) if fails else "")))
  39. res.append(pd.DataFrame(np.array(data), columns=headers))
  40. callback(0.6, ("Extract records: {}. ".format(done) + (
  41. f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
  42. return res
  43. def trans_datatime(s):
  44. try:
  45. return datetime_parse(s.strip()).strftime("%Y-%m-%dT%H:%M:%S")
  46. except Exception as e:
  47. pass
  48. def trans_bool(s):
  49. if re.match(r"(true|yes|是)$", str(s).strip(), flags=re.IGNORECASE): return ["yes", "是"]
  50. if re.match(r"(false|no|否)$", str(s).strip(), flags=re.IGNORECASE): return ["no", "否"]
  51. def column_data_type(arr):
  52. uni = len(set([a for a in arr if a is not None]))
  53. counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0}
  54. trans = {t: f for f, t in
  55. [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
  56. for a in arr:
  57. if a is None: continue
  58. if re.match(r"[+-]?[0-9]+(\.0+)?$", str(a).replace("%%", "")):
  59. counts["int"] += 1
  60. elif re.match(r"[+-]?[0-9.]+$", str(a).replace("%%", "")):
  61. counts["float"] += 1
  62. elif re.match(r"(true|false|yes|no|是|否)$", str(a), flags=re.IGNORECASE):
  63. counts["bool"] += 1
  64. elif trans_datatime(str(a)):
  65. counts["datetime"] += 1
  66. else:
  67. counts["text"] += 1
  68. counts = sorted(counts.items(), key=lambda x: x[1] * -1)
  69. ty = counts[0][0]
  70. for i in range(len(arr)):
  71. if arr[i] is None: continue
  72. try:
  73. arr[i] = trans[ty](str(arr[i]))
  74. except Exception as e:
  75. arr[i] = None
  76. if ty == "text":
  77. if len(arr) > 128 and uni / len(arr) < 0.1:
  78. ty = "keyword"
  79. return arr, ty
  80. def chunk(filename, binary=None, callback=None, **kwargs):
  81. dfs = []
  82. if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
  83. callback(0.1, "Start to parse.")
  84. excel_parser = Excel()
  85. dfs = excel_parser(filename, binary, callback)
  86. elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE):
  87. callback(0.1, "Start to parse.")
  88. txt = ""
  89. if binary:
  90. txt = binary.decode("utf-8")
  91. else:
  92. with open(filename, "r") as f:
  93. while True:
  94. l = f.readline()
  95. if not l: break
  96. txt += l
  97. lines = txt.split("\n")
  98. fails = []
  99. headers = lines[0].split(kwargs.get("delimiter", "\t"))
  100. rows = []
  101. for i, line in enumerate(lines[1:]):
  102. row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
  103. if len(row) != len(headers):
  104. fails.append(str(i))
  105. continue
  106. rows.append(row)
  107. if len(rows) % 999 == 0:
  108. callback(len(rows) * 0.6 / len(lines), ("Extract records: {}".format(len(rows)) + (
  109. f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
  110. callback(0.6, ("Extract records: {}".format(len(rows)) + (
  111. f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
  112. dfs = [pd.DataFrame(np.array(rows), columns=headers)]
  113. else:
  114. raise NotImplementedError("file type not supported yet(excel, text, csv supported)")
  115. res = []
  116. PY = Pinyin()
  117. fieds_map = {"text": "_tks", "int": "_int", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"}
  118. for df in dfs:
  119. for n in ["id", "_id", "index", "idx"]:
  120. if n in df.columns: del df[n]
  121. clmns = df.columns.values
  122. txts = list(copy.deepcopy(clmns))
  123. py_clmns = [PY.get_pinyins(n)[0].replace("-", "_") for n in clmns]
  124. clmn_tys = []
  125. for j in range(len(clmns)):
  126. cln, ty = column_data_type(df[clmns[j]])
  127. clmn_tys.append(ty)
  128. df[clmns[j]] = cln
  129. if ty == "text": txts.extend([str(c) for c in cln if c])
  130. clmns_map = [(py_clmns[j] + fieds_map[clmn_tys[j]], clmns[j]) for i in range(len(clmns))]
  131. eng = is_english(txts)
  132. for ii, row in df.iterrows():
  133. d = {}
  134. row_txt = []
  135. for j in range(len(clmns)):
  136. if row[clmns[j]] is None: continue
  137. fld = clmns_map[j][0]
  138. d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else huqie.qie(row[clmns[j]])
  139. row_txt.append("{}:{}".format(clmns[j], row[clmns[j]]))
  140. if not row_txt: continue
  141. tokenize(d, "; ".join(row_txt), eng)
  142. res.append(d)
  143. KnowledgebaseService.update_parser_config(kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}})
  144. callback(0.6, "")
  145. return res
  146. if __name__ == "__main__":
  147. import sys
  148. def dummy(a, b):
  149. pass
  150. chunk(sys.argv[1], callback=dummy)