瀏覽代碼

fix user login issue (#85)

tags/v0.1.0
KevinHuSh 1 年之前
父節點
當前提交
0429107e80
No account linked to committer's email address

+ 47
- 60
api/apps/user_app.py 查看文件

@manager.route('/login', methods=['POST', 'GET']) @manager.route('/login', methods=['POST', 'GET'])
def login(): def login():
userinfo = None
login_channel = "password" login_channel = "password"
if session.get("access_token"):
login_channel = session["access_token_from"]
if session["access_token_from"] == "github":
userinfo = user_info_from_github(session["access_token"])
elif not request.json:
if not request.json:
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
retmsg='Unautherized!') retmsg='Unautherized!')
email = request.json.get('email') if not userinfo else userinfo["email"]
email = request.json.get('email', "")
users = UserService.query(email=email) users = UserService.query(email=email)
if not users:
if request.json is not None:
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
avatar = ""
try:
avatar = download_img(userinfo["avatar_url"])
except Exception as e:
stat_logger.exception(e)
user_id = get_uuid()
try:
users = user_register(user_id, {
"access_token": session["access_token"],
"email": userinfo["email"],
"avatar": avatar,
"nickname": userinfo["login"],
"login_channel": login_channel,
"last_login_time": get_format_time(),
"is_superuser": False,
})
if not users: raise Exception('Register user failure.')
if len(users) > 1: raise Exception('Same E-mail exist!')
user = users[0]
login_user(user)
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
except Exception as e:
rollback_user_registration(user_id)
stat_logger.exception(e)
return server_error_response(e)
elif not request.json:
login_user(users[0])
return cors_reponse(data=users[0].to_json(), auth=users[0].get_id(), retmsg="Welcome back!")
if not users: return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg=f'This Email is not registered!')
password = request.json.get('password') password = request.json.get('password')
try: try:
@manager.route('/github_callback', methods=['GET']) @manager.route('/github_callback', methods=['GET'])
def github_callback(): def github_callback():
try:
import requests
res = requests.post(GITHUB_OAUTH.get("url"), data={
"client_id": GITHUB_OAUTH.get("client_id"),
"client_secret": GITHUB_OAUTH.get("secret_key"),
"code": request.args.get('code')
},headers={"Accept": "application/json"})
res = res.json()
if "error" in res:
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
retmsg=res["error_description"])
if "user:email" not in res["scope"].split(","):
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')
session["access_token"] = res["access_token"]
session["access_token_from"] = "github"
return redirect(url_for("user.login"), code=307)
import requests
res = requests.post(GITHUB_OAUTH.get("url"), data={
"client_id": GITHUB_OAUTH.get("client_id"),
"client_secret": GITHUB_OAUTH.get("secret_key"),
"code": request.args.get('code')
}, headers={"Accept": "application/json"})
res = res.json()
if "error" in res:
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
retmsg=res["error_description"])
except Exception as e:
stat_logger.exception(e)
return server_error_response(e)
if "user:email" not in res["scope"].split(","):
return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='user:email not in scope')
session["access_token"] = res["access_token"]
session["access_token_from"] = "github"
userinfo = user_info_from_github(session["access_token"])
users = UserService.query(email=userinfo["email"])
user_id = get_uuid()
if not users:
try:
try:
avatar = download_img(userinfo["avatar_url"])
except Exception as e:
stat_logger.exception(e)
avatar = ""
users = user_register(user_id, {
"access_token": session["access_token"],
"email": userinfo["email"],
"avatar": avatar,
"nickname": userinfo["login"],
"login_channel": "github",
"last_login_time": get_format_time(),
"is_superuser": False,
})
if not users: raise Exception('Register user failure.')
if len(users) > 1: raise Exception('Same E-mail exist!')
user = users[0]
login_user(user)
except Exception as e:
rollback_user_registration(user_id)
stat_logger.exception(e)
return redirect("/knowledge")
def user_info_from_github(access_token): def user_info_from_github(access_token):
for llm in LLMService.query(fid=LLM_FACTORY): for llm in LLMService.query(fid=LLM_FACTORY):
tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY}) tenant_llm.append({"tenant_id": user_id, "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type":llm.model_type, "api_key": API_KEY})
if not UserService.insert(**user):return
if not UserService.save(**user):return
TenantService.insert(**tenant) TenantService.insert(**tenant)
UserTenantService.insert(**usr_tenant) UserTenantService.insert(**usr_tenant)
TenantLLMService.insert_many(tenant_llm) TenantLLMService.insert_many(tenant_llm)

+ 0
- 1
api/db/__init__.py 查看文件

