Browse Source

Feat: conversation completion can specify different model (#9485)

### What problem does this PR solve?

Conversation completion can now specify a different model.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
tags/v0.20.2
Yongteng Lei 2 months ago
parent
commit
ffc095bd50
No account linked to committer's email address
3 changed files with 57 additions and 13 deletions
  1. 38
    5
      api/apps/conversation_app.py
  2. 6
    5
      api/db/db_models.py
  3. 13
    3
      api/db/services/dialog_service.py

+ 38
- 5
api/apps/conversation_app.py View File

@@ -29,7 +29,8 @@ from api.db.services.conversation_service import ConversationService, structure_
from api.db.services.dialog_service import DialogService, ask, chat
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import UserTenantService, TenantService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.app.tag import label_question
@@ -66,8 +67,14 @@ def set_conversation():
e, dia = DialogService.get_by_id(req["dialog_id"])
if not e:
return get_data_error_result(message="Dialog not found")
conv = {"id": conv_id, "dialog_id": req["dialog_id"], "name": name, "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}],"user_id": current_user.id,
"reference":[],}
conv = {
"id": conv_id,
"dialog_id": req["dialog_id"],
"name": name,
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}],
"user_id": current_user.id,
"reference": [],
}
ConversationService.save(**conv)
return get_json_result(data=conv)
except Exception as e:
@@ -174,6 +181,21 @@ def completion():
continue
msg.append(m)
message_id = msg[-1].get("id")
chat_model_id = req.get("llm_id", "")
req.pop("llm_id", None)

chat_model_config = {}
for model_config in [
"temperature",
"top_p",
"frequency_penalty",
"presence_penalty",
"max_tokens",
]:
config = req.get(model_config)
if config:
chat_model_config[model_config] = config

try:
e, conv = ConversationService.get_by_id(req["conversation_id"])
if not e:
@@ -190,13 +212,23 @@ def completion():
conv.reference = [r for r in conv.reference if r]
conv.reference.append({"chunks": [], "doc_aggs": []})

if chat_model_id:
if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=chat_model_id):
req.pop("chat_model_id", None)
req.pop("chat_model_config", None)
return get_data_error_result(message=f"Cannot use specified model {chat_model_id}.")
dia.llm_id = chat_model_id
dia.llm_setting = chat_model_config

is_embedded = bool(chat_model_id)
def stream():
nonlocal dia, msg, req, conv
try:
for ans in chat(dia, msg, True, **req):
ans = structure_answer(conv, ans, message_id, conv.id)
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
ConversationService.update_by_id(conv.id, conv.to_dict())
if not is_embedded:
ConversationService.update_by_id(conv.id, conv.to_dict())
except Exception as e:
traceback.print_exc()
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
@@ -214,7 +246,8 @@ def completion():
answer = None
for ans in chat(dia, msg, **req):
answer = structure_answer(conv, ans, message_id, conv.id)
ConversationService.update_by_id(conv.id, conv.to_dict())
if not is_embedded:
ConversationService.update_by_id(conv.id, conv.to_dict())
break
return get_json_result(data=answer)
except Exception as e:

+ 6
- 5
api/db/db_models.py View File

@@ -881,11 +881,12 @@ class Search(DataBaseModel):
# chat settings
"summary": False,
"chat_id": "",
# Leave it here for reference, don't need to set default values
"llm_setting": {
"temperature": 0.1,
"top_p": 0.3,
"frequency_penalty": 0.7,
"presence_penalty": 0.4,
# "temperature": 0.1,
# "top_p": 0.3,
# "frequency_penalty": 0.7,
# "presence_penalty": 0.4,
},
"chat_settingcross_languages": [],
"highlight": False,
@@ -1020,4 +1021,4 @@ def migrate_db():
migrate(migrator.add_column("dialog", "meta_data_filter", JSONField(null=True, default={})))
except Exception:
pass
logging.disable(logging.NOTSET)
logging.disable(logging.NOTSET)

+ 13
- 3
api/db/services/dialog_service.py View File

@@ -99,7 +99,6 @@ class DialogService(CommonService):

return list(chats.dicts())


@classmethod
@DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc, keywords, parser_id=None):
@@ -256,9 +255,10 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):

def meta_filter(metas: dict, filters: list[dict]):
doc_ids = []

def filter_out(v2docs, operator, value):
nonlocal doc_ids
for input,docids in v2docs.items():
for input, docids in v2docs.items():
try:
input = float(input)
value = float(value)
@@ -389,7 +389,17 @@ def chat(dialog, messages, stream=True, **kwargs):
reasoner = DeepResearcher(
chat_mdl,
prompt_config,
partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3, doc_ids=attachments),
partial(
retriever.retrieval,
embd_mdl=embd_mdl,
tenant_ids=tenant_ids,
kb_ids=dialog.kb_ids,
page=1,
page_size=dialog.top_n,
similarity_threshold=0.2,
vector_similarity_weight=0.3,
doc_ids=attachments,
),
)

for think in reasoner.thinking(kbinfos, " ".join(questions)):

Loading…
Cancel
Save