
Rework logging (#3358)

Unified all log files into one.

### What problem does this PR solve?

Unified all log files into one.

### Type of change

- [x] Refactoring
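For context, the reworked `api/utils/log_utils.py` (diffed below) drops the per-module `LoggerFactory` loggers (`flow_logger`, `stat_logger`, `chat_logger`, `database_logger`, `access_logger`, `sql_logger`) in favor of a single module-level `logger` that writes to `logs/ragflow_<pid>.log`. A minimal sketch of that pattern, with the rotation sizes assumed rather than taken from the PR and `os.getcwd()` standing in for the project's base-directory helper:

```python
import logging
import os
from logging.handlers import RotatingFileHandler

LOG_LEVEL = logging.INFO
# One log file per process, named after the pid (os.getcwd() stands in for the project base).
LOG_FILE = os.path.abspath(os.path.join(os.getcwd(), "logs", f"ragflow_{os.getpid()}.log"))
LOG_FORMAT = "%(asctime)-15s %(levelname)-8s %(process)d %(message)s"
logger = None


def getLogger():
    # Build the shared logger once; every later import reuses the same instance.
    global logger
    if logger is not None:
        return logger
    os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)
    logger = logging.getLogger("ragflow")
    logger.setLevel(LOG_LEVEL)
    handler = RotatingFileHandler(LOG_FILE, maxBytes=10 * 1024 * 1024, backupCount=5)  # assumed sizes
    handler.setFormatter(logging.Formatter(LOG_FORMAT))
    logger.addHandler(handler)
    return logger


logger = getLogger()
```

Call sites then switch from the old per-purpose loggers and `if DEBUG: print(...)` statements to `from api.utils.log_utils import logger` plus `logger.debug(...)` / `logger.exception(...)`, as the hunks below show.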
tags/v0.14.0
Zhichang Yu 11 months ago
commit a2a5631da4
75 changed files with 481 additions and 853 deletions
1. agent/canvas.py (+5, -7)
2. agent/component/arxiv.py (+2, -3)
3. agent/component/baidu.py (+2, -4)
4. agent/component/base.py (+7, -7)
5. agent/component/bing.py (+2, -3)
6. agent/component/categorize.py (+3, -3)
7. agent/component/duckduckgo.py (+2, -2)
8. agent/component/github.py (+2, -2)
9. agent/component/google.py (+3, -3)
10. agent/component/googlescholar.py (+4, -4)
11. agent/component/keyword.py (+2, -2)
12. agent/component/pubmed.py (+2, -2)
13. agent/component/relevant.py (+2, -1)
14. agent/component/retrieval.py (+2, -1)
15. agent/component/rewrite.py (+2, -1)
16. agent/component/wikipedia.py (+2, -4)
17. agent/component/yahoofinance.py (+3, -2)
18. agent/settings.py (+0, -16)
19. api/apps/__init__.py (+5, -10)
20. api/apps/canvas_app.py (+2, -1)
21. api/apps/llm_app.py (+2, -1)
22. api/apps/sdk/dataset.py (+1, -1)
23. api/apps/user_app.py (+7, -7)
24. api/db/db_models.py (+8, -11)
25. api/db/db_utils.py (+0, -6)
26. api/db/init_data.py (+14, -16)
27. api/db/operatioins.py (+0, -21)
28. api/db/services/dialog_service.py (+10, -13)
29. api/db/services/document_service.py (+4, -4)
30. api/db/services/file_service.py (+5, -4)
31. api/db/services/llm_service.py (+17, -17)
32. api/ragflow_server.py (+14, -18)
33. api/settings.py (+0, -17)
34. api/utils/api_utils.py (+4, -3)
35. api/utils/log_utils.py (+25, -287)
36. deepdoc/parser/pdf_parser.py (+25, -24)
37. deepdoc/parser/resume/entities/corporations.py (+9, -3)
38. deepdoc/parser/resume/step_two.py (+20, -15)
39. deepdoc/vision/operators.py (+2, -2)
40. deepdoc/vision/recognizer.py (+2, -1)
41. deepdoc/vision/seeit.py (+2, -1)
42. deepdoc/vision/t_recognizer.py (+5, -2)
43. graphrag/claim_extractor.py (+3, -4)
44. graphrag/community_reports_extractor.py (+4, -7)
45. graphrag/index.py (+3, -2)
46. graphrag/mind_map_extractor.py (+3, -3)
47. intergrations/chatgpt-on-wechat/plugins/ragflow_chat.py (+3, -3)
48. rag/app/book.py (+2, -1)
49. rag/app/email.py (+2, -2)
50. rag/app/laws.py (+3, -3)
51. rag/app/manual.py (+3, -2)
52. rag/app/naive.py (+9, -9)
53. rag/app/one.py (+2, -1)
54. rag/app/paper.py (+8, -7)
55. rag/app/qa.py (+3, -3)
56. rag/app/resume.py (+5, -6)
57. rag/llm/embedding_model.py (+5, -4)
58. rag/llm/rerank_model.py (+4, -3)
59. rag/nlp/__init__.py (+4, -3)
60. rag/nlp/rag_tokenizer.py (+23, -24)
61. rag/nlp/search.py (+8, -8)
62. rag/nlp/synonym.py (+6, -6)
63. rag/nlp/term_weight.py (+5, -4)
64. rag/raptor.py (+5, -6)
65. rag/settings.py (+0, -27)
66. rag/svr/cache_file_svr.py (+4, -5)
67. rag/svr/discord_svr.py (+2, -1)
68. rag/svr/task_executor.py (+30, -36)
69. rag/utils/azure_sas_conn.py (+13, -15)
70. rag/utils/azure_spn_conn.py (+13, -15)
71. rag/utils/es_conn.py (+27, -38)
72. rag/utils/infinity_conn.py (+12, -13)
73. rag/utils/minio_conn.py (+15, -16)
74. rag/utils/redis_conn.py (+4, -5)
75. rag/utils/s3_conn.py (+18, -19)

agent/canvas.py (+5, -7)

# limitations under the License.
#
import json
- import traceback
from abc import ABC
from copy import deepcopy
from functools import partial
from agent.component import component_class
from agent.component.base import ComponentBase
- from agent.settings import flow_logger, DEBUG
+ from api.utils.log_utils import logger

class Canvas(ABC):
"""

if cpn.component_name == "Answer":
self.answer.append(c)
else:
- if DEBUG: print("RUN: ", c)
+ logger.debug(f"Canvas.prepare2run: {c}")
cpids = cpn.get_dependent_components()
if any([c not in self.path[-1] for c in cpids]):
continue

prepare2run(self.components[self.path[-2][-1]]["downstream"])
while 0 <= ran < len(self.path[-1]):
- if DEBUG: print(ran, self.path)
+ logger.debug(f"Canvas.run: {ran} {self.path}")
cpn_id = self.path[-1][ran]
cpn = self.get_component(cpn_id)
if not cpn["downstream"]: break

self.get_component(p)["obj"].set_exception(e)
prepare2run([p])
break
- traceback.print_exc()
+ logger.exception("Canvas.run got exception")
break
continue

self.get_component(p)["obj"].set_exception(e)
prepare2run([p])
break
- traceback.print_exc()
+ logger.exception("Canvas.run got exception")
break

if self.answer:

agent/component/arxiv.py (+2, -3)

from abc import ABC
import arxiv
import pandas as pd
- from agent.settings import DEBUG
from agent.component.base import ComponentBase, ComponentParamBase
+ from api.utils.log_utils import logger

class ArXivParam(ComponentParamBase):
"""

return ArXiv.be_output("")

df = pd.DataFrame(arxiv_res)
- if DEBUG: print(df, ":::::::::::::::::::::::::::::::::")
+ logger.debug(f"df: {str(df)}")
return df

agent/component/baidu.py (+2, -4)

# See the License for the specific language governing permissions and
# limitations under the License.
#
- import random
from abc import ABC
- from functools import partial
import pandas as pd
import requests
import re
- from agent.settings import DEBUG
from agent.component.base import ComponentBase, ComponentParamBase
+ from api.utils.log_utils import logger

class BaiduParam(ComponentParamBase):

return Baidu.be_output("")

df = pd.DataFrame(baidu_res)
- if DEBUG: print(df, ":::::::::::::::::::::::::::::::::")
+ logger.debug(f"df: {str(df)}")
return df



agent/component/base.py (+7, -7)

import builtins
import json
import os
- from copy import deepcopy
from functools import partial
- from typing import List, Dict, Tuple, Union
+ from typing import Tuple, Union

import pandas as pd

from agent import settings
- from agent.settings import flow_logger, DEBUG
+ from api.utils.log_utils import logger

_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
_DEPRECATED_PARAMS = "_deprecated_params"

def _warn_deprecated_param(self, param_name, descr):
if self._deprecated_params_set.get(param_name):
- flow_logger.warning(
+ logger.warning(
f"{descr} {param_name} is deprecated and ignored in this version."
)

def _warn_to_deprecate_param(self, param_name, descr, new_param):
if self._deprecated_params_set.get(param_name):
- flow_logger.warning(
+ logger.warning(
f"{descr} {param_name} will be deprecated in future release; "
f"please use {new_param} instead."
)
return cpnts

def run(self, history, **kwargs):
- flow_logger.info("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False),
+ logger.info("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False),
json.dumps(kwargs, ensure_ascii=False)))
try:
res = self._run(history, **kwargs)

reversed_cpnts.extend(self._canvas.path[-2])
reversed_cpnts.extend(self._canvas.path[-1])

- if DEBUG: print(self.component_name, reversed_cpnts[::-1])
+ logger.debug(f"{self.component_name} {reversed_cpnts[::-1]}")
for u in reversed_cpnts[::-1]:
if self.get_component_name(u) in ["switch", "concentrator"]: continue
if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":

agent/component/bing.py (+2, -3)

from abc import ABC
import requests
import pandas as pd
- from agent.settings import DEBUG
from agent.component.base import ComponentBase, ComponentParamBase
+ from api.utils.log_utils import logger

class BingParam(ComponentParamBase):
"""

return Bing.be_output("")

df = pd.DataFrame(bing_res)
- if DEBUG: print(df, ":::::::::::::::::::::::::::::::::")
+ logger.debug(f"df: {str(df)}")
return df

agent/component/categorize.py (+3, -3)

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from agent.component import GenerateParam, Generate
- from agent.settings import DEBUG
+ from api.utils.log_utils import logger

class CategorizeParam(GenerateParam):

super().check()
self.check_empty(self.category_description, "[Categorize] Category examples")
for k, v in self.category_description.items():
- if not k: raise ValueError(f"[Categorize] Category name can not be empty!")
+ if not k: raise ValueError("[Categorize] Category name can not be empty!")
if not v.get("to"): raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")

def get_prompt(self):

chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
ans = chat_mdl.chat(self._param.get_prompt(), [{"role": "user", "content": input}],
self._param.gen_conf())
- if DEBUG: print(ans, ":::::::::::::::::::::::::::::::::", input)
+ logger.debug(f"input: {input}, answer: {str(ans)}")
for c in self._param.category_description.keys():
if ans.lower().find(c.lower()) >= 0:
return Categorize.be_output(self._param.category_description[c]["to"])

agent/component/duckduckgo.py (+2, -2)

from abc import ABC
from duckduckgo_search import DDGS
import pandas as pd
- from agent.settings import DEBUG
from agent.component.base import ComponentBase, ComponentParamBase
+ from api.utils.log_utils import logger

class DuckDuckGoParam(ComponentParamBase):

return DuckDuckGo.be_output("")

df = pd.DataFrame(duck_res)
- if DEBUG: print(df, ":::::::::::::::::::::::::::::::::")
+ logger.debug("df: {df}")
return df

agent/component/github.py (+2, -2)

from abc import ABC
import pandas as pd
import requests
- from agent.settings import DEBUG
from agent.component.base import ComponentBase, ComponentParamBase
+ from api.utils.log_utils import logger

class GitHubParam(ComponentParamBase):

return GitHub.be_output("")

df = pd.DataFrame(github_res)
- if DEBUG: print(df, ":::::::::::::::::::::::::::::::::")
+ logger.debug(f"df: {df}")
return df

agent/component/google.py (+3, -3)

from abc import ABC
from serpapi import GoogleSearch
import pandas as pd
- from agent.settings import DEBUG
from agent.component.base import ComponentBase, ComponentParamBase
+ from api.utils.log_utils import logger

class GoogleParam(ComponentParamBase):

"hl": self._param.language, "num": self._param.top_n})
google_res = [{"content": '<a href="' + i["link"] + '">' + i["title"] + '</a> ' + i["snippet"]} for i in
client.get_dict()["organic_results"]]
- except Exception as e:
+ except Exception:
return Google.be_output("**ERROR**: Existing Unavailable Parameters!")

if not google_res:
return Google.be_output("")

df = pd.DataFrame(google_res)
- if DEBUG: print(df, ":::::::::::::::::::::::::::::::::")
+ logger.debug(f"df: {df}")
return df

agent/component/googlescholar.py (+4, -4)

#
from abc import ABC
import pandas as pd
- from agent.settings import DEBUG
from agent.component.base import ComponentBase, ComponentParamBase
from scholarly import scholarly
+ from api.utils.log_utils import logger

class GoogleScholarParam(ComponentParamBase):

'pub_url'] + '"></a> ' + "\n author: " + ",".join(pub['bib']['author']) + '\n Abstract: ' + pub[
'bib'].get('abstract', 'no abstract')})

- except StopIteration or Exception as e:
- print("**ERROR** " + str(e))
+ except StopIteration or Exception:
+ logger.exception("GoogleScholar")
break

if not scholar_res:
return GoogleScholar.be_output("")

df = pd.DataFrame(scholar_res)
- if DEBUG: print(df, ":::::::::::::::::::::::::::::::::")
+ logger.debug(f"df: {df}")
return df

agent/component/keyword.py (+2, -2)

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from agent.component import GenerateParam, Generate
- from agent.settings import DEBUG
+ from api.utils.log_utils import logger

class KeywordExtractParam(GenerateParam):

self._param.gen_conf())

ans = re.sub(r".*keyword:", "", ans).strip()
- if DEBUG: print(ans, ":::::::::::::::::::::::::::::::::")
+ logger.info(f"ans: {ans}")
return KeywordExtract.be_output(ans)

agent/component/pubmed.py (+2, -2)

import re
import pandas as pd
import xml.etree.ElementTree as ET
- from agent.settings import DEBUG
from agent.component.base import ComponentBase, ComponentParamBase
+ from api.utils.log_utils import logger

class PubMedParam(ComponentParamBase):

return PubMed.be_output("")

df = pd.DataFrame(pubmed_res)
- if DEBUG: print(df, ":::::::::::::::::::::::::::::::::")
+ logger.debug(f"df: {df}")
return df

agent/component/relevant.py (+2, -1)

from api.db.services.llm_service import LLMBundle
from agent.component import GenerateParam, Generate
from rag.utils import num_tokens_from_string, encoder
+ from api.utils.log_utils import logger

class RelevantParam(GenerateParam):

ans = chat_mdl.chat(self._param.get_prompt(), [{"role": "user", "content": ans}],
self._param.gen_conf())

- print(ans, ":::::::::::::::::::::::::::::::::")
+ logger.info(ans)
if ans.lower().find("yes") >= 0:
return Relevant.be_output(self._param.yes)
if ans.lower().find("no") >= 0:

agent/component/retrieval.py (+2, -1)

from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler
from agent.component.base import ComponentBase, ComponentParamBase
+ from api.utils.log_utils import logger

class RetrievalParam(ComponentParamBase):

df = pd.DataFrame(kbinfos["chunks"])
df["content"] = df["content_with_weight"]
del df["content_with_weight"]
- print(">>>>>>>>>>>>>>>>>>>>>>>>>>\n", query, df)
+ logger.debug("{} {}".format(query, df))
return df





agent/component/rewrite.py (+2, -1)

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from agent.component import GenerateParam, Generate
+ from api.utils.log_utils import logger

class RewriteQuestionParam(GenerateParam):

self._canvas.history.pop()
self._canvas.history.append(("user", ans))

- print(ans, ":::::::::::::::::::::::::::::::::")
+ logger.info(ans)
return RewriteQuestion.be_output(ans)





agent/component/wikipedia.py (+2, -4)

# See the License for the specific language governing permissions and
# limitations under the License.
#
- import random
from abc import ABC
- from functools import partial
import wikipedia
import pandas as pd
- from agent.settings import DEBUG
from agent.component.base import ComponentBase, ComponentParamBase
+ from api.utils.log_utils import logger

class WikipediaParam(ComponentParamBase):

return Wikipedia.be_output("")

df = pd.DataFrame(wiki_res)
- if DEBUG: print(df, ":::::::::::::::::::::::::::::::::")
+ logger.debug(f"df: {df}")
return df

agent/component/yahoofinance.py (+3, -2)

import pandas as pd
from agent.component.base import ComponentBase, ComponentParamBase
import yfinance as yf
+ from api.utils.log_utils import logger

class YahooFinanceParam(ComponentParamBase):

{"content": "quarterly cash flow statement:\n" + msft.quarterly_cashflow.to_markdown() + "\n"})
if self._param.news:
yohoo_res.append({"content": "news:\n" + pd.DataFrame(msft.news).to_markdown() + "\n"})
- except Exception as e:
- print("**ERROR** " + str(e))
+ except Exception:
+ logger.exception("YahooFinance got exception")
if not yohoo_res:
return YahooFinance.be_output("")

agent/settings.py (+0, -16)

# See the License for the specific language governing permissions and
# limitations under the License.
#
- # Logger
- import os

- from api.utils.file_utils import get_project_base_directory
- from api.utils.log_utils import LoggerFactory, getLogger

- DEBUG = 0
- LoggerFactory.set_directory(
- os.path.join(
- get_project_base_directory(),
- "logs",
- "flow"))
- # {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
- LoggerFactory.LEVEL = 30

- flow_logger = getLogger("flow")
- database_logger = getLogger("database")
FLOAT_ZERO = 1e-8
PARAM_MAXDEPTH = 5

api/apps/__init__.py (+5, -10)

# See the License for the specific language governing permissions and
# limitations under the License.
#
- import logging
import os
import sys
from importlib.util import module_from_spec, spec_from_file_location

from flask_session import Session
from flask_login import LoginManager
- from api.settings import SECRET_KEY, stat_logger
- from api.settings import API_VERSION, access_logger
+ from api.settings import SECRET_KEY
+ from api.settings import API_VERSION
from api.utils.api_utils import server_error_response
+ from api.utils.log_utils import logger
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer

__all__ = ["app"]

- logger = logging.getLogger("flask.app")
- for h in access_logger.handlers:
- logger.addHandler(h)

Request.json = property(lambda self: self.get_json(force=True, silent=True))

app = Flask(__name__)

return user[0]
else:
return None
- except Exception as e:
- stat_logger.exception(e)
+ except Exception:
+ logger.exception("load_user got exception")
return None
else:
return None

api/apps/canvas_app.py (+2, -1)

from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
from agent.canvas import Canvas
from peewee import MySQLDatabase, PostgresqlDatabase
+ from api.utils.log_utils import logger

@manager.route('/templates', methods=['GET'])

pass
canvas.add_user_input(req["message"])
answer = canvas.run(stream=stream)
- print(canvas)
+ logger.info(canvas)
except Exception as e:
return server_error_response(e)



api/apps/llm_app.py (+2, -1)

from api.utils.api_utils import get_json_result
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
import requests
+ from api.utils.log_utils import logger

@manager.route('/factories', methods=['GET'])

if len(arr) == 0 or tc == 0:
raise Exception("Fail")
rerank_passed = True
- print(f'passed model rerank{llm.llm_name}',flush=True)
+ logger.info(f'passed model rerank {llm.llm_name}')
except Exception as e:
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
e)

api/apps/sdk/dataset.py (+1, -1)

new_key = key_mapping.get(key, key)
renamed_data[new_key] = value
renamed_list.append(renamed_data)
- return get_result(data=renamed_list)
+ return get_result(data=renamed_list)

api/apps/user_app.py (+7, -7)

)
from api.db.services.user_service import UserService, TenantService, UserTenantService
from api.db.services.file_service import FileService
- from api.settings import stat_logger
from api.utils.api_utils import get_json_result, construct_response
+ from api.utils.log_utils import logger

@manager.route("/login", methods=["POST", "GET"])
try:
avatar = download_img(user_info["avatar_url"])
except Exception as e:
- stat_logger.exception(e)
+ logger.exception(e)
avatar = ""
users = user_register(
user_id,
return redirect("/?auth=%s" % user.get_id())
except Exception as e:
rollback_user_registration(user_id)
- stat_logger.exception(e)
+ logger.exception(e)
return redirect("/?error=%s" % str(e))

# User has already registered, try to log in
try:
avatar = download_img(user_info["avatar_url"])
except Exception as e:
- stat_logger.exception(e)
+ logger.exception(e)
avatar = ""
users = user_register(
user_id,
return redirect("/?auth=%s" % user.get_id())
except Exception as e:
rollback_user_registration(user_id)
- stat_logger.exception(e)
+ logger.exception(e)
return redirect("/?error=%s" % str(e))

# User has already registered, try to log in
UserService.update_by_id(current_user.id, update_dict)
return get_json_result(data=True)
except Exception as e:
- stat_logger.exception(e)
+ logger.exception(e)
return get_json_result(
data=False, message="Update failure!", code=RetCode.EXCEPTION_ERROR
)
)
except Exception as e:
rollback_user_registration(user_id)
- stat_logger.exception(e)
+ logger.exception(e)
return get_json_result(
data=False,
message=f"User registration failure, error: {str(e)}",

api/db/db_models.py (+8, -11)

)
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
from api.db import SerializedType, ParserType
- from api.settings import DATABASE, stat_logger, SECRET_KEY, DATABASE_TYPE
- from api.utils.log_utils import getLogger
+ from api.settings import DATABASE, SECRET_KEY, DATABASE_TYPE
from api import utils

- LOGGER = getLogger()
+ from api.utils.log_utils import logger

def singleton(cls, *args, **kw):
instances = {}

database_config = DATABASE.copy()
db_name = database_config.pop("name")
self.database_connection = PooledDatabase[DATABASE_TYPE.upper()].value(db_name, **database_config)
- stat_logger.info('init database on cluster mode successfully')
+ logger.info('init database on cluster mode successfully')

class PostgresDatabaseLock:
def __init__(self, lock_name, timeout=10, db=None):

if DB:
DB.close_stale(age=30)
except Exception as e:
- LOGGER.exception(e)
+ logger.exception(e)

class DataBaseModel(BaseModel):

for name, obj in members:
if obj != DataBaseModel and issubclass(obj, DataBaseModel):
table_objs.append(obj)
- LOGGER.info(f"start create table {obj.__name__}")
+ logger.info(f"start create table {obj.__name__}")
try:
obj.create_table()
- LOGGER.info(f"create table success: {obj.__name__}")
+ logger.info(f"create table success: {obj.__name__}")
except Exception as e:
- LOGGER.exception(e)
+ logger.exception(e)
create_failed_list.append(obj.__name__)
if create_failed_list:
- LOGGER.info(f"create tables failed: {create_failed_list}")
+ logger.info(f"create tables failed: {create_failed_list}")
raise Exception(f"create tables failed: {create_failed_list}")
migrate_db()



api/db/db_utils.py (+0, -6)

from api.utils import current_timestamp, timestamp_to_date

from api.db.db_models import DB, DataBaseModel
- from api.db.runtime_config import RuntimeConfig
- from api.utils.log_utils import getLogger
- from enum import Enum

- LOGGER = getLogger()

@DB.connection_context()

api/db/init_data.py (+14, -16)

from api.db.services.user_service import TenantService, UserTenantService
from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY, LLM_BASE_URL
from api.utils.file_utils import get_project_base_directory
+ from api.utils.log_utils import logger

def encode_to_base64(input_string):

"api_key": API_KEY, "api_base": LLM_BASE_URL})

if not UserService.save(**user_info):
- print("\033[93m【ERROR】\033[0mcan't init admin.")
+ logger.info("can't init admin.")
return
TenantService.insert(**tenant)
UserTenantService.insert(**usr_tenant)
TenantLLMService.insert_many(tenant_llm)
- print(
- "【INFO】Super user initialized. \033[93memail: admin@ragflow.io, password: admin\033[0m. Changing the password after logining is strongly recomanded.")
+ logger.info(
+ "Super user initialized. email: admin@ragflow.io, password: admin. Changing the password after logining is strongly recomanded.")

chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
msg = chat_mdl.chat(system="", history=[
{"role": "user", "content": "Hello!"}], gen_conf={})
if msg.find("ERROR: ") == 0:
- print(
- "\33[91m【ERROR】\33[0m: ",
+ logger.error(
"'{}' dosen't work. {}".format(
tenant["llm_id"],
msg))
embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"])
v, c = embd_mdl.encode(["Hello!"])
if c == 0:
- print(
- "\33[91m【ERROR】\33[0m:",
- " '{}' dosen't work!".format(
+ logger.error(
+ "'{}' dosen't work!".format(
tenant["embd_id"]))

def init_llm_factory():
try:
LLMService.filter_delete([(LLM.fid == "MiniMax" or LLM.fid == "Minimax")])
- except Exception as e:
+ except Exception:
pass

factory_llm_infos = json.load(
llm_infos = factory_llm_info.pop("llm")
try:
LLMFactoriesService.save(**factory_llm_info)
- except Exception as e:
+ except Exception:
pass
LLMService.filter_delete([LLM.fid == factory_llm_info["name"]])
for llm_info in llm_infos:
llm_info["fid"] = factory_llm_info["name"]
try:
LLMService.save(**llm_info)
- except Exception as e:
+ except Exception:
pass

LLMFactoriesService.filter_delete([LLMFactories.name == "Local"])
row = deepcopy(row)
row["llm_name"] = "text-embedding-3-large"
TenantLLMService.save(**row)
- except Exception as e:
+ except Exception:
pass
break
for kb_id in KnowledgebaseService.get_all_ids():
CanvasTemplateService.save(**cnvs)
except:
CanvasTemplateService.update_by_id(cnvs["id"], cnvs)
- except Exception as e:
- print("Add graph templates error: ", e)
- print("------------", flush=True)
+ except Exception:
+ logger.exception("Add graph templates error: ")

def init_web_data():
# init_superuser()

add_graph_templates()
- print("init web data success:{}".format(time.time() - start_time))
+ logger.info("init web data success:{}".format(time.time() - start_time))

if __name__ == '__main__':

api/db/operatioins.py (+0, -21)

- #
- # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #

- import operator
- import time
- import typing
- from api.utils.log_utils import sql_logger
- import peewee

api/db/services/dialog_service.py (+10, -13)

from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
- from api.settings import chat_logger, retrievaler, kg_retrievaler
+ from api.settings import retrievaler, kg_retrievaler
from rag.app.resume import forbidden_select_fields4resume
from rag.nlp.search import index_name
from rag.utils import rmSpace, num_tokens_from_string, encoder
from api.utils.file_utils import get_project_base_directory
+ from api.utils.log_utils import logger

class DialogService(CommonService):

tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
# try to use sql if field mapping is good to go
if field_map:
- chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
+ logger.info("Use SQL to retrieval:{}".format(questions[-1]))
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
if ans:
yield ans

doc_ids=attachments,
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
- chat_logger.info(
+ logger.info(
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
retrieval_tm = timer()

yield decorate_answer(answer)
else:
answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
- chat_logger.info("User: {}|Assistant: {}".format(
+ logger.info("User: {}|Assistant: {}".format(
msg[-1]["content"], answer))
res = decorate_answer(answer)
res["audio_binary"] = tts(tts_mdl, answer)

nonlocal sys_prompt, user_promt, question, tried_times
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
"temperature": 0.06})
- print(user_promt, sql)
- chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}")
+ logger.info(f"{question} ==> {user_promt} get SQL: {sql}")
sql = re.sub(r"[\r\n]+", " ", sql.lower())
sql = re.sub(r".*select ", "select ", sql.lower())
sql = re.sub(r" +", " ", sql)

flds.append(k)
sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]

- print(f"“{question}” get SQL(refined): {sql}")
- chat_logger.info(f"“{question}” get SQL(refined): {sql}")
+ logger.info(f"{question} get SQL(refined): {sql}")
tried_times += 1
return retrievaler.sql_retrieval(sql, format="json"), sql

question, sql, tbl["error"]
)
tbl, sql = get_table()
- chat_logger.info("TRY it again: {}".format(sql))
+ logger.info("TRY it again: {}".format(sql))

- chat_logger.info("GET table: {}".format(tbl))
- print(tbl)
+ logger.info("GET table: {}".format(tbl))
if tbl.get("error") or len(tbl["rows"]) == 0:
return None

rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)

if not docid_idx or not docnm_idx:
- chat_logger.warning("SQL missing field: " + sql)
+ logger.warning("SQL missing field: " + sql)
return {
"answer": "\n".join([clmns, line, rows]),
"reference": {"chunks": [], "doc_aggs": []},

api/db/services/document_service.py (+4, -4)

import json
import random
import re
- import traceback
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import datetime
from peewee import fn

from api.db.db_utils import bulk_insert_into_db
- from api.settings import stat_logger, docStoreConn
+ from api.settings import docStoreConn
from api.utils import current_timestamp, get_format_time, get_uuid
from graphrag.mind_map_extractor import MindMapExtractor
from rag.settings import SVR_QUEUE_NAME
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db import StatusEnum
from rag.utils.redis_conn import REDIS_CONN
+ from api.utils.log_utils import logger

class DocumentService(CommonService):

cls.update_by_id(d["id"], info)
except Exception as e:
if str(e).find("'0'") < 0:
- stat_logger.error("fetch task exception:" + str(e))
+ logger.exception("fetch task exception")

@classmethod
@DB.connection_context()

"knowledge_graph_kwd": "mind_map"
})
except Exception as e:
- stat_logger.error("Mind map generation error:", traceback.format_exc())
+ logger.exception("Mind map generation error")

vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
assert len(cks) == len(vects)

api/db/services/file_service.py (+5, -4)

from api.utils import get_uuid
from api.utils.file_utils import filename_type, thumbnail_img
from rag.utils.storage_factory import STORAGE_IMPL
+ from api.utils.log_utils import logger

class FileService(CommonService):

cls.delete_folder_by_pf_id(user_id, file.id)
return cls.model.delete().where((cls.model.tenant_id == user_id)
& (cls.model.id == folder_id)).execute(),
- except Exception as e:
- print(e)
+ except Exception:
+ logger.exception("delete_folder_by_pf_id")
raise RuntimeError("Database error (File retrieval)!")

@classmethod
def move_file(cls, file_ids, folder_id):
try:
cls.filter_update((cls.model.id << file_ids, ), { 'parent_id': folder_id })
- except Exception as e:
- print(e)
+ except Exception:
+ logger.exception("move_file")
raise RuntimeError("Database error (File move)!")

@classmethod

api/db/services/llm_service.py (+17, -17)

# limitations under the License.
#
from api.db.services.user_service import TenantService
- from api.settings import database_logger
from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
from api.db import LLMType
from api.db.db_models import DB
from api.db.db_models import LLMFactories, LLM, TenantLLM
from api.db.services.common_service import CommonService
+ from api.utils.log_utils import logger

class LLMFactoriesService(CommonService):

emd, used_tokens = self.mdl.encode(texts, batch_size)
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, used_tokens):
- database_logger.error(
- "Can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
+ logger.error(
+ "LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
return emd, used_tokens

def encode_queries(self, query: str):
emd, used_tokens = self.mdl.encode_queries(query)
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, used_tokens):
- database_logger.error(
- "Can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
+ logger.error(
+ "LLMBundle.encode_queries can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens))
return emd, used_tokens

def similarity(self, query: str, texts: list):
sim, used_tokens = self.mdl.similarity(query, texts)
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, used_tokens):
- database_logger.error(
- "Can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))
+ logger.error(
+ "LLMBundle.similarity can't update token usage for {}/RERANK used_tokens: {}".format(self.tenant_id, used_tokens))
return sim, used_tokens

def describe(self, image, max_tokens=300):
txt, used_tokens = self.mdl.describe(image, max_tokens)
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, used_tokens):
- database_logger.error(
- "Can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
+ logger.error(
+ "LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
return txt

def transcription(self, audio):
txt, used_tokens = self.mdl.transcription(audio)
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, used_tokens):
- database_logger.error(
- "Can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
+ logger.error(
+ "LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
return txt

def tts(self, text):
if isinstance(chunk,int):
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, chunk, self.llm_name):
- database_logger.error(
- "Can't update token usage for {}/TTS".format(self.tenant_id))
+ logger.error(
+ "LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
return
yield chunk

txt, used_tokens = self.mdl.chat(system, history, gen_conf)
if isinstance(txt, int) and not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, used_tokens, self.llm_name):
- database_logger.error(
- "Can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
+ logger.error(
+ "LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
return txt

def chat_streamly(self, system, history, gen_conf):
if isinstance(txt, int):
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, txt, self.llm_name):
- database_logger.error(
- "Can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
+ logger.error(
+ "LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
return
yield txt

api/ragflow_server.py (+14, -18)

from api.db.runtime_config import RuntimeConfig
from api.db.services.document_service import DocumentService
from api.settings import (
- HOST,
- HTTP_PORT,
- access_logger,
- database_logger,
- stat_logger,
+ HOST, HTTP_PORT
)
from api import utils
+ from api.utils.log_utils import logger

from api.db.db_models import init_database_tables as init_web_db
from api.db.init_data import init_web_data

time.sleep(3)
try:
DocumentService.update_progress()
- except Exception as e:
- stat_logger.error("update_progress exception:" + str(e))
+ except Exception:
+ logger.exception("update_progress exception")

- if __name__ == "__main__":
- print(
- r"""
+ if __name__ == '__main__':
+ logger.info(r"""
____ ___ ______ ______ __
/ __ \ / | / ____// ____// /____ _ __
/ /_/ // /| | / / __ / /_ / // __ \| | /| / /
/ _, _// ___ |/ /_/ // __/ / // /_/ /| |/ |/ /
/_/ |_|/_/ |_|\____//_/ /_/ \____/ |__/|__/

- """,
- flush=True,
+ """)
+ logger.info(
+ f'project base: {utils.file_utils.get_project_base_directory()}'
)
- stat_logger.info(f"project base: {utils.file_utils.get_project_base_directory()}")

# init db
init_web_db()

RuntimeConfig.DEBUG = args.debug
if RuntimeConfig.DEBUG:
- stat_logger.info("run on debug mode")
+ logger.info("run on debug mode")

RuntimeConfig.init_env()
RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT)

peewee_logger = logging.getLogger("peewee")
peewee_logger.propagate = False
# rag_arch.common.log.ROpenHandler
- peewee_logger.addHandler(database_logger.handlers[0])
- peewee_logger.setLevel(database_logger.level)
+ peewee_logger.addHandler(logger.handlers[0])
+ peewee_logger.setLevel(logger.handlers[0].level)

thr = ThreadPoolExecutor(max_workers=1)
thr.submit(update_progress)

# start http server
try:
- stat_logger.info("RAG Flow http server start...")
+ logger.info("RAG Flow http server start...")
werkzeug_logger = logging.getLogger("werkzeug")
- for h in access_logger.handlers:
+ for h in logger.handlers:
werkzeug_logger.addHandler(h)
run_simple(
hostname=HOST,
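The `api/ragflow_server.py` hunk above also re-points the third-party `peewee` and `werkzeug` loggers at the unified handlers instead of the old `database_logger`/`access_logger`. A condensed sketch of just that wiring, assuming `logger` has already been configured with at least one handler:

```python
import logging

from api.utils.log_utils import logger  # the single unified logger introduced by this PR

# Reuse the unified handlers for third-party loggers (mirrors the server startup above).
peewee_logger = logging.getLogger("peewee")
peewee_logger.propagate = False  # keep peewee from also emitting through the root logger
peewee_logger.addHandler(logger.handlers[0])
peewee_logger.setLevel(logger.handlers[0].level)

werkzeug_logger = logging.getLogger("werkzeug")
for h in logger.handlers:  # werkzeug shares every unified handler
    werkzeug_logger.addHandler(h)
```

This way the ORM and HTTP access logs end up in the same rotating file as the application's own messages.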

api/settings.py (+0, -17)

from datetime import date
from enum import IntEnum, Enum
from api.utils.file_utils import get_project_base_directory
- from api.utils.log_utils import LoggerFactory, getLogger
import rag.utils.es_conn
import rag.utils.infinity_conn

- # Logger
- LoggerFactory.set_directory(
- os.path.join(
- get_project_base_directory(),
- "logs",
- "api"))
- # {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
- LoggerFactory.LEVEL = 30

- stat_logger = getLogger("stat")
- access_logger = getLogger("access")
- database_logger = getLogger("database")
- chat_logger = getLogger("chat")

import rag.utils
from rag.nlp import search
from graphrag import search as kg_search
RAG_FLOW_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
LIGHTEN = int(os.environ.get('LIGHTEN', "0"))

- SUBPROCESS_STD_LOG_NAME = "std.log"

ERROR_REPORT = True
ERROR_REPORT_WITH_PATH = False



+ 4
- 3
api/utils/api_utils.py View File

from api.db.db_models import APIToken
from api.settings import (
REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
stat_logger, CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
)
from api.settings import RetCode
from api.utils import CustomJSONEncoder, get_uuid
from api.utils import json_dumps
from api.utils.log_utils import logger


requests.models.complexjson.dumps = functools.partial(
json.dumps, cls=CustomJSONEncoder)




def server_error_response(e):
stat_logger.exception(e)
logger.exception(e)
try:
if e.code == 401:
return get_json_result(code=401, message=repr(e))




def construct_error_response(e):
stat_logger.exception(e)
logger.exception(e)
try:
if e.code == 401:
return construct_json_result(code=RetCode.UNAUTHORIZED, message=repr(e))

+ 25  - 287  api/utils/log_utils.py

# limitations under the License.
#
import os
import typing
import traceback
import logging
import inspect
from logging.handlers import TimedRotatingFileHandler
from threading import RLock
from logging.handlers import RotatingFileHandler


from api.utils import file_utils
from api.utils.file_utils import get_project_base_directory


LOG_LEVEL = logging.INFO
LOG_FILE = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"ragflow_{os.getpid()}.log"))
LOG_FORMAT = "%(asctime)-15s %(levelname)-8s %(process)d %(message)s"
logger = None


class LoggerFactory(object):
TYPE = "FILE"
LOG_FORMAT = "[%(levelname)s] [%(asctime)s] [%(module)s.%(funcName)s] [line:%(lineno)d]: %(message)s"
logging.basicConfig(format=LOG_FORMAT)
LEVEL = logging.DEBUG
logger_dict = {}
global_handler_dict = {}

LOG_DIR = None
PARENT_LOG_DIR = None
log_share = True

append_to_parent_log = None

lock = RLock()
# CRITICAL = 50
# FATAL = CRITICAL
# ERROR = 40
# WARNING = 30
# WARN = WARNING
# INFO = 20
# DEBUG = 10
# NOTSET = 0
levels = (10, 20, 30, 40)
schedule_logger_dict = {}

@staticmethod
def set_directory(directory=None, parent_log_dir=None,
append_to_parent_log=None, force=False):
if parent_log_dir:
LoggerFactory.PARENT_LOG_DIR = parent_log_dir
if append_to_parent_log:
LoggerFactory.append_to_parent_log = append_to_parent_log
with LoggerFactory.lock:
if not directory:
directory = file_utils.get_project_base_directory("logs")
if not LoggerFactory.LOG_DIR or force:
LoggerFactory.LOG_DIR = directory
if LoggerFactory.log_share:
oldmask = os.umask(000)
os.makedirs(LoggerFactory.LOG_DIR, exist_ok=True)
os.umask(oldmask)
else:
os.makedirs(LoggerFactory.LOG_DIR, exist_ok=True)
for loggerName, ghandler in LoggerFactory.global_handler_dict.items():
for className, (logger,
handler) in LoggerFactory.logger_dict.items():
logger.removeHandler(ghandler)
ghandler.close()
LoggerFactory.global_handler_dict = {}
for className, (logger,
handler) in LoggerFactory.logger_dict.items():
logger.removeHandler(handler)
_handler = None
if handler:
handler.close()
if className != "default":
_handler = LoggerFactory.get_handler(className)
logger.addHandler(_handler)
LoggerFactory.assemble_global_handler(logger)
LoggerFactory.logger_dict[className] = logger, _handler

@staticmethod
def new_logger(name):
logger = logging.getLogger(name)
logger.propagate = False
logger.setLevel(LoggerFactory.LEVEL)
def getLogger():
global logger
if logger is not None:
return logger


@staticmethod
def get_logger(class_name=None):
with LoggerFactory.lock:
if class_name in LoggerFactory.logger_dict.keys():
logger, handler = LoggerFactory.logger_dict[class_name]
if not logger:
logger, handler = LoggerFactory.init_logger(class_name)
else:
logger, handler = LoggerFactory.init_logger(class_name)
return logger

@staticmethod
def get_global_handler(logger_name, level=None, log_dir=None):
if not LoggerFactory.LOG_DIR:
return logging.StreamHandler()
if log_dir:
logger_name_key = logger_name + "_" + log_dir
else:
logger_name_key = logger_name + "_" + LoggerFactory.LOG_DIR
# if loggerName not in LoggerFactory.globalHandlerDict:
if logger_name_key not in LoggerFactory.global_handler_dict:
with LoggerFactory.lock:
if logger_name_key not in LoggerFactory.global_handler_dict:
handler = LoggerFactory.get_handler(
logger_name, level, log_dir)
LoggerFactory.global_handler_dict[logger_name_key] = handler
return LoggerFactory.global_handler_dict[logger_name_key]

@staticmethod
def get_handler(class_name, level=None, log_dir=None,
log_type=None, job_id=None):
if not log_type:
if not LoggerFactory.LOG_DIR or not class_name:
return logging.StreamHandler()
# return Diy_StreamHandler()

if not log_dir:
log_file = os.path.join(
LoggerFactory.LOG_DIR,
"{}.log".format(class_name))
else:
log_file = os.path.join(log_dir, "{}.log".format(class_name))
else:
log_file = os.path.join(log_dir, "rag_flow_{}.log".format(
log_type) if level == LoggerFactory.LEVEL else 'rag_flow_{}_error.log'.format(log_type))

os.makedirs(os.path.dirname(log_file), exist_ok=True)
if LoggerFactory.log_share:
handler = ROpenHandler(log_file,
when='D',
interval=1,
backupCount=14,
delay=True)
else:
handler = TimedRotatingFileHandler(log_file,
when='D',
interval=1,
backupCount=14,
delay=True)
if level:
handler.level = level

return handler

@staticmethod
def init_logger(class_name):
with LoggerFactory.lock:
logger = LoggerFactory.new_logger(class_name)
handler = None
if class_name:
handler = LoggerFactory.get_handler(class_name)
logger.addHandler(handler)
LoggerFactory.logger_dict[class_name] = logger, handler

else:
LoggerFactory.logger_dict["default"] = logger, handler

LoggerFactory.assemble_global_handler(logger)
return logger, handler

@staticmethod
def assemble_global_handler(logger):
if LoggerFactory.LOG_DIR:
for level in LoggerFactory.levels:
if level >= LoggerFactory.LEVEL:
level_logger_name = logging._levelToName[level]
logger.addHandler(
LoggerFactory.get_global_handler(
level_logger_name, level))
if LoggerFactory.append_to_parent_log and LoggerFactory.PARENT_LOG_DIR:
for level in LoggerFactory.levels:
if level >= LoggerFactory.LEVEL:
level_logger_name = logging._levelToName[level]
logger.addHandler(
LoggerFactory.get_global_handler(level_logger_name, level, LoggerFactory.PARENT_LOG_DIR))


def setDirectory(directory=None):
LoggerFactory.set_directory(directory)


def setLevel(level):
LoggerFactory.LEVEL = level


def getLogger(className=None, useLevelFile=False):
if className is None:
frame = inspect.stack()[1]
module = inspect.getmodule(frame[0])
className = 'stat'
return LoggerFactory.get_logger(className)

print(f"log file path: {LOG_FILE}")
os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)
logger = logging.getLogger("ragflow")
logger.setLevel(LOG_LEVEL)


def exception_to_trace_string(ex):
return "".join(traceback.TracebackException.from_exception(ex).format())
handler1 = RotatingFileHandler(LOG_FILE, maxBytes=10*1024*1024, backupCount=5)
handler1.setLevel(LOG_LEVEL)
formatter1 = logging.Formatter(LOG_FORMAT)
handler1.setFormatter(formatter1)
logger.addHandler(handler1)


handler2 = logging.StreamHandler()
handler2.setLevel(LOG_LEVEL)
formatter2 = logging.Formatter(LOG_FORMAT)
handler2.setFormatter(formatter2)
logger.addHandler(handler2)


class ROpenHandler(TimedRotatingFileHandler):
def _open(self):
prevumask = os.umask(000)
rtv = TimedRotatingFileHandler._open(self)
os.umask(prevumask)
return rtv


def sql_logger(job_id='', log_type='sql'):
key = job_id + log_type
if key in LoggerFactory.schedule_logger_dict.keys():
return LoggerFactory.schedule_logger_dict[key]
return get_job_logger(job_id=job_id, log_type=log_type)


def ready_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
prefix, suffix = base_msg(job, task, role, party_id, detail)
return f"{prefix}{msg} ready{suffix}"


def start_log(msg, job=None, task=None, role=None, party_id=None, detail=None):
prefix, suffix = base_msg(job, task, role, party_id, detail)
return f"{prefix}start to {msg}{suffix}"


def successful_log(msg, job=None, task=None, role=None,
party_id=None, detail=None):
prefix, suffix = base_msg(job, task, role, party_id, detail)
return f"{prefix}{msg} successfully{suffix}"


def warning_log(msg, job=None, task=None, role=None,
party_id=None, detail=None):
prefix, suffix = base_msg(job, task, role, party_id, detail)
return f"{prefix}{msg} is not effective{suffix}"


def failed_log(msg, job=None, task=None, role=None,
party_id=None, detail=None):
prefix, suffix = base_msg(job, task, role, party_id, detail)
return f"{prefix}failed to {msg}{suffix}"


def base_msg(job=None, task=None, role: str = None,
party_id: typing.Union[str, int] = None, detail=None):
if detail:
detail_msg = f" detail: \n{detail}"
else:
detail_msg = ""
if task is not None:
return f"task {task.f_task_id} {task.f_task_version} ", f" on {task.f_role} {task.f_party_id}{detail_msg}"
elif job is not None:
return "", f" on {job.f_role} {job.f_party_id}{detail_msg}"
elif role and party_id:
return "", f" on {role} {party_id}{detail_msg}"
else:
return "", f"{detail_msg}"


def exception_to_trace_string(ex):
return "".join(traceback.TracebackException.from_exception(ex).format())


def get_logger_base_dir():
job_log_dir = file_utils.get_rag_flow_directory('logs')
return job_log_dir


def get_job_logger(job_id, log_type):
rag_flow_log_dir = file_utils.get_rag_flow_directory('logs', 'rag_flow')
job_log_dir = file_utils.get_rag_flow_directory('logs', job_id)
if not job_id:
log_dirs = [rag_flow_log_dir]
else:
if log_type == 'audit':
log_dirs = [job_log_dir, rag_flow_log_dir]
else:
log_dirs = [job_log_dir]
if LoggerFactory.log_share:
oldmask = os.umask(000)
os.makedirs(job_log_dir, exist_ok=True)
os.makedirs(rag_flow_log_dir, exist_ok=True)
os.umask(oldmask)
else:
os.makedirs(job_log_dir, exist_ok=True)
os.makedirs(rag_flow_log_dir, exist_ok=True)
logger = LoggerFactory.new_logger(f"{job_id}_{log_type}")
for job_log_dir in log_dirs:
handler = LoggerFactory.get_handler(class_name=None, level=LoggerFactory.LEVEL,
log_dir=job_log_dir, log_type=log_type, job_id=job_id)
error_handler = LoggerFactory.get_handler(
class_name=None,
level=logging.ERROR,
log_dir=job_log_dir,
log_type=log_type,
job_id=job_id)
logger.addHandler(handler)
logger.addHandler(error_handler)
with LoggerFactory.lock:
LoggerFactory.schedule_logger_dict[job_id + log_type] = logger
return logger

logger = getLogger()
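With the rewrite above, `api.utils.log_utils` exposes a single module-level `logger` that writes to `logs/ragflow_<pid>.log` and to the console. A minimal usage sketch of the pattern the rest of this PR applies (illustrative only, not code from the diff):

```python
from api.utils.log_utils import logger

def parse_document(path: str) -> bytes:
    # Illustrative only: every module now imports this same `logger`
    # instead of stat_logger / cron_logger / chat_logger / print().
    logger.info("start parsing %s", path)
    try:
        with open(path, "rb") as f:
            data = f.read()
        logger.info("read %d bytes from %s", len(data), path)
        return data
    except Exception:
        # logger.exception() logs at ERROR level and appends the traceback,
        # which is why the bare `except Exception:` blocks in this PR drop the `as e`.
        logger.exception("failed to parse %s", path)
        return b""
```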

+ 25  - 24  deepdoc/parser/pdf_parser.py

import re
import pdfplumber
import logging
from PIL import Image, ImageDraw
from PIL import Image
import numpy as np
from timeit import default_timer as timer
from pypdf import PdfReader as pdf2_read


from api.settings import LIGHTEN
from api.utils.file_utils import get_project_base_directory
from api.utils.log_utils import logger
from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
from rag.nlp import rag_tokenizer
from copy import deepcopy
import torch
if torch.cuda.is_available():
self.updown_cnt_mdl.set_param({"device": "cuda"})
except Exception as e:
logging.error(str(e))
except Exception:
logger.exception("RAGFlowPdfParser __init__")
try:
model_dir = os.path.join(
get_project_base_directory(),
"rag/res/deepdoc")
self.updown_cnt_mdl.load_model(os.path.join(
model_dir, "updown_concat_xgb.model"))
except Exception as e:
except Exception:
model_dir = snapshot_download(
repo_id="InfiniFlow/text_concat_xgb_v1.0",
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
return True


def _table_transformer_job(self, ZM):
logging.info("Table processing...")
logger.info("Table processing...")
imgs, pos = [], []
tbcnt = [0]
MARGIN = 10
detach_feats = [b["x1"] < b_["x0"],
b["x0"] > b_["x1"]]
if (any(feats) and not any(concatting_feats)) or any(detach_feats):
print(
logger.info("{} {} {} {}".format(
b["text"],
b_["text"],
any(feats),
any(concatting_feats),
any(detach_feats))
))
i += 1
continue
# merge up and down
# continue
if tv < fv and tk:
tables[tk].insert(0, c)
logging.debug(
logger.debug(
"TABLE:" + "TABLE:" +
self.boxes[i]["text"] + self.boxes[i]["text"] +
"; Cap: " + "; Cap: " +
tk) tk)
elif fk: elif fk:
figures[fk].insert(0, c) figures[fk].insert(0, c)
logging.debug(
logger.debug(
"FIGURE:" + "FIGURE:" +
self.boxes[i]["text"] + self.boxes[i]["text"] +
"; Cap: " + "; Cap: " +
if ii is not None: if ii is not None:
b = louts[ii] b = louts[ii]
else: else:
logging.warn(
logger.warn(
f"Missing layout match: {pn + 1},%s" % f"Missing layout match: {pn + 1},%s" %
(bxs[0].get( (bxs[0].get(
"layoutno", ""))) "layoutno", "")))
if usefull(boxes[0]): if usefull(boxes[0]):
dfs(boxes[0], 0) dfs(boxes[0], 0)
else: else:
logging.debug("WASTE: " + boxes[0]["text"])
except Exception as e:
logger.debug("WASTE: " + boxes[0]["text"])
except Exception:
pass
boxes.pop(0)
mw = np.mean(widths)
res.append(
"\n".join([c["text"] + self._line_tag(c, ZM) for c in lines]))
else:
logging.debug("REMOVED: " +
logger.debug("REMOVED: " +
"<<".join([c["text"] for c in lines]))


return "\n\n".join(res)
pdf = pdfplumber.open(
fnm) if not binary else pdfplumber.open(BytesIO(binary))
return len(pdf.pages)
except Exception as e:
logging.error(str(e))
except Exception:
logger.exception("total_page_number")


def __images__(self, fnm, zoomin=3, page_from=0,
page_to=299, callback=None):
self.page_chars = [[{**c, 'top': c['top'], 'bottom': c['bottom']} for c in page.dedupe_chars().chars if self._has_color(c)] for page in
self.pdf.pages[page_from:page_to]]
self.total_page = len(self.pdf.pages)
except Exception as e:
logging.error(str(e))
except Exception:
logger.exception("RAGFlowPdfParser __images__")


self.outlines = []
try:


dfs(outlines, 0)
except Exception as e:
logging.warning(f"Outlines exception: {e}")
logger.warning(f"Outlines exception: {e}")
if not self.outlines:
logging.warning(f"Miss outlines")
logger.warning("Miss outlines")


logging.info("Images converted.")
logger.info("Images converted.")
self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join(
random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in
range(len(self.page_chars))]
self.is_english = re.search(r"[\na-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}",
"".join([b["text"] for b in random.choices(bxes, k=min(30, len(bxes)))]))


logging.info("Is it English:", self.is_english)
logger.info("Is it English:", self.is_english)


self.page_cum_height = np.cumsum(self.page_cum_height)
assert len(self.page_cum_height) == len(self.page_images) + 1
dfs(a, depth + 1)


dfs(outlines, 0)
except Exception as e:
logging.warning(f"Outlines exception: {e}")
except Exception:
logger.exception("Outlines exception")
if not self.outlines:
logging.warning(f"Miss outlines")
logger.warning("Miss outlines")


return [(l, "") for l in lines], [] return [(l, "") for l in lines], []



+ 9  - 3  deepdoc/parser/resume/entities/corporations.py

# limitations under the License.
#


import re,json,os
import re
import json
import os
import pandas as pd
from rag.nlp import rag_tokenizer
from . import regions
from api.utils.log_utils import logger


current_file_path = os.path.dirname(os.path.abspath(__file__))
GOODS = pd.read_csv(os.path.join(current_file_path, "res/corp_baike_len.csv"), sep="\t", header=0).fillna(0)
GOODS["cid"] = GOODS["cid"].astype(str)
global GOODS
try:
return GOODS.loc[str(cid), "len"]
except Exception as e:
except Exception:
pass
return default_v


GOOD_CORP = set([corpNorm(rmNoise(c), False) for c in GOOD_CORP])
for c,v in CORP_TAG.items():
cc = corpNorm(rmNoise(c), False)
if not cc: print (c)
if not cc:
logger.info(c)
CORP_TAG = {corpNorm(rmNoise(c), False):v for c,v in CORP_TAG.items()}


def is_good(nm):

+ 20  - 15  deepdoc/parser/resume/step_two.py

# limitations under the License.
#


import re, copy, time, datetime, demjson3, \
traceback, signal
import re
import copy
import time
import datetime
import demjson3
import traceback
import signal
import numpy as np
from deepdoc.parser.resume.entities import degrees, schools, corporations
from rag.nlp import rag_tokenizer, surname
from xpinyin import Pinyin
from contextlib import contextmanager
from api.utils.log_utils import logger




class TimeoutException(Exception): pass
y, m, d = getYMD(dt) y, m, d = getYMD(dt)
st_dt.append(str(y)) st_dt.append(str(y))
e["start_dt_kwd"] = str(y) e["start_dt_kwd"] = str(y)
except Exception as e:
except Exception:
pass pass


r = schools.select(n.get("school_name", "")) r = schools.select(n.get("school_name", ""))
y, m, d = getYMD(edu_end_dt) y, m, d = getYMD(edu_end_dt)
cv["work_exp_flt"] = min(int(str(datetime.date.today())[0:4]) - int(y), cv.get("work_exp_flt", 1000)) cv["work_exp_flt"] = min(int(str(datetime.date.today())[0:4]) - int(y), cv.get("work_exp_flt", 1000))
except Exception as e: except Exception as e:
print("EXCEPTION: ", e, edu_end_dt, cv.get("work_exp_flt"))
logger.exception("forEdu {} {} {}".format(e, edu_end_dt, cv.get("work_exp_flt")))
if sch: if sch:
cv["school_name_kwd"] = sch cv["school_name_kwd"] = sch
if (len(cv.get("degree_kwd", [])) >= 1 and "本科" in cv["degree_kwd"]) \ if (len(cv.get("degree_kwd", [])) >= 1 and "本科" in cv["degree_kwd"]) \
if type(n) == type(""): if type(n) == type(""):
try: try:
n = json_loads(n) n = json_loads(n)
except Exception as e:
except Exception:
continue continue


if n.get("start_time") and (not work_st_tm or n["start_time"] < work_st_tm): work_st_tm = n["start_time"] if n.get("start_time") and (not work_st_tm or n["start_time"] < work_st_tm): work_st_tm = n["start_time"]


try: try:
duas.append((datetime.datetime.strptime(ed, "%Y-%m-%d") - datetime.datetime.strptime(st, "%Y-%m-%d")).days) duas.append((datetime.datetime.strptime(ed, "%Y-%m-%d") - datetime.datetime.strptime(st, "%Y-%m-%d")).days)
except Exception as e:
print("kkkkkkkkkkkkkkkkkkkk", n.get("start_time"), n.get("end_time"))
except Exception:
logger.exception("forWork {} {}".format(n.get("start_time"), n.get("end_time")))


if n.get("scale"): if n.get("scale"):
r = re.search(r"^([0-9]+)", str(n["scale"])) r = re.search(r"^([0-9]+)", str(n["scale"]))
y, m, d = getYMD(work_st_tm) y, m, d = getYMD(work_st_tm)
cv["work_exp_flt"] = min(int(str(datetime.date.today())[0:4]) - int(y), cv.get("work_exp_flt", 1000)) cv["work_exp_flt"] = min(int(str(datetime.date.today())[0:4]) - int(y), cv.get("work_exp_flt", 1000))
except Exception as e: except Exception as e:
print("EXCEPTION: ", e, work_st_tm, cv.get("work_exp_flt"))
logger.exception("forWork {} {} {}".format(e, work_st_tm, cv.get("work_exp_flt")))


cv["job_num_int"] = 0 cv["job_num_int"] = 0
if duas: if duas:
t = k[:-4] t = k[:-4]
cv[f"{t}_kwd"] = nms cv[f"{t}_kwd"] = nms
cv[f"{t}_tks"] = rag_tokenizer.tokenize(" ".join(nms)) cv[f"{t}_tks"] = rag_tokenizer.tokenize(" ".join(nms))
except Exception as e:
print("【EXCEPTION】:", str(traceback.format_exc()), cv[k])
except Exception:
logger.exception("parse {} {}".format(str(traceback.format_exc()), cv[k]))
cv[k] = [] cv[k] = []


# tokenize fields # tokenize fields
if not y: y = "2012" if not y: y = "2012"
if not m: m = "01" if not m: m = "01"
if not d: d = "01" if not d: d = "01"
cv["updated_at_dt"] = f"%s-%02d-%02d 00:00:00" % (y, int(m), int(d))
cv["updated_at_dt"] = "%s-%02d-%02d 00:00:00" % (y, int(m), int(d))
# long text tokenize # long text tokenize


if cv.get("responsibilities"): cv["responsibilities_ltks"] = rag_tokenizer.tokenize(rmHtmlTag(cv["responsibilities"])) if cv.get("responsibilities"): cv["responsibilities_ltks"] = rag_tokenizer.tokenize(rmHtmlTag(cv["responsibilities"]))
cv["work_exp_flt"] = (time.time() - int(int(cv["work_start_time"]) / 1000)) / 3600. / 24. / 365. cv["work_exp_flt"] = (time.time() - int(int(cv["work_start_time"]) / 1000)) / 3600. / 24. / 365.
elif re.match(r"[0-9]{4}[^0-9]", str(cv["work_start_time"])): elif re.match(r"[0-9]{4}[^0-9]", str(cv["work_start_time"])):
y, m, d = getYMD(str(cv["work_start_time"])) y, m, d = getYMD(str(cv["work_start_time"]))
cv["work_start_dt"] = f"%s-%02d-%02d 00:00:00" % (y, int(m), int(d))
cv["work_start_dt"] = "%s-%02d-%02d 00:00:00" % (y, int(m), int(d))
cv["work_exp_flt"] = int(str(datetime.date.today())[0:4]) - int(y) cv["work_exp_flt"] = int(str(datetime.date.today())[0:4]) - int(y)
except Exception as e: except Exception as e:
print("【EXCEPTION】", e, "==>", cv.get("work_start_time"))
logger.exception("parse {} ==> {}".format(e, cv.get("work_start_time")))
if "work_exp_flt" not in cv and cv.get("work_experience", 0): cv["work_exp_flt"] = int(cv["work_experience"]) / 12. if "work_exp_flt" not in cv and cv.get("work_experience", 0): cv["work_exp_flt"] = int(cv["work_experience"]) / 12.


keys = list(cv.keys()) keys = list(cv.keys())


cv["tob_resume_id"] = str(cv["tob_resume_id"]) cv["tob_resume_id"] = str(cv["tob_resume_id"])
cv["id"] = cv["tob_resume_id"] cv["id"] = cv["tob_resume_id"]
print("CCCCCCCCCCCCCCC")
logger.info("CCCCCCCCCCCCCCC")


return dealWithInt64(cv) return dealWithInt64(cv)




if isinstance(d, np.integer): d = int(d) if isinstance(d, np.integer): d = int(d)
return d return d


+ 2  - 2  deepdoc/vision/operators.py

import numpy as np
import math
from PIL import Image
from api.utils.log_utils import logger




class DecodeImage(object):
return None, (None, None) return None, (None, None)
img = cv2.resize(img, (int(resize_w), int(resize_h))) img = cv2.resize(img, (int(resize_w), int(resize_h)))
except BaseException: except BaseException:
print(img.shape, resize_w, resize_h)
logger.exception("{} {} {}".format(img.shape, resize_w, resize_h))
sys.exit(0) sys.exit(0)
ratio_h = resize_h / float(h) ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w) ratio_w = resize_w / float(w)
return data return data


def resize_image_for_totaltext(self, im, max_side_len=512): def resize_image_for_totaltext(self, im, max_side_len=512):

h, w, _ = im.shape h, w, _ = im.shape
resize_w = w resize_w = w
resize_h = h resize_h = h

+ 2  - 1  deepdoc/vision/recognizer.py



from api.utils.file_utils import get_project_base_directory
from .operators import *
from api.utils.log_utils import logger




class Recognizer(object):
end_index = min((i + 1) * batch_size, len(imgs))
batch_image_list = imgs[start_index:end_index]
inputs = self.preprocess(batch_image_list)
print("preprocess")
logger.info("preprocess")
for ins in inputs:
bb = self.postprocess(self.ort_sess.run(None, {k:v for k,v in ins.items() if k in self.input_names})[0], ins, thr)
res.append(bb)

+ 2  - 1  deepdoc/vision/seeit.py

import os
import PIL
from PIL import ImageDraw
from api.utils.log_utils import logger




def save_results(image_list, results, labels, output_dir='output/', threshold=0.5):


out_path = os.path.join(output_dir, f"{idx}.jpg")
im.save(out_path, quality=95)
print("save result to: " + out_path)
logger.info("save result to: " + out_path)




def draw_box(im, result, lables, threshold=0.5):

+ 5  - 2  deepdoc/vision/t_recognizer.py

# See the License for the specific language governing permissions and
# limitations under the License.
#
import os, sys
import os
import sys
from api.utils.log_utils import logger

sys.path.insert(
0,
os.path.abspath(
} for t in lyt]
img = draw_box(images[i], lyt, labels, float(args.threshold))
img.save(outputs[i], quality=95)
print("save result to: " + outputs[i])
logger.info("save result to: " + outputs[i])




def get_table_html(img, tb_cpns, ocr):

+ 3  - 4  graphrag/claim_extractor.py



import argparse
import json
import logging
import re
import traceback
from dataclasses import dataclass
from graphrag.claim_prompt import CLAIM_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
from rag.llm.chat_model import Base as CompletionLLM
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
from api.utils.log_utils import logger


DEFAULT_TUPLE_DELIMITER = "<|>"
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
CLAIM_MAX_GLEANINGS = 1
log = logging.getLogger(__name__)




@dataclass
]
source_doc_map[document_id] = text
except Exception as e:
log.exception("error extracting claim")
logger.exception("error extracting claim")
self._on_error(
e,
traceback.format_exc(),
"claim_description": ""
}
claim = ex(info)
print(json.dumps(claim.output, ensure_ascii=False, indent=2))
logger.info(json.dumps(claim.output, ensure_ascii=False, indent=2))

+ 4  - 7  graphrag/community_reports_extractor.py

""" """


import json
import logging
import re
import traceback
from dataclasses import dataclass
from typing import Any, List, Callable
from typing import List, Callable
import networkx as nx
import pandas as pd
from graphrag import leiden
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, dict_has_keys_with_types
from rag.utils import num_tokens_from_string
from timeit import default_timer as timer

log = logging.getLogger(__name__)
from api.utils.log_utils import logger




@dataclass @dataclass
response = re.sub(r"[^\}]*$", "", response) response = re.sub(r"[^\}]*$", "", response)
response = re.sub(r"\{\{", "{", response) response = re.sub(r"\{\{", "{", response)
response = re.sub(r"\}\}", "}", response) response = re.sub(r"\}\}", "}", response)
print(response)
logger.info(response)
response = json.loads(response) response = json.loads(response)
if not dict_has_keys_with_types(response, [ if not dict_has_keys_with_types(response, [
("title", str), ("title", str),
response["weight"] = weight response["weight"] = weight
response["entities"] = ents response["entities"] = ents
except Exception as e: except Exception as e:
print("ERROR: ", traceback.format_exc())
logger.exception("CommunityReportsExtractor got exception")
self._on_error(e, traceback.format_exc(), None) self._on_error(e, traceback.format_exc(), None)
continue continue


report_sections = "\n\n".join( report_sections = "\n\n".join(
f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
) )
return f"# {title}\n\n{summary}\n\n{report_sections}" return f"# {title}\n\n{summary}\n\n{report_sections}"

+ 3  - 2  graphrag/index.py

from graphrag.mind_map_extractor import MindMapExtractor
from rag.nlp import rag_tokenizer
from rag.utils import num_tokens_from_string
from api.utils.log_utils import logger




def graph_merge(g1, g2):
chunks = []
for n, attr in graph.nodes(data=True):
if attr.get("rank", 0) == 0:
print(f"Ignore entity: {n}")
logger.info(f"Ignore entity: {n}")
continue continue
chunk = { chunk = {
"name_kwd": n, "name_kwd": n,
mg = mindmap(_chunks).output mg = mindmap(_chunks).output
if not len(mg.keys()): return chunks if not len(mg.keys()): return chunks


print(json.dumps(mg, ensure_ascii=False, indent=2))
logger.info(json.dumps(mg, ensure_ascii=False, indent=2))
chunks.append( chunks.append(
{ {
"content_with_weight": json.dumps(mg, ensure_ascii=False, indent=2), "content_with_weight": json.dumps(mg, ensure_ascii=False, indent=2),

+ 3  - 3  graphrag/mind_map_extractor.py

import logging
import os
import re
import logging
import traceback
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
import markdown_to_json
from functools import reduce
from rag.utils import num_tokens_from_string
from api.utils.log_utils import logger




@dataclass
gen_conf = {"temperature": 0.5}
response = self._llm.chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
response = re.sub(r"```[^\n]*", "", response)
print(response)
print("---------------------------------------------------\n", self._todict(markdown_to_json.dictify(response)))
logger.info(response)
logger.info(self._todict(markdown_to_json.dictify(response)))
return self._todict(markdown_to_json.dictify(response)) return self._todict(markdown_to_json.dictify(response))

+ 3  - 3  intergrations/chatgpt-on-wechat/plugins/ragflow_chat.py

from bridge.context import ContextType # Import Context, ContextType
from bridge.reply import Reply, ReplyType # Import Reply, ReplyType
from bridge import *
from common.log import logger
from api.utils.log_utils import logger
from plugins import Plugin, register # Import Plugin and register
from plugins.event import Event, EventContext, EventAction # Import event-related classes


logger.error(f"[RAGFlowChat] HTTP error when creating conversation: {response.status_code}") logger.error(f"[RAGFlowChat] HTTP error when creating conversation: {response.status_code}")
return f"Sorry, unable to connect to RAGFlow API (create conversation). HTTP status code: {response.status_code}" return f"Sorry, unable to connect to RAGFlow API (create conversation). HTTP status code: {response.status_code}"
except Exception as e: except Exception as e:
logger.exception(f"[RAGFlowChat] Exception when creating conversation: {e}")
logger.exception("[RAGFlowChat] Exception when creating conversation")
return f"Sorry, an internal error occurred: {str(e)}" return f"Sorry, an internal error occurred: {str(e)}"


# Step 2: Send the message and get a reply # Step 2: Send the message and get a reply
logger.error(f"[RAGFlowChat] HTTP error when getting answer: {response.status_code}") logger.error(f"[RAGFlowChat] HTTP error when getting answer: {response.status_code}")
return f"Sorry, unable to connect to RAGFlow API (get reply). HTTP status code: {response.status_code}" return f"Sorry, unable to connect to RAGFlow API (get reply). HTTP status code: {response.status_code}"
except Exception as e: except Exception as e:
logger.exception(f"[RAGFlowChat] Exception when getting answer: {e}")
logger.exception("[RAGFlowChat] Exception when getting answer")
return f"Sorry, an internal error occurred: {str(e)}" return f"Sorry, an internal error occurred: {str(e)}"

+ 2  - 1  rag/app/book.py

tokenize_chunks
from rag.nlp import rag_tokenizer
from deepdoc.parser import PdfParser, DocxParser, PlainParser, HtmlParser
from api.utils.log_utils import logger




class Pdf(PdfParser):
start = timer()
self._layouts_rec(zoomin)
callback(0.67, "Layout analysis finished")
print("layouts:", timer() - start)
logger.info("layouts: {}".format(timer() - start))
self._table_transformer_job(zoomin) self._table_transformer_job(zoomin)
callback(0.68, "Table analysis finished") callback(0.68, "Table analysis finished")
self._text_merge() self._text_merge()

+ 2  - 2  rag/app/email.py

from rag.nlp import rag_tokenizer, naive_merge, tokenize_chunks
from deepdoc.parser import HtmlParser, TxtParser
from timeit import default_timer as timer
from rag.settings import cron_logger
from api.utils.log_utils import logger
import io




) )


main_res.extend(tokenize_chunks(chunks, doc, eng, None)) main_res.extend(tokenize_chunks(chunks, doc, eng, None))
cron_logger.info("naive_merge({}): {}".format(filename, timer() - st))
logger.info("naive_merge({}): {}".format(filename, timer() - st))
# get the attachment info # get the attachment info
for part in msg.iter_attachments(): for part in msg.iter_attachments():
content_disposition = part.get("Content-Disposition") content_disposition = part.get("Content-Disposition")

+ 3  - 3  rag/app/laws.py

make_colon_as_title, tokenize_chunks, docx_question_level
from rag.nlp import rag_tokenizer
from deepdoc.parser import PdfParser, DocxParser, PlainParser, HtmlParser
from rag.settings import cron_logger
from api.utils.log_utils import logger




class Docx(DocxParser):
start = timer()
self._layouts_rec(zoomin)
callback(0.67, "Layout analysis finished")
cron_logger.info("layouts:".format(
(timer() - start) / (self.total_page + 0.1)))
logger.info("layouts:".format(
))
self._naive_vertical_merge() self._naive_vertical_merge()


callback(0.8, "Text extraction finished") callback(0.8, "Text extraction finished")

+ 3  - 2  rag/app/manual.py

from deepdoc.parser import PdfParser, PlainParser, DocxParser
from docx import Document
from PIL import Image
from api.utils.log_utils import logger




class Pdf(PdfParser):
# for bb in self.boxes:
# for b in bb:
# print(b)
print("OCR:", timer() - start)
logger.info("OCR: {}".format(timer() - start))


self._layouts_rec(zoomin) self._layouts_rec(zoomin)
callback(0.65, "Layout analysis finished.") callback(0.65, "Layout analysis finished.")
print("layouts:", timer() - start)
logger.info("layouts: {}".format(timer() - start))
self._table_transformer_job(zoomin) self._table_transformer_job(zoomin)
callback(0.67, "Table analysis finished.") callback(0.67, "Table analysis finished.")
self._text_merge() self._text_merge()

+ 9  - 9  rag/app/naive.py

from rag.nlp import rag_tokenizer, naive_merge, tokenize_table, tokenize_chunks, find_codec, concat_img, \
naive_merge_docx, tokenize_chunks_docx
from deepdoc.parser import PdfParser, ExcelParser, DocxParser, HtmlParser, JsonParser, MarkdownParser, TxtParser
from rag.settings import cron_logger
from api.utils.log_utils import logger
from rag.utils import num_tokens_from_string
from PIL import Image
from functools import reduce
try:
image_blob = related_part.image.blob
except UnrecognizedImageError:
print("Unrecognized image format. Skipping image.")
logger.info("Unrecognized image format. Skipping image.")
return None return None
except UnexpectedEndOfFileError: except UnexpectedEndOfFileError:
print("EOF was unexpectedly encountered while reading an image stream. Skipping image.")
logger.info("EOF was unexpectedly encountered while reading an image stream. Skipping image.")
return None return None
except InvalidImageStreamError: except InvalidImageStreamError:
print("The recognized image stream appears to be corrupted. Skipping image.")
logger.info("The recognized image stream appears to be corrupted. Skipping image.")
return None return None
try:
image = Image.open(BytesIO(image_blob)).convert('RGB')
return image
except Exception as e:
except Exception:
return None return None


def __clean(self, line): def __clean(self, line):
callback callback
) )
callback(msg="OCR finished") callback(msg="OCR finished")
cron_logger.info("OCR({}~{}): {}".format(from_page, to_page, timer() - start))
logger.info("OCR({}~{}): {}".format(from_page, to_page, timer() - start))


start = timer() start = timer()
self._layouts_rec(zoomin) self._layouts_rec(zoomin)
self._concat_downward() self._concat_downward()
# self._filter_forpages() # self._filter_forpages()


cron_logger.info("layouts: {}".format(timer() - start))
logger.info("layouts cost: {}s".format(timer() - start))
return [(b["text"], self._line_tag(b, zoomin)) return [(b["text"], self._line_tag(b, zoomin))
for b in self.boxes], tbls for b in self.boxes], tbls


return chunks return chunks


res.extend(tokenize_chunks_docx(chunks, doc, eng, images)) res.extend(tokenize_chunks_docx(chunks, doc, eng, images))
cron_logger.info("naive_merge({}): {}".format(filename, timer() - st))
logger.info("naive_merge({}): {}".format(filename, timer() - st))
return res return res


elif re.search(r"\.pdf$", filename, re.IGNORECASE): elif re.search(r"\.pdf$", filename, re.IGNORECASE):
return chunks return chunks


res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser)) res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
cron_logger.info("naive_merge({}): {}".format(filename, timer() - st))
logger.info("naive_merge({}): {}".format(filename, timer() - st))
return res return res





+ 2  - 1  rag/app/one.py

from rag.app import laws
from rag.nlp import rag_tokenizer, tokenize
from deepdoc.parser import PdfParser, ExcelParser, PlainParser, HtmlParser
from api.utils.log_utils import logger




class Pdf(PdfParser):
start = timer()
self._layouts_rec(zoomin, drop=False)
callback(0.63, "Layout analysis finished.")
print("layouts:", timer() - start)
logger.info("layouts cost: {}s".format(timer() - start))
self._table_transformer_job(zoomin) self._table_transformer_job(zoomin)
callback(0.65, "Table analysis finished.") callback(0.65, "Table analysis finished.")
self._text_merge() self._text_merge()

+ 8  - 7  rag/app/paper.py

from rag.nlp import rag_tokenizer, tokenize, tokenize_table, add_positions, bullets_category, title_frequency, tokenize_chunks
from deepdoc.parser import PdfParser, PlainParser
import numpy as np
from api.utils.log_utils import logger




class Pdf(PdfParser):
start = timer()
self._layouts_rec(zoomin)
callback(0.63, "Layout analysis finished")
print("layouts:", timer() - start)
logger.info(f"layouts cost: {timer() - start}s")
self._table_transformer_job(zoomin) self._table_transformer_job(zoomin)
callback(0.68, "Table analysis finished") callback(0.68, "Table analysis finished")
self._text_merge() self._text_merge()


# clean mess # clean mess
if column_width < self.page_images[0].size[0] / zoomin / 2: if column_width < self.page_images[0].size[0] / zoomin / 2:
print("two_column...................", column_width,
self.page_images[0].size[0] / zoomin / 2)
logger.info("two_column................... {} {}".format(column_width,
self.page_images[0].size[0] / zoomin / 2))
self.boxes = self.sort_X_by_page(self.boxes, column_width / 2) self.boxes = self.sort_X_by_page(self.boxes, column_width / 2)
for b in self.boxes: for b in self.boxes:
b["text"] = re.sub(r"([\t  ]|\u3000){2,}", " ", b["text"].strip()) b["text"] = re.sub(r"([\t  ]|\u3000){2,}", " ", b["text"].strip())
from_page, min( from_page, min(
to_page, self.total_page))) to_page, self.total_page)))
for b in self.boxes: for b in self.boxes:
print(b["text"], b.get("layoutno"))
print(tbls)
logger.info("{} {}".format(b["text"], b.get("layoutno")))
logger.info("{}".format(tbls))


return { return {
"title": title, "title": title,
doc["authors_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["authors_tks"]) doc["authors_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["authors_tks"])
# is it English # is it English
eng = lang.lower() == "english" # pdf_parser.is_english eng = lang.lower() == "english" # pdf_parser.is_english
print("It's English.....", eng)
logger.info("It's English.....{}".format(eng))


res = tokenize_table(paper["tables"], doc, eng) res = tokenize_table(paper["tables"], doc, eng)


if lvl <= most_level and i > 0 and lvl != levels[i - 1]: if lvl <= most_level and i > 0 and lvl != levels[i - 1]:
sid += 1 sid += 1
sec_ids.append(sid) sec_ids.append(sid)
print(lvl, sorted_sections[i][0], most_level, sid)
logger.info("{} {} {} {}".format(lvl, sorted_sections[i][0], most_level, sid))


chunks = [] chunks = []
last_sid = -2 last_sid = -2

+ 3  - 3  rag/app/qa.py

from deepdoc.parser.utils import get_text
from rag.nlp import is_english, random_choices, qbullets_category, add_positions, has_qbullet, docx_question_level
from rag.nlp import rag_tokenizer, tokenize_table, concat_img
from rag.settings import cron_logger
from api.utils.log_utils import logger
from deepdoc.parser import PdfParser, ExcelParser, DocxParser
from docx import Document
from PIL import Image
callback callback
) )
callback(msg="OCR finished") callback(msg="OCR finished")
cron_logger.info("OCR({}~{}): {}".format(from_page, to_page, timer() - start))
logger.info("OCR({}~{}): {}".format(from_page, to_page, timer() - start))
start = timer() start = timer()
self._layouts_rec(zoomin, drop=False) self._layouts_rec(zoomin, drop=False)
callback(0.63, "Layout analysis finished.") callback(0.63, "Layout analysis finished.")
#self._naive_vertical_merge() #self._naive_vertical_merge()
# self._concat_downward() # self._concat_downward()
#self._filter_forpages() #self._filter_forpages()
cron_logger.info("layouts: {}".format(timer() - start))
logger.info("layouts: {}".format(timer() - start))
sections = [b["text"] for b in self.boxes] sections = [b["text"] for b in self.boxes]
bull_x0_list = [] bull_x0_list = []
q_bull, reg = qbullets_category(sections) q_bull, reg = qbullets_category(sections)

+ 5  - 6  rag/app/resume.py

import datetime
import json
import re

import pandas as pd
import requests
from api.db.services.knowledgebase_service import KnowledgebaseService
from rag.nlp import rag_tokenizer
from deepdoc.parser.resume import refactor
from deepdoc.parser.resume import step_one, step_two
from rag.settings import cron_logger
from api.utils.log_utils import logger
from rag.utils import rmSpace


forbidden_select_fields4resume = [
"updated_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]))
resume = step_two.parse(resume)
return resume
except Exception as e:
cron_logger.error("Resume parser error: " + str(e))
except Exception:
logger.exception("Resume parser error")
return {} return {}




callback(-1, "Resume is not successfully parsed.") callback(-1, "Resume is not successfully parsed.")
raise Exception("Resume parser remote call fail!") raise Exception("Resume parser remote call fail!")
callback(0.6, "Done parsing. Chunking...") callback(0.6, "Done parsing. Chunking...")
print(json.dumps(resume, ensure_ascii=False, indent=2))
logger.info("chunking resume: " + json.dumps(resume, ensure_ascii=False, indent=2))


field_map = { field_map = {
"name_kwd": "姓名/名字", "name_kwd": "姓名/名字",
resume[n] = rag_tokenizer.fine_grained_tokenize(resume[n]) resume[n] = rag_tokenizer.fine_grained_tokenize(resume[n])
doc[n] = resume[n] doc[n] = resume[n]


print(doc)
logger.info("chunked resume to " + str(doc))
KnowledgebaseService.update_parser_config( KnowledgebaseService.update_parser_config(
kwargs["kb_id"], {"field_map": field_map}) kwargs["kb_id"], {"field_map": field_map})
return [doc] return [doc]

+ 5  - 4  rag/llm/embedding_model.py

from rag.utils import num_tokens_from_string, truncate
import google.generativeai as genai
import json
from api.utils.log_utils import logger


class Base(ABC):
def __init__(self, key, model_name):
DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available())
except Exception as e:
except Exception:
model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5", model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)), local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
local_dir_use_symlinks=False) local_dir_use_symlinks=False)
) )
return np.array(resp["output"]["embeddings"][0] return np.array(resp["output"]["embeddings"][0]
["embedding"]), resp["usage"]["total_tokens"] ["embedding"]), resp["usage"]["total_tokens"]
except Exception as e:
except Exception:
raise Exception("Account abnormal. Please ensure it's on good standing to use QWen's "+self.model_name) raise Exception("Account abnormal. Please ensure it's on good standing to use QWen's "+self.model_name)
return np.array([]), 0 return np.array([]), 0


if not LIGHTEN and not YoudaoEmbed._client: if not LIGHTEN and not YoudaoEmbed._client:
from BCEmbedding import EmbeddingModel as qanthing from BCEmbedding import EmbeddingModel as qanthing
try: try:
print("LOADING BCE...")
logger.info("LOADING BCE...")
YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join( YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(
get_home_cache_dir(), get_home_cache_dir(),
"bce-embedding-base_v1")) "bce-embedding-base_v1"))
except Exception as e:
except Exception:
YoudaoEmbed._client = qanthing( YoudaoEmbed._client = qanthing(
model_name_or_path=model_name.replace( model_name_or_path=model_name.replace(
"maidalun1020", "InfiniFlow")) "maidalun1020", "InfiniFlow"))

+ 4  - 3  rag/llm/rerank_model.py

from api.utils.file_utils import get_home_cache_dir
from rag.utils import num_tokens_from_string, truncate
import json
from api.utils.log_utils import logger




def sigmoid(x):
DefaultRerank._model = FlagReranker(
os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
use_fp16=torch.cuda.is_available())
except Exception as e:
except Exception:
model_dir = snapshot_download(repo_id=model_name, model_dir = snapshot_download(repo_id=model_name,
local_dir=os.path.join(get_home_cache_dir(), local_dir=os.path.join(get_home_cache_dir(),
re.sub(r"^[a-zA-Z]+/", "", model_name)), re.sub(r"^[a-zA-Z]+/", "", model_name)),
with YoudaoRerank._model_lock: with YoudaoRerank._model_lock:
if not YoudaoRerank._model: if not YoudaoRerank._model:
try: try:
print("LOADING BCE...")
logger.info("LOADING BCE...")
YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join( YoudaoRerank._model = RerankerModel(model_name_or_path=os.path.join(
get_home_cache_dir(), get_home_cache_dir(),
re.sub(r"^[a-zA-Z]+/", "", model_name))) re.sub(r"^[a-zA-Z]+/", "", model_name)))
except Exception as e:
except Exception:
YoudaoRerank._model = RerankerModel( YoudaoRerank._model = RerankerModel(
model_name_or_path=model_name.replace( model_name_or_path=model_name.replace(
"maidalun1020", "InfiniFlow")) "maidalun1020", "InfiniFlow"))

+ 4  - 3  rag/nlp/__init__.py

from cn2an import cn2an
from PIL import Image
import json
from api.utils.log_utils import logger


all_codecs = [
'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
# wrap up as es documents
for ck in chunks:
if len(ck.strip()) == 0:continue
print("--", ck)
logger.debug("-- {}".format(ck))
d = copy.deepcopy(doc) d = copy.deepcopy(doc)
if pdf_parser: if pdf_parser:
try: try:
# wrap up as es documents # wrap up as es documents
for ck, image in zip(chunks, images): for ck, image in zip(chunks, images):
if len(ck.strip()) == 0:continue if len(ck.strip()) == 0:continue
print("--", ck)
logger.debug("-- {}".format(ck))
d = copy.deepcopy(doc) d = copy.deepcopy(doc)
d["image"] = image d["image"] = image
tokenize(d, ck, eng) tokenize(d, ck, eng)


for i in range(len(cks)): for i in range(len(cks)):
cks[i] = [sections[j] for j in cks[i][::-1]] cks[i] = [sections[j] for j in cks[i][::-1]]
print("--------------\n", "\n* ".join(cks[i]))
logger.info("\n* ".join(cks[i]))


res = [[]] res = [[]]
num = [0] num = [0]

+ 23  - 24  rag/nlp/rag_tokenizer.py

import string
import sys
from hanziconv import HanziConv
from huggingface_hub import snapshot_download
from nltk import word_tokenize
from nltk.stem import PorterStemmer, WordNetLemmatizer
from api.utils.file_utils import get_project_base_directory
from api.utils.log_utils import logger




class RagTokenizer:
return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]


def loadDict_(self, fnm):
print("[HUQIE]:Build trie", fnm, file=sys.stderr)
logger.info(f"[HUQIE]:Build trie {fnm}")
try: try:
of = open(fnm, "r", encoding='utf-8') of = open(fnm, "r", encoding='utf-8')
while True: while True:
self.trie_[self.rkey_(line[0])] = 1 self.trie_[self.rkey_(line[0])] = 1
self.trie_.save(fnm + ".trie") self.trie_.save(fnm + ".trie")
of.close() of.close()
except Exception as e:
print("[HUQIE]:Faild to build trie, ", fnm, e, file=sys.stderr)
except Exception:
logger.exception(f"[HUQIE]:Build trie {fnm} failed")



def __init__(self, debug=False): def __init__(self, debug=False):
self.DEBUG = debug self.DEBUG = debug
try: try:
self.trie_ = datrie.Trie.load(self.DIR_ + ".txt.trie") self.trie_ = datrie.Trie.load(self.DIR_ + ".txt.trie")
return return
except Exception as e:
print("[HUQIE]:Build default trie", file=sys.stderr)
except Exception:
logger.exception("[HUQIE]:Build default trie")
self.trie_ = datrie.Trie(string.printable) self.trie_ = datrie.Trie(string.printable)


self.loadDict_(self.DIR_ + ".txt") self.loadDict_(self.DIR_ + ".txt")
try: try:
self.trie_ = datrie.Trie.load(fnm + ".trie") self.trie_ = datrie.Trie.load(fnm + ".trie")
return return
except Exception as e:
except Exception:
self.trie_ = datrie.Trie(string.printable) self.trie_ = datrie.Trie(string.printable)
self.loadDict_(fnm) self.loadDict_(fnm)


tks.append(tk) tks.append(tk)
F /= len(tks) F /= len(tks)
L /= len(tks) L /= len(tks)
if self.DEBUG:
print("[SC]", tks, len(tks), L, F, B / len(tks) + L + F)
logger.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
return tks, B / len(tks) + L + F return tks, B / len(tks) + L + F


def sortTks_(self, tkslist): def sortTks_(self, tkslist):
tks, s = self.maxForward_(L) tks, s = self.maxForward_(L)
tks1, s1 = self.maxBackward_(L) tks1, s1 = self.maxBackward_(L)
if self.DEBUG: if self.DEBUG:
print("[FW]", tks, s)
print("[BW]", tks1, s1)
logger.debug("[FW] {} {}".format(tks, s))
logger.debug("[BW] {} {}".format(tks1, s1))


i, j, _i, _j = 0, 0, 0, 0 i, j, _i, _j = 0, 0, 0, 0
same = 0 same = 0
res.append(" ".join(self.sortTks_(tkslist)[0][0])) res.append(" ".join(self.sortTks_(tkslist)[0][0]))


res = " ".join(self.english_normalize_(res)) res = " ".join(self.english_normalize_(res))
if self.DEBUG:
print("[TKS]", self.merge_(res))
logger.debug("[TKS] {}".format(self.merge_(res)))
return self.merge_(res) return self.merge_(res)


def fine_grained_tokenize(self, tks): def fine_grained_tokenize(self, tks):
# huqie.addUserDict("/tmp/tmp.new.tks.dict") # huqie.addUserDict("/tmp/tmp.new.tks.dict")
tks = tknzr.tokenize( tks = tknzr.tokenize(
"哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈") "哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
print(tknzr.fine_grained_tokenize(tks))
logger.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize( tks = tknzr.tokenize(
"公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。") "公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。")
print(tknzr.fine_grained_tokenize(tks))
logger.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize( tks = tknzr.tokenize(
"多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥") "多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥")
print(tknzr.fine_grained_tokenize(tks))
logger.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize( tks = tknzr.tokenize(
"实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa") "实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa")
print(tknzr.fine_grained_tokenize(tks))
logger.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize("虽然我不怎么玩") tks = tknzr.tokenize("虽然我不怎么玩")
print(tknzr.fine_grained_tokenize(tks))
logger.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的") tks = tknzr.tokenize("蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的")
print(tknzr.fine_grained_tokenize(tks))
logger.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize( tks = tknzr.tokenize(
"涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了") "涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了")
print(tknzr.fine_grained_tokenize(tks))
logger.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?") tks = tknzr.tokenize("这周日你去吗?这周日你有空吗?")
print(tknzr.fine_grained_tokenize(tks))
logger.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ") tks = tknzr.tokenize("Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ")
print(tknzr.fine_grained_tokenize(tks))
logger.info(tknzr.fine_grained_tokenize(tks))
tks = tknzr.tokenize( tks = tknzr.tokenize(
"数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-") "数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-")
print(tknzr.fine_grained_tokenize(tks))
logger.info(tknzr.fine_grained_tokenize(tks))
if len(sys.argv) < 2: if len(sys.argv) < 2:
sys.exit() sys.exit()
tknzr.DEBUG = False tknzr.DEBUG = False
line = of.readline() line = of.readline()
if not line: if not line:
break break
print(tknzr.tokenize(line))
logger.info(tknzr.tokenize(line))
of.close() of.close()

+ 8  - 8  rag/nlp/search.py

from typing import List, Optional, Dict, Union
from dataclasses import dataclass


from rag.settings import doc_store_logger
from api.utils.log_utils import logger
from rag.utils import rmSpace
from rag.nlp import rag_tokenizer, query
import numpy as np
orderBy.desc("create_timestamp_flt")
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
total=self.dataStore.getTotal(res)
doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
logger.info("Dealer.search TOTAL: {}".format(total))
else:
highlightFields = ["content_ltks", "title_tks"] if highlight else []
matchText, keywords = self.qryr.question(qst, min_match=0.3)
matchExprs = [matchText]
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids)
total=self.dataStore.getTotal(res)
doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
logger.info("Dealer.search TOTAL: {}".format(total))
else:
matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
q_vec = matchDense.embedding_data


res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids)
total=self.dataStore.getTotal(res)
doc_store_logger.info("Dealer.search TOTAL: {}".format(total))
logger.info("Dealer.search TOTAL: {}".format(total))


# If result is empty, try again with lower min_match
if total == 0:
matchDense.extra_options["similarity"] = 0.17
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], orderBy, offset, limit, idx_names, kb_ids)
total=self.dataStore.getTotal(res)
doc_store_logger.info("Dealer.search 2 TOTAL: {}".format(total))
logger.info("Dealer.search 2 TOTAL: {}".format(total))


for k in keywords:
kwds.add(k)
continue
kwds.add(kk)


doc_store_logger.info(f"TOTAL: {total}")
logger.info(f"TOTAL: {total}")
ids=self.dataStore.getChunkIds(res)
keywords=list(kwds)
highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
continue
idx.append(i)
pieces_.append(t)
doc_store_logger.info("{} => {}".format(answer, pieces_))
logger.info("{} => {}".format(answer, pieces_))
if not pieces_:
return answer, set([])


chunks_tks,
tkweight, vtweight)
mx = np.max(sim) * 0.99
doc_store_logger.info("{} SIM: {}".format(pieces_[i], mx))
logger.info("{} SIM: {}".format(pieces_[i], mx))
if mx < thr:
continue
cites[idx[i]] = list(

+ 6  - 6  rag/nlp/synonym.py

import json
import os
import time
import logging
import re


from api.utils.file_utils import get_project_base_directory
from api.utils.log_utils import logger




class Dealer:
path = os.path.join(get_project_base_directory(), "rag/res", "synonym.json")
try:
self.dictionary = json.load(open(path, 'r'))
except Exception as e:
logging.warn("Missing synonym.json")
except Exception:
logger.warning("Missing synonym.json")
self.dictionary = {}


if not redis:
logging.warning(
logger.warning(
"Realtime synonym is disabled, since no redis connection.")
if not len(self.dictionary.keys()):
logging.warning(f"Fail to load synonym")
logger.warning("Failed to load synonym")


self.redis = redis
self.load()
d = json.loads(d)
self.dictionary = d
except Exception as e:
logging.error("Fail to load synonym!" + str(e))
logger.error("Failed to load synonym!" + str(e))


def lookup(self, tk):
self.lookup_num += 1

+ 5  - 4  rag/nlp/term_weight.py

import numpy as np
from rag.nlp import rag_tokenizer
from api.utils.file_utils import get_project_base_directory
from api.utils.log_utils import logger




class Dealer:
self.ne, self.df = {}, {}
try:
self.ne = json.load(open(os.path.join(fnm, "ner.json"), "r"))
except Exception as e:
print("[WARNING] Load ner.json FAIL!")
except Exception:
logger.warning("Load ner.json FAIL!")
try:
self.df = load_dict(os.path.join(fnm, "term.freq"))
except Exception as e:
print("[WARNING] Load term.freq FAIL!")
except Exception:
logger.warning("Load term.freq FAIL!")


def pretoken(self, txt, num=False, stpwd=True):
patt = [

+ 5  - 6  rag/raptor.py

# limitations under the License.
#
import re
import traceback
from concurrent.futures import ThreadPoolExecutor, ALL_COMPLETED, wait
from threading import Lock
from typing import Tuple
import numpy as np
from sklearn.mixture import GaussianMixture


from rag.utils import num_tokens_from_string, truncate
from rag.utils import truncate
from api.utils.log_utils import logger




class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
{"temperature": 0.3, "max_tokens": self._max_token}
)
cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "", cnt)
print("SUM:", cnt)
logger.info(f"SUM: {cnt}")
embds, _ = self._embd_model.encode([cnt])
with lock:
if not len(embds[0]): return
chunks.append((cnt, embds[0]))
except Exception as e:
print(e, flush=True)
traceback.print_stack(e)
logger.exception("summarize got exception")
return e


labels = []
ck_idx = [i+start for i in range(len(lbls)) if lbls[i] == c]
threads.append(executor.submit(summarize, ck_idx, lock))
wait(threads, return_when=ALL_COMPLETED)
print([t.result() for t in threads])
logger.info(str([t.result() for t in threads]))


assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
labels.extend(lbls)
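Several hunks in this PR, like the one above, replace print(e) and traceback.print_stack(e) with a bare except Exception: followed by logger.exception(...). A tiny standalone illustration (generic names, plain logging module only) of why the exception no longer needs to be bound with "as e":

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def summarize_chunk(text):
    raise RuntimeError("LLM call failed")   # stand-in for a real failure


try:
    summarize_chunk("...")
except Exception:
    # logger.exception logs at ERROR level and automatically appends the
    # traceback of the exception currently being handled, so the handler
    # needs neither "as e" nor traceback.print_exc().
    logger.exception("summarize got exception")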

+ 0  - 27  rag/settings.py

# limitations under the License.
#
import os
import logging
from api.utils import get_base_config, decrypt_database_config
from api.utils.file_utils import get_project_base_directory
from api.utils.log_utils import LoggerFactory, getLogger



# Server
RAG_CONF_PATH = os.path.join(get_project_base_directory(), "conf")
SUBPROCESS_STD_LOG_NAME = "std.log"


ES = get_base_config("es", {})
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
pass
DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))


# Logger
LoggerFactory.set_directory(
os.path.join(
get_project_base_directory(),
"logs",
"rag"))
# {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0}
LoggerFactory.LEVEL = 30

doc_store_logger = getLogger("doc_store")
minio_logger = getLogger("minio")
s3_logger = getLogger("s3")
azure_logger = getLogger("azure")
cron_logger = getLogger("cron_logger")
chunk_logger = getLogger("chunk_logger")
database_logger = getLogger("database")

formatter = logging.Formatter("%(asctime)-15s %(levelname)-8s (%(process)d) %(message)s")
for logger in [doc_store_logger, minio_logger, s3_logger, azure_logger, cron_logger, chunk_logger, database_logger]:
logger.setLevel(logging.INFO)
for handler in logger.handlers:
handler.setFormatter(fmt=formatter)

SVR_QUEUE_NAME = "rag_flow_svr_queue"
SVR_QUEUE_RETENTION = 60*60
SVR_QUEUE_MAX_LEN = 1024
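This hunk drops the per-component LoggerFactory setup (doc_store, minio, s3, azure, cron, chunk, database) in favour of the single logger imported from api.utils.log_utils. That module is not part of the hunks shown here, so the following is only a hedged guess at its shape: a minimal sketch of one shared logger writing everything to a single rotating file. The file location, name and rotation sizes are assumptions; the format string is the one this hunk removes.

import logging
import os
from logging.handlers import RotatingFileHandler

# Assumed location; the real path comes from api/utils/log_utils.py, which is not shown here.
LOG_FILE = os.path.abspath(os.path.join("logs", "ragflow_server.log"))

os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)

logger = logging.getLogger("ragflow")
logger.setLevel(logging.INFO)

# Same format string that this hunk removes from rag/settings.py.
formatter = logging.Formatter("%(asctime)-15s %(levelname)-8s (%(process)d) %(message)s")

file_handler = RotatingFileHandler(LOG_FILE, maxBytes=10 * 1024 * 1024, backupCount=5)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)

console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

A single module-level logger of this kind is what lets every "from api.utils.log_utils import logger" site in the other hunks share one log file.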

+ 4  - 5  rag/svr/cache_file_svr.py

# See the License for the specific language governing permissions and
# limitations under the License.
#
import random
import time
import traceback


from api.db.db_models import close_connection
from api.db.services.task_service import TaskService
from rag.settings import cron_logger
from api.utils.log_utils import logger
from rag.utils.storage_factory import STORAGE_IMPL
from rag.utils.redis_conn import REDIS_CONN




def collect():
doc_locations = TaskService.get_ongoing_doc_name()
print(doc_locations)
logger.info(doc_locations)
if len(doc_locations) == 0:
time.sleep(1)
return
def main():
locations = collect()
if not locations:return
print("TASKS:", len(locations))
logger.info(f"TASKS: {len(locations)}")
for kb_id, loc in locations:
try:
if REDIS_CONN.is_alive():
if REDIS_CONN.exist(key):continue
file_bin = STORAGE_IMPL.get(kb_id, loc)
REDIS_CONN.transaction(key, file_bin, 12 * 60)
cron_logger.info("CACHE: {}".format(loc))
logger.info("CACHE: {}".format(loc))
except Exception as e:
traceback.print_stack(e)
except Exception as e:

+ 2  - 1  rag/svr/discord_svr.py

import requests
import base64
import asyncio
from api.utils.log_utils import logger


URL = '{YOUR_IP_ADDRESS:PORT}/v1/api/completion_aibotk' # Default: https://demo.ragflow.io/v1/api/completion_aibotk




@client.event
async def on_ready():
print(f'We have logged in as {client.user}')
logger.info(f'We have logged in as {client.user}')




@client.event

+ 30  - 36  rag/svr/task_executor.py

import re
import sys
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from io import BytesIO
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
from rag.nlp import search, rag_tokenizer
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import database_logger, SVR_QUEUE_NAME
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
from api.utils.log_utils import logger, LOG_FILE
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME
from rag.utils import rmSpace, num_tokens_from_string
from rag.utils.redis_conn import REDIS_CONN, Payload
from rag.utils.storage_factory import STORAGE_IMPL
d["progress"] = prog
try:
TaskService.update_progress(task_id, d)
except Exception as e:
cron_logger.error("set_progress:({}), {}".format(task_id, str(e)))
except Exception:
logger.exception(f"set_progress({task_id}) got exception")


close_connection()
if cancel:
if not PAYLOAD:
time.sleep(1)
return pd.DataFrame()
except Exception as e:
cron_logger.error("Get task event from queue exception:" + str(e))
except Exception:
logger.exception("Get task event from queue exception")
return pd.DataFrame()


msg = PAYLOAD.get_message()
return pd.DataFrame()


if TaskService.do_cancel(msg["id"]):
cron_logger.info("Task {} has been canceled.".format(msg["id"]))
logger.info("Task {} has been canceled.".format(msg["id"]))
return pd.DataFrame()
tasks = TaskService.get_tasks(msg["id"])
if not tasks:
cron_logger.warning("{} empty task!".format(msg["id"]))
logger.warning("{} empty task!".format(msg["id"]))
return []


tasks = pd.DataFrame(tasks)
st = timer()
bucket, name = File2DocumentService.get_storage_address(doc_id=row["doc_id"])
binary = get_storage_binary(bucket, name)
cron_logger.info(
logger.info(
"From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
except TimeoutError:
callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
cron_logger.error(
"Minio {}/{}: Fetch file from minio timeout.".format(row["location"], row["name"]))
logger.exception("Fetch file from minio {}/{} got timeout".format(row["location"], row["name"]))
return
except Exception as e:
if re.search("(No such file|not found)", str(e)):
callback(-1, "Can not find file <%s> from minio. Could you try it again?" % row["name"])
else:
callback(-1, "Get file from minio: %s" % str(e).replace("'", ""))
traceback.print_exc()
logger.exception("Get file from minio {}/{} got exception".format(row["location"], row["name"]))
return


try:
cks = chunker.chunk(row["name"], binary=binary, from_page=row["from_page"],
to_page=row["to_page"], lang=row["language"], callback=callback,
kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
cron_logger.info(
"Chunking({}) {}/{}".format(timer() - st, row["location"], row["name"]))
logger.info("Chunking({}) {}/{} done".format(timer() - st, row["location"], row["name"]))
except Exception as e:
callback(-1, "Internal server error while chunking: %s" %
str(e).replace("'", ""))
cron_logger.error(
"Chunking {}/{}: {}".format(row["location"], row["name"], str(e)))
traceback.print_exc()
logger.exception("Chunking {}/{} got exception".format(row["location"], row["name"]))
return


docs = []
st = timer()
STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
el += timer() - st
except Exception as e:
cron_logger.error(str(e))
traceback.print_exc()
except Exception:
logger.exception("Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["id"]))


d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
del d["image"]
docs.append(d)
cron_logger.info("MINIO PUT({}):{}".format(row["name"], el))
logger.info("MINIO PUT({}):{}".format(row["name"], el))


if row["parser_config"].get("auto_keywords", 0): if row["parser_config"].get("auto_keywords", 0):
callback(msg="Start to generate keywords for every chunk ...") callback(msg="Start to generate keywords for every chunk ...")
embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING, llm_name=r["embd_id"], lang=r["language"]) embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING, llm_name=r["embd_id"], lang=r["language"])
except Exception as e: except Exception as e:
callback(-1, msg=str(e)) callback(-1, msg=str(e))
cron_logger.error(str(e))
logger.exception("LLMBundle got exception")
continue continue


if r.get("task_type", "") == "raptor": if r.get("task_type", "") == "raptor":
cks, tk_count, vector_size = run_raptor(r, chat_mdl, embd_mdl, callback) cks, tk_count, vector_size = run_raptor(r, chat_mdl, embd_mdl, callback)
except Exception as e: except Exception as e:
callback(-1, msg=str(e)) callback(-1, msg=str(e))
cron_logger.error(str(e))
logger.exception("run_raptor got exception")
continue continue
else: else:
st = timer() st = timer()
cks = build(r) cks = build(r)
cron_logger.info("Build chunks({}): {}".format(r["name"], timer() - st))
logger.info("Build chunks({}): {}".format(r["name"], timer() - st))
if cks is None: if cks is None:
continue continue
if not cks: if not cks:
tk_count, vector_size = embedding(cks, embd_mdl, r["parser_config"], callback) tk_count, vector_size = embedding(cks, embd_mdl, r["parser_config"], callback)
except Exception as e: except Exception as e:
callback(-1, "Embedding error:{}".format(str(e))) callback(-1, "Embedding error:{}".format(str(e)))
cron_logger.error(str(e))
logger.exception("run_rembedding got exception")
tk_count = 0 tk_count = 0
cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st)) callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))


# cron_logger.info(f"task_executor init_kb index {search.index_name(r["tenant_id"])} embd_mdl {embd_mdl.llm_name} vector length {vector_size}")
# logger.info(f"task_executor init_kb index {search.index_name(r["tenant_id"])} embd_mdl {embd_mdl.llm_name} vector length {vector_size}")
init_kb(r, vector_size) init_kb(r, vector_size)
chunk_count = len(set([c["id"] for c in cks])) chunk_count = len(set([c["id"] for c in cks]))
st = timer() st = timer()
if b % 128 == 0: if b % 128 == 0:
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="") callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")


cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
if es_r: if es_r:
callback(-1, "Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
callback(-1, f"Insert chunk error, detail info please check {LOG_FILE}. Please also check ES status!")
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"]) docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
cron_logger.error('Insert chunk error: ' + str(es_r))
logger.error('Insert chunk error: ' + str(es_r))
else: else:
if TaskService.do_cancel(r["id"]): if TaskService.do_cancel(r["id"]):
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"]) docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
callback(1., "Done!") callback(1., "Done!")
DocumentService.increment_chunk_num( DocumentService.increment_chunk_num(
r["doc_id"], r["kb_id"], tk_count, chunk_count, 0) r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
cron_logger.info(
logger.info(
"Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format( "Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(
r["id"], tk_count, len(cks), timer() - st)) r["id"], tk_count, len(cks), timer() - st))


obj[CONSUMER_NAME].append(timer()) obj[CONSUMER_NAME].append(timer())
obj[CONSUMER_NAME] = obj[CONSUMER_NAME][-60:] obj[CONSUMER_NAME] = obj[CONSUMER_NAME][-60:]
REDIS_CONN.set_obj("TASKEXE", obj, 60*2) REDIS_CONN.set_obj("TASKEXE", obj, 60*2)
except Exception as e:
print("[Exception]:", str(e))
except Exception:
logger.exception("report_status got exception")
time.sleep(30) time.sleep(30)




if __name__ == "__main__": if __name__ == "__main__":
peewee_logger = logging.getLogger('peewee') peewee_logger = logging.getLogger('peewee')
peewee_logger.propagate = False peewee_logger.propagate = False
peewee_logger.addHandler(database_logger.handlers[0])
peewee_logger.setLevel(database_logger.level)
peewee_logger.addHandler(logger.handlers[0])
peewee_logger.setLevel(logger.handlers[0].level)


exe = ThreadPoolExecutor(max_workers=1) exe = ThreadPoolExecutor(max_workers=1)
exe.submit(report_status) exe.submit(report_status)
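The __main__ block above re-points the peewee logger at the unified logger's first handler so that database logs land in the same file as everything else. A hedged, generic sketch of that redirection (names are illustrative; the real handler set lives in api.utils.log_utils):

import logging

# Stand-in for the shared logger from api.utils.log_utils.
app_logger = logging.getLogger("ragflow")
app_logger.setLevel(logging.INFO)
app_logger.addHandler(logging.StreamHandler())

# Route a third-party library's logs into the same handlers.
peewee_logger = logging.getLogger("peewee")
peewee_logger.propagate = False              # don't also bubble up to the root logger
for handler in app_logger.handlers:          # reuse every handler, not just the first one
    peewee_logger.addHandler(handler)
peewee_logger.setLevel(app_logger.level)     # follow the logger's level, not a handler's

Copying every handler and reusing the logger's own level, as sketched here, is a slightly more defensive variant of the single handlers[0] / handlers[0].level used in the hunk, since a handler's level may be NOTSET.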

+ 13  - 15  rag/utils/azure_sas_conn.py

import time
from io import BytesIO
from rag import settings
from rag.settings import azure_logger
from rag.utils import singleton
from azure.storage.blob import ContainerClient


try:
if self.conn:
self.__close__()
except Exception as e:
except Exception:
pass


try:
self.conn = ContainerClient.from_container_url(self.container_url + "?" + self.sas_token)
except Exception as e:
azure_logger.error(
"Fail to connect %s " % self.container_url + str(e))
except Exception:
logger.exception("Fail to connect %s " % self.container_url)


def __close__(self):
del self.conn
for _ in range(3):
try:
return self.conn.upload_blob(name=fnm, data=BytesIO(binary), length=len(binary))
except Exception as e:
azure_logger.error(f"Fail put {bucket}/{fnm}: " + str(e))
except Exception:
logger.exception(f"Fail put {bucket}/{fnm}")
self.__open__()
time.sleep(1)


def rm(self, bucket, fnm):
try:
self.conn.delete_blob(fnm)
except Exception as e:
azure_logger.error(f"Fail rm {bucket}/{fnm}: " + str(e))
except Exception:
logger.exception(f"Fail rm {bucket}/{fnm}")


def get(self, bucket, fnm):
for _ in range(1):
try:
r = self.conn.download_blob(fnm)
return r.read()
except Exception as e:
azure_logger.error(f"fail get {bucket}/{fnm}: " + str(e))
except Exception:
logger.exception(f"fail get {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return
def obj_exist(self, bucket, fnm):
try:
return self.conn.get_blob_client(fnm).exists()
except Exception as e:
azure_logger.error(f"Fail put {bucket}/{fnm}: " + str(e))
except Exception:
logger.exception(f"Fail put {bucket}/{fnm}")
return False


def get_presigned_url(self, bucket, fnm, expires):
for _ in range(10):
try:
return self.conn.get_presigned_url("GET", bucket, fnm, expires)
except Exception as e:
azure_logger.error(f"fail get {bucket}/{fnm}: " + str(e))
except Exception:
logger.exception(f"fail get {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return

+ 13  - 15  rag/utils/azure_spn_conn.py

import os
import time
from rag import settings
from rag.settings import azure_logger
from rag.utils import singleton
from azure.identity import ClientSecretCredential, AzureAuthorityHosts
from azure.storage.filedatalake import FileSystemClient
try:
if self.conn:
self.__close__()
except Exception as e:
except Exception:
pass


try:
credentials = ClientSecretCredential(tenant_id=self.tenant_id, client_id=self.client_id, client_secret=self.secret, authority=AzureAuthorityHosts.AZURE_CHINA)
self.conn = FileSystemClient(account_url=self.account_url, file_system_name=self.container_name, credential=credentials)
except Exception as e:
azure_logger.error(
"Fail to connect %s " % self.account_url + str(e))
except Exception:
logger.exception("Fail to connect %s" % self.account_url)


def __close__(self):
del self.conn
f = self.conn.create_file(fnm)
f.append_data(binary, offset=0, length=len(binary))
return f.flush_data(len(binary))
except Exception as e:
azure_logger.error(f"Fail put {bucket}/{fnm}: " + str(e))
except Exception:
logger.exception(f"Fail put {bucket}/{fnm}")
self.__open__()
time.sleep(1)


def rm(self, bucket, fnm):
try:
self.conn.delete_file(fnm)
except Exception as e:
azure_logger.error(f"Fail rm {bucket}/{fnm}: " + str(e))
except Exception:
logger.exception(f"Fail rm {bucket}/{fnm}")


def get(self, bucket, fnm):
for _ in range(1):
client = self.conn.get_file_client(fnm)
r = client.download_file()
return r.read()
except Exception as e:
azure_logger.error(f"fail get {bucket}/{fnm}: " + str(e))
except Exception:
logger.exception(f"fail get {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return
try:
client = self.conn.get_file_client(fnm)
return client.exists()
except Exception as e:
azure_logger.error(f"Fail put {bucket}/{fnm}: " + str(e))
except Exception:
logger.exception(f"Fail put {bucket}/{fnm}")
return False


def get_presigned_url(self, bucket, fnm, expires):
for _ in range(10):
try:
return self.conn.get_presigned_url("GET", bucket, fnm, expires)
except Exception as e:
azure_logger.error(f"fail get {bucket}/{fnm}: " + str(e))
except Exception:
logger.exception(f"fail get {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return

+ 27  - 38  rag/utils/es_conn.py

from elasticsearch import Elasticsearch
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
from elastic_transport import ConnectionTimeout
from rag.settings import doc_store_logger
from api.utils.log_utils import logger
from rag import settings
from rag.utils import singleton
from api.utils.file_utils import get_project_base_directory
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
from rag.nlp import is_english, rag_tokenizer


doc_store_logger.info("Elasticsearch sdk version: "+str(elasticsearch.__version__))
logger.info("Elasticsearch sdk version: "+str(elasticsearch.__version__))




@singleton
)
if self.es:
self.info = self.es.info()
doc_store_logger.info("Connect to es.")
logger.info("Connect to es.")
break
except Exception as e:
doc_store_logger.error("Fail to connect to es: " + str(e))
except Exception:
logger.exception("Fail to connect to es")
time.sleep(1)
if not self.es.ping():
raise Exception("Can't connect to ES cluster")
return IndicesClient(self.es).create(index=indexName,
settings=self.mapping["settings"],
mappings=self.mapping["mappings"])
except Exception as e:
doc_store_logger.error("ES create index error %s ----%s" % (indexName, str(e)))
except Exception:
logger.exception("ES create index error %s" % (indexName))


def deleteIdx(self, indexName: str, knowledgebaseId: str):
try:
return self.es.indices.delete(indexName, allow_no_indices=True)
except Exception as e:
doc_store_logger.error("ES delete index error %s ----%s" % (indexName, str(e)))
except Exception:
logger.exception("ES delete index error %s" % (indexName))


def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
s = Index(indexName, self.es)
try:
return s.exists()
except Exception as e:
doc_store_logger.error("ES indexExist: " + str(e))
logger.exception("ES indexExist")
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue
return False
if limit > 0:
s = s[offset:limit]
q = s.to_dict()
doc_store_logger.info("ESConnection.search [Q]: " + json.dumps(q))
# logger.info("ESConnection.search [Q]: " + json.dumps(q))


for i in range(3):
try:
_source=True)
if str(res.get("timed_out", "")).lower() == "true":
raise Exception("Es Timeout.")
doc_store_logger.info("ESConnection.search res: " + str(res))
logger.info("ESConnection.search res: " + str(res))
return res
except Exception as e:
doc_store_logger.error(
"ES search exception: " +
str(e) +
"\n[Q]: " +
str(q))
logger.exception("ES search [Q]: " + str(q))
if str(e).find("Timeout") > 0:
continue
raise e
doc_store_logger.error("ES search timeout for 3 times!")
logger.error("ES search timeout for 3 times!")
raise Exception("ES search timeout.")


def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
chunk["id"] = chunkId
return chunk
except Exception as e:
doc_store_logger.error(
"ES get exception: " +
str(e) +
"[Q]: " +
chunkId)
logger.exception(f"ES get({chunkId}) got exception")
if str(e).find("Timeout") > 0:
continue
raise e
doc_store_logger.error("ES search timeout for 3 times!")
logger.error("ES search timeout for 3 times!")
raise Exception("ES search timeout.")


def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
return res
except Exception as e:
doc_store_logger.warning("Fail to bulk: " + str(e))
logger.warning("Fail to bulk: " + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3)
continue
self.es.update(index=indexName, id=chunkId, doc=doc)
return True
except Exception as e:
doc_store_logger.error(
"ES update exception: " + str(e) + " id:" + str(id) +
json.dumps(newValue, ensure_ascii=False))
logger.exception(f"ES failed to update(index={indexName}, id={id}, doc={json.dumps(condition, ensure_ascii=False)})")
if str(e).find("Timeout") > 0:
continue
else:
_ = ubq.execute()
return True
except Exception as e:
doc_store_logger.error("ES update exception: " +
str(e) + "[Q]:" + str(bqry.to_dict()))
logger.error("ES update exception: " + str(e) + "[Q]:" + str(bqry.to_dict()))
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
continue
return False
qry.must.append(Q("term", **{k: v}))
else:
raise Exception("Condition value must be int, str or list.")
doc_store_logger.info("ESConnection.delete [Q]: " + json.dumps(qry.to_dict()))
logger.info("ESConnection.delete [Q]: " + json.dumps(qry.to_dict()))
for _ in range(10):
try:
res = self.es.delete_by_query(
refresh=True)
return res["deleted"]
except Exception as e:
doc_store_logger.warning("Fail to delete: " + str(filter) + str(e))
logger.warning("Fail to delete: " + str(filter) + str(e))
if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE):
time.sleep(3)
continue
SQL
"""
def sql(self, sql: str, fetch_size: int, format: str):
doc_store_logger.info(f"ESConnection.sql get sql: {sql}")
logger.info(f"ESConnection.sql get sql: {sql}")
sql = re.sub(r"[ `]+", " ", sql)
sql = sql.replace("%", "")
replaces = []


for p, r in replaces:
sql = sql.replace(p, r, 1)
doc_store_logger.info(f"ESConnection.sql to es: {sql}")
logger.info(f"ESConnection.sql to es: {sql}")


for i in range(3):
try:
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout="2s")
return res
except ConnectionTimeout:
doc_store_logger.error("ESConnection.sql timeout [Q]: " + sql)
logger.exception("ESConnection.sql timeout [Q]: " + sql)
continue
except Exception as e:
doc_store_logger.error(f"ESConnection.sql failure: {sql} => " + str(e))
except Exception:
logger.exception("ESConnection.sql got exception [Q]: " + sql)
return None
doc_store_logger.error("ESConnection.sql timeout for 3 times!")
logger.error("ESConnection.sql timeout for 3 times!")
return None
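The search, get and sql paths above all follow the same control flow: try a fixed number of times, log each failure with logger.exception, retry only when the error looks like a timeout, and give up loudly afterwards. A condensed sketch of that flow, with run_query as a placeholder for the actual Elasticsearch call:

import logging

logger = logging.getLogger(__name__)


def search_with_retry(run_query, attempts=3):
    # run_query is a stand-in for the real Elasticsearch request.
    for _ in range(attempts):
        try:
            return run_query()
        except Exception as e:
            # Log the message plus the full traceback of the failure.
            logger.exception("ES search got exception")
            if "Timeout" in str(e):   # transient: try again
                continue
            raise                     # anything else propagates immediately
    logger.error("ES search timeout for %d times!", attempts)
    raise Exception("ES search timeout.")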

+ 12  - 13  rag/utils/infinity_conn.py

from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from rag import settings
from rag.settings import doc_store_logger
from api.utils.log_utils import logger
from rag.utils import singleton
import polars as pl
from polars.series.series import Series
OrderByExpr,
)



def equivalent_condition_to_str(condition: dict) -> str:
assert "_id" not in condition
cond = list()
host, port = infinity_uri.split(":")
infinity_uri = infinity.common.NetworkAddress(host, int(port))
self.connPool = ConnectionPool(infinity_uri)
doc_store_logger.info(f"Connected to infinity {infinity_uri}.")
logger.info(f"Connected to infinity {infinity_uri}.")


"""
Database operations
TODO: Infinity-sdk provides health() to wrap `show global variables` and `show tables`
"""
inf_conn = self.connPool.get_conn()
res = infinity.show_current_node()
res = inf_conn.show_current_node()
self.connPool.release_conn(inf_conn)
color = "green" if res.error_code == 0 else "red"
res2 = {
)
break
self.connPool.release_conn(inf_conn)
doc_store_logger.info(
logger.info(
f"INFINITY created table {table_name}, vector size {vectorSize}"
)


db_instance = inf_conn.get_database(self.dbName)
db_instance.drop_table(table_name, ConflictType.Ignore)
self.connPool.release_conn(inf_conn)
doc_store_logger.info(f"INFINITY dropped table {table_name}")
logger.info(f"INFINITY dropped table {table_name}")


def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
table_name = f"{indexName}_{knowledgebaseId}"
_ = db_instance.get_table(table_name)
self.connPool.release_conn(inf_conn)
return True
except Exception as e:
doc_store_logger.error("INFINITY indexExist: " + str(e))
except Exception:
logger.exception("INFINITY indexExist")
return False


"""
df_list.append(kb_res)
self.connPool.release_conn(inf_conn)
res = pl.concat(df_list)
doc_store_logger.info("INFINITY search tables: " + str(table_list))
logger.info("INFINITY search tables: " + str(table_list))
return res


def get(
str_filter = f"id IN ({str_ids})"
table_instance.delete(str_filter)
# for doc in documents:
# doc_store_logger.info(f"insert position_list: {doc['position_list']}")
# doc_store_logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
# logger.info(f"insert position_list: {doc['position_list']}")
# logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
table_instance.insert(documents)
self.connPool.release_conn(inf_conn)
doc_store_logger.info(f"inserted into {table_name} {str_ids}.") doc_store_logger.info(f"inserted into {table_name} {str_ids}.")
self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
) -> bool:
# if 'position_list' in newValue:
# doc_store_logger.info(f"update position_list: {newValue['position_list']}")
# logger.info(f"upsert position_list: {newValue['position_list']}")
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
table_name = f"{indexName}_{knowledgebaseId}"
try:
table_instance = db_instance.get_table(table_name)
except Exception:
doc_store_logger.warning(
logger.warning(
f"Skipped deleting `{filter}` from table {table_name} since the table doesn't exist."
)
return 0

+ 15  - 16  rag/utils/minio_conn.py

import os
import time
from minio import Minio
from io import BytesIO
from rag import settings
from rag.settings import minio_logger
from rag.utils import singleton
from api.utils.log_utils import logger




@singleton
try:
if self.conn:
self.__close__()
except Exception as e:
except Exception:
pass


try:
secret_key=settings.MINIO["password"],
secure=False
)
except Exception as e:
minio_logger.error(
"Fail to connect %s " % settings.MINIO["host"] + str(e))
except Exception:
logger.exception(
"Fail to connect %s " % settings.MINIO["host"])


def __close__(self):
del self.conn
len(binary)
)
return r
except Exception as e:
minio_logger.error(f"Fail put {bucket}/{fnm}: " + str(e))
except Exception:
logger.exception(f"Fail put {bucket}/{fnm}:")
self.__open__()
time.sleep(1)


def rm(self, bucket, fnm):
try:
self.conn.remove_object(bucket, fnm)
except Exception as e:
minio_logger.error(f"Fail rm {bucket}/{fnm}: " + str(e))
except Exception:
logger.exception(f"Fail rm {bucket}/{fnm}:")


def get(self, bucket, fnm):
for _ in range(1):
try:
r = self.conn.get_object(bucket, fnm)
return r.read()
except Exception as e:
minio_logger.error(f"fail get {bucket}/{fnm}: " + str(e))
except Exception:
logger.exception(f"Fail get {bucket}/{fnm}:")
self.__open__()
time.sleep(1)
return
try:
if self.conn.stat_object(bucket, fnm):return True
return False
except Exception as e:
minio_logger.error(f"Fail put {bucket}/{fnm}: " + str(e))
except Exception:
logger.exception(f"Fail to stat {bucket}/{fnm}:")
return False




for _ in range(10):
try:
return self.conn.get_presigned_url("GET", bucket, fnm, expires)
except Exception as e:
minio_logger.error(f"fail get {bucket}/{fnm}: " + str(e))
except Exception:
logger.exception(f"Fail get url {bucket}/{fnm}:")
self.__open__()
time.sleep(1)
return
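All of the object-store connectors touched by this PR (MinIO, S3, Azure) wrap their operations in the same attempt / log / reopen / sleep / retry shape. A generic sketch of that wrapper (function names are illustrative, not the project's API):

import logging
import time

logger = logging.getLogger(__name__)


def with_reconnect(operation, reopen, attempts=3, delay=1.0):
    # Run the operation; on failure log the traceback, reopen the client and retry.
    for _ in range(attempts):
        try:
            return operation()
        except Exception:
            logger.exception("storage operation failed")
            reopen()            # e.g. the connector's __open__()
            time.sleep(delay)
    return None                 # mirrors the connectors, which fall through quietly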

+ 4  - 5  rag/utils/redis_conn.py

#pipeline.expire(queue, exp)
pipeline.execute()
return True
except Exception as e:
print(e)
logging.warning("[EXCEPTION]producer" + str(queue) + "||" + str(e))
except Exception:
logging.exception("producer " + str(queue) + " got exception")
return False


def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">") -> Payload:
if 'key' in str(e):
pass
else:
logging.warning("[EXCEPTION]consumer: " + str(queue_name) + "||" + str(e))
logging.exception("consumer: " + str(queue_name) + " got exception")
return None


def get_unacked_for(self, consumer_name, queue_name, group_name):
except Exception as e:
if 'key' in str(e):
return
logging.warning("[EXCEPTION]xpending_range: " + consumer_name + "||" + str(e))
logging.exception("xpending_range: " + consumer_name + " got exception")
self.__open__()


REDIS_CONN = RedisDB()

+ 18  - 19  rag/utils/s3_conn.py

from botocore.client import Config
import time
from io import BytesIO
from rag.settings import s3_logger
from rag.utils import singleton


@singleton
try:
if self.conn:
self.__close__()
except Exception as e:
except Exception:
pass


try:
aws_secret_access_key=self.secret_key,
config=config
)
except Exception as e:
s3_logger.error(
"Fail to connect %s " % self.endpoint + str(e))
except Exception:
logger.exception(
"Fail to connect %s" % self.endpoint)


def __close__(self):
del self.conn


def bucket_exists(self, bucket):
try:
s3_logger.error(f"head_bucket bucketname {bucket}")
logger.debug(f"head_bucket bucketname {bucket}")
self.conn.head_bucket(Bucket=bucket)
exists = True
except ClientError as e:
s3_logger.error(f"head_bucket error {bucket}: " + str(e))
except ClientError:
logger.exception(f"head_bucket error {bucket}")
exists = False
return exists




if not self.bucket_exists(bucket):
self.conn.create_bucket(Bucket=bucket)
s3_logger.error(f"create bucket {bucket} ********")
logger.debug(f"create bucket {bucket} ********")


r = self.conn.upload_fileobj(BytesIO(binary), bucket, fnm)
return r
return []


def put(self, bucket, fnm, binary):
s3_logger.error(f"bucket name {bucket}; filename :{fnm}:")
logger.debug(f"bucket name {bucket}; filename :{fnm}:")
for _ in range(1):
try:
if not self.bucket_exists(bucket):
self.conn.create_bucket(Bucket=bucket)
s3_logger.error(f"create bucket {bucket} ********")
logger.info(f"create bucket {bucket} ********")
r = self.conn.upload_fileobj(BytesIO(binary), bucket, fnm)


return r
except Exception as e:
s3_logger.error(f"Fail put {bucket}/{fnm}: " + str(e))
except Exception:
logger.exception(f"Fail put {bucket}/{fnm}")
self.__open__()
time.sleep(1)


def rm(self, bucket, fnm):
try:
self.conn.delete_object(Bucket=bucket, Key=fnm)
except Exception as e:
s3_logger.error(f"Fail rm {bucket}/{fnm}: " + str(e))
except Exception:
logger.exception(f"Fail rm {bucket}/{fnm}")


def get(self, bucket, fnm):
for _ in range(1):
r = self.conn.get_object(Bucket=bucket, Key=fnm)
object_data = r['Body'].read()
return object_data
except Exception as e:
s3_logger.error(f"fail get {bucket}/{fnm}: " + str(e))
except Exception:
logger.exception(f"fail get {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return
ExpiresIn=expires)


return r
except Exception as e:
s3_logger.error(f"fail get url {bucket}/{fnm}: " + str(e))
except Exception:
logger.exception(f"fail get url {bucket}/{fnm}")
self.__open__()
time.sleep(1)
return