class ParserType(StrEnum): class ParserType(StrEnum):
GENERAL = "general"
PRESENTATION = "presentation" PRESENTATION = "presentation"
LAWS = "laws" LAWS = "laws"
MANUAL = "manual" MANUAL = "manual"

+ 1
- 1
api/db/db_models.py 查看文件

similarity_threshold = FloatField(default=0.2) similarity_threshold = FloatField(default=0.2)
vector_similarity_weight = FloatField(default=0.3) vector_similarity_weight = FloatField(default=0.3)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.GENERAL.value)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value)
parser_config = JSONField(null=False, default={"pages":[[0,1000000]]}) parser_config = JSONField(null=False, default={"pages":[[0,1000000]]})
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1") status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")

+ 2
- 2
api/db/init_data.py 查看文件

"password": "admin", "password": "admin",
"nickname": "admin", "nickname": "admin",
"is_superuser": True, "is_superuser": True,
"email": "kai.hu@infiniflow.org",
"email": "admin@ragflow.io",
"creator": "system", "creator": "system",
"status": "1", "status": "1",
} }
TenantService.insert(**tenant) TenantService.insert(**tenant)
UserTenantService.insert(**usr_tenant) UserTenantService.insert(**usr_tenant)
TenantLLMService.insert_many(tenant_llm) TenantLLMService.insert_many(tenant_llm)
print("【INFO】Super user initialized. \033[93muser name: admin, password: admin\033[0m. Changing the password after logining is strongly recomanded.")
print("【INFO】Super user initialized. \033[93memail: admin@ragflow.io, password: admin\033[0m. Changing the password after logining is strongly recomanded.")
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"]) chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={}) msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})

+ 10
- 3
api/db/services/user_service.py 查看文件

# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
from datetime import datetime

import peewee import peewee
from werkzeug.security import generate_password_hash, check_password_hash from werkzeug.security import generate_password_hash, check_password_hash


from api.db.db_models import DB, UserTenant from api.db.db_models import DB, UserTenant
from api.db.db_models import User, Tenant from api.db.db_models import User, Tenant
from api.db.services.common_service import CommonService from api.db.services.common_service import CommonService
from api.utils import get_uuid, get_format_time
from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format
from api.db import StatusEnum from api.db import StatusEnum




kwargs["id"] = get_uuid() kwargs["id"] = get_uuid()
if "password" in kwargs: if "password" in kwargs:
kwargs["password"] = generate_password_hash(str(kwargs["password"])) kwargs["password"] = generate_password_hash(str(kwargs["password"]))

kwargs["create_time"] = current_timestamp()
kwargs["create_date"] = datetime_format(datetime.now())
kwargs["update_time"] = current_timestamp()
kwargs["update_date"] = datetime_format(datetime.now())
obj = cls.model(**kwargs).save(force_insert=True) obj = cls.model(**kwargs).save(force_insert=True)
return obj return obj


@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def update_user(cls, user_id, user_dict): def update_user(cls, user_id, user_dict):
date_time = get_format_time()
with DB.atomic(): with DB.atomic():
if user_dict: if user_dict:
user_dict["update_time"] = date_time
user_dict["update_time"] = current_timestamp()
user_dict["update_date"] = datetime_format(datetime.now())
cls.model.update(user_dict).where(cls.model.id == user_id).execute() cls.model.update(user_dict).where(cls.model.id == user_id).execute()





+ 1
- 1
api/settings.py 查看文件

IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
API_KEY = LLM.get("api_key", "infiniflow API Key") API_KEY = LLM.get("api_key", "infiniflow API Key")
PARSERS = LLM.get("parsers", "general:General,qa:Q&A,resume:Resume,naive:Naive,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
PARSERS = LLM.get("parsers", "naive:General,qa:Q&A,resume:Resume,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
# distribution # distribution
DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False) DEPENDENT_DISTRIBUTION = get_base_config("dependent_distribution", False)

+ 1
- 1
deepdoc/parser/pdf_parser.py 查看文件

def __init__(self): def __init__(self):
self.ocr = OCR() self.ocr = OCR()
if not hasattr(self, "model_speciess"): if not hasattr(self, "model_speciess"):
self.model_speciess = ParserType.GENERAL.value
self.model_speciess = ParserType.NAIVE.value
self.layouter = LayoutRecognizer("layout."+self.model_speciess) self.layouter = LayoutRecognizer("layout."+self.model_speciess)
self.tbl_det = TableStructureRecognizer() self.tbl_det = TableStructureRecognizer()



+ 1
- 2
deepdoc/vision/layout_recognizer.py 查看文件

"Equation", "Equation",
] ]
def __init__(self, domain): def __init__(self, domain):
super().__init__(self.labels, domain,
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
super().__init__(self.labels, domain) #, os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16): def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
def __is_garbage(b): def __is_garbage(b):

+ 1
- 2
deepdoc/vision/table_structure_recognizer.py 查看文件

] ]
def __init__(self): def __init__(self):
super().__init__(self.labels, "tsr",
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
super().__init__(self.labels, "tsr")#,os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
def __call__(self, images, thr=0.2): def __call__(self, images, thr=0.2):
tbls = super().__call__(images, thr) tbls = super().__call__(images, thr)

+ 6
- 0
rag/app/manual.py 查看文件

import copy import copy
import re import re
from api.db import ParserType
from rag.nlp import huqie, tokenize from rag.nlp import huqie, tokenize
from deepdoc.parser import PdfParser from deepdoc.parser import PdfParser
from rag.utils import num_tokens_from_string from rag.utils import num_tokens_from_string
class Pdf(PdfParser): class Pdf(PdfParser):
def __init__(self):
self.model_speciess = ParserType.MANUAL.value
super().__init__()
def __call__(self, filename, binary=None, from_page=0, def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None): to_page=100000, zoomin=3, callback=None):
self.__images__( self.__images__(

+ 29
- 6
rag/app/naive.py 查看文件

from timeit import default_timer as timer from timeit import default_timer as timer
start = timer() start = timer()
start = timer()
self._layouts_rec(zoomin) self._layouts_rec(zoomin)
callback(0.77, "Layout analysis finished")
callback(0.5, "Layout analysis finished.")
print("paddle layouts:", timer() - start)
self._table_transformer_job(zoomin)
callback(0.7, "Table analysis finished.")
self._text_merge()
self._concat_downward(concat_between_pages=False)
self._filter_forpages()
callback(0.77, "Text merging finished")
tbls = self._extract_table_figure(True, zoomin, False)
cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1))) cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1)))
self._naive_vertical_merge()
return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes]
#self._naive_vertical_merge()
return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs): def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
Successive text will be sliced into pieces using 'delimiter'. Successive text will be sliced into pieces using 'delimiter'.
Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'. Next, these successive pieces are merge into chunks whose token number is no more than 'Max token number'.
""" """
eng = lang.lower() == "english"#is_english(cks)
doc = { doc = {
"docnm_kwd": filename, "docnm_kwd": filename,
"title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename)) "title_tks": huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", filename))
} }
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"]) doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
res = []
pdf_parser = None pdf_parser = None
sections = [] sections = []
if re.search(r"\.docx?$", filename, re.IGNORECASE): if re.search(r"\.docx?$", filename, re.IGNORECASE):
callback(0.8, "Finish parsing.") callback(0.8, "Finish parsing.")
elif re.search(r"\.pdf$", filename, re.IGNORECASE): elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf() pdf_parser = Pdf()
sections = pdf_parser(filename if not binary else binary,
sections, tbls = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback) from_page=from_page, to_page=to_page, callback=callback)
# add tables
for img, rows in tbls:
bs = 10
de = ";" if eng else ";"
for i in range(0, len(rows), bs):
d = copy.deepcopy(doc)
r = de.join(rows[i:i + bs])
r = re.sub(r"\t——(来自| in ).*”%s" % de, "", r)
tokenize(d, r, eng)
d["image"] = img
res.append(d)
elif re.search(r"\.txt$", filename, re.IGNORECASE): elif re.search(r"\.txt$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.") callback(0.1, "Start to parse.")
txt = "" txt = ""
parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;!?"}) parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;!?"})
cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["delimiter"]) cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["delimiter"])
eng = lang.lower() == "english"#is_english(cks)
res = []
# wrap up to es documents # wrap up to es documents
for ck in cks: for ck in cks:
print("--", ck) print("--", ck)

+ 2
- 2
rag/svr/task_executor.py 查看文件

from io import BytesIO from io import BytesIO
import pandas as pd import pandas as pd


from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive


from api.db import LLMType, ParserType from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
BATCH_SIZE = 64 BATCH_SIZE = 64


FACTORY = { FACTORY = {
ParserType.GENERAL.value: laws,
ParserType.NAIVE.value: naive,
ParserType.PAPER.value: paper, ParserType.PAPER.value: paper,
ParserType.BOOK.value: book, ParserType.BOOK.value: book,
ParserType.PRESENTATION.value: presentation, ParserType.PRESENTATION.value: presentation,

Loading…
取消
儲存