Browse Source

refactor retrieval_test, add SQL retrieval methods (#61)

tags/v0.1.0
KevinHuSh 1 year ago
parent
commit
5e0a689c43

+ 4
- 1
api/apps/chunk_app.py View File

@@ -227,7 +227,7 @@ def retrieval_test():
doc_ids = req.get("doc_ids", [])
similarity_threshold = float(req.get("similarity_threshold", 0.2))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
top = int(req.get("top", 1024))
top = int(req.get("top_k", 1024))
try:
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
@@ -237,6 +237,9 @@ def retrieval_test():
kb.tenant_id, LLMType.EMBEDDING.value)
ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, similarity_threshold,
vector_similarity_weight, top, doc_ids)
for c in ranks["chunks"]:
if "vector" in c:
del c["vector"]
return get_json_result(data=ranks)
except Exception as e:
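The hunk above renames the request field top to top_k and strips raw embedding vectors from returned chunks. A minimal client sketch, assuming a local deployment on port 9380 and the /v1/chunk/retrieval_test route (both assumptions, not shown in this diff):

```python
import requests

payload = {
    "kb_id": "kb_0",                 # knowledge base to search (invented id)
    "question": "What is a tenant?",
    "similarity_threshold": 0.2,
    "vector_similarity_weight": 0.3,
    "top_k": 1024,                   # renamed from "top" in this commit
}
resp = requests.post("http://localhost:9380/v1/chunk/retrieval_test",
                     json=payload, timeout=30).json()
for chunk in resp["data"]["chunks"]:
    assert "vector" not in chunk     # raw vectors are now stripped server-side
```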

+ 2
- 0
api/apps/conversation_app.py View File

@@ -229,6 +229,7 @@ def use_sql(question,field_map, tenant_id, chat_mdl):
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.1})
sql = re.sub(r".*?select ", "select ", sql, flags=re.IGNORECASE)
sql = re.sub(r" +", " ", sql)
sql = re.sub(r"[;;].*", "", sql)
if sql[:len("select ")].lower() != "select ":
return None, None
if sql[:len("select *")].lower() != "select *":
@@ -241,6 +242,7 @@ def use_sql(question,field_map, tenant_id, chat_mdl):
docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx|docnm_idx)]
# compose markdown table
clmns = "|".join([re.sub(r"/.*", "", field_map.get(tbl["columns"][i]["name"], f"C{i}")) for i in clmn_idx]) + "|原文"
line = "|".join(["------" for _ in range(len(clmn_idx))]) + "|------"
rows = ["|".join([str(r[i]) for i in clmn_idx])+"|" for r in tbl["rows"]]

+ 13
- 4
api/apps/document_app.py View File

@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License
#
#
import base64
import pathlib
import re
import flask
from elasticsearch_dsl import Q
@@ -27,7 +28,7 @@ from api.db.services import duplicate_name
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.db import FileType, TaskStatus
from api.db import FileType, TaskStatus, ParserType
from api.db.services.document_service import DocumentService
from api.settings import RetCode
from api.utils.api_utils import get_json_result
@@ -66,7 +67,7 @@ def upload():
location += "_"
blob = request.files['file'].read()
MINIO.put(kb_id, location, blob)
doc = DocumentService.insert({
doc = {
"id": get_uuid(),
"kb_id": kb.id,
"parser_id": kb.parser_id,
@@ -77,7 +78,12 @@ def upload():
"location": location,
"size": len(blob),
"thumbnail": thumbnail(filename, blob)
})
}
if doc["type"] == FileType.VISUAL:
doc["parser_id"] = ParserType.PICTURE.value
if re.search(r"\.(ppt|pptx|pages)$", filename):
doc["parser_id"] = ParserType.PRESENTATION.value
doc = DocumentService.insert(doc)
return get_json_result(data=doc.to_json())
except Exception as e:
return server_error_response(e)
@@ -283,6 +289,9 @@ def change_parser():
if doc.parser_id.lower() == req["parser_id"].lower():
return get_json_result(data=True)
if doc.type == FileType.VISUAL or re.search(r"\.(ppt|pptx|pages)$", doc.name):
return get_data_error_result(retmsg="Not supported yet!")
e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress":0, "progress_msg": ""})
if not e:
return get_data_error_result(retmsg="Document not found!")
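The upload path now chooses a parser from the file itself before inserting the row, and change_parser refuses to switch parsers for those same types. A condensed sketch of the selection rule; pick_parser_id is a hypothetical helper, the enums come from api.db as imported above:

```python
import re
from api.db import FileType, ParserType

def pick_parser_id(doc_type, filename, kb_parser_id):
    if doc_type == FileType.VISUAL:
        return ParserType.PICTURE.value       # images always use the picture parser
    if re.search(r"\.(ppt|pptx|pages)$", filename):
        return ParserType.PRESENTATION.value  # slide decks use the presentation parser
    return kb_parser_id                       # everything else inherits the KB default
```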

+ 2
- 0
api/db/__init__.py View File

@@ -78,3 +78,5 @@ class ParserType(StrEnum):
BOOK = "book"
QA = "qa"
TABLE = "table"
NAIVE = "naive"
PICTURE = "picture"

+ 1
- 1
api/db/db_models.py View File

@@ -381,7 +381,7 @@ class Tenant(DataBaseModel):
embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID")
img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID")
parser_ids = CharField(max_length=128, null=False, help_text="document processors")
parser_ids = CharField(max_length=256, null=False, help_text="document processors")
credit = IntegerField(default=512)
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")

+ 36
- 3
api/db/init_data.py View File

@@ -63,7 +63,9 @@ def init_llm_factory():
"status": "1",
},
]
llm_infos = [{
llm_infos = [
# ---------------------- OpenAI ------------------------
{
"fid": factory_infos[0]["name"],
"llm_name": "gpt-3.5-turbo",
"tags": "LLM,CHAT,4K",
@@ -105,7 +107,9 @@ def init_llm_factory():
"tags": "LLM,CHAT,IMAGE2TEXT",
"max_tokens": 765,
"model_type": LLMType.IMAGE2TEXT.value
},{
},
# ----------------------- Qwen -----------------------
{
"fid": factory_infos[1]["name"],
"llm_name": "qwen-turbo",
"tags": "LLM,CHAT,8K",
@@ -135,7 +139,9 @@ def init_llm_factory():
"tags": "LLM,CHAT,IMAGE2TEXT",
"max_tokens": 765,
"model_type": LLMType.IMAGE2TEXT.value
},{
},
# ----------------------- Infiniflow -----------------------
{
"fid": factory_infos[2]["name"],
"llm_name": "gpt-3.5-turbo",
"tags": "LLM,CHAT,4K",
@@ -160,6 +166,33 @@ def init_llm_factory():
"max_tokens": 765,
"model_type": LLMType.IMAGE2TEXT.value
},
# ---------------------- ZhipuAI ----------------------
{
"fid": factory_infos[3]["name"],
"llm_name": "glm-3-turbo",
"tags": "LLM,CHAT,",
"max_tokens": 128 * 1000,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[3]["name"],
"llm_name": "glm-4",
"tags": "LLM,CHAT,",
"max_tokens": 128 * 1000,
"model_type": LLMType.CHAT.value
}, {
"fid": factory_infos[3]["name"],
"llm_name": "glm-4v",
"tags": "LLM,CHAT,IMAGE2TEXT",
"max_tokens": 2000,
"model_type": LLMType.IMAGE2TEXT.value
},
{
"fid": factory_infos[3]["name"],
"llm_name": "embedding-2",
"tags": "TEXT EMBEDDING",
"max_tokens": 512,
"model_type": LLMType.SPEECH2TEXT.value
},
]
for info in factory_infos:
LLMFactoriesService.save(**info)
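One oddity in the ZhipuAI block above: embedding-2 is registered with model_type LLMType.SPEECH2TEXT.value even though its tag is TEXT EMBEDDING. That looks like a copy-paste slip; the entry would presumably read (an assumption, not part of this commit):

```python
{
    "fid": factory_infos[3]["name"],
    "llm_name": "embedding-2",
    "tags": "TEXT EMBEDDING",
    "max_tokens": 512,
    "model_type": LLMType.EMBEDDING.value,  # an embedding model, not speech-to-text
},
```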

+ 1
- 1
api/settings.py View File

@@ -47,7 +47,7 @@ LLM = get_base_config("llm", {})
CHAT_MDL = LLM.get("chat_model", "gpt-3.5-turbo")
EMBEDDING_MDL = LLM.get("embedding_model", "text-embedding-ada-002")
ASR_MDL = LLM.get("asr_model", "whisper-1")
PARSERS = LLM.get("parsers", "general:General,resume:esume,laws:Laws,manual:Manual,book:Book,paper:Paper,qa:Q&A,presentation:Presentation")
PARSERS = LLM.get("parsers", "general:General,qa:Q&A,resume:Resume,naive:Naive,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
IMAGE2TEXT_MDL = LLM.get("image2text_model", "gpt-4-vision-preview")
# distribution

+ 2
- 1
rag/app/naive.py View File

@@ -57,7 +57,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
callback(0.8, "Finish parsing.")
else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")
cks = naive_merge(sections, kwargs.get("chunk_token_num", 128), kwargs.get("delimer", "\n。;!?"))
parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimer": "\n。;!?"})
cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["delimer"])
eng = is_english(cks)
res = []
# wrap up to es documents
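chunk now reads its chunking parameters from a per-KB parser_config instead of loose kwargs. Note the replacement indexes the dict directly, so a caller-supplied parser_config must carry both keys; the fallback dict only applies when the kwarg is absent entirely. A usage sketch with an invented file name:

```python
from rag.app import naive

cks = naive.chunk(
    "manual.pdf",
    callback=lambda prog, msg: print(prog, msg),
    parser_config={"chunk_token_num": 256, "delimer": "\n。;!?"},
)
```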

+ 32
- 16
rag/app/qa.py View File

@@ -24,31 +24,45 @@ class Excel(object):
for i, r in enumerate(rows):
q, a = "", ""
for cell in r:
if not cell.value: continue
if not q: q = str(cell.value)
elif not a: a = str(cell.value)
else: break
if q and a: res.append((q, a))
else: fails.append(str(i+1))
if not cell.value:
continue
if not q:
q = str(cell.value)
elif not a:
a = str(cell.value)
else:
break
if q and a:
res.append((q, a))
else:
fails.append(str(i + 1))
if len(res) % 999 == 0:
callback(len(res)*0.6/total, ("Extract Q&A: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..."%(",".join(fails[:3])) if fails else "")))
callback(len(res) *
0.6 /
total, ("Extract Q&A: {}".format(len(res)) +
(f"{len(fails)} failure, line: %s..." %
(",".join(fails[:3])) if fails else "")))
callback(0.6, ("Extract Q&A: {}. ".format(len(res)) + (
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
self.is_english = is_english([rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q)>1])
self.is_english = is_english(
[rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q) > 1])
return res
def rmPrefix(txt):
return re.sub(r"^(问题|答案|回答|user|assistant|Q|A|Question|Answer|问|答)[\t:: ]+", "", txt.strip(), flags=re.IGNORECASE)
return re.sub(
r"^(问题|答案|回答|user|assistant|Q|A|Question|Answer|问|答)[\t:: ]+", "", txt.strip(), flags=re.IGNORECASE)
def beAdoc(d, q, a, eng):
qprefix = "Question: " if eng else "问题:"
aprefix = "Answer: " if eng else "回答:"
d["content_with_weight"] = "\t".join([qprefix+rmPrefix(q), aprefix+rmPrefix(a)])
d["content_with_weight"] = "\t".join(
[qprefix + rmPrefix(q), aprefix + rmPrefix(a)])
if eng:
d["content_ltks"] = " ".join([stemmer.stem(w) for w in word_tokenize(q)])
d["content_ltks"] = " ".join([stemmer.stem(w)
for w in word_tokenize(q)])
else:
d["content_ltks"] = huqie.qie(q)
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
@@ -61,7 +75,7 @@ def chunk(filename, binary=None, callback=None, **kwargs):
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
excel_parser = Excel()
for q,a in excel_parser(filename, binary, callback):
for q, a in excel_parser(filename, binary, callback):
res.append(beAdoc({}, q, a, excel_parser.is_english))
return res
elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE):
@@ -73,7 +87,8 @@ def chunk(filename, binary=None, callback=None, **kwargs):
with open(filename, "r") as f:
while True:
l = f.readline()
if not l: break
if not l:
break
txt += l
lines = txt.split("\n")
eng = is_english([rmPrefix(l) for l in lines[:100]])
@@ -93,12 +108,13 @@ def chunk(filename, binary=None, callback=None, **kwargs):
return res
raise NotImplementedError("file type not supported yet(pptx, pdf supported)")
raise NotImplementedError(
"file type not supported yet(pptx, pdf supported)")
if __name__== "__main__":
if __name__ == "__main__":
import sys
def dummy(a, b):
pass
chunk(sys.argv[1], callback=dummy)
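The reflowed rmPrefix strips a Q/A label (English or Chinese) plus its separator; beAdoc then joins the normalized pair into one chunk. A quick behavior sketch with invented strings:

```python
from rag.app.qa import rmPrefix, beAdoc

print(rmPrefix("Question: What is a tenant?"))  # -> "What is a tenant?"
print(rmPrefix("答案: 多租户隔离。"))            # -> "多租户隔离。"

d = beAdoc({}, "Q: What is a tenant?", "A: An isolated workspace.", True)
# d["content_with_weight"] == "Question: What is a tenant?\tAnswer: An isolated workspace."
```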

+ 37
- 18
rag/app/resume.py View File

@@ -11,15 +11,22 @@ from rag.utils import rmSpace
def chunk(filename, binary=None, callback=None, **kwargs):
if not re.search(r"\.(pdf|doc|docx|txt)$", filename, flags=re.IGNORECASE): raise NotImplementedError("file type not supported yet(pdf supported)")
if not re.search(r"\.(pdf|doc|docx|txt)$", filename, flags=re.IGNORECASE):
raise NotImplementedError("file type not supported yet(pdf supported)")
url = os.environ.get("INFINIFLOW_SERVER")
if not url:raise EnvironmentError("Please set environment variable: 'INFINIFLOW_SERVER'")
if not url:
raise EnvironmentError(
"Please set environment variable: 'INFINIFLOW_SERVER'")
token = os.environ.get("INFINIFLOW_TOKEN")
if not token:raise EnvironmentError("Please set environment variable: 'INFINIFLOW_TOKEN'")
if not token:
raise EnvironmentError(
"Please set environment variable: 'INFINIFLOW_TOKEN'")
if not binary:
with open(filename, "rb") as f: binary = f.read()
with open(filename, "rb") as f:
binary = f.read()
def remote_call():
nonlocal filename, binary
for _ in range(3):
@@ -27,14 +34,17 @@ def chunk(filename, binary=None, callback=None, **kwargs):
res = requests.post(url + "/v1/layout/resume/", files=[(filename, binary)],
headers={"Authorization": token}, timeout=180)
res = res.json()
if res["retcode"] != 0: raise RuntimeError(res["retmsg"])
if res["retcode"] != 0:
raise RuntimeError(res["retmsg"])
return res["data"]
except RuntimeError as e:
raise e
except Exception as e:
cron_logger.error("resume parsing:" + str(e))
callback(0.2, "Resume parsing is going on...")
resume = remote_call()
callback(0.6, "Done parsing. Chunking...")
print(json.dumps(resume, ensure_ascii=False, indent=2))
field_map = {
@@ -45,19 +55,19 @@ def chunk(filename, binary=None, callback=None, **kwargs):
"email_tks": "email/e-mail/邮箱",
"position_name_tks": "职位/职能/岗位/职责",
"expect_position_name_tks": "期望职位/期望职能/期望岗位",
"hightest_degree_kwd": "最高学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)",
"first_degree_kwd": "第一学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)",
"first_major_tks": "第一学历专业",
"first_school_name_tks": "第一学历毕业学校",
"edu_first_fea_kwd": "第一学历标签(211,留学,双一流,985,海外知名,重点大学,中专,专升本,专科,本科,大专)",
"degree_kwd": "过往学历(高中,职高,硕士,本科,博士,初中,中技,中专,专科,专升本,MPA,MBA,EMBA)",
"major_tks": "学过的专业/过往专业",
"school_name_tks": "学校/毕业院校",
"sch_rank_kwd": "学校标签(顶尖学校,精英学校,优质学校,一般学校)",
"edu_fea_kwd": "教育标签(211,留学,双一流,985,海外知名,重点大学,中专,专升本,专科,本科,大专)",
"work_exp_flt": "工作年限/工作年份/N年经验/毕业了多少年",
"birth_dt": "生日/出生年份",
"corp_nm_tks": "就职过的公司/之前的公司/上过班的公司",
@@ -69,34 +79,43 @@ def chunk(filename, binary=None, callback=None, **kwargs):
titles = []
for n in ["name_kwd", "gender_kwd", "position_name_tks", "age_int"]:
v = resume.get(n, "")
if isinstance(v, list):v = v[0]
if n.find("tks") > 0: v = rmSpace(v)
if isinstance(v, list):
v = v[0]
if n.find("tks") > 0:
v = rmSpace(v)
titles.append(str(v))
doc = {
"docnm_kwd": filename,
"title_tks": huqie.qie("-".join(titles)+"-简历")
"title_tks": huqie.qie("-".join(titles) + "-简历")
}
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
pairs = []
for n,m in field_map.items():
if not resume.get(n):continue
for n, m in field_map.items():
if not resume.get(n):
continue
v = resume[n]
if isinstance(v, list):v = " ".join(v)
if n.find("tks") > 0: v = rmSpace(v)
if isinstance(v, list):
v = " ".join(v)
if n.find("tks") > 0:
v = rmSpace(v)
pairs.append((m, str(v)))
doc["content_with_weight"] = "\n".join(["{}: {}".format(re.sub(r"([^()]+)", "", k), v) for k,v in pairs])
doc["content_with_weight"] = "\n".join(
["{}: {}".format(re.sub(r"([^()]+)", "", k), v) for k, v in pairs])
doc["content_ltks"] = huqie.qie(doc["content_with_weight"])
doc["content_sm_ltks"] = huqie.qieqie(doc["content_ltks"])
for n, _ in field_map.items(): doc[n] = resume[n]
for n, _ in field_map.items():
doc[n] = resume[n]
print(doc)
KnowledgebaseService.update_parser_config(kwargs["kb_id"], {"field_map": field_map})
KnowledgebaseService.update_parser_config(
kwargs["kb_id"], {"field_map": field_map})
return [doc]
if __name__ == "__main__":
import sys
def dummy(a, b):
pass
chunk(sys.argv[1], callback=dummy)
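resume.chunk delegates parsing to a remote layout service and now raises early when that service is unconfigured. A minimal invocation sketch; the server URL and token are placeholders, and kb_id is required because the parser writes its field_map back to the knowledge base:

```python
import os

os.environ["INFINIFLOW_SERVER"] = "http://layout.example.com"
os.environ["INFINIFLOW_TOKEN"] = "<token>"

from rag.app import resume

docs = resume.chunk("candidate.pdf", callback=lambda p, m: None, kb_id="kb_0")
```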

+ 40
- 18
rag/app/table.py View File

@@ -28,10 +28,15 @@ class Excel(object):
rows = list(ws.rows)
headers = [cell.value for cell in rows[0]]
missed = set([i for i, h in enumerate(headers) if h is None])
headers = [cell.value for i, cell in enumerate(rows[0]) if i not in missed]
headers = [
cell.value for i,
cell in enumerate(
rows[0]) if i not in missed]
data = []
for i, r in enumerate(rows[1:]):
row = [cell.value for ii, cell in enumerate(r) if ii not in missed]
row = [
cell.value for ii,
cell in enumerate(r) if ii not in missed]
if len(row) != len(headers):
fails.append(str(i))
continue
@@ -55,8 +60,10 @@ def trans_datatime(s):
def trans_bool(s):
if re.match(r"(true|yes|是)$", str(s).strip(), flags=re.IGNORECASE): return ["yes", "是"]
if re.match(r"(false|no|否)$", str(s).strip(), flags=re.IGNORECASE): return ["no", "否"]
if re.match(r"(true|yes|是)$", str(s).strip(), flags=re.IGNORECASE):
return ["yes", "是"]
if re.match(r"(false|no|否)$", str(s).strip(), flags=re.IGNORECASE):
return ["no", "否"]
def column_data_type(arr):
@@ -65,7 +72,8 @@ def column_data_type(arr):
trans = {t: f for f, t in
[(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
for a in arr:
if a is None: continue
if a is None:
continue
if re.match(r"[+-]?[0-9]+(\.0+)?$", str(a).replace("%%", "")):
counts["int"] += 1
elif re.match(r"[+-]?[0-9.]+$", str(a).replace("%%", "")):
@@ -79,7 +87,8 @@ def column_data_type(arr):
counts = sorted(counts.items(), key=lambda x: x[1] * -1)
ty = counts[0][0]
for i in range(len(arr)):
if arr[i] is None: continue
if arr[i] is None:
continue
try:
arr[i] = trans[ty](str(arr[i]))
except Exception as e:
@@ -105,7 +114,8 @@ def chunk(filename, binary=None, callback=None, **kwargs):
with open(filename, "r") as f:
while True:
l = f.readline()
if not l: break
if not l:
break
txt += l
lines = txt.split("\n")
fails = []
@@ -127,14 +137,22 @@ def chunk(filename, binary=None, callback=None, **kwargs):
dfs = [pd.DataFrame(np.array(rows), columns=headers)]
else:
raise NotImplementedError("file type not supported yet(excel, text, csv supported)")
raise NotImplementedError(
"file type not supported yet(excel, text, csv supported)")
res = []
PY = Pinyin()
fieds_map = {"text": "_tks", "int": "_int", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"}
fieds_map = {
"text": "_tks",
"int": "_int",
"keyword": "_kwd",
"float": "_flt",
"datetime": "_dt",
"bool": "_kwd"}
for df in dfs:
for n in ["id", "_id", "index", "idx"]:
if n in df.columns: del df[n]
if n in df.columns:
del df[n]
clmns = df.columns.values
txts = list(copy.deepcopy(clmns))
py_clmns = [PY.get_pinyins(n)[0].replace("-", "_") for n in clmns]
@@ -143,23 +161,29 @@ def chunk(filename, binary=None, callback=None, **kwargs):
cln, ty = column_data_type(df[clmns[j]])
clmn_tys.append(ty)
df[clmns[j]] = cln
if ty == "text": txts.extend([str(c) for c in cln if c])
clmns_map = [(py_clmns[j] + fieds_map[clmn_tys[j]], clmns[j]) for i in range(len(clmns))]
if ty == "text":
txts.extend([str(c) for c in cln if c])
clmns_map = [(py_clmns[j] + fieds_map[clmn_tys[j]], clmns[j])
for i in range(len(clmns))]
eng = is_english(txts)
for ii, row in df.iterrows():
d = {}
row_txt = []
for j in range(len(clmns)):
if row[clmns[j]] is None: continue
if row[clmns[j]] is None:
continue
fld = clmns_map[j][0]
d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else huqie.qie(row[clmns[j]])
d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else huqie.qie(
row[clmns[j]])
row_txt.append("{}:{}".format(clmns[j], row[clmns[j]]))
if not row_txt: continue
if not row_txt:
continue
tokenize(d, "; ".join(row_txt), eng)
res.append(d)
KnowledgebaseService.update_parser_config(kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}})
KnowledgebaseService.update_parser_config(
kwargs["kb_id"], {"field_map": {k: v for k, v in clmns_map}})
callback(0.6, "")
return res
@@ -168,9 +192,7 @@ def chunk(filename, binary=None, callback=None, **kwargs):
if __name__ == "__main__":
import sys
def dummy(a, b):
pass
chunk(sys.argv[1], callback=dummy)
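column_data_type votes a type per column from value patterns and then coerces every cell; the fieds_map suffix then lands in the stored field name. A behavior sketch with invented values, assuming the function returns the coerced column and its type as used above:

```python
from rag.app.table import column_data_type

cln, ty = column_data_type(["35", "41", None, "28"])
# ty == "int"; cln == [35, 41, None, 28]. With a header like 年龄 the stored
# field becomes its pinyin plus the "_int" suffix, e.g. "nian_ling_int".
```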

+ 18
- 0
rag/llm/chat_model.py View File

@@ -58,3 +58,21 @@ class QWenChat(Base):
if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content'], response.usage.output_tokens
return response.message, 0


from zhipuai import ZhipuAI
class ZhipuChat(Base):
def __init__(self, key, model_name="glm-3-turbo"):
self.client = ZhipuAI(api_key=key)
self.model_name = model_name

def chat(self, system, history, gen_conf):
from http import HTTPStatus
history.insert(0, {"role": "system", "content": system})
response = self.client.chat.completions.create(
self.model_name,
messages=history
)
if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content'], response.usage.completion_tokens
return response.message, 0
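The ZhipuChat added above passes the model name positionally and checks response.status_code / response.output, neither of which exists on the zhipuai v2 client this file imports. A corrected sketch under that SDK (the SDK shape is an assumption, not something this diff shows):

```python
from zhipuai import ZhipuAI

class ZhipuChat(Base):
    def __init__(self, key, model_name="glm-3-turbo"):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        history.insert(0, {"role": "system", "content": system})
        response = self.client.chat.completions.create(
            model=self.model_name,  # keyword argument, not positional
            messages=history,
            temperature=gen_conf.get("temperature", 0.3),
        )
        # The v2 SDK returns an object with .choices/.usage, no HTTP status.
        return response.choices[0].message.content, response.usage.completion_tokens
```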

+ 20
- 1
rag/llm/cv_model.py View File

@@ -61,7 +61,7 @@ class Base(ABC):

class GptV4(Base):
def __init__(self, key, model_name="gpt-4-vision-preview"):
self.client = OpenAI(api_key = key)
self.client = OpenAI(api_key=key)
self.model_name = model_name

def describe(self, image, max_tokens=300):
@@ -89,3 +89,22 @@ class QWenCV(Base):
if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content'], response.usage.output_tokens
return response.message, 0


from zhipuai import ZhipuAI


class Zhipu4V(Base):
def __init__(self, key, model_name="glm-4v"):
self.client = ZhipuAI(api_key=key)
self.model_name = model_name

def describe(self, image, max_tokens=1024):
b64 = self.image2base64(image)

res = self.client.chat.completions.create(
model=self.model_name,
messages=self.prompt(b64),
max_tokens=max_tokens,
)
return res.choices[0].message.content.strip(), res.usage.total_tokens

+ 18
- 2
rag/llm/embedding_model.py View File

@@ -19,7 +19,6 @@ import dashscope
from openai import OpenAI
from FlagEmbedding import FlagModel
import torch
import os
import numpy as np

from rag.utils import num_tokens_from_string
@@ -114,4 +113,21 @@ class QWenEmbed(Base):
input=text[:2048],
text_type="query"
)
return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["input_tokens"]
return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["input_tokens"]


from zhipuai import ZhipuAI
class ZhipuEmbed(Base):
def __init__(self, key, model_name="embedding-2"):
self.client = ZhipuAI(api_key=key)
self.model_name = model_name

def encode(self, texts: list, batch_size=32):
res = self.client.embeddings.create(input=texts,
model=self.model_name)
return np.array([d.embedding for d in res.data]), res.usage.total_tokens

def encode_queries(self, text):
res = self.client.embeddings.create(input=text,
model=self.model_name)
return np.array(res["data"][0]["embedding"]), res.usage.total_tokens
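encode_queries above indexes the response like a dict (res["data"][0]["embedding"]) while encode uses attribute access; with the zhipuai v2 client both calls return objects. A consistent sketch (SDK shape assumed, as with ZhipuChat):

```python
def encode_queries(self, text):
    res = self.client.embeddings.create(input=text, model=self.model_name)
    # Attribute access, matching encode() above.
    return np.array(res.data[0].embedding), res.usage.total_tokens
```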

+ 5
- 3
rag/nlp/search.py View File

@@ -268,9 +268,9 @@ class Dealer:
dim = len(sres.query_vector)
start_idx = (page - 1) * page_size
for i in idx:
ranks["total"] += 1
if sim[i] < similarity_threshold:
break
ranks["total"] += 1
start_idx -= 1
if start_idx >= 0:
continue
@@ -280,6 +280,7 @@ class Dealer:
break
id = sres.ids[i]
dnm = sres.field[id]["docnm_kwd"]
did = sres.field[id]["doc_id"]
d = {
"chunk_id": id,
"content_ltks": sres.field[id]["content_ltks"],
@@ -296,8 +297,9 @@ class Dealer:
}
ranks["chunks"].append(d)
if dnm not in ranks["doc_aggs"]:
ranks["doc_aggs"][dnm] = 0
ranks["doc_aggs"][dnm] += 1
ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
ranks["doc_aggs"][dnm]["count"] += 1
ranks["doc_aggs"] = [{"doc_name": k, "doc_id": v["doc_id"], "count": v["count"]} for k,v in sorted(ranks["doc_aggs"].items(), key=lambda x:x[1]["count"]*-1)]

return ranks
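Two behavioral changes in this hunk: total is now incremented before the similarity cut, and doc_aggs changes shape from a name-to-count map to a list sorted by count descending, each entry carrying its doc_id. With invented values, the new shape is:

```python
ranks["doc_aggs"] = [
    {"doc_name": "handbook.pdf", "doc_id": "d1", "count": 7},
    {"doc_name": "faq.docx", "doc_id": "d2", "count": 3},
]
```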


+ 7
- 5
rag/svr/task_executor.py View File

@@ -36,7 +36,7 @@ from rag.nlp import search
from io import BytesIO
import pandas as pd

from rag.app import laws, paper, presentation, manual, qa, table,book
from rag.app import laws, paper, presentation, manual, qa, table, book, resume

from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService
@@ -55,6 +55,7 @@ FACTORY = {
ParserType.LAWS.value: laws,
ParserType.QA.value: qa,
ParserType.TABLE.value: table,
ParserType.RESUME.value: resume,
}


@@ -119,7 +120,7 @@ def build(row, cvmdl):
try:
cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
cks = chunker.chunk(row["name"], MINIO.get(row["kb_id"], row["location"]), row["from_page"], row["to_page"],
callback, kb_id=row["kb_id"])
callback, kb_id=row["kb_id"], parser_config=row["parser_config"])
except Exception as e:
if re.search("(No such file|not found)", str(e)):
callback(-1, "Can not find file <%s>" % row["doc_name"])
@@ -171,7 +172,7 @@ def init_kb(row):
open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))


def embedding(docs, mdl):
def embedding(docs, mdl, parser_config={}):
tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [d["content_with_weight"] for d in docs]
tk_count = 0
if len(tts) == len(cnts):
@@ -180,7 +181,8 @@ def embedding(docs, mdl):

cnts, c = mdl.encode(cnts)
tk_count += c
vects = (0.1 * tts + 0.9 * cnts) if len(tts) == len(cnts) else cnts
title_w = float(parser_config.get("filename_embd_weight", 0.1))
vects = (title_w * tts + (1-title_w) * cnts) if len(tts) == len(cnts) else cnts

assert len(vects) == len(docs)
for i, d in enumerate(docs):
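The fixed 0.1/0.9 title/content blend becomes configurable through parser_config["filename_embd_weight"]. A toy sketch of the combination (shapes invented):

```python
import numpy as np

tts = np.random.rand(4, 768)   # one title embedding per chunk
cnts = np.random.rand(4, 768)  # one content embedding per chunk
title_w = 0.3                  # parser_config = {"filename_embd_weight": 0.3}
vects = title_w * tts + (1 - title_w) * cnts
```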
@@ -216,7 +218,7 @@ def main(comm, mod):
# TODO: exception handler
## set_progress(r["did"], -1, "ERROR: ")
try:
tk_count = embedding(cks, embd_mdl)
tk_count = embedding(cks, embd_mdl, r["parser_config"])
except Exception as e:
callback(-1, "Embedding error:{}".format(str(e)))
cron_logger.error(str(e))
