@@ -179,7 +179,41 @@ class Base(ABC):
         except Exception:
             pass
         return 0
 
+    def _calculate_dynamic_ctx(self, history):
+        """Calculate dynamic context window size"""
+        def count_tokens(text):
+            """Calculate token count for text"""
+            # Simple calculation: 1 token per ASCII character
+            # 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
+            total = 0
+            for char in text:
+                if ord(char) < 128:  # ASCII characters
+                    total += 1
+                else:  # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
+                    total += 2
+            return total
+
+        # Calculate total tokens for all messages
+        total_tokens = 0
+        for message in history:
+            content = message.get("content", "")
+            # Calculate content tokens
+            content_tokens = count_tokens(content)
+            # Add role marker token overhead
+            role_tokens = 4
+            total_tokens += content_tokens + role_tokens
+
+        # Apply 1.2x buffer ratio
+        total_tokens_with_buffer = int(total_tokens * 1.2)
+
+        if total_tokens_with_buffer <= 8192:
+            ctx_size = 8192
+        else:
+            ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
+            ctx_size = ctx_multiplier * 8192
+
+        return ctx_size
 
 class GptTurbo(Base):
     def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
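To make the new sizing logic concrete, here is a standalone sketch of the same heuristic (function and variable names are illustrative, not part of the patch):

```python
# Illustrative re-statement of the heuristic added above; not part of the patch.
def estimate_ctx(history):
    def count_tokens(text):
        # 1 token per ASCII character, 2 per non-ASCII character
        return sum(1 if ord(c) < 128 else 2 for c in text)

    # Content tokens plus 4 tokens of role-marker overhead per message
    total = sum(count_tokens(m.get("content", "")) + 4 for m in history)
    padded = int(total * 1.2)  # 1.2x safety buffer
    # Floor of 8192; otherwise round up to the next multiple of 8192
    return 8192 if padded <= 8192 else (padded // 8192 + 1) * 8192


history = [{"role": "user", "content": "a" * 10000}]  # 10,000 ASCII chars
print(estimate_ctx(history))  # (10000 + 4) * 1.2 -> 12004, rounds up to 16384
```

One quirk worth noting: a buffered total that lands exactly on a multiple of 8192 takes the `else` branch and rounds up a full extra block (16384 becomes 24576), which simply buys extra headroom.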
@@ -469,7 +503,7 @@ class ZhipuChat(Base):
 
 class OllamaChat(Base):
     def __init__(self, key, model_name, **kwargs):
-        self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
+        self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bearer {key}"})
         self.model_name = model_name
 
     def chat(self, system, history, gen_conf):
@@ -478,7 +512,12 @@ class OllamaChat(Base):
         if "max_tokens" in gen_conf:
             del gen_conf["max_tokens"]
         try:
-            options = {"num_ctx": 32768}
+            # Calculate context size
+            ctx_size = self._calculate_dynamic_ctx(history)
+
+            options = {
+                "num_ctx": ctx_size
+            }
             if "temperature" in gen_conf:
                 options["temperature"] = gen_conf["temperature"]
             if "max_tokens" in gen_conf:
@@ -489,9 +528,11 @@ class OllamaChat(Base):
                 options["presence_penalty"] = gen_conf["presence_penalty"]
             if "frequency_penalty" in gen_conf:
                 options["frequency_penalty"] = gen_conf["frequency_penalty"]
-            response = self.client.chat(model=self.model_name, messages=history, options=options, keep_alive=-1)
+
+            response = self.client.chat(model=self.model_name, messages=history, options=options, keep_alive=10)
             ans = response["message"]["content"].strip()
-            return ans, response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
+            token_count = response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
+            return ans, token_count
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
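Besides routing the dynamic `num_ctx` into the request, this hunk changes `keep_alive` from -1 to 10. Per the Ollama API documentation, a negative `keep_alive` keeps the model loaded in memory indefinitely, while a bare number is interpreted as seconds; the patched call therefore lets Ollama unload the model ten seconds after the request instead of pinning it permanently. The streaming path in the next hunk gets the same treatment.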
@@ -500,28 +541,38 @@ class OllamaChat(Base):
             history.insert(0, {"role": "system", "content": system})
         if "max_tokens" in gen_conf:
             del gen_conf["max_tokens"]
-        options = {}
-        if "temperature" in gen_conf:
-            options["temperature"] = gen_conf["temperature"]
-        if "max_tokens" in gen_conf:
-            options["num_predict"] = gen_conf["max_tokens"]
-        if "top_p" in gen_conf:
-            options["top_p"] = gen_conf["top_p"]
-        if "presence_penalty" in gen_conf:
-            options["presence_penalty"] = gen_conf["presence_penalty"]
-        if "frequency_penalty" in gen_conf:
-            options["frequency_penalty"] = gen_conf["frequency_penalty"]
-        ans = ""
-        try:
-            response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=-1)
-            for resp in response:
-                if resp["done"]:
-                    yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
-                ans = resp["message"]["content"]
-                yield ans
-        except Exception as e:
-            yield ans + "\n**ERROR**: " + str(e)
-        yield 0
+        try:
+            # Calculate context size
+            ctx_size = self._calculate_dynamic_ctx(history)
+            options = {
+                "num_ctx": ctx_size
+            }
+            if "temperature" in gen_conf:
+                options["temperature"] = gen_conf["temperature"]
+            if "max_tokens" in gen_conf:
+                options["num_predict"] = gen_conf["max_tokens"]
+            if "top_p" in gen_conf:
+                options["top_p"] = gen_conf["top_p"]
+            if "presence_penalty" in gen_conf:
+                options["presence_penalty"] = gen_conf["presence_penalty"]
+            if "frequency_penalty" in gen_conf:
+                options["frequency_penalty"] = gen_conf["frequency_penalty"]
+
+            ans = ""
+            try:
+                response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=10)
+                for resp in response:
+                    if resp["done"]:
+                        token_count = resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
+                        yield token_count
+                    ans = resp["message"]["content"]
+                    yield ans
+            except Exception as e:
+                yield ans + "\n**ERROR**: " + str(e)
+            yield 0
+        except Exception as e:
+            yield "**ERROR**: " + str(e)
+            yield 0
 
 
 class LocalAIChat(Base):
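For reference, exercising the patched `OllamaChat` could look roughly like this (host URL, model name, and prompt are placeholders; the sketch assumes a local Ollama server and the `ollama` Python package):

```python
# Hypothetical usage; "llama3" and the host URL are placeholders.
bot = OllamaChat("x", "llama3", base_url="http://localhost:11434")

ans, token_count = bot.chat(
    system="You are a concise assistant.",
    history=[{"role": "user", "content": "Say hello in one sentence."}],
    gen_conf={"temperature": 0.2},
)
print(ans, token_count)
```

With an empty or "x" key the client is constructed without an Authorization header, matching the `__init__` logic above.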