|
|
|
@@ -273,6 +273,8 @@ class ZhipuEmbed(Base): |
|
|
|
class OllamaEmbed(Base): |
|
|
|
_FACTORY_NAME = "Ollama" |
|
|
|
|
|
|
|
_special_tokens = ["<|endoftext|>"] |
|
|
|
|
|
|
|
def __init__(self, key, model_name, **kwargs):
    """Create the embedding client for the server at ``kwargs["base_url"]``.

    Args:
        key: Optional API key. An empty/falsy value or the literal
            ``"x"`` is treated as "no key": the client is created without
            an ``Authorization`` header.
        model_name: Name of the embedding model; stored on the instance
            and passed with every embeddings request.
        **kwargs: Must contain ``base_url``, the host handed to ``Client``.

    Raises:
        KeyError: If ``base_url`` is not supplied in ``kwargs``.
    """
    host = kwargs["base_url"]
    if not key or key == "x":
        # No usable credential: talk to the server unauthenticated.
        self.client = Client(host=host)
    else:
        # The HTTP auth scheme is "Bearer" (RFC 6750); the previous
        # "Bear" spelling produced a malformed Authorization header.
        self.client = Client(host=host, headers={"Authorization": f"Bearer {key}"})
    self.model_name = model_name
|
|
|
@@ -281,6 +283,9 @@ class OllamaEmbed(Base): |
|
|
|
arr = [] |
|
|
|
tks_num = 0 |
|
|
|
for txt in texts: |
|
|
|
# remove special tokens if they exist |
|
|
|
for token in OllamaEmbed._special_tokens: |
|
|
|
txt = txt.replace(token, "") |
|
|
|
res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True}) |
|
|
|
try: |
|
|
|
arr.append(res["embedding"]) |
|
|
|
@@ -290,6 +295,9 @@ class OllamaEmbed(Base): |
|
|
|
return np.array(arr), tks_num |
|
|
|
|
|
|
|
def encode_queries(self, text): |
|
|
|
# remove special tokens if they exist |
|
|
|
for token in OllamaEmbed._special_tokens: |
|
|
|
text = text.replace(token, "") |
|
|
|
res = self.client.embeddings(prompt=text, model=self.model_name, options={"use_mmap": True}) |
|
|
|
try: |
|
|
|
return np.array(res["embedding"]), 128 |