
Fix errors detected by Ruff (#3918)

### What problem does this PR solve?

Fix errors detected by Ruff: split single-line compound statements onto separate lines, replace bare `except` clauses (and unused `except ... as e` bindings) with `except Exception:`, use `isinstance()` instead of `type()` comparisons, rename ambiguous single-letter variables, drop comparisons to `True`, and remove unused imports.
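
As a minimal, hypothetical sketch of what these fixes look like (the code below is illustrative only and not taken from the repository; the rule codes in the comments are the usual pycodestyle/pyflakes codes Ruff reports for these patterns):

```python
# Hypothetical snippets illustrating the lint fixes repeated throughout this PR.

def pick_primary_email(email_info):
    # E701: "if not email_info: return None" was one line; the body now gets its own line.
    if not email_info:
        return None
    for item in email_info:
        # E721/E712: type(...) == dict and "== True" comparisons become isinstance()
        # and plain truthiness checks.
        if not isinstance(item, dict):
            continue
        if item.get("primary"):
            return item.get("email")
    return None


def parse_rank(line):
    # E722/F841: a bare "except:" (or "except Exception as e" with an unused e)
    # becomes "except Exception:".
    try:
        return int(line.split(",")[1])
    except Exception:
        return None


def join_fields(row, titles):
    # E741: single-letter names such as "l" are renamed to descriptive ones ("fields", "title").
    fields = []
    for i, cell in enumerate(row):
        if not cell:
            continue
        title = titles[i] if i < len(titles) else ""
        fields.append((title + ":" if title else "") + str(cell))
    return "; ".join(fields)
```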

### Type of change

- [x] Refactoring
tags/v0.15.0
Zhichang Yu 10 months ago
parent
commit
0d68a6cd1b
No account is linked to the contributor's email address
97 changed files with 2560 additions and 1978 deletions
  1. agent/canvas.py  +23 -13
  2. agent/component/__init__.py  +70 -0
  3. agent/component/base.py  +10 -5
  4. agent/component/categorize.py  +8 -5
  5. agent/component/deepl.py  +0 -1
  6. agent/component/exesql.py  +4 -2
  7. agent/component/generate.py  +24 -12
  8. agent/component/rewrite.py  +2 -1
  9. agent/component/switch.py  +10 -7
  10. agent/component/template.py  +2 -1
  11. agent/test/client.py  +2 -1
  12. api/apps/api_app.py  +8 -5
  13. api/apps/canvas_app.py  +4 -4
  14. api/apps/chunk_app.py  +2 -1
  15. api/apps/conversation_app.py  +4 -2
  16. api/apps/dialog_app.py  +4 -2
  17. api/apps/document_app.py  +4 -3
  18. api/apps/llm_app.py  +4 -2
  19. api/apps/sdk/agent.py  +1 -1
  20. api/apps/sdk/doc.py  +2 -3
  21. api/apps/sdk/session.py  +8 -4
  22. api/apps/user_app.py  +1 -1
  23. api/db/db_models.py  +1 -1
  24. api/db/init_data.py  +1 -1
  25. api/db/services/__init__.py  +5 -4
  26. api/db/services/api_service.py  +2 -1
  27. api/db/services/canvas_service.py  +1 -3
  28. api/db/services/common_service.py  +1 -1
  29. api/db/services/dialog_service.py  +24 -14
  30. api/db/services/document_service.py  +2 -1
  31. api/db/services/file2document_service.py  +2 -2
  32. api/db/services/file_service.py  +4 -2
  33. api/db/services/llm_service.py  +2 -1
  34. api/db/services/task_service.py  +66 -34
  35. api/db/services/user_service.py  +1 -1
  36. api/ragflow_server.py  +3 -3
  37. api/utils/api_utils.py  +0 -1
  38. api/validation.py  +1 -1
  39. deepdoc/parser/__init__.py  +13 -1
  40. deepdoc/parser/excel_parser.py  +12 -8
  41. deepdoc/parser/html_parser.py  +1 -1
  42. deepdoc/parser/json_parser.py  +1 -1
  43. deepdoc/parser/pdf_parser.py  +9 -8
  44. deepdoc/parser/resume/__init__.py  +60 -19
  45. deepdoc/parser/resume/entities/corporations.py  +47 -21
  46. deepdoc/parser/resume/entities/degrees.py  +20 -16
  47. deepdoc/parser/resume/entities/industries.py  +684 -679
  48. deepdoc/parser/resume/entities/regions.py  +758 -748
  49. deepdoc/parser/resume/entities/schools.py  +28 -17
  50. deepdoc/parser/resume/step_two.py  +202 -106
  51. deepdoc/parser/txt_parser.py  +2 -1
  52. deepdoc/vision/__init__.py  +13 -4
  53. deepdoc/vision/layout_recognizer.py  +2 -2
  54. deepdoc/vision/ocr.py  +3 -1
  55. deepdoc/vision/operators.py  +3 -3
  56. deepdoc/vision/postprocess.py  +1 -1
  57. deepdoc/vision/recognizer.py  +10 -4
  58. graphrag/community_reports_extractor.py  +4 -2
  59. graphrag/entity_embedding.py  +1 -0
  60. graphrag/graph_extractor.py  +8 -4
  61. graphrag/index.py  +2 -1
  62. graphrag/leiden.py  +8 -4
  63. intergrations/chatgpt-on-wechat/plugins/__init__.py  +5 -1
  64. intergrations/chatgpt-on-wechat/plugins/ragflow_chat.py  +0 -1
  65. rag/app/book.py  +3 -3
  66. rag/app/email.py  +1 -1
  67. rag/app/knowledge_graph.py  +2 -1
  68. rag/app/laws.py  +11 -8
  69. rag/app/manual.py  +4 -3
  70. rag/app/one.py  +3 -2
  71. rag/app/qa.py  +17 -13
  72. rag/app/table.py  +5 -4
  73. rag/llm/__init__.py  +122 -10
  74. rag/llm/chat_model.py  +56 -28
  75. rag/llm/cv_model.py  +43 -24
  76. rag/llm/sequence2txt_model.py  +2 -4
  77. rag/llm/tts_model.py  +2 -2
  78. rag/nlp/__init__.py  +15 -9
  79. rag/nlp/query.py  +4 -2
  80. rag/nlp/rag_tokenizer.py  +2 -8
  81. rag/nlp/term_weight.py  +3 -3
  82. rag/raptor.py  +4 -2
  83. rag/svr/cache_file_svr.py  +4 -2
  84. rag/svr/task_executor.py  +9 -8
  85. rag/utils/__init__.py  +14 -14
  86. rag/utils/azure_sas_conn.py  +1 -1
  87. rag/utils/azure_spn_conn.py  +1 -1
  88. rag/utils/es_conn.py  +2 -1
  89. sdk/python/ragflow_sdk/__init__.py  +13 -6
  90. sdk/python/ragflow_sdk/modules/session.py  +1 -1
  91. sdk/python/test/conftest.py  +0 -2
  92. sdk/python/test/test_frontend_api/common.py  +0 -1
  93. sdk/python/test/test_frontend_api/get_email.py  +1 -1
  94. sdk/python/test/test_frontend_api/test_chunk.py  +1 -5
  95. sdk/python/test/test_frontend_api/test_dataset.py  +2 -5
  96. sdk/python/test/test_sdk_api/get_email.py  +1 -1
  97. sdk/python/test/test_sdk_api/t_agent.py  +1 -1

+ 23
- 13
agent/canvas.py

@@ -133,7 +133,8 @@ class Canvas(ABC):
"components": {}
}
for k in self.dsl.keys():
if k in ["components"]:continue
if k in ["components"]:
continue
dsl[k] = deepcopy(self.dsl[k])

for k, cpn in self.components.items():
@@ -158,7 +159,8 @@ class Canvas(ABC):

def get_compnent_name(self, cid):
for n in self.dsl["graph"]["nodes"]:
if cid == n["id"]: return n["data"]["name"]
if cid == n["id"]:
return n["data"]["name"]
return ""

def run(self, **kwargs):
@@ -173,7 +175,8 @@ class Canvas(ABC):
if kwargs.get("stream"):
for an in ans():
yield an
else: yield ans
else:
yield ans
return

if not self.path:
@@ -188,7 +191,8 @@ class Canvas(ABC):
def prepare2run(cpns):
nonlocal ran, ans
for c in cpns:
if self.path[-1] and c == self.path[-1][-1]: continue
if self.path[-1] and c == self.path[-1][-1]:
continue
cpn = self.components[c]["obj"]
if cpn.component_name == "Answer":
self.answer.append(c)
@@ -197,7 +201,8 @@ class Canvas(ABC):
if c not in without_dependent_checking:
cpids = cpn.get_dependent_components()
if any([cc not in self.path[-1] for cc in cpids]):
if c not in waiting: waiting.append(c)
if c not in waiting:
waiting.append(c)
continue
yield "*'{}'* is running...🕞".format(self.get_compnent_name(c))
ans = cpn.run(self.history, **kwargs)
@@ -211,10 +216,12 @@ class Canvas(ABC):
logging.debug(f"Canvas.run: {ran} {self.path}")
cpn_id = self.path[-1][ran]
cpn = self.get_component(cpn_id)
if not cpn["downstream"]: break
if not cpn["downstream"]:
break

loop = self._find_loop()
if loop: raise OverflowError(f"Too much loops: {loop}")
if loop:
raise OverflowError(f"Too much loops: {loop}")

if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]:
switch_out = cpn["obj"].output()[1].iloc[0, 0]
@@ -283,19 +290,22 @@ class Canvas(ABC):

def _find_loop(self, max_loops=6):
path = self.path[-1][::-1]
if len(path) < 2: return False
if len(path) < 2:
return False

for i in range(len(path)):
if path[i].lower().find("answer") >= 0:
path = path[:i]
break

if len(path) < 2: return False
if len(path) < 2:
return False

for l in range(2, len(path) // 2):
pat = ",".join(path[0:l])
for loc in range(2, len(path) // 2):
pat = ",".join(path[0:loc])
path_str = ",".join(path)
if len(pat) >= len(path_str): return False
if len(pat) >= len(path_str):
return False
loop = max_loops
while path_str.find(pat) == 0 and loop >= 0:
loop -= 1
@@ -303,7 +313,7 @@ class Canvas(ABC):
return False
path_str = path_str[len(pat)+1:]
if loop < 0:
pat = " => ".join([p.split(":")[0] for p in path[0:l]])
pat = " => ".join([p.split(":")[0] for p in path[0:loc]])
return pat + " => " + pat

return False

+ 70
- 0
agent/component/__init__.py

@@ -39,3 +39,73 @@ def component_class(class_name):
m = importlib.import_module("agent.component")
c = getattr(m, class_name)
return c

__all__ = [
"Begin",
"BeginParam",
"Generate",
"GenerateParam",
"Retrieval",
"RetrievalParam",
"Answer",
"AnswerParam",
"Categorize",
"CategorizeParam",
"Switch",
"SwitchParam",
"Relevant",
"RelevantParam",
"Message",
"MessageParam",
"RewriteQuestion",
"RewriteQuestionParam",
"KeywordExtract",
"KeywordExtractParam",
"Concentrator",
"ConcentratorParam",
"Baidu",
"BaiduParam",
"DuckDuckGo",
"DuckDuckGoParam",
"Wikipedia",
"WikipediaParam",
"PubMed",
"PubMedParam",
"ArXiv",
"ArXivParam",
"Google",
"GoogleParam",
"Bing",
"BingParam",
"GoogleScholar",
"GoogleScholarParam",
"DeepL",
"DeepLParam",
"GitHub",
"GitHubParam",
"BaiduFanyi",
"BaiduFanyiParam",
"QWeather",
"QWeatherParam",
"ExeSQL",
"ExeSQLParam",
"YahooFinance",
"YahooFinanceParam",
"WenCai",
"WenCaiParam",
"Jin10",
"Jin10Param",
"TuShare",
"TuShareParam",
"AkShare",
"AkShareParam",
"Crawler",
"CrawlerParam",
"Invoke",
"InvokeParam",
"Template",
"TemplateParam",
"Email",
"EmailParam",
"component_class"
]

+ 10
- 5
agent/component/base.py

@@ -428,7 +428,8 @@ class ComponentBase(ABC):
def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
o = getattr(self._param, self._param.output_var_name)
if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
if not isinstance(o, list): o = [o]
if not isinstance(o, list):
o = [o]
o = pd.DataFrame(o)

if allow_partial or not isinstance(o, partial):
@@ -440,7 +441,8 @@ class ComponentBase(ABC):
for oo in o():
if not isinstance(oo, pd.DataFrame):
outs = pd.DataFrame(oo if isinstance(oo, list) else [oo])
else: outs = oo
else:
outs = oo
return self._param.output_var_name, outs

def reset(self):
@@ -482,13 +484,15 @@ class ComponentBase(ABC):
outs.append(pd.DataFrame([{"content": q["value"]}]))
if outs:
df = pd.concat(outs, ignore_index=True)
if "content" in df: df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
if "content" in df:
df = df.drop_duplicates(subset=['content']).reset_index(drop=True)
return df

upstream_outs = []

for u in reversed_cpnts[::-1]:
if self.get_component_name(u) in ["switch", "concentrator"]: continue
if self.get_component_name(u) in ["switch", "concentrator"]:
continue
if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
if o is not None:
@@ -532,7 +536,8 @@ class ComponentBase(ABC):
reversed_cpnts.extend(self._canvas.path[-1])

for u in reversed_cpnts[::-1]:
if self.get_component_name(u) in ["switch", "answer"]: continue
if self.get_component_name(u) in ["switch", "answer"]:
continue
return self._canvas.get_component(u)["obj"].output()[1]

@staticmethod

+ 8
- 5
agent/component/categorize.py

@@ -34,15 +34,18 @@ class CategorizeParam(GenerateParam):
super().check()
self.check_empty(self.category_description, "[Categorize] Category examples")
for k, v in self.category_description.items():
if not k: raise ValueError("[Categorize] Category name can not be empty!")
if not v.get("to"): raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")
if not k:
raise ValueError("[Categorize] Category name can not be empty!")
if not v.get("to"):
raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")

def get_prompt(self):
cate_lines = []
for c, desc in self.category_description.items():
for l in desc.get("examples", "").split("\n"):
if not l: continue
cate_lines.append("Question: {}\tCategory: {}".format(l, c))
for line in desc.get("examples", "").split("\n"):
if not line:
continue
cate_lines.append("Question: {}\tCategory: {}".format(line, c))
descriptions = []
for c, desc in self.category_description.items():
if desc.get("description"):

+ 0
- 1
agent/component/deepl.py

@@ -14,7 +14,6 @@
# limitations under the License.
#
from abc import ABC
import re
from agent.component.base import ComponentBase, ComponentParamBase
import deepl


+ 4
- 2
agent/component/exesql.py

@@ -46,8 +46,10 @@ class ExeSQLParam(ComponentParamBase):
self.check_empty(self.password, "Database password")
self.check_positive_integer(self.top_n, "Number of records")
if self.database == "rag_flow":
if self.host == "ragflow-mysql": raise ValueError("The host is not accessible.")
if self.password == "infini_rag_flow": raise ValueError("The host is not accessible.")
if self.host == "ragflow-mysql":
raise ValueError("The host is not accessible.")
if self.password == "infini_rag_flow":
raise ValueError("The host is not accessible.")


class ExeSQL(ComponentBase, ABC):

+ 24
- 12
agent/component/generate.py

@@ -51,11 +51,16 @@ class GenerateParam(ComponentParamBase):

def gen_conf(self):
conf = {}
if self.max_tokens > 0: conf["max_tokens"] = self.max_tokens
if self.temperature > 0: conf["temperature"] = self.temperature
if self.top_p > 0: conf["top_p"] = self.top_p
if self.presence_penalty > 0: conf["presence_penalty"] = self.presence_penalty
if self.frequency_penalty > 0: conf["frequency_penalty"] = self.frequency_penalty
if self.max_tokens > 0:
conf["max_tokens"] = self.max_tokens
if self.temperature > 0:
conf["temperature"] = self.temperature
if self.top_p > 0:
conf["top_p"] = self.top_p
if self.presence_penalty > 0:
conf["presence_penalty"] = self.presence_penalty
if self.frequency_penalty > 0:
conf["frequency_penalty"] = self.frequency_penalty
return conf


@@ -83,7 +88,8 @@ class Generate(ComponentBase):
recall_docs = []
for i in idx:
did = retrieval_res.loc[int(i), "doc_id"]
if did in doc_ids: continue
if did in doc_ids:
continue
doc_ids.add(did)
recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]})

@@ -108,7 +114,8 @@ class Generate(ComponentBase):
retrieval_res = []
self._param.inputs = []
for para in self._param.parameters:
if not para.get("component_id"): continue
if not para.get("component_id"):
continue
component_id = para["component_id"].split("@")[0]
if para["component_id"].lower().find("@") >= 0:
cpn_id, key = para["component_id"].split("@")
@@ -142,7 +149,8 @@ class Generate(ComponentBase):

if retrieval_res:
retrieval_res = pd.concat(retrieval_res, ignore_index=True)
else: retrieval_res = pd.DataFrame([])
else:
retrieval_res = pd.DataFrame([])

for n, v in kwargs.items():
prompt = re.sub(r"\{%s\}" % re.escape(n), str(v).replace("\\", " "), prompt)
@@ -164,9 +172,11 @@ class Generate(ComponentBase):
return pd.DataFrame([res])

msg = self._canvas.get_history(self._param.message_history_window_size)
if len(msg) < 1: msg.append({"role": "user", "content": ""})
if len(msg) < 1:
msg.append({"role": "user", "content": ""})
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
if len(msg) < 2: msg.append({"role": "user", "content": ""})
if len(msg) < 2:
msg.append({"role": "user", "content": ""})
ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())

if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
@@ -185,9 +195,11 @@ class Generate(ComponentBase):
return

msg = self._canvas.get_history(self._param.message_history_window_size)
if len(msg) < 1: msg.append({"role": "user", "content": ""})
if len(msg) < 1:
msg.append({"role": "user", "content": ""})
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
if len(msg) < 2: msg.append({"role": "user", "content": ""})
if len(msg) < 2:
msg.append({"role": "user", "content": ""})
answer = ""
for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf()):
res = {"content": ans, "reference": []}

+ 2
- 1
agent/component/rewrite.py

@@ -95,7 +95,8 @@ class RewriteQuestion(Generate, ABC):
hist = self._canvas.get_history(4)
conv = []
for m in hist:
if m["role"] not in ["user", "assistant"]: continue
if m["role"] not in ["user", "assistant"]:
continue
conv.append("{}: {}".format(m["role"].upper(), m["content"]))
conv = "\n".join(conv)


+ 10
- 7
agent/component/switch.py

@@ -41,7 +41,8 @@ class SwitchParam(ComponentParamBase):
def check(self):
self.check_empty(self.conditions, "[Switch] conditions")
for cond in self.conditions:
if not cond["to"]: raise ValueError(f"[Switch] 'To' can not be empty!")
if not cond["to"]:
raise ValueError("[Switch] 'To' can not be empty!")


class Switch(ComponentBase, ABC):
@@ -51,7 +52,8 @@ class Switch(ComponentBase, ABC):
res = []
for cond in self._param.conditions:
for item in cond["items"]:
if not item["cpn_id"]: continue
if not item["cpn_id"]:
continue
if item["cpn_id"].find("begin") >= 0:
continue
cid = item["cpn_id"].split("@")[0]
@@ -63,7 +65,8 @@ class Switch(ComponentBase, ABC):
for cond in self._param.conditions:
res = []
for item in cond["items"]:
if not item["cpn_id"]:continue
if not item["cpn_id"]:
continue
cid = item["cpn_id"].split("@")[0]
if item["cpn_id"].find("@") > 0:
cpn_id, key = item["cpn_id"].split("@")
@@ -107,22 +110,22 @@ class Switch(ComponentBase, ABC):
elif operator == ">":
try:
return True if float(input) > float(value) else False
except Exception as e:
except Exception:
return True if input > value else False
elif operator == "<":
try:
return True if float(input) < float(value) else False
except Exception as e:
except Exception:
return True if input < value else False
elif operator == "≥":
try:
return True if float(input) >= float(value) else False
except Exception as e:
except Exception:
return True if input >= value else False
elif operator == "≤":
try:
return True if float(input) <= float(value) else False
except Exception as e:
except Exception:
return True if input <= value else False

raise ValueError('Not supported operator' + operator)

+ 2
- 1
agent/component/template.py

@@ -47,7 +47,8 @@ class Template(ComponentBase):

self._param.inputs = []
for para in self._param.parameters:
if not para.get("component_id"): continue
if not para.get("component_id"):
continue
component_id = para["component_id"].split("@")[0]
if para["component_id"].lower().find("@") >= 0:
cpn_id, key = para["component_id"].split("@")

+ 2
- 1
agent/test/client.py

@@ -43,6 +43,7 @@ if __name__ == '__main__':
else:
print(ans["content"])

if DEBUG: print(canvas.path)
if DEBUG:
print(canvas.path)
question = input("\n==================== User =====================\n> ")
canvas.add_user_input(question)

+ 8
- 5
api/apps/api_app.py

@@ -142,7 +142,6 @@ def set_conversation():
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json
try:
if objs[0].source == "agent":
e, cvs = UserCanvasService.get_by_id(objs[0].dialog_id)
@@ -188,7 +187,8 @@ def completion():
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(message="Conversation not found!")
if "quote" not in req: req["quote"] = False
if "quote" not in req:
req["quote"] = False

msg = []
for m in req["messages"]:
@@ -197,7 +197,8 @@ def completion():
if m["role"] == "assistant" and not msg:
continue
msg.append(m)
if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
if not msg[-1].get("id"):
msg[-1]["id"] = get_uuid()
message_id = msg[-1]["id"]

def fillin_conv(ans):
@@ -674,11 +675,13 @@ def completion_faq():
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(message="Conversation not found!")
if "quote" not in req: req["quote"] = True
if "quote" not in req:
req["quote"] = True

msg = []
msg.append({"role": "user", "content": req["word"]})
if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
if not msg[-1].get("id"):
msg[-1]["id"] = get_uuid()
message_id = msg[-1]["id"]

def fillin_conv(ans):

+ 4
- 4
api/apps/canvas_app.py

@@ -13,10 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import json
import traceback
from functools import partial
from flask import request, Response
from flask_login import login_required, current_user
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
@@ -60,7 +58,8 @@ def rm():
def save():
req = request.json
req["user_id"] = current_user.id
if not isinstance(req["dsl"], str): req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
if not isinstance(req["dsl"], str):
req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)

req["dsl"] = json.loads(req["dsl"])
if "id" not in req:
@@ -153,7 +152,8 @@ def run():
return resp

for answer in canvas.run(stream=False):
if answer.get("running_status"): continue
if answer.get("running_status"):
continue
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
if final_ans.get("reference"):

+ 2
- 1
api/apps/chunk_app.py

@@ -237,7 +237,8 @@ def create():
e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
if not e:
return get_data_error_result(message="Knowledgebase not found!")
if kb.pagerank: d["pagerank_fea"] = kb.pagerank
if kb.pagerank:
d["pagerank_fea"] = kb.pagerank

embd_id = DocumentService.get_embd_id(req["doc_id"])
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)

+ 4
- 2
api/apps/conversation_app.py

@@ -281,10 +281,12 @@ def thumbup():
if req["message_id"] == msg.get("id", "") and msg.get("role", "") == "assistant":
if up_down:
msg["thumbup"] = True
if "feedback" in msg: del msg["feedback"]
if "feedback" in msg:
del msg["feedback"]
else:
msg["thumbup"] = False
if feedback: msg["feedback"] = feedback
if feedback:
msg["feedback"] = feedback
break

ConversationService.update_by_id(conv["id"], conv)

+ 4
- 2
api/apps/dialog_app.py

@@ -37,10 +37,12 @@ def set_dialog():
top_n = req.get("top_n", 6)
top_k = req.get("top_k", 1024)
rerank_id = req.get("rerank_id", "")
if not rerank_id: req["rerank_id"] = ""
if not rerank_id:
req["rerank_id"] = ""
similarity_threshold = req.get("similarity_threshold", 0.1)
vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
if vector_similarity_weight is None: vector_similarity_weight = 0.3
if vector_similarity_weight is None:
vector_similarity_weight = 0.3
llm_setting = req.get("llm_setting", {})
default_prompt = {
"system": """你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答。当所有知识库内容都与问题无关时,你的回答必须包括“知识库中未找到您要的答案!”这句话。回答需要考虑聊天历史。

+ 4
- 3
api/apps/document_app.py

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License
#
import json
import os.path
import pathlib
import re
@@ -90,7 +89,8 @@ def web_crawl():
raise LookupError("Can't find this knowledgebase!")

blob = html2pdf(url)
if not blob: return server_error_response(ValueError("Download failure."))
if not blob:
return server_error_response(ValueError("Download failure."))

root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"]
@@ -290,7 +290,8 @@ def change_status():
def rm():
req = request.json
doc_ids = req["doc_id"]
if isinstance(doc_ids, str): doc_ids = [doc_ids]
if isinstance(doc_ids, str):
doc_ids = [doc_ids]

for doc_id in doc_ids:
if not DocumentService.accessible4deletion(doc_id, current_user.id):

+ 4
- 2
api/apps/llm_app.py

@@ -351,8 +351,10 @@ def list_app():

llm_set = set([m["llm_name"] + "@" + m["fid"] for m in llms])
for o in objs:
if not o.api_key: continue
if o.llm_name + "@" + o.llm_factory in llm_set: continue
if not o.api_key:
continue
if o.llm_name + "@" + o.llm_factory in llm_set:
continue
llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True})

res = {}

+ 1
- 1
api/apps/sdk/agent.py

@@ -14,7 +14,7 @@
# limitations under the License.
#

from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
from api.db.services.canvas_service import UserCanvasService
from api.utils.api_utils import get_error_data_result, token_required
from api.utils.api_utils import get_result
from flask import request

+ 2
- 3
api/apps/sdk/doc.py

@@ -41,7 +41,6 @@ from api.utils.api_utils import construct_json_result, get_parser_config
from rag.nlp import search
from rag.utils import rmSpace
from rag.utils.storage_factory import STORAGE_IMPL
import os

MAXIMUM_OF_UPLOADING_FILES = 256

@@ -976,12 +975,12 @@ def add_chunk(tenant_id, dataset_id, document_id):
if not req.get("content"):
return get_error_data_result(message="`content` is required")
if "important_keywords" in req:
if type(req["important_keywords"]) != list:
if not isinstance(req["important_keywords"], list):
return get_error_data_result(
"`important_keywords` is required to be a list"
)
if "questions" in req:
if type(req["questions"]) != list:
if not isinstance(req["questions"], list):
return get_error_data_result(
"`questions` is required to be a list"
)

+ 8
- 4
api/apps/sdk/session.py

@@ -143,8 +143,10 @@ def completion(tenant_id, chat_id):
}
conv.message.append(question)
for m in conv.message:
if m["role"] == "system": continue
if m["role"] == "assistant" and not msg: continue
if m["role"] == "system":
continue
if m["role"] == "assistant" and not msg:
continue
msg.append(m)
message_id = msg[-1].get("id")
e, dia = DialogService.get_by_id(conv.dialog_id)
@@ -267,7 +269,8 @@ def agent_completion(tenant_id, agent_id):
if m["role"] == "assistant" and not msg:
continue
msg.append(m)
if not msg[-1].get("id"): msg[-1]["id"] = get_uuid()
if not msg[-1].get("id"):
msg[-1]["id"] = get_uuid()
message_id = msg[-1]["id"]

stream = req.get("stream", True)
@@ -361,7 +364,8 @@ def agent_completion(tenant_id, agent_id):
return resp

for answer in canvas.run(stream=False):
if answer.get("running_status"): continue
if answer.get("running_status"):
continue
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
if final_ans.get("reference"):

+ 1
- 1
api/apps/user_app.py

@@ -330,7 +330,7 @@ def user_info_from_github(access_token):
headers=headers,
).json()
user_info["email"] = next(
(email for email in email_info if email["primary"] == True), None
(email for email in email_info if email["primary"]), None
)["email"]
return user_info


+ 1
- 1
api/db/db_models.py

@@ -130,7 +130,7 @@ def is_continuous_field(cls: typing.Type) -> bool:
for p in cls.__bases__:
if p in CONTINUOUS_FIELD_TYPE:
return True
elif p != Field and p != object:
elif p is not Field and p is not object:
if is_continuous_field(p):
return True
else:

+ 1
- 1
api/db/init_data.py

@@ -170,7 +170,7 @@ def add_graph_templates():
cnvs = json.load(open(os.path.join(dir, fnm), "r"))
try:
CanvasTemplateService.save(**cnvs)
except:
except Exception:
CanvasTemplateService.update_by_id(cnvs["id"], cnvs)
except Exception:
logging.exception("Add graph templates error: ")

+ 5
- 4
api/db/services/__init__.py

@@ -15,13 +15,14 @@
#
import pathlib
import re
from .user_service import UserService
from .user_service import UserService as UserService


def duplicate_name(query_func, **kwargs):
fnm = kwargs["name"]
objs = query_func(**kwargs)
if not objs: return fnm
if not objs:
return fnm
ext = pathlib.Path(fnm).suffix #.jpg
nm = re.sub(r"%s$"%ext, "", fnm)
r = re.search(r"\(([0-9]+)\)$", nm)
@@ -31,8 +32,8 @@ def duplicate_name(query_func, **kwargs):
nm = re.sub(r"\([0-9]+\)$", "", nm)
c += 1
nm = f"{nm}({c})"
if ext: nm += f"{ext}"
if ext:
nm += f"{ext}"

kwargs["name"] = nm
return duplicate_name(query_func, **kwargs)


+ 2
- 1
api/db/services/api_service.py

@@ -64,7 +64,8 @@ class API4ConversationService(CommonService):
@classmethod
@DB.connection_context()
def stats(cls, tenant_id, from_date, to_date, source=None):
if len(to_date) == 10: to_date += " 23:59:59"
if len(to_date) == 10:
to_date += " 23:59:59"
return cls.model.select(
cls.model.create_date.truncate("day").alias("dt"),
peewee.fn.COUNT(

+ 1
- 3
api/db/services/canvas_service.py

@@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime
import peewee
from api.db.db_models import DB, API4Conversation, APIToken, Dialog, CanvasTemplate, UserCanvas
from api.db.db_models import DB, CanvasTemplate, UserCanvas
from api.db.services.common_service import CommonService



+ 1
- 1
api/db/services/common_service.py

@@ -115,7 +115,7 @@ class CommonService:
try:
obj = cls.model.query(id=pid)[0]
return True, obj
except Exception as e:
except Exception:
return False, None

@classmethod

+ 24
- 14
api/db/services/dialog_service.py

@@ -106,15 +106,15 @@ def message_fit_in(msg, max_length=4000):
return c, msg

ll = num_tokens_from_string(msg_[0]["content"])
l = num_tokens_from_string(msg_[-1]["content"])
if ll / (ll + l) > 0.8:
ll2 = num_tokens_from_string(msg_[-1]["content"])
if ll / (ll + ll2) > 0.8:
m = msg_[0]["content"]
m = encoder.decode(encoder.encode(m)[:max_length - l])
m = encoder.decode(encoder.encode(m)[:max_length - ll2])
msg[0]["content"] = m
return max_length, msg

m = msg_[1]["content"]
m = encoder.decode(encoder.encode(m)[:max_length - l])
m = encoder.decode(encoder.encode(m)[:max_length - ll2])
msg[1]["content"] = m
return max_length, msg

@@ -257,7 +257,8 @@ def chat(dialog, messages, stream=True, **kwargs):
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
if not recall_docs: recall_docs = kbinfos["doc_aggs"]
if not recall_docs:
recall_docs = kbinfos["doc_aggs"]
kbinfos["doc_aggs"] = recall_docs

refs = deepcopy(kbinfos)
@@ -433,13 +434,15 @@ def relevant(tenant_id, llm_id, question, contents: list):
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
No other words needed except 'yes' or 'no'.
"""
if not contents:return False
if not contents:
return False
contents = "Documents: \n" + " - ".join(contents)
contents = f"Question: {question}\n" + contents
if num_tokens_from_string(contents) >= chat_mdl.max_length - 4:
contents = encoder.decode(encoder.encode(contents)[:chat_mdl.max_length - 4])
ans = chat_mdl.chat(prompt, [{"role": "user", "content": contents}], {"temperature": 0.01})
if ans.lower().find("yes") >= 0: return True
if ans.lower().find("yes") >= 0:
return True
return False


@@ -481,8 +484,10 @@ Requirements:
]
_, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple): kwd = kwd[0]
if kwd.find("**ERROR**") >=0: return ""
if isinstance(kwd, tuple):
kwd = kwd[0]
if kwd.find("**ERROR**") >=0:
return ""
return kwd


@@ -508,8 +513,10 @@ Requirements:
]
_, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple): kwd = kwd[0]
if kwd.find("**ERROR**") >= 0: return ""
if isinstance(kwd, tuple):
kwd = kwd[0]
if kwd.find("**ERROR**") >= 0:
return ""
return kwd


@@ -520,7 +527,8 @@ def full_question(tenant_id, llm_id, messages):
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
conv = []
for m in messages:
if m["role"] not in ["user", "assistant"]: continue
if m["role"] not in ["user", "assistant"]:
continue
conv.append("{}: {}".format(m["role"].upper(), m["content"]))
conv = "\n".join(conv)
today = datetime.date.today().isoformat()
@@ -581,7 +589,8 @@ Output: What's the weather in Rochester on {tomorrow}?


def tts(tts_mdl, text):
if not tts_mdl or not text: return
if not tts_mdl or not text:
return
bin = b""
for chunk in tts_mdl.tts(text):
bin += chunk
@@ -641,7 +650,8 @@ def ask(question, kb_ids, tenant_id):
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
if not recall_docs: recall_docs = kbinfos["doc_aggs"]
if not recall_docs:
recall_docs = kbinfos["doc_aggs"]
kbinfos["doc_aggs"] = recall_docs
refs = deepcopy(kbinfos)
for c in refs["chunks"]:

+ 2
- 1
api/db/services/document_service.py

@@ -532,7 +532,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
try:
mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output,
ensure_ascii=False, indent=2)
if len(mind_map) < 32: raise Exception("Few content: " + mind_map)
if len(mind_map) < 32:
raise Exception("Few content: " + mind_map)
cks.append({
"id": get_uuid(),
"doc_id": doc_id,

+ 2
- 2
api/db/services/file2document_service.py

@@ -20,7 +20,7 @@ from api.db.db_models import DB
from api.db.db_models import File, File2Document
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.utils import current_timestamp, datetime_format, get_uuid
from api.utils import current_timestamp, datetime_format


class File2DocumentService(CommonService):
@@ -63,7 +63,7 @@ class File2DocumentService(CommonService):
def update_by_file_id(cls, file_id, obj):
obj["update_time"] = current_timestamp()
obj["update_date"] = datetime_format(datetime.now())
num = cls.model.update(obj).where(cls.model.id == file_id).execute()
# num = cls.model.update(obj).where(cls.model.id == file_id).execute()
e, obj = cls.get_by_id(cls.model.id)
return obj


+ 4
- 2
api/db/services/file_service.py

@@ -85,7 +85,8 @@ class FileService(CommonService):
.join(Document, on=(File2Document.document_id == Document.id))
.join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id))
.where(cls.model.id == file_id))
if not kbs: return []
if not kbs:
return []
kbs_info_list = []
for kb in list(kbs.dicts()):
kbs_info_list.append({"kb_id": kb['id'], "kb_name": kb['name']})
@@ -304,7 +305,8 @@ class FileService(CommonService):
@classmethod
@DB.connection_context()
def add_file_from_kb(cls, doc, kb_folder_id, tenant_id):
for _ in File2DocumentService.get_by_document_id(doc["id"]): return
for _ in File2DocumentService.get_by_document_id(doc["id"]):
return
file = {
"id": get_uuid(),
"parent_id": kb_folder_id,

+ 2
- 1
api/db/services/llm_service.py

@@ -107,7 +107,8 @@ class TenantLLMService(CommonService):

model_config = cls.get_api_key(tenant_id, mdlnm)
mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
if model_config: model_config = model_config.to_dict()
if model_config:
model_config = model_config.to_dict()
if not model_config:
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)

+ 66
- 34
api/db/services/task_service.py

@@ -57,28 +57,33 @@ class TaskService(CommonService):
Tenant.img2txt_id,
Tenant.asr_id,
Tenant.llm_id,
cls.model.update_time]
docs = cls.model.select(*fields) \
.join(Document, on=(cls.model.doc_id == Document.id)) \
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) \
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id)) \
cls.model.update_time,
]
docs = (
cls.model.select(*fields)
.join(Document, on=(cls.model.doc_id == Document.id))
.join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id))
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
.where(cls.model.id == task_id)
)
docs = list(docs.dicts())
if not docs: return None
if not docs:
return None

msg = "\nTask has been received."
prog = random.random() / 10.
prog = random.random() / 10.0
if docs[0]["retry_count"] >= 3:
msg = "\nERROR: Task is abandoned after 3 times attempts."
prog = -1

cls.model.update(progress_msg=cls.model.progress_msg + msg,
progress=prog,
retry_count=docs[0]["retry_count"]+1
).where(
cls.model.id == docs[0]["id"]).execute()
cls.model.update(
progress_msg=cls.model.progress_msg + msg,
progress=prog,
retry_count=docs[0]["retry_count"] + 1,
).where(cls.model.id == docs[0]["id"]).execute()

if docs[0]["retry_count"] >= 3: return None
if docs[0]["retry_count"] >= 3:
return None

return docs[0]

@@ -86,21 +91,44 @@ class TaskService(CommonService):
@DB.connection_context()
def get_ongoing_doc_name(cls):
with DB.lock("get_task", -1):
docs = cls.model.select(*[Document.id, Document.kb_id, Document.location, File.parent_id]) \
.join(Document, on=(cls.model.doc_id == Document.id)) \
.join(File2Document, on=(File2Document.document_id == Document.id), join_type=JOIN.LEFT_OUTER) \
.join(File, on=(File2Document.file_id == File.id), join_type=JOIN.LEFT_OUTER) \
docs = (
cls.model.select(
*[Document.id, Document.kb_id, Document.location, File.parent_id]
)
.join(Document, on=(cls.model.doc_id == Document.id))
.join(
File2Document,
on=(File2Document.document_id == Document.id),
join_type=JOIN.LEFT_OUTER,
)
.join(
File,
on=(File2Document.file_id == File.id),
join_type=JOIN.LEFT_OUTER,
)
.where(
Document.status == StatusEnum.VALID.value,
Document.run == TaskStatus.RUNNING.value,
~(Document.type == FileType.VIRTUAL.value),
cls.model.progress < 1,
cls.model.create_time >= current_timestamp() - 1000 * 600
cls.model.create_time >= current_timestamp() - 1000 * 600,
)
)
docs = list(docs.dicts())
if not docs: return []

return list(set([(d["parent_id"] if d["parent_id"] else d["kb_id"], d["location"]) for d in docs]))
if not docs:
return []

return list(
set(
[
(
d["parent_id"] if d["parent_id"] else d["kb_id"],
d["location"],
)
for d in docs
]
)
)

@classmethod
@DB.connection_context()
@@ -118,28 +146,30 @@ class TaskService(CommonService):
def update_progress(cls, id, info):
if os.environ.get("MACOS"):
if info["progress_msg"]:
cls.model.update(progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]).where(
cls.model.id == id).execute()
cls.model.update(
progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]
).where(cls.model.id == id).execute()
if "progress" in info:
cls.model.update(progress=info["progress"]).where(
cls.model.id == id).execute()
cls.model.id == id
).execute()
return

with DB.lock("update_progress", -1):
if info["progress_msg"]:
cls.model.update(progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]).where(
cls.model.id == id).execute()
cls.model.update(
progress_msg=cls.model.progress_msg + "\n" + info["progress_msg"]
).where(cls.model.id == id).execute()
if "progress" in info:
cls.model.update(progress=info["progress"]).where(
cls.model.id == id).execute()
cls.model.id == id
).execute()


def queue_tasks(doc: dict, bucket: str, name: str):
def new_task():
return {
"id": get_uuid(),
"doc_id": doc["id"]
}
return {"id": get_uuid(), "doc_id": doc["id"]}

tsks = []

if doc["type"] == FileType.PDF.value:
@@ -150,8 +180,8 @@ def queue_tasks(doc: dict, bucket: str, name: str):
if doc["parser_id"] == "paper":
page_size = doc["parser_config"].get("task_page_size", 22)
if doc["parser_id"] in ["one", "knowledge_graph"] or not do_layout:
page_size = 10 ** 9
page_ranges = doc["parser_config"].get("pages") or [(1, 10 ** 5)]
page_size = 10**9
page_ranges = doc["parser_config"].get("pages") or [(1, 10**5)]
for s, e in page_ranges:
s -= 1
s = max(0, s)
@@ -177,4 +207,6 @@ def queue_tasks(doc: dict, bucket: str, name: str):
DocumentService.begin2parse(doc["id"])

for t in tsks:
assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=t), "Can't access Redis. Please check the Redis' status."
assert REDIS_CONN.queue_product(
SVR_QUEUE_NAME, message=t
), "Can't access Redis. Please check the Redis' status."

+ 1
- 1
api/db/services/user_service.py

@@ -22,7 +22,7 @@ from api.db import UserTenantRole
from api.db.db_models import DB, UserTenant
from api.db.db_models import User, Tenant
from api.db.services.common_service import CommonService
from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format
from api.utils import get_uuid, current_timestamp, datetime_format
from api.db import StatusEnum



+ 3
- 3
api/ragflow_server.py

@@ -21,10 +21,7 @@
import logging
import os
from api.utils.log_utils import initRootLogger
LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
initRootLogger("ragflow_server", LOG_LEVELS)

import os
import signal
import sys
import time
@@ -44,6 +41,9 @@ from api.versions import get_ragflow_version
from api.utils import show_configs
from rag.settings import print_rag_settings

LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
initRootLogger("ragflow_server", LOG_LEVELS)


def update_progress():
while True:

+ 0
- 1
api/utils/api_utils.py

@@ -36,7 +36,6 @@ from werkzeug.http import HTTP_STATUS_CODES
from api.db.db_models import APIToken
from api import settings

from api import settings
from api.utils import CustomJSONEncoder, get_uuid
from api.utils import json_dumps
from api.constants import REQUEST_WAIT_SEC, REQUEST_MAX_WAIT_SEC

+ 1
- 1
api/validation.py

@@ -45,5 +45,5 @@ try:
pool = Pool(processes=1)
thread = pool.apply_async(download_nltk_data)
binary = thread.get(timeout=60)
except Exception as e:
except Exception:
print('\x1b[6;37;41m WARNING \x1b[0m' + "Downloading NLTK data failure.", flush=True)

+ 13
- 1
deepdoc/parser/__init__.py

@@ -18,4 +18,16 @@ from .ppt_parser import RAGFlowPptParser as PptParser
from .html_parser import RAGFlowHtmlParser as HtmlParser
from .json_parser import RAGFlowJsonParser as JsonParser
from .markdown_parser import RAGFlowMarkdownParser as MarkdownParser
from .txt_parser import RAGFlowTxtParser as TxtParser
from .txt_parser import RAGFlowTxtParser as TxtParser

__all__ = [
"PdfParser",
"PlainParser",
"DocxParser",
"ExcelParser",
"PptParser",
"HtmlParser",
"JsonParser",
"MarkdownParser",
"TxtParser",
]

+ 12
- 8
deepdoc/parser/excel_parser.py

@@ -29,7 +29,8 @@ class RAGFlowExcelParser:
for sheetname in wb.sheetnames:
ws = wb[sheetname]
rows = list(ws.rows)
if not rows: continue
if not rows:
continue

tb_rows_0 = "<tr>"
for t in list(rows[0]):
@@ -40,7 +41,9 @@ class RAGFlowExcelParser:
tb = ""
tb += f"<table><caption>{sheetname}</caption>"
tb += tb_rows_0
for r in list(rows[1 + chunk_i * chunk_rows:1 + (chunk_i + 1) * chunk_rows]):
for r in list(
rows[1 + chunk_i * chunk_rows : 1 + (chunk_i + 1) * chunk_rows]
):
tb += "<tr>"
for i, c in enumerate(r):
if c.value is None:
@@ -62,20 +65,21 @@ class RAGFlowExcelParser:
for sheetname in wb.sheetnames:
ws = wb[sheetname]
rows = list(ws.rows)
if not rows:continue
if not rows:
continue
ti = list(rows[0])
for r in list(rows[1:]):
l = []
fields = []
for i, c in enumerate(r):
if not c.value:
continue
t = str(ti[i].value) if i < len(ti) else ""
t += (":" if t else "") + str(c.value)
l.append(t)
l = "; ".join(l)
fields.append(t)
line = "; ".join(fields)
if sheetname.lower().find("sheet") < 0:
l += " ——" + sheetname
res.append(l)
line += " ——" + sheetname
res.append(line)
return res

@staticmethod

+ 1
- 1
deepdoc/parser/html_parser.py

@@ -36,7 +36,7 @@ class RAGFlowHtmlParser:

@classmethod
def parser_txt(cls, txt):
if type(txt) != str:
if not isinstance(txt, str):
raise TypeError("txt type should be str!")
html_doc = readability.Document(txt)
title = html_doc.title()

+ 1
- 1
deepdoc/parser/json_parser.py

@@ -22,7 +22,7 @@ class RAGFlowJsonParser:
txt = binary.decode(encoding, errors="ignore")
json_data = json.loads(txt)
chunks = self.split_json(json_data, True)
sections = [json.dumps(l, ensure_ascii=False) for l in chunks if l]
sections = [json.dumps(line, ensure_ascii=False) for line in chunks if line]
return sections

@staticmethod

+ 9
- 8
deepdoc/parser/pdf_parser.py

@@ -752,7 +752,7 @@ class RAGFlowPdfParser:
"x1": np.max([b["x1"] for b in bxs]),
"bottom": np.max([b["bottom"] for b in bxs]) - ht
}
louts = [l for l in self.page_layout[pn] if l["type"] == ltype]
louts = [layout for layout in self.page_layout[pn] if layout["type"] == ltype]
ii = Recognizer.find_overlapped(b, louts, naive=True)
if ii is not None:
b = louts[ii]
@@ -763,7 +763,8 @@ class RAGFlowPdfParser:
"layoutno", "")))

left, top, right, bott = b["x0"], b["top"], b["x1"], b["bottom"]
if right < left: right = left + 1
if right < left:
right = left + 1
poss.append((pn + self.page_from, left, right, top, bott))
return self.page_images[pn] \
.crop((left * ZM, top * ZM,
@@ -845,7 +846,8 @@ class RAGFlowPdfParser:
top = bx["top"] - self.page_cum_height[pn[0] - 1]
bott = bx["bottom"] - self.page_cum_height[pn[0] - 1]
page_images_cnt = len(self.page_images)
if pn[-1] - 1 >= page_images_cnt: return ""
if pn[-1] - 1 >= page_images_cnt:
return ""
while bott * ZM > self.page_images[pn[-1] - 1].size[1]:
bott -= self.page_images[pn[-1] - 1].size[1] / ZM
pn.append(pn[-1] + 1)
@@ -889,7 +891,6 @@ class RAGFlowPdfParser:
nonlocal mh, pw, lines, widths
lines.append(line)
widths.append(width(line))
width_mean = np.mean(widths)
mmj = self.proj_match(
line["text"]) or line.get(
"layout_type",
@@ -994,7 +995,7 @@ class RAGFlowPdfParser:
else:
self.is_english = False

st = timer()
# st = timer()
for i, img in enumerate(self.page_images_x2):
chars = self.page_chars[i] if not self.is_english else []
self.mean_height.append(
@@ -1028,8 +1029,8 @@ class RAGFlowPdfParser:

self.page_cum_height = np.cumsum(self.page_cum_height)
assert len(self.page_cum_height) == len(self.page_images) + 1
if len(self.boxes) == 0 and zoomin < 9: self.__images__(fnm, zoomin * 3, page_from,
page_to, callback)
if len(self.boxes) == 0 and zoomin < 9:
self.__images__(fnm, zoomin * 3, page_from, page_to, callback)

def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
self.__images__(fnm, zoomin)
@@ -1168,7 +1169,7 @@ class PlainParser(object):
if not self.outlines:
logging.warning("Miss outlines")

return [(l, "") for l in lines], []
return [(line, "") for line in lines], []

def crop(self, ck, need_position):
raise NotImplementedError

+ 60
- 19
deepdoc/parser/resume/__init__.py

@@ -15,21 +15,42 @@ import datetime


def refactor(cv):
for n in ["raw_txt", "parser_name", "inference", "ori_text", "use_time", "time_stat"]:
if n in cv and cv[n] is not None: del cv[n]
for n in [
"raw_txt",
"parser_name",
"inference",
"ori_text",
"use_time",
"time_stat",
]:
if n in cv and cv[n] is not None:
del cv[n]
cv["is_deleted"] = 0
if "basic" not in cv: cv["basic"] = {}
if cv["basic"].get("photo2"): del cv["basic"]["photo2"]
if "basic" not in cv:
cv["basic"] = {}
if cv["basic"].get("photo2"):
del cv["basic"]["photo2"]

for n in ["education", "work", "certificate", "project", "language", "skill", "training"]:
if n not in cv or cv[n] is None: continue
if type(cv[n]) == type({}): cv[n] = [v for _, v in cv[n].items()]
if type(cv[n]) != type([]):
for n in [
"education",
"work",
"certificate",
"project",
"language",
"skill",
"training",
]:
if n not in cv or cv[n] is None:
continue
if isinstance(cv[n], dict):
cv[n] = [v for _, v in cv[n].items()]
if not isinstance(cv[n], list):
del cv[n]
continue
vv = []
for v in cv[n]:
if "external" in v and v["external"] is not None: del v["external"]
if "external" in v and v["external"] is not None:
del v["external"]
vv.append(v)
cv[n] = {str(i): vv[i] for i in range(len(vv))}

@@ -42,24 +63,44 @@ def refactor(cv):
cv["basic"][t] = cv["basic"][n]
del cv["basic"][n]

work = sorted([v for _, v in cv.get("work", {}).items()], key=lambda x: x.get("start_time", ""))
edu = sorted([v for _, v in cv.get("education", {}).items()], key=lambda x: x.get("start_time", ""))
work = sorted(
[v for _, v in cv.get("work", {}).items()],
key=lambda x: x.get("start_time", ""),
)
edu = sorted(
[v for _, v in cv.get("education", {}).items()],
key=lambda x: x.get("start_time", ""),
)

if work:
cv["basic"]["work_start_time"] = work[0].get("start_time", "")
cv["basic"]["management_experience"] = 'Y' if any(
[w.get("management_experience", '') == 'Y' for w in work]) else 'N'
cv["basic"]["management_experience"] = (
"Y"
if any([w.get("management_experience", "") == "Y" for w in work])
else "N"
)
cv["basic"]["annual_salary"] = work[-1].get("annual_salary_from", "0")

for n in ["annual_salary_from", "annual_salary_to", "industry_name", "position_name", "responsibilities",
"corporation_type", "scale", "corporation_name"]:
for n in [
"annual_salary_from",
"annual_salary_to",
"industry_name",
"position_name",
"responsibilities",
"corporation_type",
"scale",
"corporation_name",
]:
cv["basic"][n] = work[-1].get(n, "")

if edu:
for n in ["school_name", "discipline_name"]:
if n in edu[-1]: cv["basic"][n] = edu[-1][n]
if n in edu[-1]:
cv["basic"][n] = edu[-1][n]

cv["basic"]["updated_at"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
if "contact" not in cv: cv["contact"] = {}
if not cv["contact"].get("name"): cv["contact"]["name"] = cv["basic"].get("name", "")
return cv
if "contact" not in cv:
cv["contact"] = {}
if not cv["contact"].get("name"):
cv["contact"]["name"] = cv["basic"].get("name", "")
return cv

+ 47
- 21
deepdoc/parser/resume/entities/corporations.py

@@ -21,13 +21,18 @@ from . import regions


current_file_path = os.path.dirname(os.path.abspath(__file__))
GOODS = pd.read_csv(os.path.join(current_file_path, "res/corp_baike_len.csv"), sep="\t", header=0).fillna(0)
GOODS = pd.read_csv(
os.path.join(current_file_path, "res/corp_baike_len.csv"), sep="\t", header=0
).fillna(0)
GOODS["cid"] = GOODS["cid"].astype(str)
GOODS = GOODS.set_index(["cid"])
CORP_TKS = json.load(open(os.path.join(current_file_path, "res/corp.tks.freq.json"), "r"))
CORP_TKS = json.load(
open(os.path.join(current_file_path, "res/corp.tks.freq.json"), "r")
)
GOOD_CORP = json.load(open(os.path.join(current_file_path, "res/good_corp.json"), "r"))
CORP_TAG = json.load(open(os.path.join(current_file_path, "res/corp_tag.json"), "r"))


def baike(cid, default_v=0):
global GOODS
try:
@@ -39,27 +44,41 @@ def baike(cid, default_v=0):

def corpNorm(nm, add_region=True):
global CORP_TKS
if not nm or type(nm)!=type(""):return ""
if not nm or isinstance(nm, str):
return ""
nm = rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(nm)).lower()
nm = re.sub(r"&amp;", "&", nm)
nm = re.sub(r"[\(\)()\+'\"\t \*\\【】-]+", " ", nm)
nm = re.sub(r"([—-]+.*| +co\..*|corp\..*| +inc\..*| +ltd.*)", "", nm, 10000, re.IGNORECASE)
nm = re.sub(r"(计算机|技术|(技术|科技|网络)*有限公司|公司|有限|研发中心|中国|总部)$", "", nm, 10000, re.IGNORECASE)
if not nm or (len(nm)<5 and not regions.isName(nm[0:2])):return nm
nm = re.sub(
r"([—-]+.*| +co\..*|corp\..*| +inc\..*| +ltd.*)", "", nm, 10000, re.IGNORECASE
)
nm = re.sub(
r"(计算机|技术|(技术|科技|网络)*有限公司|公司|有限|研发中心|中国|总部)$",
"",
nm,
10000,
re.IGNORECASE,
)
if not nm or (len(nm) < 5 and not regions.isName(nm[0:2])):
return nm

tks = rag_tokenizer.tokenize(nm).split()
reg = [t for i,t in enumerate(tks) if regions.isName(t) and (t != "中国" or i > 0)]
reg = [t for i, t in enumerate(tks) if regions.isName(t) and (t != "中国" or i > 0)]
nm = ""
for t in tks:
if regions.isName(t) or t in CORP_TKS:continue
if re.match(r"[0-9a-zA-Z\\,.]+", t) and re.match(r".*[0-9a-zA-Z\,.]+$", nm):nm += " "
if regions.isName(t) or t in CORP_TKS:
continue
if re.match(r"[0-9a-zA-Z\\,.]+", t) and re.match(r".*[0-9a-zA-Z\,.]+$", nm):
nm += " "
nm += t

r = re.search(r"^([^a-z0-9 \(\)&]{2,})[a-z ]{4,}$", nm.strip())
if r:nm = r.group(1)
if r:
nm = r.group(1)
r = re.search(r"^([a-z ]{3,})[^a-z0-9 \(\)&]{2,}$", nm.strip())
if r:nm = r.group(1)
return nm.strip() + (("" if not reg else "(%s)"%reg[0]) if add_region else "")
if r:
nm = r.group(1)
return nm.strip() + (("" if not reg else "(%s)" % reg[0]) if add_region else "")


def rmNoise(n):
@@ -67,33 +86,40 @@ def rmNoise(n):
n = re.sub(r"[,. &()()]+", "", n)
return n


GOOD_CORP = set([corpNorm(rmNoise(c), False) for c in GOOD_CORP])
for c,v in CORP_TAG.items():
for c, v in CORP_TAG.items():
cc = corpNorm(rmNoise(c), False)
if not cc:
logging.debug(c)
CORP_TAG = {corpNorm(rmNoise(c), False):v for c,v in CORP_TAG.items()}
CORP_TAG = {corpNorm(rmNoise(c), False): v for c, v in CORP_TAG.items()}


def is_good(nm):
global GOOD_CORP
if nm.find("外派")>=0:return False
if nm.find("外派") >= 0:
return False
nm = rmNoise(nm)
nm = corpNorm(nm, False)
for n in GOOD_CORP:
if re.match(r"[0-9a-zA-Z]+$", n):
if n == nm: return True
elif nm.find(n)>=0:return True
if n == nm:
return True
elif nm.find(n) >= 0:
return True
return False


def corp_tag(nm):
global CORP_TAG
nm = rmNoise(nm)
nm = corpNorm(nm, False)
for n in CORP_TAG.keys():
if re.match(r"[0-9a-zA-Z., ]+$", n):
if n == nm: return CORP_TAG[n]
elif nm.find(n)>=0:
if len(n)<3 and len(nm)/len(n)>=2:continue
if n == nm:
return CORP_TAG[n]
elif nm.find(n) >= 0:
if len(n) < 3 and len(nm) / len(n) >= 2:
continue
return CORP_TAG[n]
return []


+ 20
- 16
deepdoc/parser/resume/entities/degrees.py

@@ -11,27 +11,31 @@
# limitations under the License.
#

TBL = {"94":"EMBA",
"6":"MBA",
"95":"MPA",
"92":"专升本",
"4":"专科",
"90":"中专",
"91":"中技",
"86":"初中",
"3":"博士",
"10":"博士后",
"1":"本科",
"2":"硕士",
"87":"职高",
"89":"高中"
TBL = {
"94": "EMBA",
"6": "MBA",
"95": "MPA",
"92": "专升本",
"4": "专科",
"90": "中专",
"91": "中技",
"86": "初中",
"3": "博士",
"10": "博士后",
"1": "本科",
"2": "硕士",
"87": "职高",
"89": "高中",
}

TBL_ = {v:k for k,v in TBL.items()}
TBL_ = {v: k for k, v in TBL.items()}


def get_name(id):
return TBL.get(str(id), "")


def get_id(nm):
if not nm:return ""
if not nm:
return ""
return TBL_.get(nm.upper().strip(), "")

+ 684
- 679
deepdoc/parser/resume/entities/industries.py
File diff suppressed because it is too large


+ 758
- 748
deepdoc/parser/resume/entities/regions.py
File diff suppressed because it is too large


+ 28
- 17
deepdoc/parser/resume/entities/schools.py

@@ -16,8 +16,11 @@ import json
import re
import copy
import pandas as pd

current_file_path = os.path.dirname(os.path.abspath(__file__))
TBL = pd.read_csv(os.path.join(current_file_path, "res/schools.csv"), sep="\t", header=0).fillna("")
TBL = pd.read_csv(
os.path.join(current_file_path, "res/schools.csv"), sep="\t", header=0
).fillna("")
TBL["name_en"] = TBL["name_en"].map(lambda x: x.lower().strip())
GOOD_SCH = json.load(open(os.path.join(current_file_path, "res/good_sch.json"), "r"))
GOOD_SCH = set([re.sub(r"[,. &()()]+", "", c) for c in GOOD_SCH])
@@ -26,14 +29,15 @@ GOOD_SCH = set([re.sub(r"[,. &()()]+", "", c) for c in GOOD_SCH])
def loadRank(fnm):
global TBL
TBL["rank"] = 1000000
with open(fnm, "r", encoding='utf-8') as f:
with open(fnm, "r", encoding="utf-8") as f:
while True:
l = f.readline()
if not l:break
l = l.strip("\n").split(",")
line = f.readline()
if not line:
break
line = line.strip("\n").split(",")
try:
nm,rk = l[0].strip(),int(l[1])
#assert len(TBL[((TBL.name_cn == nm) | (TBL.name_en == nm))]),f"<{nm}>"
nm, rk = line[0].strip(), int(line[1])
# assert len(TBL[((TBL.name_cn == nm) | (TBL.name_en == nm))]),f"<{nm}>"
TBL.loc[((TBL.name_cn == nm) | (TBL.name_en == nm)), "rank"] = rk
except Exception:
pass
@@ -44,27 +48,35 @@ loadRank(os.path.join(current_file_path, "res/school.rank.csv"))

def split(txt):
tks = []
for t in re.sub(r"[ \t]+", " ",txt).split():
if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and \
re.match(r"[a-zA-Z]", t) and tks:
for t in re.sub(r"[ \t]+", " ", txt).split():
if (
tks
and re.match(r".*[a-zA-Z]$", tks[-1])
and re.match(r"[a-zA-Z]", t)
and tks
):
tks[-1] = tks[-1] + " " + t
else:tks.append(t)
else:
tks.append(t)
return tks


def select(nm):
global TBL
if not nm:return
if isinstance(nm, list):nm = str(nm[0])
if not nm:
return
if isinstance(nm, list):
nm = str(nm[0])
nm = split(nm)[0]
nm = str(nm).lower().strip()
nm = re.sub(r"[((][^()()]+[))]", "", nm.lower())
nm = re.sub(r"(^the |[,.&()();;·]+|^(英国|美国|瑞士))", "", nm)
nm = re.sub(r"大学.*学院", "大学", nm)
tbl = copy.deepcopy(TBL)
tbl["hit_alias"] = tbl["alias"].map(lambda x:nm in set(x.split("+")))
res = tbl[((tbl.name_cn == nm) | (tbl.name_en == nm) | (tbl.hit_alias == True))]
if res.empty:return
tbl["hit_alias"] = tbl["alias"].map(lambda x: nm in set(x.split("+")))
res = tbl[((tbl.name_cn == nm) | (tbl.name_en == nm) | tbl.hit_alias)]
if res.empty:
return

return json.loads(res.to_json(orient="records"))[0]

@@ -74,4 +86,3 @@ def is_good(nm):
nm = re.sub(r"[((][^()()]+[))]", "", nm.lower())
nm = re.sub(r"[''`‘’“”,. &()();;]+", "", nm)
return nm in GOOD_SCH


+ 202
- 106
deepdoc/parser/resume/step_two.py

@@ -25,7 +25,8 @@ from xpinyin import Pinyin
from contextlib import contextmanager


class TimeoutException(Exception): pass
class TimeoutException(Exception):
pass


@contextmanager
@@ -50,8 +51,10 @@ def rmHtmlTag(line):


def highest_degree(dg):
if not dg: return ""
if type(dg) == type(""): dg = [dg]
if not dg:
return ""
if isinstance(dg, str):
dg = [dg]
m = {"初中": 0, "高中": 1, "中专": 2, "大专": 3, "专升本": 4, "本科": 5, "硕士": 6, "博士": 7, "博士后": 8}
return sorted([(d, m.get(d, -1)) for d in dg], key=lambda x: x[1] * -1)[0][0]

@@ -68,10 +71,12 @@ def forEdu(cv):
for ii, n in enumerate(sorted(cv["education_obj"], key=lambda x: x.get("start_time", "3"))):
e = {}
if n.get("end_time"):
if n["end_time"] > edu_end_dt: edu_end_dt = n["end_time"]
if n["end_time"] > edu_end_dt:
edu_end_dt = n["end_time"]
try:
dt = n["end_time"]
if re.match(r"[0-9]{9,}", dt): dt = turnTm2Dt(dt)
if re.match(r"[0-9]{9,}", dt):
dt = turnTm2Dt(dt)
y, m, d = getYMD(dt)
ed_dt.append(str(y))
e["end_dt_kwd"] = str(y)
@@ -80,7 +85,8 @@ def forEdu(cv):
if n.get("start_time"):
try:
dt = n["start_time"]
if re.match(r"[0-9]{9,}", dt): dt = turnTm2Dt(dt)
if re.match(r"[0-9]{9,}", dt):
dt = turnTm2Dt(dt)
y, m, d = getYMD(dt)
st_dt.append(str(y))
e["start_dt_kwd"] = str(y)
@@ -89,13 +95,20 @@ def forEdu(cv):

r = schools.select(n.get("school_name", ""))
if r:
if str(r.get("type", "")) == "1": fea.append("211")
if str(r.get("type", "")) == "2": fea.append("211")
if str(r.get("is_abroad", "")) == "1": fea.append("留学")
if str(r.get("is_double_first", "")) == "1": fea.append("双一流")
if str(r.get("is_985", "")) == "1": fea.append("985")
if str(r.get("is_world_known", "")) == "1": fea.append("海外知名")
if r.get("rank") and cv["school_rank_int"] > r["rank"]: cv["school_rank_int"] = r["rank"]
if str(r.get("type", "")) == "1":
fea.append("211")
if str(r.get("type", "")) == "2":
fea.append("211")
if str(r.get("is_abroad", "")) == "1":
fea.append("留学")
if str(r.get("is_double_first", "")) == "1":
fea.append("双一流")
if str(r.get("is_985", "")) == "1":
fea.append("985")
if str(r.get("is_world_known", "")) == "1":
fea.append("海外知名")
if r.get("rank") and cv["school_rank_int"] > r["rank"]:
cv["school_rank_int"] = r["rank"]

if n.get("school_name") and isinstance(n["school_name"], str):
sch.append(re.sub(r"(211|985|重点大学|[,&;;-])", "", n["school_name"]))
@@ -106,22 +119,25 @@ def forEdu(cv):
maj.append(n["discipline_name"])
e["major_kwd"] = n["discipline_name"]

if not n.get("degree") and "985" in fea and not first_fea: n["degree"] = "1"
if not n.get("degree") and "985" in fea and not first_fea:
n["degree"] = "1"

if n.get("degree"):
d = degrees.get_name(n["degree"])
if d: e["degree_kwd"] = d
if d == "本科" and ("专科" in deg or "专升本" in deg or "中专" in deg or "大专" in deg or re.search(r"(成人|自考|自学考试)",
n.get(
"school_name",
""))): d = "专升本"
if d: deg.append(d)
if d:
e["degree_kwd"] = d
if d == "本科" and ("专科" in deg or "专升本" in deg or "中专" in deg or "大专" in deg or re.search(r"(成人|自考|自学考试)", n.get("school_name",""))):
d = "专升本"
if d:
deg.append(d)

# for first degree
if not fdeg and d in ["中专", "专升本", "专科", "本科", "大专"]:
fdeg = [d]
if n.get("school_name"): fsch = [n["school_name"]]
if n.get("discipline_name"): fmaj = [n["discipline_name"]]
if n.get("school_name"):
fsch = [n["school_name"]]
if n.get("discipline_name"):
fmaj = [n["discipline_name"]]
first_fea = copy.deepcopy(fea)

edu_nst.append(e)
@@ -140,16 +156,26 @@ def forEdu(cv):
else:
cv["sch_rank_kwd"].append("一般学校")

if edu_nst: cv["edu_nst"] = edu_nst
if fea: cv["edu_fea_kwd"] = list(set(fea))
if first_fea: cv["edu_first_fea_kwd"] = list(set(first_fea))
if maj: cv["major_kwd"] = maj
if fsch: cv["first_school_name_kwd"] = fsch
if fdeg: cv["first_degree_kwd"] = fdeg
if fmaj: cv["first_major_kwd"] = fmaj
if st_dt: cv["edu_start_kwd"] = st_dt
if ed_dt: cv["edu_end_kwd"] = ed_dt
if ed_dt: cv["edu_end_int"] = max([int(t) for t in ed_dt])
if edu_nst:
cv["edu_nst"] = edu_nst
if fea:
cv["edu_fea_kwd"] = list(set(fea))
if first_fea:
cv["edu_first_fea_kwd"] = list(set(first_fea))
if maj:
cv["major_kwd"] = maj
if fsch:
cv["first_school_name_kwd"] = fsch
if fdeg:
cv["first_degree_kwd"] = fdeg
if fmaj:
cv["first_major_kwd"] = fmaj
if st_dt:
cv["edu_start_kwd"] = st_dt
if ed_dt:
cv["edu_end_kwd"] = ed_dt
if ed_dt:
cv["edu_end_int"] = max([int(t) for t in ed_dt])
if deg:
if "本科" in deg and "专科" in deg:
deg.append("专升本")
@@ -158,8 +184,10 @@ def forEdu(cv):
cv["highest_degree_kwd"] = highest_degree(deg)
if edu_end_dt:
try:
if re.match(r"[0-9]{9,}", edu_end_dt): edu_end_dt = turnTm2Dt(edu_end_dt)
if edu_end_dt.strip("\n") == "至今": edu_end_dt = cv.get("updated_at_dt", str(datetime.date.today()))
if re.match(r"[0-9]{9,}", edu_end_dt):
edu_end_dt = turnTm2Dt(edu_end_dt)
if edu_end_dt.strip("\n") == "至今":
edu_end_dt = cv.get("updated_at_dt", str(datetime.date.today()))
y, m, d = getYMD(edu_end_dt)
cv["work_exp_flt"] = min(int(str(datetime.date.today())[0:4]) - int(y), cv.get("work_exp_flt", 1000))
except Exception as e:
@@ -171,7 +199,8 @@ def forEdu(cv):
or not cv.get("degree_kwd"):
for c in sch:
if schools.is_good(c):
if "tag_kwd" not in cv: cv["tag_kwd"] = []
if "tag_kwd" not in cv:
cv["tag_kwd"] = []
cv["tag_kwd"].append("好学校")
cv["tag_kwd"].append("好学历")
break
@@ -180,28 +209,39 @@ def forEdu(cv):
any([d.lower() in ["硕士", "博士", "mba", "博士"] for d in cv.get("degree_kwd", [])])) \
or all([d.lower() in ["硕士", "博士", "mba", "博士后"] for d in cv.get("degree_kwd", [])]) \
or any([d in ["mba", "emba", "博士后"] for d in cv.get("degree_kwd", [])]):
if "tag_kwd" not in cv: cv["tag_kwd"] = []
if "好学历" not in cv["tag_kwd"]: cv["tag_kwd"].append("好学历")

if cv.get("major_kwd"): cv["major_tks"] = rag_tokenizer.tokenize(" ".join(maj))
if cv.get("school_name_kwd"): cv["school_name_tks"] = rag_tokenizer.tokenize(" ".join(sch))
if cv.get("first_school_name_kwd"): cv["first_school_name_tks"] = rag_tokenizer.tokenize(" ".join(fsch))
if cv.get("first_major_kwd"): cv["first_major_tks"] = rag_tokenizer.tokenize(" ".join(fmaj))
if "tag_kwd" not in cv:
cv["tag_kwd"] = []
if "好学历" not in cv["tag_kwd"]:
cv["tag_kwd"].append("好学历")

if cv.get("major_kwd"):
cv["major_tks"] = rag_tokenizer.tokenize(" ".join(maj))
if cv.get("school_name_kwd"):
cv["school_name_tks"] = rag_tokenizer.tokenize(" ".join(sch))
if cv.get("first_school_name_kwd"):
cv["first_school_name_tks"] = rag_tokenizer.tokenize(" ".join(fsch))
if cv.get("first_major_kwd"):
cv["first_major_tks"] = rag_tokenizer.tokenize(" ".join(fmaj))

return cv


def forProj(cv):
if not cv.get("project_obj"): return cv
if not cv.get("project_obj"):
return cv

pro_nms, desc = [], []
for i, n in enumerate(
sorted(cv.get("project_obj", []), key=lambda x: str(x.get("updated_at", "")) if type(x) == type({}) else "",
sorted(cv.get("project_obj", []), key=lambda x: str(x.get("updated_at", "")) if isinstance(x, dict) else "",
reverse=True)):
if n.get("name"): pro_nms.append(n["name"])
if n.get("describe"): desc.append(str(n["describe"]))
if n.get("responsibilities"): desc.append(str(n["responsibilities"]))
if n.get("achivement"): desc.append(str(n["achivement"]))
if n.get("name"):
pro_nms.append(n["name"])
if n.get("describe"):
desc.append(str(n["describe"]))
if n.get("responsibilities"):
desc.append(str(n["responsibilities"]))
if n.get("achivement"):
desc.append(str(n["achivement"]))

if pro_nms:
# cv["pro_nms_tks"] = rag_tokenizer.tokenize(" ".join(pro_nms))
@@ -233,15 +273,16 @@ def forWork(cv):
work_st_tm = ""
corp_tags = []
for i, n in enumerate(
sorted(cv.get("work_obj", []), key=lambda x: str(x.get("start_time", "")) if type(x) == type({}) else "",
sorted(cv.get("work_obj", []), key=lambda x: str(x.get("start_time", "")) if isinstance(x, dict) else "",
reverse=True)):
if type(n) == type(""):
if isinstance(n, str):
try:
n = json_loads(n)
except Exception:
continue

if n.get("start_time") and (not work_st_tm or n["start_time"] < work_st_tm): work_st_tm = n["start_time"]
if n.get("start_time") and (not work_st_tm or n["start_time"] < work_st_tm):
work_st_tm = n["start_time"]
for c in flds:
if not n.get(c) or str(n[c]) == '0':
fea[c].append("")
@@ -262,14 +303,18 @@ def forWork(cv):
fea[c].append(rmHtmlTag(str(n[c]).lower()))

y, m, d = getYMD(n.get("start_time"))
if not y or not m: continue
if not y or not m:
continue
st = "%s-%02d-%02d" % (y, int(m), int(d))
latest_job_tm = st

y, m, d = getYMD(n.get("end_time"))
if (not y or not m) and i > 0: continue
if not y or not m or int(y) > 2022: y, m, d = getYMD(str(n.get("updated_at", "")))
if not y or not m: continue
if (not y or not m) and i > 0:
continue
if not y or not m or int(y) > 2022:
y, m, d = getYMD(str(n.get("updated_at", "")))
if not y or not m:
continue
ed = "%s-%02d-%02d" % (y, int(m), int(d))

try:
@@ -279,22 +324,28 @@ def forWork(cv):

if n.get("scale"):
r = re.search(r"^([0-9]+)", str(n["scale"]))
if r: scales.append(int(r.group(1)))
if r:
scales.append(int(r.group(1)))

if goodcorp:
if "tag_kwd" not in cv: cv["tag_kwd"] = []
if "tag_kwd" not in cv:
cv["tag_kwd"] = []
cv["tag_kwd"].append("好公司")
if goodcorp_:
if "tag_kwd" not in cv: cv["tag_kwd"] = []
if "tag_kwd" not in cv:
cv["tag_kwd"] = []
cv["tag_kwd"].append("好公司(曾)")

if corp_tags:
if "tag_kwd" not in cv: cv["tag_kwd"] = []
if "tag_kwd" not in cv:
cv["tag_kwd"] = []
cv["tag_kwd"].extend(corp_tags)
cv["corp_tag_kwd"] = [c for c in corp_tags if re.match(r"(综合|行业)", c)]

if latest_job_tm: cv["latest_job_dt"] = latest_job_tm
if fea["corporation_id"]: cv["corporation_id"] = fea["corporation_id"]
if latest_job_tm:
cv["latest_job_dt"] = latest_job_tm
if fea["corporation_id"]:
cv["corporation_id"] = fea["corporation_id"]

if fea["position_name"]:
cv["position_name_tks"] = rag_tokenizer.tokenize(fea["position_name"][0])
@@ -317,18 +368,23 @@ def forWork(cv):
cv["responsibilities_ltks"] = rag_tokenizer.tokenize(fea["responsibilities"][0])
cv["resp_ltks"] = rag_tokenizer.tokenize(" ".join(fea["responsibilities"][1:]))

if fea["subordinates_count"]: fea["subordinates_count"] = [int(i) for i in fea["subordinates_count"] if
if fea["subordinates_count"]:
fea["subordinates_count"] = [int(i) for i in fea["subordinates_count"] if
re.match(r"[^0-9]+$", str(i))]
if fea["subordinates_count"]: cv["max_sub_cnt_int"] = np.max(fea["subordinates_count"])
if fea["subordinates_count"]:
cv["max_sub_cnt_int"] = np.max(fea["subordinates_count"])

if type(cv.get("corporation_id")) == type(1): cv["corporation_id"] = [str(cv["corporation_id"])]
if not cv.get("corporation_id"): cv["corporation_id"] = []
if isinstance(cv.get("corporation_id"), int):
cv["corporation_id"] = [str(cv["corporation_id"])]
if not cv.get("corporation_id"):
cv["corporation_id"] = []
for i in cv.get("corporation_id", []):
cv["baike_flt"] = max(corporations.baike(i), cv["baike_flt"] if "baike_flt" in cv else 0)

if work_st_tm:
try:
if re.match(r"[0-9]{9,}", work_st_tm): work_st_tm = turnTm2Dt(work_st_tm)
if re.match(r"[0-9]{9,}", work_st_tm):
work_st_tm = turnTm2Dt(work_st_tm)
y, m, d = getYMD(work_st_tm)
cv["work_exp_flt"] = min(int(str(datetime.date.today())[0:4]) - int(y), cv.get("work_exp_flt", 1000))
except Exception as e:
@@ -339,28 +395,37 @@ def forWork(cv):
cv["dua_flt"] = np.mean(duas)
cv["cur_dua_int"] = duas[0]
cv["job_num_int"] = len(duas)
if scales: cv["scale_flt"] = np.max(scales)
if scales:
cv["scale_flt"] = np.max(scales)
return cv


def turnTm2Dt(b):
if not b: return
if not b:
return
b = str(b).strip()
if re.match(r"[0-9]{10,}", b): b = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(b[:10])))
if re.match(r"[0-9]{10,}", b):
b = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(b[:10])))
return b


def getYMD(b):
y, m, d = "", "", "01"
if not b: return (y, m, d)
if not b:
return (y, m, d)
b = turnTm2Dt(b)
if re.match(r"[0-9]{4}", b): y = int(b[:4])
if re.match(r"[0-9]{4}", b):
y = int(b[:4])
r = re.search(r"[0-9]{4}.?([0-9]{1,2})", b)
if r: m = r.group(1)
if r:
m = r.group(1)
r = re.search(r"[0-9]{4}.?[0-9]{,2}.?([0-9]{1,2})", b)
if r: d = r.group(1)
if not d or int(d) == 0 or int(d) > 31: d = "1"
if not m or int(m) > 12 or int(m) < 1: m = "1"
if r:
d = r.group(1)
if not d or int(d) == 0 or int(d) > 31:
d = "1"
if not m or int(m) > 12 or int(m) < 1:
m = "1"
return (y, m, d)


@@ -369,7 +434,8 @@ def birth(cv):
cv["integerity_flt"] *= 0.9
return cv
y, m, d = getYMD(cv["birth"])
if not m or not y: return cv
if not m or not y:
return cv
b = "%s-%02d-%02d" % (y, int(m), int(d))
cv["birth_dt"] = b
cv["birthday_kwd"] = "%02d%02d" % (int(m), int(d))
@@ -380,7 +446,8 @@ def birth(cv):

def parse(cv):
for k in cv.keys():
if cv[k] == '\\N': cv[k] = ''
if cv[k] == '\\N':
cv[k] = ''
# cv = cv.asDict()
tks_fld = ["address", "corporation_name", "discipline_name", "email", "expect_city_names",
"expect_industry_name", "expect_position_name", "industry_name", "industry_names", "name",
@@ -402,9 +469,12 @@ def parse(cv):

rmkeys = []
for k in cv.keys():
if cv[k] is None: rmkeys.append(k)
if (type(cv[k]) == type([]) or type(cv[k]) == type("")) and len(cv[k]) == 0: rmkeys.append(k)
for k in rmkeys: del cv[k]
if cv[k] is None:
rmkeys.append(k)
if (isinstance(cv[k], list) or isinstance(cv[k], str)) and len(cv[k]) == 0:
rmkeys.append(k)
for k in rmkeys:
del cv[k]

integerity = 0.
flds_num = 0.
@@ -414,7 +484,8 @@ def parse(cv):
flds_num += len(flds)
for f in flds:
v = str(cv.get(f, ""))
if len(v) > 0 and v != '0' and v != '[]': integerity += 1
if len(v) > 0 and v != '0' and v != '[]':
integerity += 1

hasValues(tks_fld)
hasValues(small_tks_fld)
@@ -433,7 +504,8 @@ def parse(cv):
(r"[ ()\(\)人/·0-9-]+", ""),
(r".*(元|规模|于|=|北京|上海|至今|中国|工资|州|shanghai|强|餐饮|融资|职).*", "")]:
cv["corporation_type"] = re.sub(p, r, cv["corporation_type"], 1000, re.IGNORECASE)
if len(cv["corporation_type"]) < 2: del cv["corporation_type"]
if len(cv["corporation_type"]) < 2:
del cv["corporation_type"]

if cv.get("political_status"):
for p, r in [
@@ -441,9 +513,11 @@ def parse(cv):
(r".*(无党派|公民).*", "群众"),
(r".*团员.*", "团员")]:
cv["political_status"] = re.sub(p, r, cv["political_status"])
if not re.search(r"[党团群]", cv["political_status"]): del cv["political_status"]
if not re.search(r"[党团群]", cv["political_status"]):
del cv["political_status"]

if cv.get("phone"): cv["phone"] = re.sub(r"^0*86([0-9]{11})", r"\1", re.sub(r"[^0-9]+", "", cv["phone"]))
if cv.get("phone"):
cv["phone"] = re.sub(r"^0*86([0-9]{11})", r"\1", re.sub(r"[^0-9]+", "", cv["phone"]))

keys = list(cv.keys())
for k in keys:
@@ -454,9 +528,11 @@ def parse(cv):
cv[k] = [a for _, a in cv[k].items()]
nms = []
for n in cv[k]:
if type(n) != type({}) or "name" not in n or not n.get("name"): continue
if not isinstance(n, dict) or "name" not in n or not n.get("name"):
continue
n["name"] = re.sub(r"((442)|\t )", "", n["name"]).strip().lower()
if not n["name"]: continue
if not n["name"]:
continue
nms.append(n["name"])
if nms:
t = k[:-4]
@@ -469,15 +545,18 @@ def parse(cv):
# tokenize fields
if k in tks_fld:
cv[f"{k}_tks"] = rag_tokenizer.tokenize(cv[k])
if k in small_tks_fld: cv[f"{k}_sm_tks"] = rag_tokenizer.tokenize(cv[f"{k}_tks"])
if k in small_tks_fld:
cv[f"{k}_sm_tks"] = rag_tokenizer.tokenize(cv[f"{k}_tks"])

# keyword fields
if k in kwd_fld: cv[f"{k}_kwd"] = [n.lower()
if k in kwd_fld:
cv[f"{k}_kwd"] = [n.lower()
for n in re.split(r"[\t,,;;. ]",
re.sub(r"([^a-zA-Z])[ ]+([^a-zA-Z ])", r"\1,\2", cv[k])
) if n]

if k in num_fld and cv.get(k): cv[f"{k}_int"] = cv[k]
if k in num_fld and cv.get(k):
cv[f"{k}_int"] = cv[k]

cv["email_kwd"] = cv.get("email_tks", "").replace(" ", "")
# for name field
@@ -501,10 +580,12 @@ def parse(cv):
cv["name_py_pref0_tks"] = ""
cv["name_py_pref_tks"] = ""
for py in PY.get_pinyins(nm[:20], ''):
for i in range(2, len(py) + 1): cv["name_py_pref_tks"] += " " + py[:i]
for i in range(2, len(py) + 1):
cv["name_py_pref_tks"] += " " + py[:i]
for py in PY.get_pinyins(nm[:20], ' '):
py = py.split()
for i in range(1, len(py) + 1): cv["name_py_pref0_tks"] += " " + "".join(py[:i])
for i in range(1, len(py) + 1):
cv["name_py_pref0_tks"] += " " + "".join(py[:i])

cv["name_kwd"] = name
cv["name_pinyin_kwd"] = PY.get_pinyins(nm[:20], ' ')[:3]
@@ -526,22 +607,30 @@ def parse(cv):
cv["updated_at_dt"] = cv["updated_at"].strftime('%Y-%m-%d %H:%M:%S')
else:
y, m, d = getYMD(str(cv.get("updated_at", "")))
if not y: y = "2012"
if not m: m = "01"
if not d: d = "01"
if not y:
y = "2012"
if not m:
m = "01"
if not d:
d = "01"
cv["updated_at_dt"] = "%s-%02d-%02d 00:00:00" % (y, int(m), int(d))
# long text tokenize

if cv.get("responsibilities"): cv["responsibilities_ltks"] = rag_tokenizer.tokenize(rmHtmlTag(cv["responsibilities"]))
if cv.get("responsibilities"):
cv["responsibilities_ltks"] = rag_tokenizer.tokenize(rmHtmlTag(cv["responsibilities"]))

# for yes or no field
fea = []
for f, y, n in is_fld:
if f not in cv: continue
if cv[f] == '是': fea.append(y)
if cv[f] == '否': fea.append(n)
if f not in cv:
continue
if cv[f] == '是':
fea.append(y)
if cv[f] == '否':
fea.append(n)

if fea: cv["tag_kwd"] = fea
if fea:
cv["tag_kwd"] = fea

cv = forEdu(cv)
cv = forProj(cv)
@@ -550,9 +639,11 @@ def parse(cv):

cv["corp_proj_sch_deg_kwd"] = [c for c in cv.get("corp_tag_kwd", [])]
for i in range(len(cv["corp_proj_sch_deg_kwd"])):
for j in cv.get("sch_rank_kwd", []): cv["corp_proj_sch_deg_kwd"][i] += "+" + j
for j in cv.get("sch_rank_kwd", []):
cv["corp_proj_sch_deg_kwd"][i] += "+" + j
for i in range(len(cv["corp_proj_sch_deg_kwd"])):
if cv.get("highest_degree_kwd"): cv["corp_proj_sch_deg_kwd"][i] += "+" + cv["highest_degree_kwd"]
if cv.get("highest_degree_kwd"):
cv["corp_proj_sch_deg_kwd"][i] += "+" + cv["highest_degree_kwd"]

try:
if not cv.get("work_exp_flt") and cv.get("work_start_time"):
@@ -565,17 +656,21 @@ def parse(cv):
cv["work_exp_flt"] = int(str(datetime.date.today())[0:4]) - int(y)
except Exception as e:
logging.exception("parse {} ==> {}".format(e, cv.get("work_start_time")))
if "work_exp_flt" not in cv and cv.get("work_experience", 0): cv["work_exp_flt"] = int(cv["work_experience"]) / 12.
if "work_exp_flt" not in cv and cv.get("work_experience", 0):
cv["work_exp_flt"] = int(cv["work_experience"]) / 12.

keys = list(cv.keys())
for k in keys:
if not re.search(r"_(fea|tks|nst|dt|int|flt|ltks|kwd|id)$", k): del cv[k]
if not re.search(r"_(fea|tks|nst|dt|int|flt|ltks|kwd|id)$", k):
del cv[k]
for k in cv.keys():
if not re.search("_(kwd|id)$", k) or type(cv[k]) != type([]): continue
if not re.search("_(kwd|id)$", k) or not isinstance(cv[k], list):
continue
cv[k] = list(set([re.sub("(市)$", "", str(n)) for n in cv[k] if n not in ['中国', '0']]))
keys = [k for k in cv.keys() if re.search(r"_feas*$", k)]
for k in keys:
if cv[k] <= 0: del cv[k]
if cv[k] <= 0:
del cv[k]

cv["tob_resume_id"] = str(cv["tob_resume_id"])
cv["id"] = cv["tob_resume_id"]
@@ -592,5 +687,6 @@ def dealWithInt64(d):
if isinstance(d, list):
d = [dealWithInt64(t) for t in d]

if isinstance(d, np.integer): d = int(d)
if isinstance(d, np.integer):
d = int(d)
return d
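step_two.py also swaps direct type comparisons such as `type(dg) == type("")` for `isinstance()` (Ruff E721), which accepts subclasses and reads more naturally. A self-contained sketch of the rule (illustrative names only):

def normalize(value):
    # Mirrors the isinstance() checks introduced above.
    if isinstance(value, str):    # was: type(value) == type("")
        return [value]
    if isinstance(value, dict):   # was: type(value) == type({})
        return list(value.values())
    if isinstance(value, list):
        return value
    return []

print(normalize("bachelor"))    # ['bachelor']
print(normalize({"a": 1}))      # [1]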

+ 2
- 1
deepdoc/parser/txt_parser.py

@@ -51,6 +51,7 @@ class RAGFlowTxtParser:
dels = [d for d in dels if d]
dels = "|".join(dels)
secs = re.split(r"(%s)" % dels, txt)
for sec in secs: add_chunk(sec)
for sec in secs:
add_chunk(sec)

return [[c, ""] for c in cks]

+ 13
- 4
deepdoc/vision/__init__.py

@@ -18,7 +18,6 @@ from .recognizer import Recognizer
from .layout_recognizer import LayoutRecognizer
from .table_structure_recognizer import TableStructureRecognizer


def init_in_out(args):
from PIL import Image
import os
@@ -47,7 +46,7 @@ def init_in_out(args):
try:
images.append(Image.open(fnm))
outputs.append(os.path.split(fnm)[-1])
except Exception as e:
except Exception:
traceback.print_exc()

if os.path.isdir(args.inputs):
@@ -56,6 +55,16 @@ def init_in_out(args):
else:
images_and_outputs(args.inputs)

for i in range(len(outputs)): outputs[i] = os.path.join(args.output_dir, outputs[i])
for i in range(len(outputs)):
outputs[i] = os.path.join(args.output_dir, outputs[i])

return images, outputs


return images, outputs
__all__ = [
"OCR",
"Recognizer",
"LayoutRecognizer",
"TableStructureRecognizer",
"init_in_out",
]
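The `__all__` list added above makes the re-exports of deepdoc.vision explicit, so Ruff no longer flags them as unused imports (F401) and wildcard importers get a stable public API. The same convention for a hypothetical package (not repo code):

# somepkg/__init__.py (hypothetical example)
from .ocr import OCR                # re-exported on purpose
from .recognizer import Recognizer  # re-exported on purpose

__all__ = [
    "OCR",
    "Recognizer",
]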

+ 2
- 2
deepdoc/vision/layout_recognizer.py

@@ -42,7 +42,7 @@ class LayoutRecognizer(Recognizer):
get_project_base_directory(),
"rag/res/deepdoc")
super().__init__(self.labels, domain, model_dir)
except Exception as e:
except Exception:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
local_dir_use_symlinks=False)
@@ -77,7 +77,7 @@ class LayoutRecognizer(Recognizer):
"page_number": pn,
} for b in lts if float(b["score"]) >= 0.8 or b["type"] not in self.garbage_layouts]
lts = self.sort_Y_firstly(lts, np.mean(
[l["bottom"] - l["top"] for l in lts]) / 2)
[lt["bottom"] - lt["top"] for lt in lts]) / 2)
lts = self.layouts_cleanup(bxs, lts)
page_layout.append(lts)


+ 3
- 1
deepdoc/vision/ocr.py

@@ -19,7 +19,9 @@ from huggingface_hub import snapshot_download

from api.utils.file_utils import get_project_base_directory
from .operators import *
import math
import numpy as np
import cv2
import onnxruntime as ort

from .postprocess import build_post_process
@@ -484,7 +486,7 @@ class OCR(object):
"rag/res/deepdoc")
self.text_detector = TextDetector(model_dir)
self.text_recognizer = TextRecognizer(model_dir)
except Exception as e:
except Exception:
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
local_dir_use_symlinks=False)
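Both vision modules above also drop the unused `as e` binding from `except Exception as e:` (Ruff F841): the handler only falls back to downloading the model, so nothing reads the exception object. A small sketch of the rule, with hypothetical names:

def load_model(local_path, download):
    # Hypothetical fallback loader; `local_path` and `download` are illustrative.
    try:
        with open(local_path, "rb") as f:
            return f.read()
    except Exception:            # nothing uses the exception, so don't bind it
        return download()

print(load_model("/nonexistent/model.onnx", lambda: b"downloaded bytes"))
# Keep the name only when the handler actually uses it, e.g.
#   except Exception as e:
#       logging.warning("local load failed: %s", e)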

+ 3
- 3
deepdoc/vision/operators.py

@@ -232,7 +232,7 @@ class LinearResize(object):
"""
assert len(self.target_size) == 2
assert self.target_size[0] > 0 and self.target_size[1] > 0
im_channel = im.shape[2]
_im_channel = im.shape[2]
im_scale_y, im_scale_x = self.generate_scale(im)
im = cv2.resize(
im,
@@ -255,7 +255,7 @@ class LinearResize(object):
im_scale_y: the resize ratio of Y
"""
origin_shape = im.shape[:2]
im_c = im.shape[2]
_im_c = im.shape[2]
if self.keep_ratio:
im_size_min = np.min(origin_shape)
im_size_max = np.max(origin_shape)
@@ -581,7 +581,7 @@ class SRResize(object):
return data

images_HR = data["image_hr"]
label_strs = data["label"]
_label_strs = data["label"]
transform = ResizeNormalize((imgW, imgH))
images_HR = transform(images_HR)
data["img_hr"] = images_HR

+ 1
- 1
deepdoc/vision/postprocess.py

@@ -121,7 +121,7 @@ class DBPostProcess(object):
outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
cv2.CHAIN_APPROX_SIMPLE)
if len(outs) == 3:
img, contours, _ = outs[0], outs[1], outs[2]
_img, contours, _ = outs[0], outs[1], outs[2]
elif len(outs) == 2:
contours, _ = outs[0], outs[1]
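operators.py and postprocess.py rename values that are unpacked but never read with a leading underscore, which silences Ruff's unused-variable warning (F841) while keeping the unpacking explicit. A runnable sketch of the same idea (assumes numpy and opencv-python are installed):

import cv2
import numpy as np

bitmap = np.zeros((8, 8), dtype=np.uint8)        # toy binary image
outs = cv2.findContours(bitmap, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
if len(outs) == 3:                               # OpenCV 3.x: (image, contours, hierarchy)
    _img, contours, _ = outs
else:                                            # OpenCV 4.x: (contours, hierarchy)
    contours, _ = outs
print(len(contours))                             # 0 for an empty bitmap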


+ 10
- 4
deepdoc/vision/recognizer.py

@@ -13,15 +13,18 @@

import logging
import os
import math
import numpy as np
import cv2
from copy import deepcopy


import onnxruntime as ort
from huggingface_hub import snapshot_download

from api.utils.file_utils import get_project_base_directory
from .operators import *


class Recognizer(object):
def __init__(self, label_list, task_name, model_dir=None):
"""
@@ -277,7 +280,8 @@ class Recognizer(object):
return
min_dis, min_i = 1000000, None
for i,b in enumerate(boxes):
if box.get("layoutno", "0") != b.get("layoutno", "0"): continue
if box.get("layoutno", "0") != b.get("layoutno", "0"):
continue
dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]), abs(box["x0"]+box["x1"] - b["x1"] - b["x0"])/2)
if dis < min_dis:
min_i = i
@@ -402,7 +406,8 @@ class Recognizer(object):
scores = np.max(boxes[:, 4:], axis=1)
boxes = boxes[scores > thr, :]
scores = scores[scores > thr]
if len(boxes) == 0: return []
if len(boxes) == 0:
return []

# Get the class with the highest confidence
class_ids = np.argmax(boxes[:, 4:], axis=1)
@@ -432,7 +437,8 @@ class Recognizer(object):
for i in range(len(image_list)):
if not isinstance(image_list[i], np.ndarray):
imgs.append(np.array(image_list[i]))
else: imgs.append(image_list[i])
else:
imgs.append(image_list[i])

batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size)
for i in range(batch_loop_cnt):
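recognizer.py (and ocr.py above) add direct imports for math, numpy, cv2 and onnxruntime instead of relying on whatever `from .operators import *` happens to expose, which keeps Ruff's star-import checks (F403/F405) quiet and makes each name's origin obvious. The shape of the fix, as a standalone sketch:

# Before, names like `math` or `np` appeared to come from nowhere:
#   from .operators import *
# After, the dependency is stated where it is used:
import math
import numpy as np

batch_size = 16
images = [np.zeros((4, 4)) for _ in range(37)]
batch_loop_cnt = math.ceil(float(len(images)) / batch_size)
print(batch_loop_cnt)    # 3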

+ 4
- 2
graphrag/community_reports_extractor.py

@@ -88,7 +88,8 @@ class CommunityReportsExtractor:
("findings", list),
("rating", float),
("rating_explanation", str),
]): continue
]):
continue
response["weight"] = weight
response["entities"] = ents
except Exception as e:
@@ -100,7 +101,8 @@ class CommunityReportsExtractor:
res_str.append(self._get_text_output(response))
res_dict.append(response)
over += 1
if callback: callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
if callback:
callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")

return CommunityReportsResult(
structured_output=res_dict,

+ 1
- 0
graphrag/entity_embedding.py

@@ -8,6 +8,7 @@ Reference:
from typing import Any
import numpy as np
import networkx as nx
from dataclasses import dataclass
from graphrag.leiden import stable_largest_connected_component



+ 8
- 4
graphrag/graph_extractor.py

@@ -129,9 +129,11 @@ class GraphExtractor:
source_doc_map[doc_index] = text
all_records[doc_index] = result
total_token_count += token_count
if callback: callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
if callback:
callback(msg=f"{doc_index+1}/{total}, elapsed: {timer() - st}s, used tokens: {total_token_count}")
except Exception as e:
if callback: callback(msg="Knowledge graph extraction error:{}".format(str(e)))
if callback:
callback(msg="Knowledge graph extraction error:{}".format(str(e)))
logging.exception("error extracting graph")
self._on_error(
e,
@@ -164,7 +166,8 @@ class GraphExtractor:
text = perform_variable_replacements(self._extraction_prompt, variables=variables)
gen_conf = {"temperature": 0.3}
response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
if response.find("**ERROR**") >= 0: raise Exception(response)
if response.find("**ERROR**") >= 0:
raise Exception(response)
token_count = num_tokens_from_string(text + response)

results = response or ""
@@ -175,7 +178,8 @@ class GraphExtractor:
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
history.append({"role": "user", "content": text})
response = self._llm.chat("", history, gen_conf)
if response.find("**ERROR**") >=0: raise Exception(response)
if response.find("**ERROR**") >=0:
raise Exception(response)
results += response or ""

# if this is the final glean, don't bother updating the continuation flag
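The graphrag changes repeatedly split the one-line `if callback: callback(...)` guards; the callback itself is an optional progress hook that callers may omit. A minimal sketch of that calling convention, with a hypothetical function and message format:

from timeit import default_timer as timer

def extract(docs, callback=None):
    st = timer()
    total_tokens = 0
    for i, doc in enumerate(docs):
        total_tokens += len(doc.split())     # stand-in for real token counting
        if callback:
            callback(msg=f"{i + 1}/{len(docs)}, elapsed: {timer() - st}s, used tokens: {total_tokens}")

extract(["a b c", "d e"], callback=lambda msg: print(msg))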

+ 2
- 1
graphrag/index.py

@@ -134,7 +134,8 @@ def build_knowledge_graph_chunks(tenant_id: str, chunks: list[str], callback, en
callback(0.75, "Extracting mind graph.")
mindmap = MindMapExtractor(llm_bdl)
mg = mindmap(_chunks).output
if not len(mg.keys()): return chunks
if not len(mg.keys()):
return chunks

logging.debug(json.dumps(mg, ensure_ascii=False, indent=2))
chunks.append(

+ 8
- 4
graphrag/leiden.py

@@ -78,7 +78,8 @@ def _compute_leiden_communities(
) -> dict[int, dict[str, int]]:
"""Return Leiden root communities."""
results: dict[int, dict[str, int]] = {}
if is_empty(graph): return results
if is_empty(graph):
return results
if use_lcc:
graph = stable_largest_connected_component(graph)

@@ -100,7 +101,8 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
logging.debug(
"Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc
)
if not graph.nodes(): return {}
if not graph.nodes():
return {}

node_id_to_community_map = _compute_leiden_communities(
graph=graph,
@@ -125,9 +127,11 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
result[community_id]["nodes"].append(node_id)
result[community_id]["weight"] += graph.nodes[node_id].get("rank", 0) * graph.nodes[node_id].get("weight", 1)
weights = [comm["weight"] for _, comm in result.items()]
if not weights:continue
if not weights:
continue
max_weight = max(weights)
for _, comm in result.items(): comm["weight"] /= max_weight
for _, comm in result.items():
comm["weight"] /= max_weight

return results_by_level


+ 5
- 1
intergrations/chatgpt-on-wechat/plugins/__init__.py

@@ -1 +1,5 @@
from .ragflow_chat import *
from .ragflow_chat import RAGFlowChat

__all__ = [
"RAGFlowChat"
]

+ 0
- 1
intergrations/chatgpt-on-wechat/plugins/ragflow_chat.py

@@ -2,7 +2,6 @@ import logging
import requests
from bridge.context import ContextType # Import Context, ContextType
from bridge.reply import Reply, ReplyType # Import Reply, ReplyType
from bridge import *
from plugins import Plugin, register # Import Plugin and register
from plugins.event import Event, EventContext, EventAction # Import event-related classes


+ 3
- 3
rag/app/book.py

@@ -94,7 +94,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)
sections = txt.split("\n")
sections = [(l, "") for l in sections if l]
sections = [(line, "") for line in sections if line]
remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
callback(0.8, "Finish parsing.")
@@ -102,7 +102,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
sections = HtmlParser()(filename, binary)
sections = [(l, "") for l in sections if l]
sections = [(line, "") for line in sections if line]
remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
callback(0.8, "Finish parsing.")
@@ -112,7 +112,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary)
sections = doc_parsed['content'].split('\n')
sections = [(l, "") for l in sections if l]
sections = [(line, "") for line in sections if line]
remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
callback(0.8, "Finish parsing.")

+ 1
- 1
rag/app/email.py

@@ -75,7 +75,7 @@ def chunk(
_add_content(msg, msg.get_content_type())

sections = TxtParser.parser_txt("\n".join(text_txt)) + [
(l, "") for l in HtmlParser.parser_txt("\n".join(html_txt)) if l
(line, "") for line in HtmlParser.parser_txt("\n".join(html_txt)) if line
]

st = timer()

+ 2
- 1
rag/app/knowledge_graph.py

@@ -18,7 +18,8 @@ def chunk(filename, binary, tenant_id, from_page=0, to_page=100000,
chunks = build_knowledge_graph_chunks(tenant_id, sections, callback,
parser_config.get("entity_types", ["organization", "person", "location", "event", "time"])
)
for c in chunks: c["docnm_kwd"] = filename
for c in chunks:
c["docnm_kwd"] = filename

doc = {
"docnm_kwd": filename,

+ 11
- 8
rag/app/laws.py

@@ -48,7 +48,7 @@ class Docx(DocxParser):
continue
if 'w:br' in run._element.xml and 'type="page"' in run._element.xml:
pn += 1
return [l for l in lines if l]
return [line for line in lines if line]

def __call__(self, filename, binary=None, from_page=0, to_page=100000):
self.doc = Document(
@@ -60,7 +60,8 @@ class Docx(DocxParser):
if pn > to_page:
break
question_level, p_text = docx_question_level(p, bull)
if not p_text.strip("\n"):continue
if not p_text.strip("\n"):
continue
lines.append((question_level, p_text))

for run in p.runs:
@@ -78,19 +79,21 @@ class Docx(DocxParser):
if lines[e][0] <= lines[s][0]:
break
e += 1
if e - s == 1 and visit[s]: continue
if e - s == 1 and visit[s]:
continue
sec = []
next_level = lines[s][0] + 1
while not sec and next_level < 22:
for i in range(s+1, e):
if lines[i][0] != next_level: continue
if lines[i][0] != next_level:
continue
sec.append(lines[i][1])
visit[i] = True
next_level += 1
sec.insert(0, lines[s][1])

sections.append("\n".join(sec))
return [l for l in sections if l]
return [s for s in sections if s]

def __str__(self) -> str:
return f'''
@@ -168,13 +171,13 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)
sections = txt.split("\n")
sections = [l for l in sections if l]
sections = [s for s in sections if s]
callback(0.8, "Finish parsing.")

elif re.search(r"\.(htm|html)$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
sections = HtmlParser()(filename, binary)
sections = [l for l in sections if l]
sections = [s for s in sections if s]
callback(0.8, "Finish parsing.")

elif re.search(r"\.doc$", filename, re.IGNORECASE):
@@ -182,7 +185,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary)
sections = doc_parsed['content'].split('\n')
sections = [l for l in sections if l]
sections = [s for s in sections if s]
callback(0.8, "Finish parsing.")

else:

+ 4
- 3
rag/app/manual.py

@@ -190,7 +190,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
sections, tbls = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback)
if sections and len(sections[0]) < 3:
sections = [(t, l, [[0] * 5]) for t, l in sections]
sections = [(t, lvl, [[0] * 5]) for t, lvl in sections]
# set pivot using the most frequent type of title,
# then merge between 2 pivot
if len(sections) > 0 and len(pdf_parser.outlines) / len(sections) > 0.1:
@@ -211,7 +211,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
else:
bull = bullets_category([txt for txt, _, _ in sections])
most_level, levels = title_frequency(
bull, [(txt, l) for txt, l, poss in sections])
bull, [(txt, lvl) for txt, lvl, _ in sections])

assert len(sections) == len(levels)
sec_ids = []
@@ -225,7 +225,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
sections = [(txt, sec_ids[i], poss)
for i, (txt, _, poss) in enumerate(sections)]
for (img, rows), poss in tbls:
if not rows: continue
if not rows:
continue
sections.append((rows if isinstance(rows, str) else rows[0], -1,
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))


+ 3
- 2
rag/app/one.py

@@ -54,7 +54,8 @@ class Pdf(PdfParser):
sections = [(b["text"], self.get_position(b, zoomin))
for i, b in enumerate(self.boxes)]
for (img, rows), poss in tbls:
if not rows:continue
if not rows:
continue
sections.append((rows if isinstance(rows, str) else rows[0],
[(p[0] + 1 - from_page, p[1], p[2], p[3], p[4]) for p in poss]))
return [(txt, "") for txt, _ in sorted(sections, key=lambda x: (
@@ -109,7 +110,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary)
sections = doc_parsed['content'].split('\n')
sections = [l for l in sections if l]
sections = [s for s in sections if s]
callback(0.8, "Finish parsing.")

else:

+ 17
- 13
rag/app/qa.py

@@ -171,7 +171,7 @@ class Pdf(PdfParser):
tbl_bottom = tbls[tbl_index][1][0][4]
tbl_tag = "@@{}\t{:.1f}\t{:.1f}\t{:.1f}\t{:.1f}##" \
.format(tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom)
tbl_text = ''.join(tbls[tbl_index][0][1])
_tbl_text = ''.join(tbls[tbl_index][0][1])
return tbl_pn, tbl_left, tbl_right, tbl_top, tbl_bottom, tbl_tag,


@@ -325,9 +325,11 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
txt = get_text(filename, binary)
lines = txt.split("\n")
comma, tab = 0, 0
for l in lines:
if len(l.split(",")) == 2: comma += 1
if len(l.split("\t")) == 2: tab += 1
for line in lines:
if len(line.split(",")) == 2:
comma += 1
if len(line.split("\t")) == 2:
tab += 1
delimiter = "\t" if tab >= comma else ","

fails = []
@@ -336,18 +338,21 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
while i < len(lines):
arr = lines[i].split(delimiter)
if len(arr) != 2:
if question: answer += "\n" + lines[i]
if question:
answer += "\n" + lines[i]
else:
fails.append(str(i+1))
elif len(arr) == 2:
if question and answer: res.append(beAdoc(deepcopy(doc), question, answer, eng))
if question and answer:
res.append(beAdoc(deepcopy(doc), question, answer, eng))
question, answer = arr
i += 1
if len(res) % 999 == 0:
callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))

if question: res.append(beAdoc(deepcopy(doc), question, answer, eng))
if question:
res.append(beAdoc(deepcopy(doc), question, answer, eng))

callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
@@ -367,19 +372,18 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
callback(0.1, "Start to parse.")
txt = get_text(filename, binary)
lines = txt.split("\n")
last_question, last_answer = "", ""
_last_question, last_answer = "", ""
question_stack, level_stack = [], []
code_block = False
level_index = [-1] * 7
for index, l in enumerate(lines):
if l.strip().startswith('```'):
for index, line in enumerate(lines):
if line.strip().startswith('```'):
code_block = not code_block
question_level, question = 0, ''
if not code_block:
question_level, question = mdQuestionLevel(l)
question_level, question = mdQuestionLevel(line)

if not question_level or question_level > 6: # not a question
last_answer = f'{last_answer}\n{l}'
last_answer = f'{last_answer}\n{line}'
else: # is a question
if last_answer.strip():
sum_question = '\n'.join(question_stack)
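Earlier in this file's hunks, the delimiter for plain-text Q&A is chosen by counting how many lines split into exactly two fields with a tab versus a comma; the renamed loop variable just makes that heuristic easier to read. A standalone sketch of the heuristic with toy data (not repo code):

def guess_delimiter(lines):
    comma, tab = 0, 0
    for line in lines:
        if len(line.split(",")) == 2:
            comma += 1
        if len(line.split("\t")) == 2:
            tab += 1
    return "\t" if tab >= comma else ","

sample = ["What is RAG?\tRetrieval-augmented generation.", "Q2\tA2"]
print(repr(guess_delimiter(sample)))    # '\t'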

+ 5
- 4
rag/app/table.py

@@ -41,14 +41,16 @@ class Excel(ExcelParser):
for sheetname in wb.sheetnames:
ws = wb[sheetname]
rows = list(ws.rows)
if not rows:continue
if not rows:
continue
headers = [cell.value for cell in rows[0]]
missed = set([i for i, h in enumerate(headers) if h is None])
headers = [
cell.value for i,
cell in enumerate(
rows[0]) if i not in missed]
if not headers:continue
if not headers:
continue
data = []
for i, r in enumerate(rows[1:]):
rn += 1
@@ -88,7 +90,6 @@ def trans_bool(s):

def column_data_type(arr):
arr = list(arr)
uni = len(set([a for a in arr if a is not None]))
counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0}
trans = {t: f for f, t in
[(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
@@ -157,7 +158,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000,
continue
if i >= to_page:
break
row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
row = [field for field in line.split(kwargs.get("delimiter", "\t"))]
if len(row) != len(headers):
fails.append(str(i))
continue

+ 122
- 10
rag/llm/__init__.py

@@ -13,12 +13,124 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .embedding_model import *
from .chat_model import *
from .cv_model import *
from .rerank_model import *
from .sequence2txt_model import *
from .tts_model import *
from .embedding_model import (
OllamaEmbed,
LocalAIEmbed,
OpenAIEmbed,
AzureEmbed,
XinferenceEmbed,
QWenEmbed,
ZhipuEmbed,
FastEmbed,
YoudaoEmbed,
BaiChuanEmbed,
JinaEmbed,
DefaultEmbedding,
MistralEmbed,
BedrockEmbed,
GeminiEmbed,
NvidiaEmbed,
LmStudioEmbed,
OpenAI_APIEmbed,
CoHereEmbed,
TogetherAIEmbed,
PerfXCloudEmbed,
UpstageEmbed,
SILICONFLOWEmbed,
ReplicateEmbed,
BaiduYiyanEmbed,
VoyageEmbed,
HuggingFaceEmbed,
VolcEngineEmbed,
)
from .chat_model import (
GptTurbo,
AzureChat,
ZhipuChat,
QWenChat,
OllamaChat,
LocalAIChat,
XinferenceChat,
MoonshotChat,
DeepSeekChat,
VolcEngineChat,
BaiChuanChat,
MiniMaxChat,
MistralChat,
GeminiChat,
BedrockChat,
GroqChat,
OpenRouterChat,
StepFunChat,
NvidiaChat,
LmStudioChat,
OpenAI_APIChat,
CoHereChat,
LeptonAIChat,
TogetherAIChat,
PerfXCloudChat,
UpstageChat,
NovitaAIChat,
SILICONFLOWChat,
YiChat,
ReplicateChat,
HunyuanChat,
SparkChat,
BaiduYiyanChat,
AnthropicChat,
GoogleChat,
HuggingFaceChat,
)

from .cv_model import (
GptV4,
AzureGptV4,
OllamaCV,
XinferenceCV,
QWenCV,
Zhipu4V,
LocalCV,
GeminiCV,
OpenRouterCV,
LocalAICV,
NvidiaCV,
LmStudioCV,
StepFunCV,
OpenAI_APICV,
TogetherAICV,
YiCV,
HunyuanCV,
)
from .rerank_model import (
LocalAIRerank,
DefaultRerank,
JinaRerank,
YoudaoRerank,
XInferenceRerank,
NvidiaRerank,
LmStudioRerank,
OpenAI_APIRerank,
CoHereRerank,
TogetherAIRerank,
SILICONFLOWRerank,
BaiduYiyanRerank,
VoyageRerank,
QWenRerank,
)
from .sequence2txt_model import (
GPTSeq2txt,
QWenSeq2txt,
AzureSeq2txt,
XinferenceSeq2txt,
TencentCloudSeq2txt,
)
from .tts_model import (
FishAudioTTS,
QwenTTS,
OpenAITTS,
SparkTTS,
XinferenceTTS,
)

EmbeddingModel = {
"Ollama": OllamaEmbed,
@@ -48,7 +160,7 @@ EmbeddingModel = {
"BaiduYiyan": BaiduYiyanEmbed,
"Voyage AI": VoyageEmbed,
"HuggingFace": HuggingFaceEmbed,
"VolcEngine":VolcEngineEmbed,
"VolcEngine": VolcEngineEmbed,
}

CvModel = {
@@ -68,7 +180,7 @@ CvModel = {
"OpenAI-API-Compatible": OpenAI_APICV,
"TogetherAI": TogetherAICV,
"01.AI": YiCV,
"Tencent Hunyuan": HunyuanCV
"Tencent Hunyuan": HunyuanCV,
}

ChatModel = {
@@ -111,7 +223,7 @@ ChatModel = {
}

RerankModel = {
"LocalAI":LocalAIRerank,
"LocalAI": LocalAIRerank,
"BAAI": DefaultRerank,
"Jina": JinaRerank,
"Youdao": YoudaoRerank,
@@ -132,7 +244,7 @@ Seq2txtModel = {
"Tongyi-Qianwen": QWenSeq2txt,
"Azure-OpenAI": AzureSeq2txt,
"Xinference": XinferenceSeq2txt,
"Tencent Cloud": TencentCloudSeq2txt
"Tencent Cloud": TencentCloudSeq2txt,
}

TTSModel = {
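The explicit import block above exists mainly to populate the provider-name to class registries (EmbeddingModel, ChatModel and the rest) without tripping Ruff's star-import checks. A toy version of that registry pattern, with made-up classes:

class OpenAIEmbedDemo:
    def encode(self, texts):
        return [[0.0] * 3 for _ in texts]    # placeholder vectors

class OllamaEmbedDemo(OpenAIEmbedDemo):
    pass

EMBEDDING_REGISTRY = {
    "OpenAI": OpenAIEmbedDemo,
    "Ollama": OllamaEmbedDemo,
}

def get_embedder(factory_name):
    return EMBEDDING_REGISTRY[factory_name]()   # raises KeyError for unknown providers

print(get_embedder("Ollama").encode(["hello"]))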

+ 56
- 28
rag/llm/chat_model.py

@@ -69,7 +69,8 @@ class Base(ABC):
stream=True,
**gen_conf)
for resp in response:
if not resp.choices: continue
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content
@@ -81,7 +82,8 @@ class Base(ABC):
)
elif isinstance(resp.usage, dict):
total_tokens = resp.usage.get("total_tokens", total_tokens)
else: total_tokens = resp.usage.total_tokens
else:
total_tokens = resp.usage.total_tokens

if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
@@ -98,13 +100,15 @@ class Base(ABC):

class GptTurbo(Base):
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
if not base_url: base_url = "https://api.openai.com/v1"
if not base_url:
base_url = "https://api.openai.com/v1"
super().__init__(key, model_name, base_url)


class MoonshotChat(Base):
def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
if not base_url: base_url = "https://api.moonshot.cn/v1"
if not base_url:
base_url = "https://api.moonshot.cn/v1"
super().__init__(key, model_name, base_url)


@@ -128,7 +132,8 @@ class HuggingFaceChat(Base):

class DeepSeekChat(Base):
def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
if not base_url: base_url = "https://api.deepseek.com/v1"
if not base_url:
base_url = "https://api.deepseek.com/v1"
super().__init__(key, model_name, base_url)


@@ -202,7 +207,8 @@ class BaiChuanChat(Base):
stream=True,
**self._format_params(gen_conf))
for resp in response:
if not resp.choices: continue
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
ans += resp.choices[0].delta.content
@@ -313,8 +319,10 @@ class ZhipuChat(Base):
if system:
history.insert(0, {"role": "system", "content": system})
try:
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
@@ -333,8 +341,10 @@ class ZhipuChat(Base):
def chat_streamly(self, system, history, gen_conf):
if system:
history.insert(0, {"role": "system", "content": system})
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]
ans = ""
tk_count = 0
try:
@@ -345,7 +355,8 @@ class ZhipuChat(Base):
**gen_conf
)
for resp in response:
if not resp.choices[0].delta.content: continue
if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content
ans += delta
if resp.choices[0].finish_reason == "length":
@@ -354,7 +365,8 @@ class ZhipuChat(Base):
else:
ans += LENGTH_NOTIFICATION_EN
tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop":
tk_count = resp.usage.total_tokens
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
@@ -372,11 +384,16 @@ class OllamaChat(Base):
history.insert(0, {"role": "system", "content": system})
try:
options = {}
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf:
options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf:
options["top_p"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
response = self.client.chat(
model=self.model_name,
messages=history,
@@ -392,11 +409,16 @@ class OllamaChat(Base):
if system:
history.insert(0, {"role": "system", "content": system})
options = {}
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf:
options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf:
options["top_p"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
ans = ""
try:
response = self.client.chat(
@@ -636,7 +658,8 @@ class MistralChat(Base):
messages=history,
**gen_conf)
for resp in response:
if not resp.choices or not resp.choices[0].delta.content: continue
if not resp.choices or not resp.choices[0].delta.content:
continue
ans += resp.choices[0].delta.content
total_tokens += 1
if resp.choices[0].finish_reason == "length":
@@ -1196,7 +1219,8 @@ class SparkChat(Base):
assert model_name in model2version or model_name in version2model, f"The given model name is not supported yet. Support: {list(model2version.keys())}"
if model_name in model2version:
model_version = model2version[model_name]
else: model_version = model_name
else:
model_version = model_name
super().__init__(key, model_version, base_url)


@@ -1281,8 +1305,10 @@ class AnthropicChat(Base):
self.system = system
if "max_tokens" not in gen_conf:
gen_conf["max_tokens"] = 4096
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]

ans = ""
try:
@@ -1312,8 +1338,10 @@ class AnthropicChat(Base):
self.system = system
if "max_tokens" not in gen_conf:
gen_conf["max_tokens"] = 4096
if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
if "presence_penalty" in gen_conf:
del gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
del gen_conf["frequency_penalty"]

ans = ""
total_tokens = 0
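Several chat backends above (ZhipuChat, AnthropicChat) delete `presence_penalty` and `frequency_penalty` from `gen_conf` because those providers reject them. A hedged sketch of that clean-up as a standalone helper; the constant and helper name are illustrative, not from the repo:

UNSUPPORTED_KEYS = ("presence_penalty", "frequency_penalty")   # assumed per-provider list

def sanitize_gen_conf(gen_conf):
    conf = dict(gen_conf)                 # copy instead of mutating the caller's dict
    for key in UNSUPPORTED_KEYS:
        if key in conf:
            del conf[key]
    conf.setdefault("max_tokens", 4096)   # same default the Anthropic branch applies above
    return conf

print(sanitize_gen_conf({"temperature": 0.3, "presence_penalty": 0.5}))
# {'temperature': 0.3, 'max_tokens': 4096}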

+ 43
- 24
rag/llm/cv_model.py

@@ -25,6 +25,7 @@ import base64
from io import BytesIO
import json
import requests
from transformers import GenerationConfig

from rag.nlp import is_english
from api.utils import get_uuid
@@ -77,14 +78,16 @@ class Base(ABC):
stream=True
)
for resp in response:
if not resp.choices[0].delta.content: continue
if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content
ans += delta
if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop":
tk_count = resp.usage.total_tokens
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
@@ -99,7 +102,7 @@ class Base(ABC):
buffered = BytesIO()
try:
image.save(buffered, format="JPEG")
except Exception as e:
except Exception:
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")

@@ -139,7 +142,8 @@ class Base(ABC):

class GptV4(Base):
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
if not base_url: base_url="https://api.openai.com/v1"
if not base_url:
base_url="https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
@@ -149,7 +153,8 @@ class GptV4(Base):
prompt = self.prompt(b64)
for i in range(len(prompt)):
for c in prompt[i]["content"]:
if "text" in c: c["type"] = "text"
if "text" in c:
c["type"] = "text"

res = self.client.chat.completions.create(
model=self.model_name,
@@ -171,7 +176,8 @@ class AzureGptV4(Base):
prompt = self.prompt(b64)
for i in range(len(prompt)):
for c in prompt[i]["content"]:
if "text" in c: c["type"] = "text"
if "text" in c:
c["type"] = "text"

res = self.client.chat.completions.create(
model=self.model_name,
@@ -344,14 +350,16 @@ class Zhipu4V(Base):
stream=True
)
for resp in response:
if not resp.choices[0].delta.content: continue
if not resp.choices[0].delta.content:
continue
delta = resp.choices[0].delta.content
ans += delta
if resp.choices[0].finish_reason == "length":
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
if resp.choices[0].finish_reason == "stop":
tk_count = resp.usage.total_tokens
yield ans
except Exception as e:
yield ans + "\n**ERROR**: " + str(e)
@@ -389,11 +397,16 @@ class OllamaCV(Base):
if his["role"] == "user":
his["images"] = [image]
options = {}
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf:
options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf:
options["top_k"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
response = self.client.chat(
model=self.model_name,
messages=history,
@@ -414,11 +427,16 @@ class OllamaCV(Base):
if his["role"] == "user":
his["images"] = [image]
options = {}
if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
if "temperature" in gen_conf:
options["temperature"] = gen_conf["temperature"]
if "max_tokens" in gen_conf:
options["num_predict"] = gen_conf["max_tokens"]
if "top_p" in gen_conf:
options["top_k"] = gen_conf["top_p"]
if "presence_penalty" in gen_conf:
options["presence_penalty"] = gen_conf["presence_penalty"]
if "frequency_penalty" in gen_conf:
options["frequency_penalty"] = gen_conf["frequency_penalty"]
ans = ""
try:
response = self.client.chat(
@@ -469,7 +487,7 @@ class XinferenceCV(Base):

class GeminiCV(Base):
def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
from google.generativeai import client, GenerativeModel, GenerationConfig
from google.generativeai import client, GenerativeModel
client.configure(api_key=key)
_client = client.get_default_generative_client()
self.model_name = model_name
@@ -503,7 +521,7 @@ class GeminiCV(Base):
if his["role"] == "user":
his["parts"] = [his["content"]]
his.pop("content")
history[-1]["parts"].append(f"data:image/jpeg;base64," + image)
history[-1]["parts"].append("data:image/jpeg;base64," + image)

response = self.model.generate_content(history, generation_config=GenerationConfig(
max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
@@ -519,7 +537,6 @@ class GeminiCV(Base):
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

ans = ""
tk_count = 0
try:
for his in history:
if his["role"] == "assistant":
@@ -529,14 +546,15 @@ class GeminiCV(Base):
if his["role"] == "user":
his["parts"] = [his["content"]]
his.pop("content")
history[-1]["parts"].append(f"data:image/jpeg;base64," + image)
history[-1]["parts"].append("data:image/jpeg;base64," + image)

response = self.model.generate_content(history, generation_config=GenerationConfig(
max_output_tokens=gen_conf.get("max_tokens", 1000), temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7)), stream=True)

for resp in response:
if not resp.text: continue
if not resp.text:
continue
ans += resp.text
yield ans
except Exception as e:
@@ -632,7 +650,8 @@ class NvidiaCV(Base):

class StepFunCV(GptV4):
def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
if not base_url: base_url="https://api.stepfun.com/v1"
if not base_url:
base_url="https://api.stepfun.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
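The OllamaChat and OllamaCV hunks translate OpenAI-style `gen_conf` keys into Ollama's `options` dict one `if` per key; the reformatting only puts each assignment on its own line. The same translation could also be written as a lookup table (an alternative sketch, not what this PR does):

KEY_MAP = {
    "temperature": "temperature",
    "max_tokens": "num_predict",
    "top_p": "top_p",
    "presence_penalty": "presence_penalty",
    "frequency_penalty": "frequency_penalty",
}

def to_ollama_options(gen_conf):
    return {target: gen_conf[src] for src, target in KEY_MAP.items() if src in gen_conf}

print(to_ollama_options({"temperature": 0.2, "max_tokens": 128}))
# {'temperature': 0.2, 'num_predict': 128}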

+ 2
- 4
rag/llm/sequence2txt_model.py

@@ -15,12 +15,9 @@
#
import requests
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
import io
from abc import ABC
from ollama import Client
from openai import OpenAI
import os
import json
from rag.utils import num_tokens_from_string
import base64
@@ -49,7 +46,8 @@ class Base(ABC):

class GPTSeq2txt(Base):
def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"):
if not base_url: base_url = "https://api.openai.com/v1"
if not base_url:
base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name


+ 2
- 2
rag/llm/tts_model.py

@@ -16,7 +16,6 @@

import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
@@ -175,7 +174,8 @@ class QwenTTS(Base):

class OpenAITTS(Base):
def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
if not base_url: base_url = "https://api.openai.com/v1"
if not base_url:
base_url = "https://api.openai.com/v1"
self.api_key = key
self.model_name = model_name
self.base_url = base_url

+ 15
- 9
rag/nlp/__init__.py

@@ -222,7 +222,8 @@ def bullets_category(sections):

def is_english(texts):
eng = 0
if not texts: return False
if not texts:
return False
for t in texts:
if re.match(r"[ `a-zA-Z.,':;/\"?<>!\(\)-]", t.strip()):
eng += 1
@@ -250,7 +251,8 @@ def tokenize_chunks(chunks, doc, eng, pdf_parser=None):
res = []
# wrap up as es documents
for ck in chunks:
if len(ck.strip()) == 0:continue
if len(ck.strip()) == 0:
continue
logging.debug("-- {}".format(ck))
d = copy.deepcopy(doc)
if pdf_parser:
@@ -269,7 +271,8 @@ def tokenize_chunks_docx(chunks, doc, eng, images):
res = []
# wrap up as es documents
for ck, image in zip(chunks, images):
if len(ck.strip()) == 0:continue
if len(ck.strip()) == 0:
continue
logging.debug("-- {}".format(ck))
d = copy.deepcopy(doc)
d["image"] = image
@@ -288,8 +291,10 @@ def tokenize_table(tbls, doc, eng, batch_size=10):
d = copy.deepcopy(doc)
tokenize(d, rows, eng)
d["content_with_weight"] = rows
if img: d["image"] = img
if poss: add_positions(d, poss)
if img:
d["image"] = img
if poss:
add_positions(d, poss)
res.append(d)
continue
de = "; " if eng else "; "
@@ -387,9 +392,9 @@ def title_frequency(bull, sections):
if re.search(r"(title|head)", layout) and not not_title(txt.split("@")[0]):
levels[i] = bullets_size
most_level = bullets_size+1
for l, c in sorted(Counter(levels).items(), key=lambda x:x[1]*-1):
if l <= bullets_size:
most_level = l
for level, c in sorted(Counter(levels).items(), key=lambda x:x[1]*-1):
if level <= bullets_size:
most_level = level
break
return most_level, levels

@@ -504,7 +509,8 @@ def naive_merge(sections, chunk_token_num=128, delimiter="\n。;!?"):
def add_chunk(t, pos):
nonlocal cks, tk_nums, delimiter
tnum = num_tokens_from_string(t)
if not pos: pos = ""
if not pos:
pos = ""
if tnum < 8:
pos = ""
# Ensure that the length of the merged chunk does not exceed chunk_token_num

+ 4
- 2
rag/nlp/query.py

@@ -121,7 +121,8 @@ class FulltextQueryer:
            keywords.append(tt)
            twts = self.tw.weights([tt])
            syns = self.syn.lookup(tt)
-           if syns and len(keywords) < 32: keywords.extend(syns)
+           if syns and len(keywords) < 32:
+               keywords.extend(syns)
            logging.debug(json.dumps(twts, ensure_ascii=False))
            tms = []
            for tk, w in sorted(twts, key=lambda x: x[1] * -1):
@@ -147,7 +148,8 @@ class FulltextQueryer:

                tk_syns = self.syn.lookup(tk)
                tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
-               if len(keywords) < 32: keywords.extend([s for s in tk_syns if s])
+               if len(keywords) < 32:
+                   keywords.extend([s for s in tk_syns if s])
                tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
                tk_syns = [f"\"{s}\"" if s.find(" ")>0 else s for s in tk_syns]


+ 2
- 8
rag/nlp/rag_tokenizer.py

@@ -104,7 +104,6 @@ class RagTokenizer:
        return HanziConv.toSimplified(line)

    def dfs_(self, chars, s, preTks, tkslist):
-       MAX_L = 10
        res = s
        # if s > MAX_L or s>= len(chars):
        if s >= len(chars):
@@ -184,12 +183,6 @@ class RagTokenizer:
        return sorted(res, key=lambda x: x[1], reverse=True)

    def merge_(self, tks):
-       patts = [
-           (r"[ ]+", " "),
-           (r"([0-9\+\.,%\*=-]) ([0-9\+\.,%\*=-])", r"\1\2"),
-       ]
-       # for p,s in patts: tks = re.sub(p, s, tks)
-
        # if split chars is part of token
        res = []
        tks = re.sub(r"[ ]+", " ", tks).split()
@@ -284,7 +277,8 @@ class RagTokenizer:
        same = 0
        while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
            same += 1
-       if same > 0: res.append(" ".join(tks[j: j + same]))
+       if same > 0:
+           res.append(" ".join(tks[j: j + same]))
        _i = i + same
        _j = j + same
        j = _j + 1

+ 3
- 3
rag/nlp/term_weight.py

@@ -62,10 +62,10 @@ class Dealer:
            res = {}
            f = open(fnm, "r")
            while True:
-               l = f.readline()
-               if not l:
+               line = f.readline()
+               if not line:
                    break
-               arr = l.replace("\n", "").split("\t")
+               arr = line.replace("\n", "").split("\t")
                if len(arr) < 2:
                    res[arr[0]] = 0
                else:

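Renaming `l` to `line` here, and `l` to `level` in `title_frequency` above, addresses Ruff's E741 check ("ambiguous variable name"), which flags `l`, `I` and `O` because they are easy to confuse with `1` and `0`. A small sketch of the same reader loop under that convention (a hypothetical helper, not the repo's code):

    def read_tab_separated(fnm):
        res = {}
        with open(fnm, "r") as f:
            for line in f:                        # `for l in f:` would trip E741
                arr = line.rstrip("\n").split("\t")
                if len(arr) >= 2:
                    res[arr[0]] = arr[1]
        return res
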
+ 4
- 2
rag/raptor.py

@@ -47,7 +47,8 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
    def __call__(self, chunks, random_state, callback=None):
        layers = [(0, len(chunks))]
        start, end = 0, len(chunks)
-       if len(chunks) <= 1: return
+       if len(chunks) <= 1:
+           return
        chunks = [(s, a) for s, a in chunks if len(a) > 0]

def summarize(ck_idx, lock):
@@ -66,7 +67,8 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                logging.debug(f"SUM: {cnt}")
                embds, _ = self._embd_model.encode([cnt])
                with lock:
-                   if not len(embds[0]): return
+                   if not len(embds[0]):
+                       return
                    chunks.append((cnt, embds[0]))
            except Exception as e:
                logging.exception("summarize got exception")

+ 4
- 2
rag/svr/cache_file_svr.py

@@ -33,14 +33,16 @@ def collect():

def main():
    locations = collect()
-   if not locations:return
+   if not locations:
+       return
    logging.info(f"TASKS: {len(locations)}")
    for kb_id, loc in locations:
        try:
            if REDIS_CONN.is_alive():
                try:
                    key = "{}/{}".format(kb_id, loc)
-                   if REDIS_CONN.exist(key):continue
+                   if REDIS_CONN.exist(key):
+                       continue
                    file_bin = STORAGE_IMPL.get(kb_id, loc)
                    REDIS_CONN.transaction(key, file_bin, 12 * 60)
                    logging.info("CACHE: {}".format(loc))

+ 9
- 8
rag/svr/task_executor.py

@@ -23,18 +23,12 @@ import os

from api.utils.log_utils import initRootLogger

-CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
-CONSUMER_NAME = "task_executor_" + CONSUMER_NO
-LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
-initRootLogger(CONSUMER_NAME, LOG_LEVELS)

from datetime import datetime
import json
import os
import hashlib
import copy
import re
import sys
import time
import threading
from functools import partial
@@ -63,6 +57,11 @@ from rag.utils import rmSpace, num_tokens_from_string
from rag.utils.redis_conn import REDIS_CONN, Payload
from rag.utils.storage_factory import STORAGE_IMPL

+CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
+CONSUMER_NAME = "task_executor_" + CONSUMER_NO
+LOG_LEVELS = os.environ.get("LOG_LEVELS", "")
+initRootLogger(CONSUMER_NAME, LOG_LEVELS)

BATCH_SIZE = 64

FACTORY = {
@@ -201,7 +200,8 @@ def build_chunks(task, progress_callback):
"doc_id": task["doc_id"],
"kb_id": str(task["kb_id"])
}
if task["pagerank"]: doc["pagerank_fea"] = int(task["pagerank"])
if task["pagerank"]:
doc["pagerank_fea"] = int(task["pagerank"])
el = 0
for ck in cks:
d = copy.deepcopy(doc)
@@ -342,7 +342,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
"docnm_kwd": row["name"],
"title_tks": rag_tokenizer.tokenize(row["name"])
}
if row["pagerank"]: doc["pagerank_fea"] = int(row["pagerank"])
if row["pagerank"]:
doc["pagerank_fea"] = int(row["pagerank"])
res = []
tk_count = 0
for content, vctr in chunks[original_length:]:

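The first hunk of this file moves the `CONSUMER_*` / `initRootLogger` block from the middle of the import section to just after it. Once any module-level statement runs before an import, Ruff reports E402 ("module level import not at top of file") for every import that follows, so grouping imports first is the usual fix. A simplified, hypothetical sketch of the shape Ruff expects (the names are placeholders, not the real executor config):

    # All imports first ...
    import os
    import sys

    # ... then module-level setup.
    CONSUMER_NO = sys.argv[1] if len(sys.argv) > 1 else "0"
    LOG_LEVELS = os.environ.get("LOG_LEVELS", "")

One side effect worth noting is that the logger is now initialised after the imports it used to precede.
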
+ 14
- 14
rag/utils/__init__.py

@@ -41,15 +41,15 @@ def findMaxDt(fnm):
    try:
        with open(fnm, "r") as f:
            while True:
-               l = f.readline()
-               if not l:
+               line = f.readline()
+               if not line:
                    break
-               l = l.strip("\n")
-               if l == 'nan':
+               line = line.strip("\n")
+               if line == 'nan':
                    continue
-               if l > m:
-                   m = l
-   except Exception as e:
+               if line > m:
+                   m = line
+   except Exception:
        pass
    return m

@@ -59,15 +59,15 @@ def findMaxTm(fnm):
    try:
        with open(fnm, "r") as f:
            while True:
-               l = f.readline()
-               if not l:
+               line = f.readline()
+               if not line:
                    break
-               l = l.strip("\n")
-               if l == 'nan':
+               line = line.strip("\n")
+               if line == 'nan':
                    continue
-               if int(l) > m:
-                   m = int(l)
-   except Exception as e:
+               if int(line) > m:
+                   m = int(line)
+   except Exception:
        pass
    return m

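Both helpers in this file also drop the unused `as e` from their `except` clauses: Ruff's F841 check ("local variable assigned but never used") flags an exception name that is never logged or re-raised. A minimal sketch:

    def to_int(value, default=0):
        try:
            return int(value)
        except (TypeError, ValueError):   # was `... as e:` with `e` unused (F841)
            return default
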

+ 1
- 1
rag/utils/azure_sas_conn.py

@@ -32,7 +32,7 @@ class RAGFlowAzureSasBlob(object):
        self.conn = None

    def health(self):
-       bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
+       _bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
        return self.conn.upload_blob(name=fnm, data=BytesIO(binary), length=len(binary))

    def put(self, bucket, fnm, binary):

+ 1
- 1
rag/utils/azure_spn_conn.py

@@ -36,7 +36,7 @@ class RAGFlowAzureSpnBlob(object):
        self.conn = None

    def health(self):
-       bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
+       _bucket, fnm, binary = "txtxtxtxt1", "txtxtxtxt1", b"_t@@@1"
        f = self.conn.create_file(fnm)
        f.append_data(binary, offset=0, length=len(binary))
        return f.flush_data(len(binary))

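In both Azure connectors the `bucket` element of the health-check tuple is never used, so it is renamed `_bucket`: a leading underscore is the conventional way to mark a deliberately ignored value, and Ruff's unused-variable checks skip names matching its dummy-variable pattern (underscore-prefixed by default). A tiny sketch with hypothetical values:

    def host_of(addr):
        host, _port = addr.split(":", 1)   # `_port` is intentionally unused
        return host

    assert host_of("es01:9200") == "es01"
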
+ 2
- 1
rag/utils/es_conn.py

@@ -132,7 +132,8 @@ class ESConnection(DocStoreConnection):
                bqry.filter.append(
                    Q("bool", must_not=Q("range", available_int={"lt": 1})))
                continue
-           if not v: continue
+           if not v:
+               continue
            if isinstance(v, list):
                bqry.filter.append(Q("terms", **{k: v}))
            elif isinstance(v, str) or isinstance(v, int):

+ 13
- 6
sdk/python/ragflow_sdk/__init__.py

@@ -1,14 +1,21 @@
from beartype.claw import beartype_this_package
beartype_this_package()  # <-- raise exceptions in your code

import importlib.metadata

-__version__ = importlib.metadata.version("ragflow_sdk")

from .ragflow import RAGFlow
from .modules.dataset import DataSet
from .modules.chat import Chat
from .modules.session import Session
from .modules.document import Document
from .modules.chunk import Chunk
-from .modules.agent import Agent
from .modules.agent import Agent
+
+__version__ = importlib.metadata.version("ragflow_sdk")
+
+__all__ = [
+    "RAGFlow",
+    "DataSet",
+    "Chat",
+    "Session",
+    "Document",
+    "Chunk",
+    "Agent"
+]

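Re-exports in a package `__init__.py` look like unused imports to Ruff's F401 check; declaring them in `__all__`, as this hunk does, marks them as the package's public interface (and also pins down what `from ragflow_sdk import *` exposes). A stripped-down sketch of the same layout with placeholder submodule names rather than the real SDK modules:

    # mypkg/__init__.py  (hypothetical)
    from .client import Client      # re-exported below, so F401 no longer applies
    from .dataset import DataSet

    __all__ = ["Client", "DataSet"]
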
+ 1
- 1
sdk/python/ragflow_sdk/modules/session.py

@@ -29,7 +29,7 @@ class Session(Base):
                raise Exception(json_data["message"])
            if line.startswith("data:"):
                json_data = json.loads(line[5:])
-               if json_data["data"] != True:
+               if not json_data["data"]:
                    answer = json_data["data"]["answer"]
                    reference = json_data["data"]["reference"]
                    temp_dict = {

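The `!= True` comparison is what Ruff's E712 check ("comparison to True") flags, and the bare truth test above is the rewrite its autofix produces; Ruff documents that fix as potentially behaviour-changing when the value is not strictly boolean, which is easy to see with a made-up payload:

    data = {"answer": "hi", "reference": []}   # hypothetical streaming payload

    print(data != True)   # True  -> the equality test passes for any non-True value
    print(not data)       # False -> the truth test passes only for empty/falsy values
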
+ 0
- 2
sdk/python/test/conftest.py

@@ -1,5 +1,3 @@
-import string
-import random
import os
import pytest
import requests

+ 0
- 1
sdk/python/test/test_frontend_api/common.py

@@ -39,7 +39,6 @@ def update_dataset(auth, json_req):
def upload_file(auth, dataset_id, path):
    authorization = {"Authorization": auth}
    url = f"{HOST_ADDRESS}/v1/document/upload"
-   base_name = os.path.basename(path)
    json_req = {
        "kb_id": dataset_id,
    }

+ 1
- 1
sdk/python/test/test_frontend_api/get_email.py

@@ -1,3 +1,3 @@
def test_get_email(get_email):
-   print(f"\nEmail account:",flush=True)
+   print("\nEmail account:",flush=True)
    print(f"{get_email}\n",flush=True)

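Dropping the `f` prefix from a string with no `{}` placeholders is Ruff's F541 check ("f-string without any placeholders"); the prefix is harmless at runtime but signals interpolation that never happens. Sketch:

    account = "test@example.com"            # placeholder value
    print("\nEmail account:", flush=True)   # plain string: nothing to interpolate
    print(f"{account}\n", flush=True)       # f-string kept where a field is used
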
+ 1
- 5
sdk/python/test/test_frontend_api/test_chunk.py

@@ -13,14 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from common import HOST_ADDRESS, create_dataset, list_dataset, rm_dataset, update_dataset, upload_file, DATASET_NAME_LIMIT
+from common import create_dataset, list_dataset, rm_dataset, upload_file
from common import list_document, get_docs_info, parse_docs
from time import sleep
from timeit import default_timer as timer
-import re
-import pytest
-import random
-import string


def test_parse_txt_document(get_auth):

+ 2
- 5
sdk/python/test/test_frontend_api/test_dataset.py

@@ -1,6 +1,5 @@
-from common import HOST_ADDRESS, create_dataset, list_dataset, rm_dataset, update_dataset, DATASET_NAME_LIMIT
+from common import create_dataset, list_dataset, rm_dataset, update_dataset, DATASET_NAME_LIMIT
import re
import pytest
import random
import string

@@ -33,8 +32,6 @@ def test_dataset(get_auth):

def test_dataset_1k_dataset(get_auth):
    # create dataset
-   authorization = {"Authorization": get_auth}
-   url = f"{HOST_ADDRESS}/v1/kb/create"
    for i in range(1000):
        res = create_dataset(get_auth, f"test_create_dataset_{i}")
        assert res.get("code") == 0, f"{res.get('message')}"
@@ -76,7 +73,7 @@ def test_duplicated_name_dataset(get_auth):
        dataset_id = item.get("id")
        dataset_list.append(dataset_id)
        match = re.match(pattern, dataset_name)
-       assert match != None
+       assert match is not None

    for dataset_id in dataset_list:
        res = rm_dataset(get_auth, dataset_id)

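`assert match != None` becomes `assert match is not None`: Ruff's E711 check ("comparison to None") asks for an identity test, since `None` is a singleton and `!=` can be overridden by a custom `__ne__`. A self-contained sketch with an illustrative pattern:

    import re

    match = re.match(r"test_create_dataset_\d+", "test_create_dataset_7")
    assert match is not None    # identity check instead of `match != None`
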
+ 1
- 1
sdk/python/test/test_sdk_api/get_email.py

@@ -1,3 +1,3 @@
def test_get_email(get_email):
-   print(f"\nEmail account:",flush=True)
+   print("\nEmail account:",flush=True)
    print(f"{get_email}\n",flush=True)

+ 1
- 1
sdk/python/test/test_sdk_api/t_agent.py

@@ -1,4 +1,4 @@
-from ragflow_sdk import RAGFlow,Agent
+from ragflow_sdk import RAGFlow
from common import HOST_ADDRESS
import pytest

