|
|
|
@@ -123,30 +123,38 @@ class QWenEmbed(Base): |
|
|
|
|
|
|
|
def encode(self, texts: list, batch_size=10):
    """Embed a list of documents with the DashScope text-embedding API.

    Each text is first passed through ``truncate(t, 2048)``; the results
    are then sent to ``dashscope.TextEmbedding.call`` in slices of
    ``batch_size`` with ``text_type="document"``.

    Args:
        texts: The documents to embed.
        batch_size: Number of texts sent per API call.

    Returns:
        A tuple ``(embeddings, token_count)``: a numpy array built from the
        collected embeddings (one entry per input text, in input order) and
        the sum of ``usage.total_tokens`` over all calls.

    Raises:
        Exception: If anything fails while calling the service or reading
            its response. The original error is attached as ``__cause__``.
    """
    import dashscope

    try:
        res = []
        token_count = 0
        truncated = [truncate(t, 2048) for t in texts]
        for start in range(0, len(truncated), batch_size):
            resp = dashscope.TextEmbedding.call(
                model=self.model_name,
                input=truncated[start:start + batch_size],
                text_type="document"
            )
            items = resp["output"]["embeddings"]
            # Place each embedding at the slot given by its "text_index"
            # rather than relying on the order of the response list.
            embds = [[] for _ in items]
            for item in items:
                embds[item["text_index"]] = item["embedding"]
            res.extend(embds)
            token_count += resp["usage"]["total_tokens"]
        return np.array(res), token_count
    except Exception as e:
        # Any failure is reported with the same user-facing message; chain
        # the cause so the real error is still visible in the traceback.
        raise Exception("Account abnormal. Please ensure it's on good standing.") from e
|
|
|
|
|
|
|
def encode_queries(self, text):
    """Embed a single query string with the DashScope text-embedding API.

    The first 2048 characters of ``text`` are sent to
    ``dashscope.TextEmbedding.call`` with ``text_type="query"``.

    NOTE(review): this method uses the name ``dashscope`` without importing
    it locally, so it presumably relies on a module-level import that is not
    visible in this part of the file -- confirm.

    Args:
        text: The query text to embed.

    Returns:
        A tuple ``(embedding, token_count)``: a numpy array built from the
        first embedding in the response and the response's
        ``usage.total_tokens`` value.

    Raises:
        Exception: If anything fails while calling the service or reading
            its response. The original error is attached as ``__cause__``.
    """
    try:
        resp = dashscope.TextEmbedding.call(
            model=self.model_name,
            input=text[:2048],
            text_type="query"
        )
        embedding = resp["output"]["embeddings"][0]["embedding"]
        return np.array(embedding), resp["usage"]["total_tokens"]
    except Exception as e:
        # Any failure is reported with the same user-facing message; chain
        # the cause so the real error is still visible in the traceback.
        raise Exception("Account abnormal. Please ensure it's on good standing.") from e
|
|
|
|
|
|
|
|
|
|
|
class ZhipuEmbed(Base): |