| @@ -33,49 +33,14 @@ from api.utils.api_utils import get_json_result, cors_reponse | |||
| @manager.route('/login', methods=['POST', 'GET']) | |||
| def login(): | |||
| userinfo = None | |||
| login_channel = "password" | |||
| if session.get("access_token"): | |||
| login_channel = session["access_token_from"] | |||
| if session["access_token_from"] == "github": | |||
| userinfo = user_info_from_github(session["access_token"]) | |||
| elif not request.json: | |||
| if not request.json: | |||
| return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, | |||
| retmsg='Unautherized!') | |||
| email = request.json.get('email') if not userinfo else userinfo["email"] | |||
| email = request.json.get('email', "") | |||
| users = UserService.query(email=email) | |||
| if not users: | |||
| if request.json is not None: | |||
| return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!') | |||
| avatar = "" | |||
| try: | |||
| avatar = download_img(userinfo["avatar_url"]) | |||
| except Exception as e: | |||
| stat_logger.exception(e) | |||
| user_id = get_uuid() | |||
| try: | |||
| users = user_register(user_id, { | |||
| "access_token": session["access_token"], | |||
| "email": userinfo["email"], | |||
| "avatar": avatar, | |||
| "nickname": userinfo["login"], | |||
| "login_channel": login_channel, | |||
| "last_login_time": get_format_time(), | |||
| "is_superuser": False, | |||
| }) | |||
| if not users: raise Exception('Register user failure.') | |||
| if len(users) > 1: raise Exception('Same E-mail exist!') | |||
| user = users[0] | |||
| login_user(user) | |||
| return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!") | |||
| except Exception as e: | |||
| rollback_user_registration(user_id) | |||
| stat_logger.exception(e) | |||
| return server_error_response(e) | |||
| elif not request.json: | |||
| login_user(users[0]) | |||
| return cors_reponse(data=users[0].to_json(), auth=users[0].get_id(), retmsg="Welcome back!") | |||
| if not users: return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!') | |||
| password = request.json.get('password') | |||
| try: | |||
| @@ -97,28 +62,50 @@ def login(): | |||
| @manager.route('/github_callback', methods=['GET']) | |||
| def github_callback(): | |||
| try: | |||
| import requests | |||
| res = requests.post(GITHUB_OAUTH.get("url"), data={ | |||
| "client_id": GITHUB_OAUTH.get("client_id"), | |||
| "client_secret": GITHUB_OAUTH.get("secret_key"), | |||
| "code": request.args.get('code') | |||
| },headers={"Accept": "application/json"}) | |||
| res = res.json() | |||
| if "error" in res: | |||
| return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, | |||
| retmsg=res["error_description"]) | |||
| if "user:email" not in res["scope"].split(","): | |||
| return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope') | |||
| session["access_token"] = res["access_token"] | |||
| session["access_token_from"] = "github" | |||
| return redirect(url_for("user.login"), code=307) | |||
| import requests | |||
| res = requests.post(GITHUB_OAUTH.get("url"), data={ | |||
| "client_id": GITHUB_OAUTH.get("client_id"), | |||
| "client_secret": GITHUB_OAUTH.get("secret_key"), | |||
| "code": request.args.get('code') | |||
| }, headers={"Accept": "application/json"}) | |||
| res = res.json() | |||
| if "error" in res: | |||
| return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, | |||
| retmsg=res["error_description"]) | |||
| except Exception as e: | |||
| stat_logger.exception(e) | |||
| return server_error_response(e) | |||
| if "user:email" not in res["scope"].split(","): | |||
| return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope') | |||
| session["access_token"] = res["access_token"] | |||
| session["access_token_from"] = "github" | |||
| userinfo = user_info_from_github(session["access_token"]) | |||
| users = UserService.query(email=userinfo["email"]) | |||
| user_id = get_uuid() | |||
| if not users: | |||
| try: | |||
| try: | |||
| avatar = download_img(userinfo["avatar_url"]) | |||
| except Exception as e: | |||
| stat_logger.exception(e) | |||
| avatar = "" | |||
| users = user_register(user_id, { | |||
| "access_token": session["access_token"], | |||
| "email": userinfo["email"], | |||
| "avatar": avatar, | |||
| "nickname": userinfo["login"], | |||
| "login_channel": "github", | |||
| "last_login_time": get_format_time(), | |||
| "is_superuser": False, | |||
| }) | |||
| if not users: raise Exception('Register user failure.') | |||
| if len(users) > 1: raise Exception('Same E-mail exist!') | |||
| user = users[0] | |||
| login_user(user) | |||
| except Exception as e: | |||
| rollback_user_registration(user_id) | |||
| stat_logger.exception(e) | |||
| return redirect("/knowledge") | |||
| def user_info_from_github(access_token): | |||
| @@ -208,7 +195,7 @@ def user_register(user_id, user): | |||
| for llm in LLMService.query(fid=LLM_FACTORY): | |||
| tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY}) | |||
| if not UserService.insert(**user):return | |||
| if not UserService.save(**user):return | |||
| TenantService.insert(**tenant) | |||
| UserTenantService.insert(**usr_tenant) | |||
| TenantLLMService.insert_many(tenant_llm) | |||
| @@ -69,7 +69,6 @@ class TaskStatus(StrEnum): | |||
| class ParserType(StrEnum): | |||
| GENERAL = "general" | |||
| PRESENTATION = "presentation" | |||
| LAWS = "laws" | |||
| MANUAL = "manual" | |||
| @@ -475,7 +475,7 @@ class Knowledgebase(DataBaseModel): | |||
| similarity_threshold = FloatField(default=0.2) | |||
| vector_similarity_weight = FloatField(default=0.3) | |||
| parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.GENERAL.value) | |||
| parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value) | |||
| parser_config = JSONField(null=False, default={"pages":[[0,1000000]]}) | |||
| status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") | |||
| @@ -30,7 +30,7 @@ def init_superuser(): | |||
| "password": "admin", | |||
| "nickname": "admin", | |||
| "is_superuser": True, | |||
| "email": "kai.hu@infiniflow.org", | |||
| "email": "admin@ragflow.io", | |||
| "creator": "system", | |||
| "status": "1", | |||
| } | |||
| @@ -61,7 +61,7 @@ def init_superuser(): | |||
| TenantService.insert(**tenant) | |||
| UserTenantService.insert(**usr_tenant) | |||
| TenantLLMService.insert_many(tenant_llm) | |||
| print("【INFO】Super user initialized. \033[93muser name: admin, password: admin\033[0m. Changing the password after logining is strongly recomanded.") | |||
| print("【INFO】Super user initialized. \033[93memail: admin@ragflow.io, password: admin\033[0m. Changing the password after logining is strongly recomanded.") | |||
| chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"]) | |||
| msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={}) | |||
| @@ -13,6 +13,8 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| from datetime import datetime | |||
| import peewee | |||
| from werkzeug.security import generate_password_hash, check_password_hash | |||
| @@ -20,7 +22,7 @@ from api.db import UserTenantRole | |||
| from api.db.db_models import DB, UserTenant | |||
| from api.db.db_models import User, Tenant | |||
| from api.db.services.common_service import CommonService | |||
| from api.utils import get_uuid, get_format_time | |||
| from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format | |||
| from api.db import StatusEnum | |||
| @@ -53,6 +55,11 @@ class UserService(CommonService): | |||
| kwargs["id"] = get_uuid() | |||
| if "password" in kwargs: | |||
| kwargs["password"] = generate_password_hash(str(kwargs["password"])) | |||
| kwargs["create_time"] = current_timestamp() | |||
| kwargs["create_date"] = datetime_format(datetime.now()) | |||
| kwargs["update_time"] = current_timestamp() | |||
| kwargs["update_date"] = datetime_format(datetime.now()) | |||
| obj = cls.model(**kwargs).save(force_insert=True) | |||
| return obj | |||
| @@ -66,10 +73,10 @@ class UserService(CommonService): | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def update_user(cls, user_id, user_dict): | |||
| date_time = get_format_time() | |||
| with DB.atomic(): | |||
| if user_dict: | |||
| user_dict["update_time"] = date_time | |||
| user_dict["update_time"] = current_timestamp() | |||
| user_dict["update_date"] = datetime_format(datetime.now()) | |||
| cls.model.update(user_dict).where(cls.model.id == user_id).execute() | |||
| @@ -76,7 +76,7 @@ ASR_MDL = default_llm[LLM_FACTORY]["asr_model"] | |||
| IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] | |||
| API_KEY = LLM.get("api_key", "infiniflow API Key") | |||
| PARSERS = LLM.get("parsers", "general:General,qa:Q&A,resume:Resume,naive:Naive,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture") | |||
| PARSERS = LLM.get("parsers", "naive:General,qa:Q&A,resume:Resume,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture") | |||
| # distribution | |||
| DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False) | |||
| @@ -25,7 +25,7 @@ class HuParser: | |||
| def __init__(self): | |||
| self.ocr = OCR() | |||
| if not hasattr(self, "model_speciess"): | |||
| self.model_speciess = ParserType.GENERAL.value | |||
| self.model_speciess = ParserType.NAIVE.value | |||
| self.layouter = LayoutRecognizer("layout."+self.model_speciess) | |||
| self.tbl_det = TableStructureRecognizer() | |||
| @@ -34,8 +34,7 @@ class LayoutRecognizer(Recognizer): | |||
| "Equation", | |||
| ] | |||
| def __init__(self, domain): | |||
| super().__init__(self.labels, domain, | |||
| os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) | |||
| super().__init__(self.labels, domain) #, os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) | |||
| def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16): | |||
| def __is_garbage(b): | |||
| @@ -33,8 +33,7 @@ class TableStructureRecognizer(Recognizer): | |||
| ] | |||
| def __init__(self): | |||
| super().__init__(self.labels, "tsr", | |||
| os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) | |||
| super().__init__(self.labels, "tsr")#,os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) | |||
| def __call__(self, images, thr=0.2): | |||
| tbls = super().__call__(images, thr) | |||
| @@ -1,11 +1,17 @@ | |||
| import copy | |||
| import re | |||
| from api.db import ParserType | |||
| from rag.nlp import huqie, tokenize | |||
| from deepdoc.parser import PdfParser | |||
| from rag.utils import num_tokens_from_string | |||
| class Pdf(PdfParser): | |||
| def __init__(self): | |||
| self.model_speciess = ParserType.MANUAL.value | |||
| super().__init__() | |||
| def __call__(self, filename, binary=None, from_page=0, | |||
| to_page=100000, zoomin=3, callback=None): | |||
| self.__images__( | |||
| @@ -30,11 +30,21 @@ class Pdf(PdfParser): | |||
| from timeit import default_timer as timer | |||
| start = timer() | |||
| start = timer() | |||
| self._layouts_rec(zoomin) | |||
| callback(0.77, "Layout analysis finished") | |||
| callback(0.5, "Layout analysis finished.") | |||
| print("paddle layouts:", timer() - start) | |||
| self._table_transformer_job(zoomin) | |||
| callback(0.7, "Table analysis finished.") | |||
| self._text_merge() | |||
| self._concat_downward(concat_between_pages=False) | |||
| self._filter_forpages() | |||
| callback(0.77, "Text merging finished") | |||
| tbls = self._extract_table_figure(True, zoomin, False) | |||
| cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1))) | |||
| self._naive_vertical_merge() | |||
| return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes] | |||
| #self._naive_vertical_merge() | |||
| return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls | |||
| def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): | |||
| @@ -44,11 +54,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| Successive text will be sliced into pieces using 'delimiter'. | |||
| Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'. | |||
| """ | |||
| eng = lang.lower() == "english"#is_english(cks) | |||
| doc = { | |||
| "docnm_kwd": filename, | |||
| "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename)) | |||
| } | |||
| doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"]) | |||
| res = [] | |||
| pdf_parser = None | |||
| sections = [] | |||
| if re.search(r"\.docx?$", filename, re.IGNORECASE): | |||
| @@ -58,8 +71,19 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| callback(0.8, "Finish parsing.") | |||
| elif re.search(r"\.pdf$", filename, re.IGNORECASE): | |||
| pdf_parser = Pdf() | |||
| sections = pdf_parser(filename if not binary else binary, | |||
| sections, tbls = pdf_parser(filename if not binary else binary, | |||
| from_page=from_page, to_page=to_page, callback=callback) | |||
| # add tables | |||
| for img, rows in tbls: | |||
| bs = 10 | |||
| de = ";" if eng else ";" | |||
| for i in range(0, len(rows), bs): | |||
| d = copy.deepcopy(doc) | |||
| r = de.join(rows[i:i + bs]) | |||
| r = re.sub(r"\t——(来自| in ).*”%s" % de, "", r) | |||
| tokenize(d, r, eng) | |||
| d["image"] = img | |||
| res.append(d) | |||
| elif re.search(r"\.txt$", filename, re.IGNORECASE): | |||
| callback(0.1, "Start to parse.") | |||
| txt = "" | |||
| @@ -79,8 +103,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca | |||
| parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;!?"}) | |||
| cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["delimiter"]) | |||
| eng = lang.lower() == "english"#is_english(cks) | |||
| res = [] | |||
| # wrap up to es documents | |||
| for ck in cks: | |||
| print("--", ck) | |||
| @@ -37,7 +37,7 @@ from rag.nlp import search | |||
| from io import BytesIO | |||
| import pandas as pd | |||
| from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture | |||
| from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive | |||
| from api.db import LLMType, ParserType | |||
| from api.db.services.document_service import DocumentService | |||
| @@ -48,7 +48,7 @@ from api.utils.file_utils import get_project_base_directory | |||
| BATCH_SIZE = 64 | |||
| FACTORY = { | |||
| ParserType.GENERAL.value: laws, | |||
| ParserType.NAIVE.value: naive, | |||
| ParserType.PAPER.value: paper, | |||
| ParserType.BOOK.value: book, | |||
| ParserType.PRESENTATION.value: presentation, | |||