
Feat: OpenAI-compatible-API supports references (#8997)

### What problem does this PR solve?

The OpenAI-compatible API now supports returning references (the retrieved chunks backing an answer), in both streaming and non-streaming mode.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
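In client terms, the new behavior is opt-in via the request body. Below is a minimal sketch distilled from the example this PR adds to the endpoint docstring; the address, API key, and `<chat_id>` are placeholders for your own deployment:

```python
from openai import OpenAI

# Placeholder address, key, and chat id: substitute your deployment's values.
client = OpenAI(api_key="ragflow-api-key", base_url="http://ragflow_address/api/v1/chats_openai/<chat_id>")

completion = client.chat.completions.create(
    model="model",
    messages=[{"role": "user", "content": "Can you tell me how to install neovim"}],
    stream=True,
    extra_body={"reference": True},  # the new flag added by this PR
)

for chunk in completion:
    # References and the assembled answer ride on the last chunk of the stream.
    if chunk.choices[0].finish_reason == "stop":
        print(chunk.choices[0].delta.reference)
        print(chunk.choices[0].delta.final_content)
    else:
        print(chunk.choices[0].delta.content, end="")
```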
tags/v0.20.0
Yongteng Lei, 3 months ago
parent revision 8345e92671

2 changed files, 114 additions and 65 deletions:

1. api/apps/sdk/session.py (+100, -61)
2. docs/references/python_api_reference.md (+14, -4)

api/apps/sdk/session.py (+100, -61)

@@ -19,21 +19,22 @@ import time
 
 import tiktoken
 from flask import Response, jsonify, request
-from api.db.services.conversation_service import ConversationService, iframe_completion
-from api.db.services.conversation_service import completion as rag_completion
-from api.db.services.canvas_service import completion as agent_completion, completionOpenAI
 
 from agent.canvas import Canvas
 from api.db import LLMType, StatusEnum
 from api.db.db_models import APIToken
 from api.db.services.api_service import API4ConversationService
-from api.db.services.canvas_service import UserCanvasService
+from api.db.services.canvas_service import UserCanvasService, completionOpenAI
+from api.db.services.canvas_service import completion as agent_completion
+from api.db.services.conversation_service import ConversationService, iframe_completion
+from api.db.services.conversation_service import completion as rag_completion
 from api.db.services.dialog_service import DialogService, ask, chat
 from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.utils import get_uuid
-from api.utils.api_utils import get_result, token_required, get_data_openai, get_error_data_result, validate_request, check_duplicate_ids
 from api.db.services.llm_service import LLMBundle
-
+from api.utils import get_uuid
+from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_result, token_required, validate_request
+from rag.prompts import chunks_format
 
 
 @manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
@@ -181,10 +182,16 @@ def chat_completion(tenant_id, chat_id):
 def chat_completion_openai_like(tenant_id, chat_id):
     """
     OpenAI-like chat completion API that simulates the behavior of OpenAI's completions endpoint.
     This function allows users to interact with a model and receive responses based on a series of historical messages.
     If `stream` is set to True (the default), the response is streamed in chunks, mimicking the OpenAI-style API.
     If `stream` is set to False explicitly, the response is returned as a single complete answer.
+
+    Reference:
+
+    - If `stream` is True, the final answer and reference information appear in the **last chunk** of the stream.
+    - If `stream` is False, the reference is included in `choices[0].message.reference`.
+
     Example usage:
 
     curl -X POST https://ragflow_address.com/api/v1/chats_openai/<chat_id>/chat/completions \
@@ -202,7 +209,10 @@ def chat_completion_openai_like(tenant_id, chat_id):
 
     model = "model"
     client = OpenAI(api_key="ragflow-api-key", base_url=f"http://ragflow_address/api/v1/chats_openai/<chat_id>")
 
+    stream = True
+    reference = True
+
     completion = client.chat.completions.create(
         model=model,
         messages=[
@@ -211,17 +221,24 @@ def chat_completion_openai_like(tenant_id, chat_id):
             {"role": "assistant", "content": "I am an AI assistant named..."},
             {"role": "user", "content": "Can you tell me how to install neovim"},
         ],
-        stream=True
+        stream=stream,
+        extra_body={"reference": reference}
     )
-    stream = True
 
     if stream:
-        for chunk in completion:
-            print(chunk)
+        for chunk in completion:
+            print(chunk)
+            if reference and chunk.choices[0].finish_reason == "stop":
+                print(f"Reference:\n{chunk.choices[0].delta.reference}")
+                print(f"Final content:\n{chunk.choices[0].delta.final_content}")
     else:
         print(completion.choices[0].message.content)
+        if reference:
+            print(completion.choices[0].message.reference)
     """
-    req = request.json
+    req = request.get_json()
 
+    need_reference = bool(req.get("reference", False))
+
     messages = req.get("messages", [])
     # To prevent empty [] input
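Note that `need_reference` is read straight from the request JSON, which is why the OpenAI SDK's `extra_body={"reference": True}` works: the SDK merges `extra_body` into the top-level request body. For clients that skip the SDK, here is a hedged sketch with `requests` against the same placeholder host; the `data:...` / `data:[DONE]` framing matches the `yield` statements later in this file, and the bearer-token header is assumed from the curl example in the docstring:

```python
import json

import requests

# Hypothetical host, key, and chat id; endpoint path and body keys come from the diff above.
url = "http://ragflow_address/api/v1/chats_openai/<chat_id>/chat/completions"
headers = {"Authorization": "Bearer ragflow-api-key"}
body = {
    "model": "model",
    "messages": [{"role": "user", "content": "Can you tell me how to install neovim"}],
    "stream": True,
    "reference": True,  # what extra_body={"reference": True} sends on the wire
}

with requests.post(url, json=body, headers=headers, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        # Server-sent events: each payload line starts with "data:".
        if not line or not line.startswith("data:"):
            continue
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        choice = chunk["choices"][0]
        if choice.get("finish_reason") == "stop":
            print(choice["delta"].get("reference"))      # list produced by chunks_format()
            print(choice["delta"].get("final_content"))  # the full answer in one piece
```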
@@ -261,9 +278,23 @@ def chat_completion_openai_like(tenant_id, chat_id):
         token_used = 0
         answer_cache = ""
         reasoning_cache = ""
+        last_ans = {}
         response = {
             "id": f"chatcmpl-{chat_id}",
-            "choices": [{"delta": {"content": "", "role": "assistant", "function_call": None, "tool_calls": None, "reasoning_content": ""}, "finish_reason": None, "index": 0, "logprobs": None}],
+            "choices": [
+                {
+                    "delta": {
+                        "content": "",
+                        "role": "assistant",
+                        "function_call": None,
+                        "tool_calls": None,
+                        "reasoning_content": "",
+                    },
+                    "finish_reason": None,
+                    "index": 0,
+                    "logprobs": None,
+                }
+            ],
             "created": int(time.time()),
             "model": "model",
             "object": "chat.completion.chunk",
@@ -272,7 +303,8 @@ def chat_completion_openai_like(tenant_id, chat_id):
         }
 
         try:
-            for ans in chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools):
+            for ans in chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
+                last_ans = ans
                 answer = ans["answer"]
 
                 reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
@@ -324,6 +356,9 @@ def chat_completion_openai_like(tenant_id, chat_id):
             response["choices"][0]["delta"]["reasoning_content"] = None
             response["choices"][0]["finish_reason"] = "stop"
             response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
+            if need_reference:
+                response["choices"][0]["delta"]["reference"] = chunks_format(last_ans.get("reference", []))
+                response["choices"][0]["delta"]["final_content"] = last_ans.get("answer", "")
             yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
             yield "data:[DONE]\n\n"
@@ -335,7 +370,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
         return resp
     else:
         answer = None
-        for ans in chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools):
+        for ans in chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
             # focus answer content only
             answer = ans
             break
@@ -356,14 +391,28 @@ def chat_completion_openai_like(tenant_id, chat_id):
                     "rejected_prediction_tokens": 0,  # 0 for simplicity
                 },
             },
-            "choices": [{"message": {"role": "assistant", "content": content}, "logprobs": None, "finish_reason": "stop", "index": 0}],
+            "choices": [
+                {
+                    "message": {
+                        "role": "assistant",
+                        "content": content,
+                    },
+                    "logprobs": None,
+                    "finish_reason": "stop",
+                    "index": 0,
+                }
+            ],
         }
+        if need_reference:
+            response["choices"][0]["message"]["reference"] = chunks_format(answer.get("reference", []))
 
         return jsonify(response)
 
 
-@manager.route('/agents_openai/<agent_id>/chat/completions', methods=['POST']) # noqa: F821
+@manager.route("/agents_openai/<agent_id>/chat/completions", methods=["POST"]) # noqa: F821
 @validate_request("model", "messages") # noqa: F821
 @token_required
-def agents_completion_openai_compatibility (tenant_id, agent_id):
+def agents_completion_openai_compatibility(tenant_id, agent_id):
     req = request.json
     tiktokenenc = tiktoken.get_encoding("cl100k_base")
     messages = req.get("messages", [])
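The non-streaming path mirrors the streaming one: with `stream=False` the reference list hangs off the message object instead of a delta. Reusing the placeholder `client` from the sketch near the top of this page:

```python
completion = client.chat.completions.create(
    model="model",
    messages=[{"role": "user", "content": "Who are you?"}],
    stream=False,
    extra_body={"reference": True},
)
print(completion.choices[0].message.content)
print(completion.choices[0].message.reference)  # chunks_format() list; absent when the flag is off
```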
@@ -371,29 +420,31 @@ def agents_completion_openai_compatibility (tenant_id, agent_id):
         return get_error_data_result("You must provide at least one message.")
     if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
         return get_error_data_result(f"You don't own the agent {agent_id}")
     filtered_messages = [m for m in messages if m["role"] in ["user", "assistant"]]
     prompt_tokens = sum(len(tiktokenenc.encode(m["content"])) for m in filtered_messages)
     if not filtered_messages:
-        return jsonify(get_data_openai(
-            id=agent_id,
-            content="No valid messages found (user or assistant).",
-            finish_reason="stop",
-            model=req.get("model", ""),
-            completion_tokens=len(tiktokenenc.encode("No valid messages found (user or assistant).")),
-            prompt_tokens=prompt_tokens,
-        ))
+        return jsonify(
+            get_data_openai(
+                id=agent_id,
+                content="No valid messages found (user or assistant).",
+                finish_reason="stop",
+                model=req.get("model", ""),
+                completion_tokens=len(tiktokenenc.encode("No valid messages found (user or assistant).")),
+                prompt_tokens=prompt_tokens,
+            )
+        )
 
     # Get the last user message as the question
     question = next((m["content"] for m in reversed(messages) if m["role"] == "user"), "")
     if req.get("stream", True):
-        return Response(completionOpenAI(tenant_id, agent_id, question, session_id=req.get("id", req.get("metadata", {}).get("id","")), stream=True), mimetype="text/event-stream")
+        return Response(completionOpenAI(tenant_id, agent_id, question, session_id=req.get("id", req.get("metadata", {}).get("id", "")), stream=True), mimetype="text/event-stream")
     else:
         # For non-streaming, just return the response directly
-        response = next(completionOpenAI(tenant_id, agent_id, question, session_id=req.get("id", req.get("metadata", {}).get("id","")), stream=False))
+        response = next(completionOpenAI(tenant_id, agent_id, question, session_id=req.get("id", req.get("metadata", {}).get("id", "")), stream=False))
         return jsonify(response)
 
 
 @manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
 @token_required
@@ -547,7 +598,7 @@ def list_agent_session(tenant_id, agent_id):
 def delete(tenant_id, chat_id):
     if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
         return get_error_data_result(message="You don't own the chat")
     errors = []
     success_count = 0
     req = request.json
@@ -563,10 +614,10 @@ def delete(tenant_id, chat_id):
                 conv_list.append(conv.id)
     else:
         conv_list = ids
     unique_conv_ids, duplicate_messages = check_duplicate_ids(conv_list, "session")
     conv_list = unique_conv_ids
     for id in conv_list:
         conv = ConversationService.query(id=id, dialog_id=chat_id)
         if not conv:
@@ -574,25 +625,19 @@ def delete(tenant_id, chat_id):
             continue
         ConversationService.delete_by_id(id)
         success_count += 1
     if errors:
         if success_count > 0:
-            return get_result(
-                data={"success_count": success_count, "errors": errors},
-                message=f"Partially deleted {success_count} sessions with {len(errors)} errors"
-            )
+            return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
         else:
             return get_error_data_result(message="; ".join(errors))
 
     if duplicate_messages:
         if success_count > 0:
-            return get_result(
-                message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors",
-                data={"success_count": success_count, "errors": duplicate_messages}
-            )
+            return get_result(message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages})
         else:
             return get_error_data_result(message=";".join(duplicate_messages))
 
     return get_result()


@@ -632,25 +677,19 @@ def delete_agent_session(tenant_id, agent_id):
             continue
         API4ConversationService.delete_by_id(session_id)
         success_count += 1
     if errors:
         if success_count > 0:
-            return get_result(
-                data={"success_count": success_count, "errors": errors},
-                message=f"Partially deleted {success_count} sessions with {len(errors)} errors"
-            )
+            return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} sessions with {len(errors)} errors")
         else:
             return get_error_data_result(message="; ".join(errors))
 
     if duplicate_messages:
         if success_count > 0:
-            return get_result(
-                message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors",
-                data={"success_count": success_count, "errors": duplicate_messages}
-            )
+            return get_result(message=f"Partially deleted {success_count} sessions with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages})
         else:
             return get_error_data_result(message=";".join(duplicate_messages))
 
     return get_result()
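Both delete endpoints report partial failures in the same compact shape after this reflow. A hypothetical partially-failed response, assuming `get_result()` wraps its arguments in the usual code/message/data envelope:

```python
# Hypothetical payload for 2 deletions and 1 error (envelope keys assumed from get_result()).
partial_result = {
    "code": 0,
    "message": "Partially deleted 2 sessions with 1 errors",
    "data": {"success_count": 2, "errors": ["..."]},
}
```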


@@ -719,7 +758,7 @@ Related search terms:
 
 Reason:
 - When searching, users often only use one or two keywords, making it difficult to fully express their information needs.
-- Generating related search terms can help users dig deeper into relevant information and improve search efficiency. 
+- Generating related search terms can help users dig deeper into relevant information and improve search efficiency.
 - At the same time, related terms can also help search engines better understand user needs and return more accurate search results.
 
 """

docs/references/python_api_reference.md (+14, -4)

@@ -69,21 +69,31 @@ from openai import OpenAI
 model = "model"
 client = OpenAI(api_key="ragflow-api-key", base_url=f"http://ragflow_address/api/v1/chats_openai/<chat_id>")
 
+stream = True
+reference = True
+
 completion = client.chat.completions.create(
     model=model,
     messages=[
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Who are you?"},
         {"role": "assistant", "content": "I am an AI assistant named..."},
         {"role": "user", "content": "Can you tell me how to install neovim"},
     ],
-    stream=True
+    stream=stream,
+    extra_body={"reference": reference}
 )
 
-stream = True
 if stream:
-    for chunk in completion:
-        print(chunk)
+    for chunk in completion:
+        print(chunk)
+        if reference and chunk.choices[0].finish_reason == "stop":
+            print(f"Reference:\n{chunk.choices[0].delta.reference}")
+            print(f"Final content:\n{chunk.choices[0].delta.final_content}")
 else:
     print(completion.choices[0].message.content)
+    if reference:
+        print(completion.choices[0].message.reference)
 ```
 
 ## DATASET MANAGEMENT
