浏览代码

fix #994 (#1006)

### What problem does this PR solve?

#994 

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
tags/v0.8.0
KevinHuSh 1年前
父节点
当前提交
b9bb11879f
没有帐户链接到提交者的电子邮件
共有 1 个文件被更改,包括 29 次插入21 次删除
  1. 29
    21
      rag/llm/embedding_model.py

+ 29
- 21
rag/llm/embedding_model.py 查看文件



def encode(self, texts: list, batch_size=10): def encode(self, texts: list, batch_size=10):
import dashscope import dashscope
res = []
token_count = 0
texts = [truncate(t, 2048) for t in texts]
for i in range(0, len(texts), batch_size):
try:
res = []
token_count = 0
texts = [truncate(t, 2048) for t in texts]
for i in range(0, len(texts), batch_size):
resp = dashscope.TextEmbedding.call(
model=self.model_name,
input=texts[i:i + batch_size],
text_type="document"
)
embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
for e in resp["output"]["embeddings"]:
embds[e["text_index"]] = e["embedding"]
res.extend(embds)
token_count += resp["usage"]["total_tokens"]
return np.array(res), token_count
except Exception as e:
raise Exception("Account abnormal. Please ensure it's on good standing.")
return np.array([]), 0

def encode_queries(self, text):
try:
resp = dashscope.TextEmbedding.call( resp = dashscope.TextEmbedding.call(
model=self.model_name, model=self.model_name,
input=texts[i:i + batch_size],
text_type="document"
input=text[:2048],
text_type="query"
) )
embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
for e in resp["output"]["embeddings"]:
embds[e["text_index"]] = e["embedding"]
res.extend(embds)
token_count += resp["usage"]["total_tokens"]
return np.array(res), token_count

def encode_queries(self, text):
resp = dashscope.TextEmbedding.call(
model=self.model_name,
input=text[:2048],
text_type="query"
)
return np.array(resp["output"]["embeddings"][0]
["embedding"]), resp["usage"]["total_tokens"]
return np.array(resp["output"]["embeddings"][0]
["embedding"]), resp["usage"]["total_tokens"]
except Exception as e:
raise Exception("Account abnormal. Please ensure it's on good standing.")
return np.array([]), 0




class ZhipuEmbed(Base): class ZhipuEmbed(Base):

正在加载...
取消
保存