 question = req["question"]
 kb_id = req["kb_id"]
 doc_ids = req.get("doc_ids", [])
-similarity_threshold = float(req.get("similarity_threshold", 0.4))
+similarity_threshold = float(req.get("similarity_threshold", 0.2))
 vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
 top = int(req.get("top", 1024))
 try:
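These are the request parameters of the retrieval-test endpoint; the hunk lowers the default similarity cutoff from 0.4 to 0.2 so fewer candidates are discarded before reranking. A minimal sketch of how the defaults and coercions behave (the `req` payload is illustrative):

    req = {"question": "What is RAG?", "kb_id": "kb_42"}   # illustrative JSON body

    similarity_threshold = float(req.get("similarity_threshold", 0.2))  # absent -> new default
    top = int(req.get("top", 1024))          # candidate pool size before reranking
    doc_ids = req.get("doc_ids", [])         # optional filter; empty means "all docs"

    assert (similarity_threshold, top, doc_ids) == (0.2, 1024, [])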
| if p["key"] not in kwargs: | if p["key"] not in kwargs: | ||||
| prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ") | prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ") | ||||
| model_config = TenantLLMService.get_api_key(dialog.tenant_id, LLMType.CHAT.value, dialog.llm_id) | |||||
| model_config = TenantLLMService.get_api_key(dialog.tenant_id, dialog.llm_id) | |||||
| if not model_config: raise LookupError("LLM({}) API key not found".format(dialog.llm_id)) | if not model_config: raise LookupError("LLM({}) API key not found".format(dialog.llm_id)) | ||||
| question = messages[-1]["content"] | question = messages[-1]["content"] | ||||
| kwargs["knowledge"] = "\n".join(knowledges) | kwargs["knowledge"] = "\n".join(knowledges) | ||||
| gen_conf = dialog.llm_setting[dialog.llm_setting_type] | gen_conf = dialog.llm_setting[dialog.llm_setting_type] | ||||
| msg = [{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"] | msg = [{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"] | ||||
| used_token_count = message_fit_in(msg, int(llm.max_tokens * 0.97)) | |||||
| used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97)) | |||||
| if "max_tokens" in gen_conf: | if "max_tokens" in gen_conf: | ||||
| gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count) | gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count) | ||||
| mdl = ChatModel[model_config.llm_factory](model_config["api_key"], dialog.llm_id) | |||||
| mdl = ChatModel[model_config.llm_factory](model_config.api_key, dialog.llm_id) | |||||
| answer = mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf) | answer = mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf) | ||||
| answer = retrievaler.insert_citations(answer, | answer = retrievaler.insert_citations(answer, | ||||
| embd_mdl, | embd_mdl, | ||||
| tkweight=1-dialog.vector_similarity_weight, | tkweight=1-dialog.vector_similarity_weight, | ||||
| vtweight=dialog.vector_similarity_weight) | vtweight=dialog.vector_similarity_weight) | ||||
| for c in kbinfos["chunks"]: | |||||
| if c.get("vector"):del c["vector"] | |||||
| return {"answer": answer, "retrieval": kbinfos} | return {"answer": answer, "retrieval": kbinfos} |
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
-# limitations under the License.
+# limitations under the License
+#
 #
 import base64
 import pathlib

 while MINIO.obj_exist(kb_id, location):
     location += "_"
 blob = request.files['file'].read()
-MINIO.put(kb_id, filename, blob)
+MINIO.put(kb_id, location, blob)
 doc = DocumentService.insert({
     "id": get_uuid(),
     "kb_id": kb.id,

 e, doc = DocumentService.get_by_id(req["doc_id"])
 if not e:
     return get_data_error_result(retmsg="Document not found!")
-ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id))
+tenant_id = DocumentService.get_tenant_id(req["doc_id"])
+if not tenant_id:
+    return get_data_error_result(retmsg="Tenant not found!")
+ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id))
 DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num * -1, doc.chunk_num * -1, 0)
 if not DocumentService.delete_by_id(req["doc_id"]):
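The upload fix stores the blob under the deduplicated `location` rather than the original `filename`; otherwise the collision loop above is dead code and a re-upload silently overwrites the stored object. A self-contained repro with a stub object store (`obj_exist`/`put` mirror the two MINIO calls the hunk relies on):

    class FakeStore:
        """Stub standing in for MINIO in this sketch."""
        def __init__(self):
            self.objects = {}
        def obj_exist(self, bucket, name):
            return (bucket, name) in self.objects
        def put(self, bucket, name, blob):
            self.objects[(bucket, name)] = blob

    MINIO = FakeStore()
    kb_id, location = "kb1", "report.pdf"
    MINIO.put(kb_id, location, b"old upload")

    while MINIO.obj_exist(kb_id, location):
        location += "_"                   # probe until the name is free
    MINIO.put(kb_id, location, b"new upload")

    assert location == "report.pdf_"      # the old code wrote to filename, clobbering the first blob
    assert MINIO.objects[(kb_id, "report.pdf")] == b"old upload"

The deletion fix follows the same theme as the chat hunk: the Elasticsearch index is named per tenant, not per knowledge base, so deleting by `doc.kb_id` targeted a nonexistent index and left orphaned chunks behind.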
 llms = LLMService.get_all()
 llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
 for m in llms:
-    m["available"] = m.llm_name in mdlnms
+    m["available"] = m["llm_name"] in mdlnms
 res = {}
 for m in llms:
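A one-character change, but it fixes an `AttributeError`: after `to_dict()` the list holds plain dicts, not ORM rows. Minimal repro:

    class LLMRecord:
        """Stub ORM row; to_dict() mirrors what the service layer returns."""
        def __init__(self, llm_name):
            self.llm_name = llm_name
        def to_dict(self):
            return {"llm_name": self.llm_name}

    mdlnms = {"gpt-3.5-turbo"}
    llms = [LLMRecord("gpt-3.5-turbo").to_dict()]

    try:
        llms[0].llm_name                  # old code: plain dicts have no attributes
    except AttributeError:
        pass

    llms[0]["available"] = llms[0]["llm_name"] in mdlnms   # the fix
    assert llms[0]["available"]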
 doc_num = IntegerField(default=0)
 token_num = IntegerField(default=0)
 chunk_num = IntegerField(default=0)
-similarity_threshold = FloatField(default=0.4)
+similarity_threshold = FloatField(default=0.2)
 vector_similarity_weight = FloatField(default=0.3)
 parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
 prompt_config = JSONField(null=False, default={"system": "", "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?",
                                                "parameters": [], "empty_response": "Sorry! 知识库中未找到相关内容!"})
-similarity_threshold = FloatField(default=0.4)
+similarity_threshold = FloatField(default=0.2)
 vector_similarity_weight = FloatField(default=0.3)
 top_n = IntegerField(default=6)

 model_config = cls.get_api_key(tenant_id, mdlnm)
 if not model_config: raise LookupError("Model({}) not found".format(mdlnm))
-model_config = model_config[0].to_dict()
+model_config = model_config.to_dict()
 if llm_type == LLMType.EMBEDDING.value:
     if model_config["llm_factory"] not in EmbeddingModel: return
     return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
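The model-field defaults here track the 0.4-to-0.2 threshold change made at the API layer, keeping schema and handler in sync. `EmbeddingModel` (like `ChatModel` earlier) is a registry dict mapping an `llm_factory` name to a wrapper class, and the removed `[0]` suggests `get_api_key` now returns a single row rather than a list. A hedged sketch of the registry pattern, with an illustrative stub class and factory name:

    class StubEmbed:
        """Illustrative stand-in for a vendor embedding wrapper."""
        def __init__(self, key, model_name):
            self.key, self.model_name = key, model_name

    EmbeddingModel = {"StubVendor": StubEmbed}   # real keys are factory names like "OpenAI"

    model_config = {"llm_factory": "StubVendor", "api_key": "k", "llm_name": "embed-small"}
    if model_config["llm_factory"] in EmbeddingModel:
        mdl = EmbeddingModel[model_config["llm_factory"]](
            model_config["api_key"], model_config["llm_name"])
        assert mdl.model_name == "embed-small"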
| if re.match(r".*\.pdf$", filename): | if re.match(r".*\.pdf$", filename): | ||||
| return FileType.PDF.value | return FileType.PDF.value | ||||
| if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename): | |||||
| if re.match(r".*\.(docx|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename): | |||||
| return FileType.DOC.value | return FileType.DOC.value | ||||
| if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename): | if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename): |
 class Base(ABC):
+    def __init__(self, key, model_name):
+        pass
     def chat(self, system, history, gen_conf):
         raise NotImplementedError("Please implement encode method!")

 class GptTurbo(Base):
-    def __init__(self):
-        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
+    def __init__(self, key, model_name="gpt-3.5-turbo"):
+        self.client = OpenAI(api_key=key)
+        self.model_name = model_name
     def chat(self, system, history, gen_conf):
         history.insert(0, {"role": "system", "content": system})
         res = self.client.chat.completions.create(
-            model="gpt-3.5-turbo",
+            model=self.model_name,
             messages=history,
             **gen_conf)
         return res.choices[0].message.content.strip()

+from dashscope import Generation

 class QWenChat(Base):
+    def __init__(self, key, model_name=Generation.Models.qwen_turbo):
+        import dashscope
+        dashscope.api_key = key
+        self.model_name = model_name
     def chat(self, system, history, gen_conf):
         from http import HTTPStatus
-        from dashscope import Generation
-        # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
         history.insert(0, {"role": "system", "content": system})
         response = Generation.call(
-            Generation.Models.qwen_turbo,
+            self.model_name,
             messages=history,
             result_format='message'
         )
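The thrust of this refactor is dependency injection: the API key and model name arrive through the constructor instead of `os.environ`, so one server process can serve many tenants with different credentials and models. A standalone sketch of how such a registry is then driven (the stub class and factory name are illustrative, not the real ones):

    class StubChat:
        """Illustrative wrapper following the new (key, model_name) contract."""
        def __init__(self, key, model_name="stub-model"):
            self.key, self.model_name = key, model_name
        def chat(self, system, history, gen_conf):
            return "[%s] echo: %s" % (self.model_name, history[-1]["content"])

    ChatModel = {"StubVendor": StubChat}   # the real registry maps factory names to classes

    mdl = ChatModel["StubVendor"]("tenant-a-key", "stub-model")
    print(mdl.chat("You are helpful.", [{"role": "user", "content": "Hi"}], {}))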
         raise NotImplementedError("Please implement encode method!")

     def image2base64(self, image):
+        if isinstance(image, bytes):
+            return base64.b64encode(image).decode("utf-8")
         if isinstance(image, BytesIO):
             return base64.b64encode(image.getvalue()).decode("utf-8")
         buffered = BytesIO()

 class GptV4(Base):
     def __init__(self, key, model_name="gpt-4-vision-preview"):
-        self.client = OpenAI(key)
+        self.client = OpenAI(api_key=key)
         self.model_name = model_name
     def describe(self, image, max_tokens=300):
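The new `bytes` branch lets callers pass a raw image blob without wrapping it first, and `OpenAI(api_key=key)` fixes a positional-argument mistake. A standalone version of the helper as fixed (the PIL fall-through after `buffered = BytesIO()` is elided in this sketch):

    import base64
    from io import BytesIO

    def image2base64(image):
        if isinstance(image, bytes):                 # new: raw blobs work directly
            return base64.b64encode(image).decode("utf-8")
        if isinstance(image, BytesIO):
            return base64.b64encode(image.getvalue()).decode("utf-8")
        raise TypeError("PIL-image branch elided in this sketch")

    assert image2base64(b"\x89PNG") == image2base64(BytesIO(b"\x89PNG"))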
 if len(t) < 5: continue
 idx.append(i)
 pieces_.append(t)
+es_logger.info("{} => {}".format(answer, pieces_))
 if not pieces_: return answer
-ans_v = embd_mdl.encode(pieces_)
+ans_v, c = embd_mdl.encode(pieces_)
 assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
     len(ans_v[0]), len(chunk_v[0]))

 Dealer.trans2floats(
     sres.field[i]["q_%d_vec" % len(sres.query_vector)]) for i in sres.ids]
 if not ins_embd:
-    return []
+    return [], [], []
 ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ") for i in sres.ids]
 sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
                                                 ins_embd,

 def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
               vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True):
+    ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
+    if not question: return ranks
     req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": top,
            "question": question, "vector": True,
            "similarity": similarity_threshold}

     sim, tsim, vsim = self.rerank(
         sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
     idx = np.argsort(sim * -1)
-    ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
     dim = len(sres.query_vector)
     start_idx = (page - 1) * page_size
     for i in idx:
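Two defensive fixes here: `return [], [], []` keeps the three-way unpack at `sim, tsim, vsim = self.rerank(...)` from raising `ValueError` when nothing was retrieved, and hoisting `ranks` above the new early return means an empty question yields a well-formed empty result instead of a crash. The score being sorted is a token/vector blend; a hedged sketch of the idea only (the real `hybrid_similarity` lives in the query module):

    import numpy as np

    def hybrid_similarity(tksim, vtsim, vector_similarity_weight=0.3):
        # Blend token-overlap and vector-cosine scores with complementary weights.
        tkweight = 1 - vector_similarity_weight
        return tkweight * np.asarray(tksim) + vector_similarity_weight * np.asarray(vtsim)

    sim = hybrid_similarity([0.9, 0.1, 0.5], [0.2, 0.8, 0.5])
    idx = np.argsort(sim * -1)     # argsort of -sim == descending sort of sim
    print(idx)                     # indices of the best-scoring chunks first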
     field = TextChunker.Fields()
     field.text_chunks = [(txt, binary)]
     field.table_chunks = []
+    return field
 return TextChunker()(binary)

 doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
 output_buffer = BytesIO()
 docs = []
-md5 = hashlib.md5()
 for txt, img in obj.text_chunks:
     d = copy.deepcopy(doc)
+    md5 = hashlib.md5()
     md5.update((txt + str(d["doc_id"])).encode("utf-8"))
     d["_id"] = md5.hexdigest()
     d["content_ltks"] = huqie.qie(txt)

 for i, txt in enumerate(arr):
     d = copy.deepcopy(doc)
     d["content_ltks"] = huqie.qie(txt)
+    md5 = hashlib.md5()
     md5.update((txt + str(d["doc_id"])).encode("utf-8"))
     d["_id"] = md5.hexdigest()
     if not img:
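The added `return field` closes a fall-through bug (the branch built a `Fields` object and then dropped it). Moving `hashlib.md5()` inside the loops is the more subtle fix: a hash object accumulates every `update()`, so a shared instance makes each chunk `_id` depend on all preceding chunks, and the same chunk re-ingested later gets a different id, defeating idempotent upserts. Repro:

    import hashlib

    def chunk_ids_shared(chunks, doc_id):
        md5 = hashlib.md5()                   # old, buggy placement: one shared hash
        ids = []
        for txt in chunks:
            md5.update((txt + doc_id).encode("utf-8"))
            ids.append(md5.hexdigest())
        return ids

    def chunk_ids_fresh(chunks, doc_id):
        ids = []
        for txt in chunks:
            md5 = hashlib.md5()               # the fix: one hash per chunk
            md5.update((txt + doc_id).encode("utf-8"))
            ids.append(md5.hexdigest())
        return ids

    a = chunk_ids_shared(["x", "y"], "doc1")
    b = chunk_ids_fresh(["x", "y"], "doc1")
    assert a[0] == b[0] and a[1] != b[1]      # second id diverges with a shared hash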
 def main(comm, mod):
+    global model
+    from rag.llm import HuEmbedding
+    model = HuEmbedding()
     tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
     tm = findMaxTm(tm_fnm)
     rows = collect(comm, mod, tm)

 set_progress(r["id"], random.randint(70, 95) / 100.,
              "Finished embedding! Start to build index!")
 init_kb(r)
+chunk_count = len(set([c["_id"] for c in cks]))
 es_r = ELASTICSEARCH.bulk(cks, search.index_name(r["tenant_id"]))
 if es_r:
     set_progress(r["id"], -1, "Index failure!")
     cron_logger.error(str(es_r))
 else:
     set_progress(r["id"], 1., "Done!")
-    DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, len(cks), timer() - st_tm)
+    DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, chunk_count, timer() - st_tm)
 cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))
 tmf.write(str(r["update_time"]) + "\n")
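Counting distinct `_id`s instead of `len(cks)` keeps the document's chunk counter honest: Elasticsearch bulk indexing upserts by `_id`, so duplicate chunk ids collapse into a single indexed document and the raw list length over-counts what actually landed in the index. For example:

    cks = [{"_id": "a1"}, {"_id": "b2"}, {"_id": "a1"}]   # third is a duplicate id

    chunk_count = len(set(c["_id"] for c in cks))
    assert chunk_count == 2 and len(cks) == 3    # only two documents end up indexed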