
# Fix: Patch LiteLLM (#9416)

### What problem does this PR solve?

Patch the LiteLLM refactor. `_construct_completion_args()` no longer hard-codes `"stream": False` or unconditionally attaches `tools`/`tool_choice`; each call site now passes explicit `stream` and `tools` flags, tool arguments are only attached when tools are actually bound, and the base-URL override is keyed off `FACTORY_DEFAULT_BASE_URL`. #9408

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
tags/v0.20.2
Yongteng Lei, 2 months ago
Commit a0c2da1219
1 file changed: 20 additions and 10 deletions
rag/llm/chat_model.py (+20 −10)
Non-streaming chat: the call site now passes explicit `stream=False, tools=False`.

```diff
         if self.model_name.lower().find("qwen3") >= 0:
             kwargs["extra_body"] = {"enable_thinking": False}
 
-        completion_args = self._construct_completion_args(history=history, **gen_conf)
+        completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf)
         response = litellm.completion(
             **completion_args,
             drop_params=True,
```
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4)) logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
reasoning_start = False reasoning_start = False


completion_args = self._construct_completion_args(history=history, **gen_conf)
completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
stop = kwargs.get("stop") stop = kwargs.get("stop")
if stop: if stop:
completion_args["stop"] = stop completion_args["stop"] = stop
The core change: `_construct_completion_args` takes explicit `stream` and `tools` flags instead of hard-coding `"stream": False` and always attaching the bound tools, and the base-URL override is keyed off `FACTORY_DEFAULT_BASE_URL`.

```diff
         self.toolcall_session = toolcall_session
         self.tools = tools
 
-    def _construct_completion_args(self, history, **kwargs):
+    def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs):
         completion_args = {
             "model": self.model_name,
             "messages": history,
-            "stream": False,
-            "tools": self.tools,
-            "tool_choice": "auto",
             "api_key": self.api_key,
             **kwargs,
         }
-        if self.provider in SupportedLiteLLMProvider:
+        if stream:
+            completion_args.update(
+                {
+                    "stream": stream,
+                }
+            )
+        if tools and self.tools:
+            completion_args.update(
+                {
+                    "tools": self.tools,
+                    "tool_choice": "auto",
+                }
+            )
+        if self.provider in FACTORY_DEFAULT_BASE_URL:
             completion_args.update({"api_base": self.base_url})
         elif self.provider == SupportedLiteLLMProvider.Bedrock:
             completion_args.pop("api_key", None)
```
Tool-calling loop (non-streaming): `stream=False, tools=True`.

```diff
         for _ in range(self.max_rounds + 1):
             logging.info(f"{self.tools=}")
 
-            completion_args = self._construct_completion_args(history=history, **gen_conf)
+            completion_args = self._construct_completion_args(history=history, stream=False, tools=True, **gen_conf)
             response = litellm.completion(
                 **completion_args,
                 drop_params=True,
```
reasoning_start = False reasoning_start = False
logging.info(f"{tools=}") logging.info(f"{tools=}")


completion_args = self._construct_completion_args(history=history, **gen_conf)
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
response = litellm.completion( response = litellm.completion(
**completion_args, **completion_args,
drop_params=True, drop_params=True,
logging.warning(f"Exceed max rounds: {self.max_rounds}") logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"}) history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})


completion_args = self._construct_completion_args(history=history, **gen_conf)
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
response = litellm.completion( response = litellm.completion(
**completion_args, **completion_args,
drop_params=True, drop_params=True,
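Taken together, the call sites now cover all four `(stream, tools)` combinations, and every call keeps `drop_params=True`, which tells LiteLLM to drop parameters a given provider does not support instead of raising. A minimal streaming consumption sketch under those settings (model name and key are placeholders, not values from the patch):

```python
import litellm

# Hypothetical arguments standing in for _construct_completion_args() output.
completion_args = {
    "model": "gpt-4o-mini",
    "messages": [{"role": "user", "content": "ping"}],
    "api_key": "sk-placeholder",
    "stream": True,
}

response = litellm.completion(**completion_args, drop_params=True)
for chunk in response:
    # Streaming chunks follow the OpenAI delta format; content may be None.
    print(chunk.choices[0].delta.content or "", end="")
```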