Browse Source

Refactor: change LLM chat output from full to delta (incremental) (#6534)

### What problem does this PR solve?

Change LLM chat output from full to delta (incremental)

### Type of change

- [x] Refactoring
tags/v0.18.0
Yongteng Lei 7 months ago
parent
commit
df3890827d
No account linked to committer's email address
3 changed files with 276 additions and 398 deletions
  1. 103
    97
      api/apps/sdk/session.py
  2. 8
    5
      api/db/services/llm_service.py
  3. 165
    296
      rag/llm/chat_model.py

+ 103
- 97
api/apps/sdk/session.py View File

@@ -13,31 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import json
import re
import time

from api.db import LLMType
from api.db.services.conversation_service import ConversationService, iframe_completion
from api.db.services.conversation_service import completion as rag_completion
from api.db.services.canvas_service import completion as agent_completion
from api.db.services.dialog_service import ask, chat
from flask import Response, jsonify, request

from agent.canvas import Canvas
from api.db import StatusEnum
from api.db import LLMType, StatusEnum
from api.db.db_models import APIToken
from api.db.services.api_service import API4ConversationService
from api.db.services.canvas_service import UserCanvasService
from api.db.services.dialog_service import DialogService
from api.db.services.canvas_service import completion as agent_completion
from api.db.services.conversation_service import ConversationService, iframe_completion
from api.db.services.conversation_service import completion as rag_completion
from api.db.services.dialog_service import DialogService, ask, chat
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils import get_uuid
from api.utils.api_utils import get_error_data_result, validate_request
from api.utils.api_utils import get_result, token_required
from api.db.services.llm_service import LLMBundle
from api.db.services.file_service import FileService
from api.utils import get_uuid
from api.utils.api_utils import get_error_data_result, get_result, token_required, validate_request

from flask import jsonify, request, Response

@manager.route('/chats/<chat_id>/sessions', methods=['POST']) # noqa: F821
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
@token_required
def create(tenant_id, chat_id):
req = request.json
@@ -50,7 +48,7 @@ def create(tenant_id, chat_id):
"dialog_id": req["dialog_id"],
"name": req.get("name", "New session"),
"message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue")}],
"user_id": req.get("user_id", "")
"user_id": req.get("user_id", ""),
}
if not conv.get("name"):
return get_error_data_result(message="`name` can not be empty.")
@@ -59,20 +57,20 @@ def create(tenant_id, chat_id):
if not e:
return get_error_data_result(message="Fail to create a session!")
conv = conv.to_dict()
conv['messages'] = conv.pop("message")
conv["messages"] = conv.pop("message")
conv["chat_id"] = conv.pop("dialog_id")
del conv["reference"]
return get_result(data=conv)


@manager.route('/agents/<agent_id>/sessions', methods=['POST']) # noqa: F821
@manager.route("/agents/<agent_id>/sessions", methods=["POST"]) # noqa: F821
@token_required
def create_agent_session(tenant_id, agent_id):
req = request.json
if not request.is_json:
req = request.form
files = request.files
user_id = request.args.get('user_id', '')
user_id = request.args.get("user_id", "")

