|
|
|
@@ -27,8 +27,7 @@ import torch |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from api.utils.file_utils import get_project_base_directory, get_home_cache_dir |
|
|
|
from rag.utils import num_tokens_from_string |
|
|
|
|
|
|
|
from rag.utils import num_tokens_from_string, truncate |
|
|
|
|
|
|
|
try: |
|
|
|
flag_model = FlagModel(os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5"), |
|
|
|
@@ -70,7 +69,7 @@ class DefaultEmbedding(Base): |
|
|
|
self.model = flag_model |
|
|
|
|
|
|
|
def encode(self, texts: list, batch_size=32): |
|
|
|
texts = [t[:2000] for t in texts] |
|
|
|
texts = [truncate(t, 2048) for t in texts] |
|
|
|
token_count = 0 |
|
|
|
for t in texts: |
|
|
|
token_count += num_tokens_from_string(t) |
|
|
|
@@ -93,12 +92,14 @@ class OpenAIEmbed(Base): |
|
|
|
self.model_name = model_name |
|
|
|
|
|
|
|
def encode(self, texts: list, batch_size=32): |
|
|
|
texts = [truncate(t, 8196) for t in texts] |
|
|
|
res = self.client.embeddings.create(input=texts, |
|
|
|
model=self.model_name) |
|
|
|
return np.array([d.embedding for d in res.data]), res.usage.total_tokens |
|
|
|
return np.array([d.embedding for d in res.data] |
|
|
|
), res.usage.total_tokens |
|
|
|
|
|
|
|
def encode_queries(self, text): |
|
|
|
res = self.client.embeddings.create(input=[text], |
|
|
|
res = self.client.embeddings.create(input=[truncate(text, 8196)], |
|
|
|
model=self.model_name) |
|
|
|
return np.array(res.data[0].embedding), res.usage.total_tokens |
|
|
|
|
|
|
|
@@ -112,7 +113,7 @@ class QWenEmbed(Base): |
|
|
|
import dashscope |
|
|
|
res = [] |
|
|
|
token_count = 0 |
|
|
|
texts = [txt[:2048] for txt in texts] |
|
|
|
texts = [truncate(t, 2048) for t in texts] |
|
|
|
for i in range(0, len(texts), batch_size): |
|
|
|
resp = dashscope.TextEmbedding.call( |
|
|
|
model=self.model_name, |