### What problem does this PR solve? #4567 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)tags/v0.16.0
| @@ -53,7 +53,7 @@ class Base(ABC): | |||
| ans += LENGTH_NOTIFICATION_CN | |||
| else: | |||
| ans += LENGTH_NOTIFICATION_EN | |||
| return ans, response.usage.total_tokens | |||
| return ans, self.total_token_count(response) | |||
| except openai.APIError as e: | |||
| return "**ERROR**: " + str(e), 0 | |||
| @@ -75,15 +75,11 @@ class Base(ABC): | |||
| resp.choices[0].delta.content = "" | |||
| ans += resp.choices[0].delta.content | |||
| if not hasattr(resp, "usage") or not resp.usage: | |||
| total_tokens = ( | |||
| total_tokens | |||
| + num_tokens_from_string(resp.choices[0].delta.content) | |||
| ) | |||
| elif isinstance(resp.usage, dict): | |||
| total_tokens = resp.usage.get("total_tokens", total_tokens) | |||
| tol = self.total_token_count(resp) | |||
| if not tol: | |||
| total_tokens += num_tokens_from_string(resp.choices[0].delta.content) | |||
| else: | |||
| total_tokens = resp.usage.total_tokens | |||
| total_tokens = tol | |||
| if resp.choices[0].finish_reason == "length": | |||
| if is_chinese(ans): | |||
| @@ -97,6 +93,17 @@ class Base(ABC): | |||
| yield total_tokens | |||
| def total_token_count(self, resp): | |||
| try: | |||
| return resp.usage.total_tokens | |||
| except Exception: | |||
| pass | |||
| try: | |||
| return resp["usage"]["total_tokens"] | |||
| except Exception: | |||
| pass | |||
| return 0 | |||
| class GptTurbo(Base): | |||
| def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"): | |||
| @@ -182,7 +189,7 @@ class BaiChuanChat(Base): | |||
| ans += LENGTH_NOTIFICATION_CN | |||
| else: | |||
| ans += LENGTH_NOTIFICATION_EN | |||
| return ans, response.usage.total_tokens | |||
| return ans, self.total_token_count(response) | |||
| except openai.APIError as e: | |||
| return "**ERROR**: " + str(e), 0 | |||
| @@ -212,14 +219,11 @@ class BaiChuanChat(Base): | |||
| if not resp.choices[0].delta.content: | |||
| resp.choices[0].delta.content = "" | |||
| ans += resp.choices[0].delta.content | |||
| total_tokens = ( | |||
| ( | |||
| total_tokens | |||
| + num_tokens_from_string(resp.choices[0].delta.content) | |||
| ) | |||
| if not hasattr(resp, "usage") | |||
| else resp.usage["total_tokens"] | |||
| ) | |||
| tol = self.total_token_count(resp) | |||
| if not tol: | |||
| total_tokens += num_tokens_from_string(resp.choices[0].delta.content) | |||
| else: | |||
| total_tokens = tol | |||
| if resp.choices[0].finish_reason == "length": | |||
| if is_chinese([ans]): | |||
| ans += LENGTH_NOTIFICATION_CN | |||
| @@ -256,7 +260,7 @@ class QWenChat(Base): | |||
| tk_count = 0 | |||
| if response.status_code == HTTPStatus.OK: | |||
| ans += response.output.choices[0]['message']['content'] | |||
| tk_count += response.usage.total_tokens | |||
| tk_count += self.total_token_count(response) | |||
| if response.output.choices[0].get("finish_reason", "") == "length": | |||
| if is_chinese([ans]): | |||
| ans += LENGTH_NOTIFICATION_CN | |||
| @@ -292,7 +296,7 @@ class QWenChat(Base): | |||
| for resp in response: | |||
| if resp.status_code == HTTPStatus.OK: | |||
| ans = resp.output.choices[0]['message']['content'] | |||
| tk_count = resp.usage.total_tokens | |||
| tk_count = self.total_token_count(resp) | |||
| if resp.output.choices[0].get("finish_reason", "") == "length": | |||
| if is_chinese(ans): | |||
| ans += LENGTH_NOTIFICATION_CN | |||
| @@ -334,7 +338,7 @@ class ZhipuChat(Base): | |||
| ans += LENGTH_NOTIFICATION_CN | |||
| else: | |||
| ans += LENGTH_NOTIFICATION_EN | |||
| return ans, response.usage.total_tokens | |||
| return ans, self.total_token_count(response) | |||
| except Exception as e: | |||
| return "**ERROR**: " + str(e), 0 | |||
| @@ -364,9 +368,9 @@ class ZhipuChat(Base): | |||
| ans += LENGTH_NOTIFICATION_CN | |||
| else: | |||
| ans += LENGTH_NOTIFICATION_EN | |||
| tk_count = resp.usage.total_tokens | |||
| tk_count = self.total_token_count(resp) | |||
| if resp.choices[0].finish_reason == "stop": | |||
| tk_count = resp.usage.total_tokens | |||
| tk_count = self.total_token_count(resp) | |||
| yield ans | |||
| except Exception as e: | |||
| yield ans + "\n**ERROR**: " + str(e) | |||
| @@ -569,7 +573,7 @@ class MiniMaxChat(Base): | |||
| ans += LENGTH_NOTIFICATION_CN | |||
| else: | |||
| ans += LENGTH_NOTIFICATION_EN | |||
| return ans, response["usage"]["total_tokens"] | |||
| return ans, self.total_token_count(response) | |||
| except Exception as e: | |||
| return "**ERROR**: " + str(e), 0 | |||
| @@ -603,11 +607,11 @@ class MiniMaxChat(Base): | |||
| if "choices" in resp and "delta" in resp["choices"][0]: | |||
| text = resp["choices"][0]["delta"]["content"] | |||
| ans += text | |||
| total_tokens = ( | |||
| total_tokens + num_tokens_from_string(text) | |||
| if "usage" not in resp | |||
| else resp["usage"]["total_tokens"] | |||
| ) | |||
| tol = self.total_token_count(resp) | |||
| if not tol: | |||
| total_tokens += num_tokens_from_string(text) | |||
| else: | |||
| total_tokens = tol | |||
| yield ans | |||
| except Exception as e: | |||
| @@ -640,7 +644,7 @@ class MistralChat(Base): | |||
| ans += LENGTH_NOTIFICATION_CN | |||
| else: | |||
| ans += LENGTH_NOTIFICATION_EN | |||
| return ans, response.usage.total_tokens | |||
| return ans, self.total_token_count(response) | |||
| except openai.APIError as e: | |||
| return "**ERROR**: " + str(e), 0 | |||
| @@ -838,7 +842,7 @@ class GeminiChat(Base): | |||
| yield 0 | |||
| class GroqChat: | |||
| class GroqChat(Base): | |||
| def __init__(self, key, model_name, base_url=''): | |||
| from groq import Groq | |||
| self.client = Groq(api_key=key) | |||
| @@ -863,7 +867,7 @@ class GroqChat: | |||
| ans += LENGTH_NOTIFICATION_CN | |||
| else: | |||
| ans += LENGTH_NOTIFICATION_EN | |||
| return ans, response.usage.total_tokens | |||
| return ans, self.total_token_count(response) | |||
| except Exception as e: | |||
| return ans + "\n**ERROR**: " + str(e), 0 | |||
| @@ -1255,7 +1259,7 @@ class BaiduYiyanChat(Base): | |||
| **gen_conf | |||
| ).body | |||
| ans = response['result'] | |||
| return ans, response["usage"]["total_tokens"] | |||
| return ans, self.total_token_count(response) | |||
| except Exception as e: | |||
| return ans + "\n**ERROR**: " + str(e), 0 | |||
| @@ -1283,7 +1287,7 @@ class BaiduYiyanChat(Base): | |||
| for resp in response: | |||
| resp = resp.body | |||
| ans += resp['result'] | |||
| total_tokens = resp["usage"]["total_tokens"] | |||
| total_tokens = self.total_token_count(resp) | |||
| yield ans | |||
| @@ -44,11 +44,23 @@ class Base(ABC): | |||
| def encode_queries(self, text: str): | |||
| raise NotImplementedError("Please implement encode method!") | |||
| def total_token_count(self, resp): | |||
| try: | |||
| return resp.usage.total_tokens | |||
| except Exception: | |||
| pass | |||
| try: | |||
| return resp["usage"]["total_tokens"] | |||
| except Exception: | |||
| pass | |||
| return 0 | |||
| class DefaultEmbedding(Base): | |||
| _model = None | |||
| _model_name = "" | |||
| _model_lock = threading.Lock() | |||
| def __init__(self, key, model_name, **kwargs): | |||
| """ | |||
| If you have trouble downloading HuggingFace models, -_^ this might help!! | |||
| @@ -115,13 +127,13 @@ class OpenAIEmbed(Base): | |||
| res = self.client.embeddings.create(input=texts[i:i + batch_size], | |||
| model=self.model_name) | |||
| ress.extend([d.embedding for d in res.data]) | |||
| total_tokens += res.usage.total_tokens | |||
| total_tokens += self.total_token_count(res) | |||
| return np.array(ress), total_tokens | |||
| def encode_queries(self, text): | |||
| res = self.client.embeddings.create(input=[truncate(text, 8191)], | |||
| model=self.model_name) | |||
| return np.array(res.data[0].embedding), res.usage.total_tokens | |||
| return np.array(res.data[0].embedding), self.total_token_count(res) | |||
| class LocalAIEmbed(Base): | |||
| @@ -188,7 +200,7 @@ class QWenEmbed(Base): | |||
| for e in resp["output"]["embeddings"]: | |||
| embds[e["text_index"]] = e["embedding"] | |||
| res.extend(embds) | |||
| token_count += resp["usage"]["total_tokens"] | |||
| token_count += self.total_token_count(resp) | |||
| return np.array(res), token_count | |||
| except Exception as e: | |||
| raise Exception("Account abnormal. Please ensure it's on good standing to use QWen's "+self.model_name) | |||
| @@ -203,7 +215,7 @@ class QWenEmbed(Base): | |||
| text_type="query" | |||
| ) | |||
| return np.array(resp["output"]["embeddings"][0] | |||
| ["embedding"]), resp["usage"]["total_tokens"] | |||
| ["embedding"]), self.total_token_count(resp) | |||
| except Exception: | |||
| raise Exception("Account abnormal. Please ensure it's on good standing to use QWen's "+self.model_name) | |||
| return np.array([]), 0 | |||
| @@ -229,13 +241,13 @@ class ZhipuEmbed(Base): | |||
| res = self.client.embeddings.create(input=txt, | |||
| model=self.model_name) | |||
| arr.append(res.data[0].embedding) | |||
| tks_num += res.usage.total_tokens | |||
| tks_num += self.total_token_count(res) | |||
| return np.array(arr), tks_num | |||
| def encode_queries(self, text): | |||
| res = self.client.embeddings.create(input=text, | |||
| model=self.model_name) | |||
| return np.array(res.data[0].embedding), res.usage.total_tokens | |||
| return np.array(res.data[0].embedding), self.total_token_count(res) | |||
| class OllamaEmbed(Base): | |||
| @@ -318,13 +330,13 @@ class XinferenceEmbed(Base): | |||
| for i in range(0, len(texts), batch_size): | |||
| res = self.client.embeddings.create(input=texts[i:i + batch_size], model=self.model_name) | |||
| ress.extend([d.embedding for d in res.data]) | |||
| total_tokens += res.usage.total_tokens | |||
| total_tokens += self.total_token_count(res) | |||
| return np.array(ress), total_tokens | |||
| def encode_queries(self, text): | |||
| res = self.client.embeddings.create(input=[text], | |||
| model=self.model_name) | |||
| return np.array(res.data[0].embedding), res.usage.total_tokens | |||
| return np.array(res.data[0].embedding), self.total_token_count(res) | |||
| class YoudaoEmbed(Base): | |||
| @@ -383,7 +395,7 @@ class JinaEmbed(Base): | |||
| } | |||
| res = requests.post(self.base_url, headers=self.headers, json=data).json() | |||
| ress.extend([d["embedding"] for d in res["data"]]) | |||
| token_count += res["usage"]["total_tokens"] | |||
| token_count += self.total_token_count(res) | |||
| return np.array(ress), token_count | |||
| def encode_queries(self, text): | |||
| @@ -447,13 +459,13 @@ class MistralEmbed(Base): | |||
| res = self.client.embeddings(input=texts[i:i + batch_size], | |||
| model=self.model_name) | |||
| ress.extend([d.embedding for d in res.data]) | |||
| token_count += res.usage.total_tokens | |||
| token_count += self.total_token_count(res) | |||
| return np.array(ress), token_count | |||
| def encode_queries(self, text): | |||
| res = self.client.embeddings(input=[truncate(text, 8196)], | |||
| model=self.model_name) | |||
| return np.array(res.data[0].embedding), res.usage.total_tokens | |||
| return np.array(res.data[0].embedding), self.total_token_count(res) | |||
| class BedrockEmbed(Base): | |||
| @@ -565,7 +577,7 @@ class NvidiaEmbed(Base): | |||
| } | |||
| res = requests.post(self.base_url, headers=self.headers, json=payload).json() | |||
| ress.extend([d["embedding"] for d in res["data"]]) | |||
| token_count += res["usage"]["total_tokens"] | |||
| token_count += self.total_token_count(res) | |||
| return np.array(ress), token_count | |||
| def encode_queries(self, text): | |||
| @@ -677,7 +689,7 @@ class SILICONFLOWEmbed(Base): | |||
| if "data" not in res or not isinstance(res["data"], list) or len(res["data"]) != len(texts_batch): | |||
| raise ValueError(f"SILICONFLOWEmbed.encode got invalid response from {self.base_url}") | |||
| ress.extend([d["embedding"] for d in res["data"]]) | |||
| token_count += res["usage"]["total_tokens"] | |||
| token_count += self.total_token_count(res) | |||
| return np.array(ress), token_count | |||
| def encode_queries(self, text): | |||
| @@ -689,7 +701,7 @@ class SILICONFLOWEmbed(Base): | |||
| res = requests.post(self.base_url, json=payload, headers=self.headers).json() | |||
| if "data" not in res or not isinstance(res["data"], list) or len(res["data"])!= 1: | |||
| raise ValueError(f"SILICONFLOWEmbed.encode_queries got invalid response from {self.base_url}") | |||
| return np.array(res["data"][0]["embedding"]), res["usage"]["total_tokens"] | |||
| return np.array(res["data"][0]["embedding"]), self.total_token_count(res) | |||
| class ReplicateEmbed(Base): | |||
| @@ -727,14 +739,14 @@ class BaiduYiyanEmbed(Base): | |||
| res = self.client.do(model=self.model_name, texts=texts).body | |||
| return ( | |||
| np.array([r["embedding"] for r in res["data"]]), | |||
| res["usage"]["total_tokens"], | |||
| self.total_token_count(res), | |||
| ) | |||
| def encode_queries(self, text): | |||
| res = self.client.do(model=self.model_name, texts=[text]).body | |||
| return ( | |||
| np.array([r["embedding"] for r in res["data"]]), | |||
| res["usage"]["total_tokens"], | |||
| self.total_token_count(res), | |||
| ) | |||
| @@ -42,6 +42,17 @@ class Base(ABC): | |||
| def similarity(self, query: str, texts: list): | |||
| raise NotImplementedError("Please implement encode method!") | |||
| def total_token_count(self, resp): | |||
| try: | |||
| return resp.usage.total_tokens | |||
| except Exception: | |||
| pass | |||
| try: | |||
| return resp["usage"]["total_tokens"] | |||
| except Exception: | |||
| pass | |||
| return 0 | |||
| class DefaultRerank(Base): | |||
| _model = None | |||
| @@ -115,7 +126,7 @@ class JinaRerank(Base): | |||
| rank = np.zeros(len(texts), dtype=float) | |||
| for d in res["results"]: | |||
| rank[d["index"]] = d["relevance_score"] | |||
| return rank, res["usage"]["total_tokens"] | |||
| return rank, self.total_token_count(res) | |||
| class YoudaoRerank(DefaultRerank): | |||
| @@ -417,7 +428,7 @@ class BaiduYiyanRerank(Base): | |||
| rank = np.zeros(len(texts), dtype=float) | |||
| for d in res["results"]: | |||
| rank[d["index"]] = d["relevance_score"] | |||
| return rank, res["usage"]["total_tokens"] | |||
| return rank, self.total_token_count(res) | |||
| class VoyageRerank(Base): | |||