浏览代码

fix english query bug (#840)

### What problem does this PR solve?

#834 

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
tags/v0.6.0
KevinHuSh 1年前
父节点
当前提交
2b36283712
没有帐户链接到提交者的电子邮件
共有 4 个文件被更改,包括 80 次插入7 次删除
  1. 1
    1
      api/db/services/dialog_service.py
  2. 44
    0
      rag/llm/chat_model.py
  3. 29
    1
      rag/llm/rpc_server.py
  4. 6
    5
      rag/nlp/query.py

+ 1
- 1
api/db/services/dialog_service.py 查看文件

kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold, dialog.similarity_threshold,
dialog.vector_similarity_weight, dialog.vector_similarity_weight,
doc_ids=kwargs.get("doc_ids", "").split(","),
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
top=1024, aggs=False) top=1024, aggs=False)
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]] knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
chat_logger.info( chat_logger.info(

+ 44
- 0
rag/llm/chat_model.py 查看文件

import openai import openai
from ollama import Client from ollama import Client
from rag.nlp import is_english from rag.nlp import is_english
from rag.utils import num_tokens_from_string




class Base(ABC): class Base(ABC):
except Exception as e: except Exception as e:
yield ans + "\n**ERROR**: " + str(e) yield ans + "\n**ERROR**: " + str(e)
yield 0 yield 0


class LocalLLM(Base):
class RPCProxy:
def __init__(self, host, port):
self.host = host
self.port = int(port)
self.__conn()

def __conn(self):
from multiprocessing.connection import Client
self._connection = Client(
(self.host, self.port), authkey=b'infiniflow-token4kevinhu')

def __getattr__(self, name):
import pickle

def do_rpc(*args, **kwargs):
for _ in range(3):
try:
self._connection.send(
pickle.dumps((name, args, kwargs)))
return pickle.loads(self._connection.recv())
except Exception as e:
self.__conn()
raise Exception("RPC connection lost!")

return do_rpc

def __init__(self, key, model_name="glm-3-turbo"):
self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)

def chat(self, system, history, gen_conf):
if system:
history.insert(0, {"role": "system", "content": system})
try:
ans = self.client.chat(
history,
gen_conf
)
return ans, num_tokens_from_string(ans)
except Exception as e:
return "**ERROR**: " + str(e), 0

+ 29
- 1
rag/llm/rpc_server.py 查看文件

import pickle import pickle
import random import random
import time import time
from copy import deepcopy
from multiprocessing.connection import Listener from multiprocessing.connection import Listener
from threading import Thread from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
def torch_gc(): def torch_gc():
return str(e) return str(e)
def chat_streamly(messages, gen_conf):
global tokenizer
model = Model()
try:
torch_gc()
conf = deepcopy(gen_conf)
print(messages, conf)
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
streamer = TextStreamer(tokenizer)
conf["inputs"] = model_inputs.input_ids
conf["streamer"] = streamer
conf["max_new_tokens"] = conf["max_tokens"]
del conf["max_tokens"]
thread = Thread(target=model.generate, kwargs=conf)
thread.start()
for _, new_text in enumerate(streamer):
yield new_text
except Exception as e:
yield "**ERROR**: " + str(e)
def Model(): def Model():
global models global models
random.seed(time.time()) random.seed(time.time())
handler = RPCHandler() handler = RPCHandler()
handler.register_function(chat) handler.register_function(chat)
handler.register_function(chat_streamly)
models = [] models = []
for _ in range(1): for _ in range(1):

+ 6
- 5
rag/nlp/query.py 查看文件

patts = [ patts = [
(r"是*(什么样的|哪家|一下|那家|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""), (r"是*(什么样的|哪家|一下|那家|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""),
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "), (r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
(r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down)", " ")
(r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down) ", " ")
] ]
for r, p in patts: for r, p in patts:
txt = re.sub(r, p, txt, flags=re.IGNORECASE) txt = re.sub(r, p, txt, flags=re.IGNORECASE)


def question(self, txt, tbl="qa", min_match="60%"): def question(self, txt, tbl="qa", min_match="60%"):
txt = re.sub( txt = re.sub(
r"[ \r\n\t,,。??/`!!&]+",
r"[ \r\n\t,,。??/`!!&\^%%]+",
" ", " ",
rag_tokenizer.tradi2simp( rag_tokenizer.tradi2simp(
rag_tokenizer.strQ2B( rag_tokenizer.strQ2B(


if not self.isChinese(txt): if not self.isChinese(txt):
tks = rag_tokenizer.tokenize(txt).split(" ") tks = rag_tokenizer.tokenize(txt).split(" ")
q = copy.deepcopy(tks)
for i in range(1, len(tks)):
q.append("\"%s %s\"^2" % (tks[i - 1], tks[i]))
tks_w = self.tw.weights(tks)
q = [re.sub(r"[ \\\"']+", "", tk)+"^{:.4f}".format(w) for tk, w in tks_w]
for i in range(1, len(tks_w)):
q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
if not q: if not q:
q.append(txt) q.append(txt)
return Q("bool", return Q("bool",

正在加载...
取消
保存