瀏覽代碼

Fix: disallowed special token while embedding (#8692)

### What problem does this PR solve?

https://github.com/infiniflow/ragflow/issues/8567

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
tags/v0.20.0
Stephen Hu 3 月之前
父節點
當前提交
e60ec0a31b
沒有連結到貢獻者的電子郵件帳戶。
共有 1 個檔案被更改,包括 8 行新增、0 行刪除
  1. 8
    0
      rag/llm/embedding_model.py

+ 8
- 0
rag/llm/embedding_model.py 查看文件

@@ -273,6 +273,8 @@ class ZhipuEmbed(Base):
class OllamaEmbed(Base):
_FACTORY_NAME = "Ollama"

_special_tokens = ["<|endoftext|>"]

def __init__(self, key, model_name, **kwargs):
    """Initialize the Ollama embedding client.

    Args:
        key: API key; a falsy value or the placeholder "x" means the target
            server requires no authentication header.
        model_name: Name of the Ollama embedding model to use.
        **kwargs: Must contain "base_url", the Ollama server address.
            (KeyError is raised if it is missing, matching prior behavior.)
    """
    # Without a real key, talk to the server unauthenticated; otherwise
    # attach a Bearer token. NOTE: the auth scheme is "Bearer" per
    # RFC 6750 — the previous "Bear {key}" was a typo and would be
    # rejected by any proxy validating the Authorization scheme.
    if not key or key == "x":
        self.client = Client(host=kwargs["base_url"])
    else:
        self.client = Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"})
    self.model_name = model_name
@@ -281,6 +283,9 @@ class OllamaEmbed(Base):
arr = []
tks_num = 0
for txt in texts:
# remove special tokens if they exist
for token in OllamaEmbed._special_tokens:
txt = txt.replace(token, "")
res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True})
try:
arr.append(res["embedding"])
@@ -290,6 +295,9 @@ class OllamaEmbed(Base):
return np.array(arr), tks_num

def encode_queries(self, text):
# remove special tokens if they exist
for token in OllamaEmbed._special_tokens:
text = text.replace(token, "")
res = self.client.embeddings(prompt=text, model=self.model_name, options={"use_mmap": True})
try:
return np.array(res["embedding"]), 128

Loading…
取消
儲存