e, cvs = UserCanvasService.get_by_id(agent_id)
if not e:
@@ -113,7 +111,7 @@ def create_agent_session(tenant_id, agent_id):
ele.pop("value")
else:
if req is not None and req.get(ele["key"]):
ele["value"] = req[ele['key']]
ele["value"] = req[ele["key"]]
else:
if "value" in ele:
ele.pop("value")
@@ -121,20 +119,13 @@ def create_agent_session(tenant_id, agent_id):
for ans in canvas.run(stream=False):
pass
cvs.dsl = json.loads(str(canvas))
conv = {
"id": get_uuid(),
"dialog_id": cvs.id,
"user_id": user_id,
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
"source": "agent",
"dsl": cvs.dsl
}
conv = {"id": get_uuid(), "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
API4ConversationService.save(**conv)
conv["agent_id"] = conv.pop("dialog_id")
return get_result(data=conv)


@manager.route('/chats/<chat_id>/sessions/<session_id>', methods=['PUT']) # noqa: F821
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"]) # noqa: F821
@token_required
def update(tenant_id, chat_id, session_id):
req = request.json
@@ -156,14 +147,14 @@ def update(tenant_id, chat_id, session_id):
return get_result()


@manager.route('/chats/<chat_id>/completions', methods=['POST']) # noqa: F821
@manager.route("/chats/<chat_id>/completions", methods=["POST"]) # noqa: F821
@token_required
def chat_completion(tenant_id, chat_id):
req = request.json
if not req:
req = {"question": ""}
if not req.get("session_id"):
req["question"]=""
req["question"] = ""
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
return get_error_data_result(f"You don't own the chat {chat_id}")
if req.get("session_id"):
@@ -185,7 +176,7 @@ def chat_completion(tenant_id, chat_id):
return get_result(data=answer)


@manager.route('/chats_openai/<chat_id>/chat/completions', methods=['POST']) # noqa: F821
@manager.route("/chats_openai/<chat_id>/chat/completions", methods=["POST"]) # noqa: F821
@validate_request("model", "messages") # noqa: F821
@token_required
def chat_completion_openai_like(tenant_id, chat_id):
@@ -260,35 +251,60 @@ def chat_completion_openai_like(tenant_id, chat_id):
def streamed_response_generator(chat_id, dia, msg):
token_used = 0
answer_cache = ""
reasoning_cache = ""
response = {
"id": f"chatcmpl-{chat_id}",
"choices": [
{
"delta": {
"content": "",
"role": "assistant",
"function_call": None,
"tool_calls": None
},
"finish_reason": None,
"index": 0,
"logprobs": None
}
],
"choices": [{"delta": {"content": "", "role": "assistant", "function_call": None, "tool_calls": None, "reasoning_content": ""}, "finish_reason": None, "index": 0, "logprobs": None}],
"created": int(time.time()),
"model": "model",
"object": "chat.completion.chunk",
"system_fingerprint": "",
"usage": None
"usage": None,
}

try:
for ans in chat(dia, msg, True):
answer = ans["answer"]
incremental = answer.replace(answer_cache, "", 1)
answer_cache = answer.rstrip("</think>")
token_used += len(incremental)
response["choices"][0]["delta"]["content"] = incremental

reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
if reasoning_match:
reasoning_part = reasoning_match.group(1)
content_part = answer[reasoning_match.end() :]
else:
reasoning_part = ""
content_part = answer

reasoning_incremental = ""
if reasoning_part:
if reasoning_part.startswith(reasoning_cache):
reasoning_incremental = reasoning_part.replace(reasoning_cache, "", 1)
else:
reasoning_incremental = reasoning_part
reasoning_cache = reasoning_part

content_incremental = ""
if content_part:
if content_part.startswith(answer_cache):
content_incremental = content_part.replace(answer_cache, "", 1)
else:
content_incremental = content_part
answer_cache = content_part

token_used += len(reasoning_incremental) + len(content_incremental)

if not any([reasoning_incremental, content_incremental]):
continue

if reasoning_incremental:
response["choices"][0]["delta"]["reasoning_content"] = reasoning_incremental
else:
response["choices"][0]["delta"]["reasoning_content"] = None

if content_incremental:
response["choices"][0]["delta"]["content"] = content_incremental
else:
response["choices"][0]["delta"]["content"] = None

yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
except Exception as e:
response["choices"][0]["delta"]["content"] = "**ERROR**: " + str(e)
@@ -296,16 +312,12 @@ def chat_completion_openai_like(tenant_id, chat_id):

# The last chunk
response["choices"][0]["delta"]["content"] = None
response["choices"][0]["delta"]["reasoning_content"] = None
response["choices"][0]["finish_reason"] = "stop"
response["usage"] = {
"prompt_tokens": len(prompt),
"completion_tokens": token_used,
"total_tokens": len(prompt) + token_used
}
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
yield "data:[DONE]\n\n"


resp = Response(streamed_response_generator(chat_id, dia, msg), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
@@ -320,7 +332,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
break
content = answer["answer"]

response = {
response = {
"id": f"chatcmpl-{chat_id}",
"object": "chat.completion",
"created": int(time.time()),
@@ -332,25 +344,15 @@ def chat_completion_openai_like(tenant_id, chat_id):
"completion_tokens_details": {
"reasoning_tokens": context_token_used,
"accepted_prediction_tokens": len(content),
"rejected_prediction_tokens": 0 # 0 for simplicity
}
"rejected_prediction_tokens": 0, # 0 for simplicity
},
},
"choices": [
{
"message": {
"role": "assistant",
"content": content
},
"logprobs": None,
"finish_reason": "stop",
"index": 0
}
]
"choices": [{"message": {"role": "assistant", "content": content}, "logprobs": None, "finish_reason": "stop", "index": 0}],
}
return jsonify(response)


@manager.route('/agents/<agent_id>/completions', methods=['POST']) # noqa: F821
@manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
@token_required
def agent_completions(tenant_id, agent_id):
req = request.json
@@ -361,8 +363,8 @@ def agent_completions(tenant_id, agent_id):
dsl = cvs[0].dsl
if not isinstance(dsl, str):
dsl = json.dumps(dsl)
#canvas = Canvas(dsl, tenant_id)
#if canvas.get_preset_param():
# canvas = Canvas(dsl, tenant_id)
# if canvas.get_preset_param():
# req["question"] = ""
conv = API4ConversationService.query(id=req["session_id"], dialog_id=agent_id)
if not conv:
@@ -376,9 +378,7 @@ def agent_completions(tenant_id, agent_id):
states = {field: current_dsl.get(field, []) for field in state_fields}
current_dsl.update(new_dsl)
current_dsl.update(states)
API4ConversationService.update_by_id(req["session_id"], {
"dsl": current_dsl
})
API4ConversationService.update_by_id(req["session_id"], {"dsl": current_dsl})
else:
req["question"] = ""
if req.get("stream", True):
@@ -395,7 +395,7 @@ def agent_completions(tenant_id, agent_id):
return get_error_data_result(str(e))


@manager.route('/chats/<chat_id>/sessions', methods=['GET']) # noqa: F821
@manager.route("/chats/<chat_id>/sessions", methods=["GET"]) # noqa: F821
@token_required
def list_session(tenant_id, chat_id):
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
@@ -414,7 +414,7 @@ def list_session(tenant_id, chat_id):
if not convs:
return get_result(data=[])
for conv in convs:
conv['messages'] = conv.pop("message")
conv["messages"] = conv.pop("message")
infos = conv["messages"]
for info in infos:
if "prompt" in info:
@@ -448,7 +448,7 @@ def list_session(tenant_id, chat_id):
return get_result(data=convs)


@manager.route('/agents/<agent_id>/sessions', methods=['GET']) # noqa: F821
@manager.route("/agents/<agent_id>/sessions", methods=["GET"]) # noqa: F821
@token_required
def list_agent_session(tenant_id, agent_id):
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
@@ -464,12 +464,11 @@ def list_agent_session(tenant_id, agent_id):
desc = True
# dsl defaults to True in all cases except for False and false
include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id,
user_id, include_dsl)
convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id, user_id, include_dsl)
if not convs:
return get_result(data=[])
for conv in convs:
conv['messages'] = conv.pop("message")
conv["messages"] = conv.pop("message")
infos = conv["messages"]
for info in infos:
if "prompt" in info:
@@ -502,7 +501,7 @@ def list_agent_session(tenant_id, agent_id):
return get_result(data=convs)


