Selaa lähdekoodia

Fix: Patch LiteLLM (#9416)

### What problem does this PR solve?

Patch LiteLLM refactor. #9408

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
tags/v0.20.2
Yongteng Lei 2 kuukautta sitten
vanhempi
commit
a0c2da1219
No account linked to committer's email address
1 muutettua tiedostoa jossa 20 lisäystä ja 10 poistoa
  1. 20
    10
      rag/llm/chat_model.py

+ 20
- 10
rag/llm/chat_model.py Näytä tiedosto

@@ -1455,7 +1455,7 @@ class LiteLLMBase(ABC):
if self.model_name.lower().find("qwen3") >= 0:
kwargs["extra_body"] = {"enable_thinking": False}

completion_args = self._construct_completion_args(history=history, **gen_conf)
completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf)
response = litellm.completion(
**completion_args,
drop_params=True,
@@ -1475,7 +1475,7 @@ class LiteLLMBase(ABC):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
reasoning_start = False

completion_args = self._construct_completion_args(history=history, **gen_conf)
completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
stop = kwargs.get("stop")
if stop:
completion_args["stop"] = stop
@@ -1571,17 +1571,27 @@ class LiteLLMBase(ABC):
self.toolcall_session = toolcall_session
self.tools = tools

def _construct_completion_args(self, history, **kwargs):
def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs):
completion_args = {
"model": self.model_name,
"messages": history,
"stream": False,
"tools": self.tools,
"tool_choice": "auto",
"api_key": self.api_key,
**kwargs,
}
if self.provider in SupportedLiteLLMProvider:
if stream:
completion_args.update(
{
"stream": stream,
}
)
if tools and self.tools:
completion_args.update(
{
"tools": self.tools,
"tool_choice": "auto",
}
)
if self.provider in FACTORY_DEFAULT_BASE_URL:
completion_args.update({"api_base": self.base_url})
elif self.provider == SupportedLiteLLMProvider.Bedrock:
completion_args.pop("api_key", None)
@@ -1611,7 +1621,7 @@ class LiteLLMBase(ABC):
for _ in range(self.max_rounds + 1):
logging.info(f"{self.tools=}")

completion_args = self._construct_completion_args(history=history, **gen_conf)
completion_args = self._construct_completion_args(history=history, stream=False, tools=True, **gen_conf)
response = litellm.completion(
**completion_args,
drop_params=True,
@@ -1708,7 +1718,7 @@ class LiteLLMBase(ABC):
reasoning_start = False
logging.info(f"{tools=}")

completion_args = self._construct_completion_args(history=history, **gen_conf)
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
response = litellm.completion(
**completion_args,
drop_params=True,
@@ -1786,7 +1796,7 @@ class LiteLLMBase(ABC):
logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})

completion_args = self._construct_completion_args(history=history, **gen_conf)
completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
response = litellm.completion(
**completion_args,
drop_params=True,

Loading…
Peruuta
Tallenna