 @manager.route('/login', methods=['POST', 'GET'])
 def login():
-    userinfo = None
     login_channel = "password"
-    if session.get("access_token"):
-        login_channel = session["access_token_from"]
-        if session["access_token_from"] == "github":
-            userinfo = user_info_from_github(session["access_token"])
-    elif not request.json:
+    if not request.json:
         return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
                                retmsg='Unautherized!')
-    email = request.json.get('email') if not userinfo else userinfo["email"]
+    email = request.json.get('email', "")
     users = UserService.query(email=email)
-    if not users:
-        if request.json is not None:
-            return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
-        avatar = ""
-        try:
-            avatar = download_img(userinfo["avatar_url"])
-        except Exception as e:
-            stat_logger.exception(e)
-        user_id = get_uuid()
-        try:
-            users = user_register(user_id, {
-                "access_token": session["access_token"],
-                "email": userinfo["email"],
-                "avatar": avatar,
-                "nickname": userinfo["login"],
-                "login_channel": login_channel,
-                "last_login_time": get_format_time(),
-                "is_superuser": False,
-            })
-            if not users: raise Exception('Register user failure.')
-            if len(users) > 1: raise Exception('Same E-mail exist!')
-            user = users[0]
-            login_user(user)
-            return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
-        except Exception as e:
-            rollback_user_registration(user_id)
-            stat_logger.exception(e)
-            return server_error_response(e)
-    elif not request.json:
-        login_user(users[0])
-        return cors_reponse(data=users[0].to_json(), auth=users[0].get_id(), retmsg="Welcome back!")
+    if not users: return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
     password = request.json.get('password')
     try:
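With the GitHub branch gone, /login is a plain JSON password endpoint again. The request shape it expects looks roughly like the sketch below; the base URL and route prefix are assumptions, and RAGFlow's web front end may additionally encrypt the password before posting, so this only illustrates the payload, not a working credential exchange.

import requests

BASE_URL = "http://127.0.0.1:9380/v1/user"  # assumed host and blueprint prefix

resp = requests.post(f"{BASE_URL}/login",
                     json={"email": "admin@ragflow.io", "password": "<password as the UI sends it>"})
body = resp.json()
# get_json_result()/cors_reponse() payloads carry retcode, retmsg and data.
print(body.get("retcode"), body.get("retmsg"))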
 @manager.route('/github_callback', methods=['GET'])
 def github_callback():
-    try:
-        import requests
-        res = requests.post(GITHUB_OAUTH.get("url"), data={
-            "client_id": GITHUB_OAUTH.get("client_id"),
-            "client_secret": GITHUB_OAUTH.get("secret_key"),
-            "code": request.args.get('code')
-        },headers={"Accept": "application/json"})
-        res = res.json()
-        if "error" in res:
-            return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
-                                   retmsg=res["error_description"])
-        if "user:email" not in res["scope"].split(","):
-            return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')
-        session["access_token"] = res["access_token"]
-        session["access_token_from"] = "github"
-        return redirect(url_for("user.login"), code=307)
-    except Exception as e:
-        stat_logger.exception(e)
-        return server_error_response(e)
+    import requests
+    res = requests.post(GITHUB_OAUTH.get("url"), data={
+        "client_id": GITHUB_OAUTH.get("client_id"),
+        "client_secret": GITHUB_OAUTH.get("secret_key"),
+        "code": request.args.get('code')
+    }, headers={"Accept": "application/json"})
+    res = res.json()
+    if "error" in res:
+        return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
+                               retmsg=res["error_description"])
+    if "user:email" not in res["scope"].split(","):
+        return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')
+    session["access_token"] = res["access_token"]
+    session["access_token_from"] = "github"
+    userinfo = user_info_from_github(session["access_token"])
+    users = UserService.query(email=userinfo["email"])
+    user_id = get_uuid()
+    if not users:
+        try:
+            try:
+                avatar = download_img(userinfo["avatar_url"])
+            except Exception as e:
+                stat_logger.exception(e)
+                avatar = ""
+            users = user_register(user_id, {
+                "access_token": session["access_token"],
+                "email": userinfo["email"],
+                "avatar": avatar,
+                "nickname": userinfo["login"],
+                "login_channel": "github",
+                "last_login_time": get_format_time(),
+                "is_superuser": False,
+            })
+            if not users: raise Exception('Register user failure.')
+            if len(users) > 1: raise Exception('Same E-mail exist!')
+            user = users[0]
+            login_user(user)
+        except Exception as e:
+            rollback_user_registration(user_id)
+            stat_logger.exception(e)
+    return redirect("/knowledge")
 def user_info_from_github(access_token):
     for llm in LLMService.query(fid=LLM_FACTORY):
         tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY})
-    if not UserService.insert(**user):return
+    if not UserService.save(**user):return
     TenantService.insert(**tenant)
     UserTenantService.insert(**usr_tenant)
     TenantLLMService.insert_many(tenant_llm)
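The callback relies on user_info_from_github(), whose body is not part of this hunk. Against the public GitHub REST API a helper with that contract would look roughly like this sketch; the endpoints and field names are GitHub's documented ones, but the actual implementation in user_app.py may differ.

import requests

def user_info_from_github_sketch(access_token: str) -> dict:
    # Hypothetical stand-in for user_info_from_github(); returns the keys the
    # callback above reads: "login", "avatar_url" and "email".
    headers = {"Accept": "application/json",
               "Authorization": f"token {access_token}"}
    user = requests.get("https://api.github.com/user", headers=headers).json()
    if not user.get("email"):
        # The profile email can be empty; fall back to the primary address,
        # which is why the user:email scope is checked in github_callback().
        emails = requests.get("https://api.github.com/user/emails",
                              headers=headers).json()
        primary = next((e for e in emails if e.get("primary")), None)
        user["email"] = primary["email"] if primary else None
    return user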
 class ParserType(StrEnum):
-    GENERAL = "general"
     PRESENTATION = "presentation"
     LAWS = "laws"
     MANUAL = "manual"
     similarity_threshold = FloatField(default=0.2)
     vector_similarity_weight = FloatField(default=0.3)
-    parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.GENERAL.value)
+    parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value)
     parser_config = JSONField(null=False, default={"pages":[[0,1000000]]})
     status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
| "password": "admin", | "password": "admin", | ||||
| "nickname": "admin", | "nickname": "admin", | ||||
| "is_superuser": True, | "is_superuser": True, | ||||
| "email": "kai.hu@infiniflow.org", | |||||
| "email": "admin@ragflow.io", | |||||
| "creator": "system", | "creator": "system", | ||||
| "status": "1", | "status": "1", | ||||
| } | } | ||||
| TenantService.insert(**tenant) | TenantService.insert(**tenant) | ||||
| UserTenantService.insert(**usr_tenant) | UserTenantService.insert(**usr_tenant) | ||||
| TenantLLMService.insert_many(tenant_llm) | TenantLLMService.insert_many(tenant_llm) | ||||
| print("【INFO】Super user initialized. \033[93muser name: admin, password: admin\033[0m. Changing the password after logining is strongly recomanded.") | |||||
| print("【INFO】Super user initialized. \033[93memail: admin@ragflow.io, password: admin\033[0m. Changing the password after logining is strongly recomanded.") | |||||
| chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"]) | chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"]) | ||||
| msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={}) | msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={}) |
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from datetime import datetime
 import peewee
 from werkzeug.security import generate_password_hash, check_password_hash
 from api.db.db_models import DB, UserTenant
 from api.db.db_models import User, Tenant
 from api.db.services.common_service import CommonService
-from api.utils import get_uuid, get_format_time
+from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format
 from api.db import StatusEnum
         kwargs["id"] = get_uuid()
         if "password" in kwargs:
             kwargs["password"] = generate_password_hash(str(kwargs["password"]))
+        kwargs["create_time"] = current_timestamp()
+        kwargs["create_date"] = datetime_format(datetime.now())
+        kwargs["update_time"] = current_timestamp()
+        kwargs["update_date"] = datetime_format(datetime.now())
         obj = cls.model(**kwargs).save(force_insert=True)
         return obj
     @classmethod
     @DB.connection_context()
     def update_user(cls, user_id, user_dict):
-        date_time = get_format_time()
         with DB.atomic():
             if user_dict:
-                user_dict["update_time"] = date_time
+                user_dict["update_time"] = current_timestamp()
+                user_dict["update_date"] = datetime_format(datetime.now())
                 cls.model.update(user_dict).where(cls.model.id == user_id).execute()
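current_timestamp() and datetime_format() are not defined in this diff; in the utils module they come from they behave roughly as below, which is what gives every freshly saved user both integer *_time and datetime *_date columns (treat the exact return types as assumptions):

import time
from datetime import datetime

def current_timestamp() -> int:
    # Assumed: millisecond epoch for the create_time/update_time integer columns.
    return int(time.time() * 1000)

def datetime_format(dt: datetime) -> datetime:
    # Assumed: the same instant truncated to whole seconds for create_date/update_date.
    return datetime(dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second)

kwargs = {}
kwargs["create_time"] = current_timestamp()              # e.g. 1708400000000
kwargs["create_date"] = datetime_format(datetime.now())  # e.g. 2024-02-20 11:33:20
print(kwargs)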
 IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
 API_KEY = LLM.get("api_key", "infiniflow API Key")
-PARSERS = LLM.get("parsers", "general:General,qa:Q&A,resume:Resume,naive:Naive,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
+PARSERS = LLM.get("parsers", "naive:General,qa:Q&A,resume:Resume,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
 # distribution
 DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)
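The PARSERS default keeps the "General" label but hangs it off the naive id. One way the "<id>:<label>" list unpacks (how the API layer actually consumes the string is not shown here):

PARSERS = ("naive:General,qa:Q&A,resume:Resume,table:Table,laws:Laws,manual:Manual,"
           "book:Book,paper:Paper,presentation:Presentation,picture:Picture")

parser_choices = dict(item.split(":", 1) for item in PARSERS.split(","))
assert parser_choices["naive"] == "General"   # the label survives, the "general" id does not
assert "general" not in parser_choices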
     def __init__(self):
         self.ocr = OCR()
         if not hasattr(self, "model_speciess"):
-            self.model_speciess = ParserType.GENERAL.value
+            self.model_speciess = ParserType.NAIVE.value
         self.layouter = LayoutRecognizer("layout."+self.model_speciess)
         self.tbl_det = TableStructureRecognizer()
| "Equation", | "Equation", | ||||
| ] | ] | ||||
| def __init__(self, domain): | def __init__(self, domain): | ||||
| super().__init__(self.labels, domain, | |||||
| os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) | |||||
| super().__init__(self.labels, domain) #, os.path.join(get_project_base_directory(), "rag/res/deepdoc/")) | |||||
| def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16): | def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16): | ||||
| def __is_garbage(b): | def __is_garbage(b): |
     ]
     def __init__(self):
-        super().__init__(self.labels, "tsr",
-                         os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
+        super().__init__(self.labels, "tsr")#,os.path.join(get_project_base_directory(), "rag/res/deepdoc/")
     def __call__(self, images, thr=0.2):
         tbls = super().__call__(images, thr)
 import copy
 import re
+from api.db import ParserType
 from rag.nlp import huqie, tokenize
 from deepdoc.parser import PdfParser
 from rag.utils import num_tokens_from_string
 class Pdf(PdfParser):
+    def __init__(self):
+        self.model_speciess = ParserType.MANUAL.value
+        super().__init__()
     def __call__(self, filename, binary=None, from_page=0,
                  to_page=100000, zoomin=3, callback=None):
         self.__images__(
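model_speciess is what the base parser turns into a layout model name ("layout." + self.model_speciess in the pdf_parser hunk above), so giving this Pdf subclass its own __init__ pins it to the manual layout model while the base default becomes naive. A condensed, stand-alone view of that interaction (not the real deepdoc classes):

class BasePdfParserSketch:
    def __init__(self):
        if not hasattr(self, "model_speciess"):
            self.model_speciess = "naive"              # new default, was "general"
        self.layout_model_name = "layout." + self.model_speciess

class ManualPdfSketch(BasePdfParserSketch):
    def __init__(self):
        self.model_speciess = "manual"                 # must be set before super().__init__()
        super().__init__()

assert BasePdfParserSketch().layout_model_name == "layout.naive"
assert ManualPdfSketch().layout_model_name == "layout.manual"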
         from timeit import default_timer as timer
         start = timer()
+        start = timer()
         self._layouts_rec(zoomin)
-        callback(0.77, "Layout analysis finished")
+        callback(0.5, "Layout analysis finished.")
+        print("paddle layouts:", timer() - start)
+        self._table_transformer_job(zoomin)
+        callback(0.7, "Table analysis finished.")
+        self._text_merge()
+        self._concat_downward(concat_between_pages=False)
+        self._filter_forpages()
+        callback(0.77, "Text merging finished")
+        tbls = self._extract_table_figure(True, zoomin, False)
         cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1)))
-        self._naive_vertical_merge()
-        return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes]
+        #self._naive_vertical_merge()
+        return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls
 def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
         Successive text will be sliced into pieces using 'delimiter'.
         Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'.
     """
+    eng = lang.lower() == "english"#is_english(cks)
     doc = {
         "docnm_kwd": filename,
         "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
     }
     doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
+    res = []
     pdf_parser = None
     sections = []
     if re.search(r"\.docx?$", filename, re.IGNORECASE):
         callback(0.8, "Finish parsing.")
     elif re.search(r"\.pdf$", filename, re.IGNORECASE):
         pdf_parser = Pdf()
-        sections = pdf_parser(filename if not binary else binary,
+        sections, tbls = pdf_parser(filename if not binary else binary,
                               from_page=from_page, to_page=to_page, callback=callback)
+        # add tables
+        for img, rows in tbls:
+            bs = 10
+            de = ";" if eng else "；"
+            for i in range(0, len(rows), bs):
+                d = copy.deepcopy(doc)
+                r = de.join(rows[i:i + bs])
+                r = re.sub(r"\t——(来自| in ).*”%s" % de, "", r)
+                tokenize(d, r, eng)
+                d["image"] = img
+                res.append(d)
     elif re.search(r"\.txt$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
         txt = ""
     parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;!?"})
     cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["delimiter"])
-    eng = lang.lower() == "english"#is_english(cks)
-    res = []
     # wrap up to es documents
     for ck in cks:
         print("--", ck)
 from io import BytesIO
 import pandas as pd
-from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture
+from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive
 from api.db import LLMType, ParserType
 from api.db.services.document_service import DocumentService
 BATCH_SIZE = 64
 FACTORY = {
-    ParserType.GENERAL.value: laws,
+    ParserType.NAIVE.value: naive,
     ParserType.PAPER.value: paper,
     ParserType.BOOK.value: book,
     ParserType.PRESENTATION.value: presentation,
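With naive imported and registered, the executor can resolve the renamed parser id. Schematically the dispatch amounts to the following (a hypothetical helper run inside the RAGFlow source tree; the surrounding task plumbing is omitted, and documents stored with the old "general" id would need remapping to "naive" first):

from rag.app import naive, laws, manual   # requires the RAGFlow source tree on the path

FACTORY = {
    "naive": naive,    # previously ParserType.GENERAL.value mapped to laws
    "laws": laws,
    "manual": manual,
    # ... remaining ParserType entries as in the hunk above
}

def run_chunker(parser_id, filename, binary=None, callback=None):
    chunker = FACTORY[parser_id.lower()]
    return chunker.chunk(filename, binary=binary, callback=callback)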