@manager.route('/chats/<chat_id>/sessions', methods=["DELETE"]) # noqa: F821
@manager.route("/chats/<chat_id>/sessions", methods=["DELETE"]) # noqa: F821
@token_required
def delete(tenant_id, chat_id):
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
@@ -528,14 +527,14 @@ def delete(tenant_id, chat_id):
return get_result()


@manager.route('/agents/<agent_id>/sessions', methods=["DELETE"]) # noqa: F821
@manager.route("/agents/<agent_id>/sessions", methods=["DELETE"]) # noqa: F821
@token_required
def delete_agent_session(tenant_id, agent_id):
req = request.json
cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
if not cvs:
return get_error_data_result(f"You don't own the agent {agent_id}")
convs = API4ConversationService.query(dialog_id=agent_id)
if not convs:
return get_error_data_result(f"Agent {agent_id} has no sessions")
@@ -551,16 +550,16 @@ def delete_agent_session(tenant_id, agent_id):
conv_list.append(conv.id)
else:
conv_list = ids
for session_id in conv_list:
conv = API4ConversationService.query(id=session_id, dialog_id=agent_id)
if not conv:
return get_error_data_result(f"The agent doesn't own the session ${session_id}")
API4ConversationService.delete_by_id(session_id)
return get_result()

@manager.route('/sessions/ask', methods=['POST']) # noqa: F821

