@@ -29,7 +29,8 @@ from api.db.services.conversation_service import ConversationService, structure_
 from api.db.services.dialog_service import DialogService, ask, chat
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
-from api.db.services.user_service import UserTenantService, TenantService
+from api.db.services.tenant_llm_service import TenantLLMService
+from api.db.services.user_service import TenantService, UserTenantService
 from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
 from graphrag.general.mind_map_extractor import MindMapExtractor
 from rag.app.tag import label_question
@@ -66,8 +67,14 @@ def set_conversation():
         e, dia = DialogService.get_by_id(req["dialog_id"])
         if not e:
             return get_data_error_result(message="Dialog not found")
-        conv = {"id": conv_id, "dialog_id": req["dialog_id"], "name": name, "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}],"user_id": current_user.id,
-                "reference":[],}
+        conv = {
+            "id": conv_id,
+            "dialog_id": req["dialog_id"],
+            "name": name,
+            "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}],
+            "user_id": current_user.id,
+            "reference": [],
+        }
         ConversationService.save(**conv)
         return get_json_result(data=conv)
     except Exception as e:
@@ -174,6 +181,21 @@ def completion():
             continue
         msg.append(m)
     message_id = msg[-1].get("id")
+    chat_model_id = req.get("llm_id", "")
+    req.pop("llm_id", None)
+
+    chat_model_config = {}
+    for model_config in [
+        "temperature",
+        "top_p",
+        "frequency_penalty",
+        "presence_penalty",
+        "max_tokens",
+    ]:
+        config = req.get(model_config)
+        if config:
+            chat_model_config[model_config] = config
+
     try:
         e, conv = ConversationService.get_by_id(req["conversation_id"])
         if not e:
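This hunk lifts an optional `llm_id` plus any sampling parameters out of the request body before the rest of the payload is forwarded to `chat`. A standalone sketch of the same extraction pattern, with a purely illustrative request payload:

```python
# Sketch of the per-request model-override extraction above; the payload is illustrative.
req = {"conversation_id": "c1", "llm_id": "deepseek-chat", "temperature": 0.2, "max_tokens": 512}

chat_model_id = req.pop("llm_id", "")  # "" when the caller does not override the model

chat_model_config = {}
for key in ("temperature", "top_p", "frequency_penalty", "presence_penalty", "max_tokens"):
    value = req.get(key)
    # Falsy values (e.g. temperature=0) are skipped, mirroring the diff's `if config:` check.
    if value:
        chat_model_config[key] = value

print(chat_model_id, chat_model_config)  # deepseek-chat {'temperature': 0.2, 'max_tokens': 512}
```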
|
|
|
@@ -190,13 +212,23 @@ def completion():
         conv.reference = [r for r in conv.reference if r]
         conv.reference.append({"chunks": [], "doc_aggs": []})
 
+        if chat_model_id:
+            if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=chat_model_id):
+                req.pop("chat_model_id", None)
+                req.pop("chat_model_config", None)
+                return get_data_error_result(message=f"Cannot use specified model {chat_model_id}.")
+            dia.llm_id = chat_model_id
+            dia.llm_setting = chat_model_config
+
+        is_embedded = bool(chat_model_id)
         def stream():
             nonlocal dia, msg, req, conv
             try:
                 for ans in chat(dia, msg, True, **req):
                     ans = structure_answer(conv, ans, message_id, conv.id)
                     yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
-                ConversationService.update_by_id(conv.id, conv.to_dict())
+                if not is_embedded:
+                    ConversationService.update_by_id(conv.id, conv.to_dict())
             except Exception as e:
                 traceback.print_exc()
                 yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
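The `stream()` generator emits server-sent events: every chunk is a `data:` line carrying a JSON envelope and terminated by a blank line, and failures are reported in-band with `code: 500` instead of aborting the HTTP stream. The diff does not show how the generator is returned to the client, but a minimal Flask sketch of the same SSE pattern (route and payloads are assumptions, not code from this PR) is:

```python
import json

from flask import Flask, Response

app = Flask(__name__)

@app.route("/sse-demo")
def sse_demo():
    def stream():
        try:
            for i in range(3):
                # One SSE event per chunk: "data:<json>\n\n".
                yield "data:" + json.dumps({"code": 0, "data": {"answer": f"chunk {i}"}}, ensure_ascii=False) + "\n\n"
        except Exception as e:
            # Surface errors as a final well-formed event rather than a broken stream.
            yield "data:" + json.dumps({"code": 500, "message": str(e)}, ensure_ascii=False) + "\n\n"

    return Response(stream(), mimetype="text/event-stream")
```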
|
|
|
@@ -214,7 +246,8 @@ def completion():
             answer = None
             for ans in chat(dia, msg, **req):
                 answer = structure_answer(conv, ans, message_id, conv.id)
-                ConversationService.update_by_id(conv.id, conv.to_dict())
+                if not is_embedded:
+                    ConversationService.update_by_id(conv.id, conv.to_dict())
                 break
             return get_json_result(data=answer)
     except Exception as e:
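Taken together, the change lets a caller pin a chat model and its sampling settings for a single completion, provided the tenant has an API key for that model, and skips persisting the conversation when such an override is active (`is_embedded`). Assuming the handler is mounted at `POST /v1/conversation/completion` behind token auth (neither detail is shown in this diff), a client call might look like:

```python
import requests  # hypothetical client; URL, port, and auth header are assumptions

resp = requests.post(
    "http://localhost:9380/v1/conversation/completion",
    headers={"Authorization": "Bearer <token>"},
    json={
        "conversation_id": "<conversation_id>",
        "messages": [{"role": "user", "content": "Summarize the knowledge base."}],
        "llm_id": "deepseek-chat",  # per-request model override introduced by this diff
        "temperature": 0.1,         # collected into chat_model_config
        "stream": False,
    },
)
print(resp.json())
```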