@manager.route("/sessions/ask", methods=["POST"]) # noqa: F821
@token_required
def ask_about(tenant_id):
req = request.json
@@ -586,9 +585,7 @@ def ask_about(tenant_id):
for ans in ask(req["question"], req["kb_ids"], uid):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

resp = Response(stream(), mimetype="text/event-stream")
@@ -599,7 +596,7 @@ def ask_about(tenant_id):
return resp


@manager.route('/sessions/related_questions', methods=['POST']) # noqa: F821
@manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821
@token_required
def related_questions(tenant_id):
req = request.json
@@ -631,18 +628,27 @@ Reason:
- At the same time, related terms can also help search engines better understand user needs and return more accurate search results.

"""
ans = chat_mdl.chat(prompt, [{"role": "user", "content": f"""
ans = chat_mdl.chat(
prompt,
[
{
"role": "user",
"content": f"""
Keywords: {question}
Related search terms:
"""}], {"temperature": 0.9})
""",
}
],
{"temperature": 0.9},
)
return get_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])


@manager.route('/chatbots/<dialog_id>/completions', methods=['POST']) # noqa: F821
@manager.route("/chatbots/<dialog_id>/completions", methods=["POST"]) # noqa: F821
def chatbot_completions(dialog_id):
req = request.json

token = request.headers.get('Authorization').split()
token = request.headers.get("Authorization").split()
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"')
token = token[1]
@@ -665,11 +671,11 @@ def chatbot_completions(dialog_id):
return get_result(data=answer)


@manager.route('/agentbots/<agent_id>/completions', methods=['POST']) # noqa: F821
@manager.route("/agentbots/<agent_id>/completions", methods=["POST"]) # noqa: F821
def agent_bot_completions(agent_id):
req = request.json

token = request.headers.get('Authorization').split()
token = request.headers.get("Authorization").split()
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!"')
token = token[1]

+ 8
- 5
api/db/services/llm_service.py View File

@@ -324,15 +324,18 @@ class LLMBundle:
if self.langfuse:
generation = self.trace.generation(name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})

output = ""
ans = ""
for txt in self.mdl.chat_streamly(system, history, gen_conf):
if isinstance(txt, int):
if self.langfuse:
generation.end(output={"output": output})
generation.end(output={"output": ans})

if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name):
logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))
return
return ans

if txt.endswith("</think>"):
ans = ans.rstrip("</think>")

output = txt
yield txt
ans += txt
yield ans

+ 165
- 296
rag/llm/chat_model.py
File diff suppressed because it is too large
View File


Loading…
Cancel
Save