### What problem does this PR solve? change default models to buildin models https://github.com/infiniflow/ragflow/issues/7774 ### Type of change - [x] New Feature (non-breaking change which adds functionality)tags/v0.19.0
| @@ -16,6 +16,7 @@ | |||
| import logging | |||
| from flask import request | |||
| from api import settings | |||
| from api.db import StatusEnum | |||
| from api.db.services.dialog_service import DialogService | |||
| @@ -23,15 +24,14 @@ from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db.services.llm_service import TenantLLMService | |||
| from api.db.services.user_service import TenantService | |||
| from api.utils import get_uuid | |||
| from api.utils.api_utils import get_error_data_result, token_required, get_result, check_duplicate_ids | |||
| from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required | |||
| @manager.route('/chats', methods=['POST']) # noqa: F821 | |||
| @manager.route("/chats", methods=["POST"]) # noqa: F821 | |||
| @token_required | |||
| def create(tenant_id): | |||
| req = request.json | |||
| ids = [i for i in req.get("dataset_ids", []) if i] | |||
| ids = [i for i in req.get("dataset_ids", []) if i] | |||
| for kb_id in ids: | |||
| kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id) | |||
| if not kbs: | |||
| @@ -40,34 +40,30 @@ def create(tenant_id): | |||
| kb = kbs[0] | |||
| if kb.chunk_num == 0: | |||
| return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file") | |||
| kbs = KnowledgebaseService.get_by_ids(ids) if ids else [] | |||
| embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison | |||
| embd_count = list(set(embd_ids)) | |||
| if len(embd_count) > 1: | |||
| return get_result(message='Datasets use different embedding models."', | |||
| code=settings.RetCode.AUTHENTICATION_ERROR) | |||
| return get_result(message='Datasets use different embedding models."', code=settings.RetCode.AUTHENTICATION_ERROR) | |||
| req["kb_ids"] = ids | |||
| # llm | |||
| llm = req.get("llm") | |||
| if llm: | |||
| if "model_name" in llm: | |||
| req["llm_id"] = llm.pop("model_name") | |||
| if not TenantLLMService.query(tenant_id=tenant_id, llm_name=req["llm_id"], model_type="chat"): | |||
| return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist") | |||
| if req.get("llm_id") is not None: | |||
| llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"]) | |||
| if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="chat"): | |||
| return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist") | |||
| req["llm_setting"] = req.pop("llm") | |||
| e, tenant = TenantService.get_by_id(tenant_id) | |||
| if not e: | |||
| return get_error_data_result(message="Tenant not found!") | |||
| # prompt | |||
| prompt = req.get("prompt") | |||
| key_mapping = {"parameters": "variables", | |||
| "prologue": "opener", | |||
| "quote": "show_quote", | |||
| "system": "prompt", | |||
| "rerank_id": "rerank_model", | |||
| "vector_similarity_weight": "keywords_similarity_weight"} | |||
| key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id","top_k"] | |||
| key_mapping = {"parameters": "variables", "prologue": "opener", "quote": "show_quote", "system": "prompt", "rerank_id": "rerank_model", "vector_similarity_weight": "keywords_similarity_weight"} | |||
| key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id", "top_k"] | |||
| if prompt: | |||
| for new_key, old_key in key_mapping.items(): | |||
| if old_key in prompt: | |||
| @@ -85,9 +81,7 @@ def create(tenant_id): | |||
| req["rerank_id"] = req.get("rerank_id", "") | |||
| if req.get("rerank_id"): | |||
| value_rerank_model = ["BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"] | |||
| if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id, | |||
| llm_name=req.get("rerank_id"), | |||
| model_type="rerank"): | |||
| if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id, llm_name=req.get("rerank_id"), model_type="rerank"): | |||
| return get_error_data_result(f"`rerank_model` {req.get('rerank_id')} doesn't exist") | |||
| if not req.get("llm_id"): | |||
| req["llm_id"] = tenant.llm_id | |||
| @@ -106,27 +100,24 @@ def create(tenant_id): | |||
| {knowledge} | |||
| The above is the knowledge base.""", | |||
| "prologue": "Hi! I'm your assistant, what can I do for you?", | |||
| "parameters": [ | |||
| {"key": "knowledge", "optional": False} | |||
| ], | |||
| "parameters": [{"key": "knowledge", "optional": False}], | |||
| "empty_response": "Sorry! No relevant content was found in the knowledge base!", | |||
| "quote": True, | |||
| "tts": False, | |||
| "refine_multiturn": True | |||
| "refine_multiturn": True, | |||
| } | |||
| key_list_2 = ["system", "prologue", "parameters", "empty_response", "quote", "tts", "refine_multiturn"] | |||
| if "prompt_config" not in req: | |||
| req['prompt_config'] = {} | |||
| req["prompt_config"] = {} | |||
| for key in key_list_2: | |||
| temp = req['prompt_config'].get(key) | |||
| if (not temp and key == 'system') or (key not in req["prompt_config"]): | |||
| req['prompt_config'][key] = default_prompt[key] | |||
| for p in req['prompt_config']["parameters"]: | |||
| temp = req["prompt_config"].get(key) | |||
| if (not temp and key == "system") or (key not in req["prompt_config"]): | |||
| req["prompt_config"][key] = default_prompt[key] | |||
| for p in req["prompt_config"]["parameters"]: | |||
| if p["optional"]: | |||
| continue | |||
| if req['prompt_config']["system"].find("{%s}" % p["key"]) < 0: | |||
| return get_error_data_result( | |||
| message="Parameter '{}' is not used".format(p["key"])) | |||
| if req["prompt_config"]["system"].find("{%s}" % p["key"]) < 0: | |||
| return get_error_data_result(message="Parameter '{}' is not used".format(p["key"])) | |||
| # save | |||
| if not DialogService.save(**req): | |||
| return get_error_data_result(message="Fail to new a chat!") | |||
| @@ -141,10 +132,7 @@ def create(tenant_id): | |||
| renamed_dict[new_key] = value | |||
| res["prompt"] = renamed_dict | |||
| del res["prompt_config"] | |||
| new_dict = {"similarity_threshold": res["similarity_threshold"], | |||
| "keywords_similarity_weight": 1-res["vector_similarity_weight"], | |||
| "top_n": res["top_n"], | |||
| "rerank_model": res['rerank_id']} | |||
| new_dict = {"similarity_threshold": res["similarity_threshold"], "keywords_similarity_weight": 1 - res["vector_similarity_weight"], "top_n": res["top_n"], "rerank_model": res["rerank_id"]} | |||
| res["prompt"].update(new_dict) | |||
| for key in key_list: | |||
| del res[key] | |||
| @@ -156,11 +144,11 @@ def create(tenant_id): | |||
| return get_result(data=res) | |||
| @manager.route('/chats/<chat_id>', methods=['PUT']) # noqa: F821 | |||
| @manager.route("/chats/<chat_id>", methods=["PUT"]) # noqa: F821 | |||
| @token_required | |||
| def update(tenant_id, chat_id): | |||
| if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value): | |||
| return get_error_data_result(message='You do not own the chat') | |||
| return get_error_data_result(message="You do not own the chat") | |||
| req = request.json | |||
| ids = req.get("dataset_ids") | |||
| if "show_quotation" in req: | |||
| @@ -174,14 +162,12 @@ def update(tenant_id, chat_id): | |||
| kb = kbs[0] | |||
| if kb.chunk_num == 0: | |||
| return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file") | |||
| kbs = KnowledgebaseService.get_by_ids(ids) | |||
| embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison | |||
| embd_count = list(set(embd_ids)) | |||
| if len(embd_count) != 1: | |||
| return get_result( | |||
| message='Datasets use different embedding models."', | |||
| code=settings.RetCode.AUTHENTICATION_ERROR) | |||
| return get_result(message='Datasets use different embedding models."', code=settings.RetCode.AUTHENTICATION_ERROR) | |||
| req["kb_ids"] = ids | |||
| llm = req.get("llm") | |||
| if llm: | |||
| @@ -195,13 +181,8 @@ def update(tenant_id, chat_id): | |||
| return get_error_data_result(message="Tenant not found!") | |||
| # prompt | |||
| prompt = req.get("prompt") | |||
| key_mapping = {"parameters": "variables", | |||
| "prologue": "opener", | |||
| "quote": "show_quote", | |||
| "system": "prompt", | |||
| "rerank_id": "rerank_model", | |||
| "vector_similarity_weight": "keywords_similarity_weight"} | |||
| key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id","top_k"] | |||
| key_mapping = {"parameters": "variables", "prologue": "opener", "quote": "show_quote", "system": "prompt", "rerank_id": "rerank_model", "vector_similarity_weight": "keywords_similarity_weight"} | |||
| key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id", "top_k"] | |||
| if prompt: | |||
| for new_key, old_key in key_mapping.items(): | |||
| if old_key in prompt: | |||
| @@ -214,16 +195,12 @@ def update(tenant_id, chat_id): | |||
| res = res.to_json() | |||
| if req.get("rerank_id"): | |||
| value_rerank_model = ["BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"] | |||
| if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id, | |||
| llm_name=req.get("rerank_id"), | |||
| model_type="rerank"): | |||
| if req["rerank_id"] not in value_rerank_model and not TenantLLMService.query(tenant_id=tenant_id, llm_name=req.get("rerank_id"), model_type="rerank"): | |||
| return get_error_data_result(f"`rerank_model` {req.get('rerank_id')} doesn't exist") | |||
| if "name" in req: | |||
| if not req.get("name"): | |||
| return get_error_data_result(message="`name` cannot be empty.") | |||
| if req["name"].lower() != res["name"].lower() \ | |||
| and len( | |||
| DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0: | |||
| if req["name"].lower() != res["name"].lower() and len(DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0: | |||
| return get_error_data_result(message="Duplicated chat name in updating chat.") | |||
| if "prompt_config" in req: | |||
| res["prompt_config"].update(req["prompt_config"]) | |||
| @@ -246,7 +223,7 @@ def update(tenant_id, chat_id): | |||
| return get_result() | |||
| @manager.route('/chats', methods=['DELETE']) # noqa: F821 | |||
| @manager.route("/chats", methods=["DELETE"]) # noqa: F821 | |||
| @token_required | |||
| def delete(tenant_id): | |||
| errors = [] | |||
| @@ -273,30 +250,23 @@ def delete(tenant_id): | |||
| temp_dict = {"status": StatusEnum.INVALID.value} | |||
| DialogService.update_by_id(id, temp_dict) | |||
| success_count += 1 | |||
| if errors: | |||
| if success_count > 0: | |||
| return get_result( | |||
| data={"success_count": success_count, "errors": errors}, | |||
| message=f"Partially deleted {success_count} chats with {len(errors)} errors" | |||
| ) | |||
| return get_result(data={"success_count": success_count, "errors": errors}, message=f"Partially deleted {success_count} chats with {len(errors)} errors") | |||
| else: | |||
| return get_error_data_result(message="; ".join(errors)) | |||
| if duplicate_messages: | |||
| if success_count > 0: | |||
| return get_result( | |||
| message=f"Partially deleted {success_count} chats with {len(duplicate_messages)} errors", | |||
| data={"success_count": success_count, "errors": duplicate_messages} | |||
| ) | |||
| return get_result(message=f"Partially deleted {success_count} chats with {len(duplicate_messages)} errors", data={"success_count": success_count, "errors": duplicate_messages}) | |||
| else: | |||
| return get_error_data_result(message=";".join(duplicate_messages)) | |||
| return get_result() | |||
| return get_result() | |||
| @manager.route('/chats', methods=['GET']) # noqa: F821 | |||
| @manager.route("/chats", methods=["GET"]) # noqa: F821 | |||
| @token_required | |||
| def list_chat(tenant_id): | |||
| id = request.args.get("id") | |||
| @@ -316,13 +286,15 @@ def list_chat(tenant_id): | |||
| if not chats: | |||
| return get_result(data=[]) | |||
| list_assts = [] | |||
| key_mapping = {"parameters": "variables", | |||
| "prologue": "opener", | |||
| "quote": "show_quote", | |||
| "system": "prompt", | |||
| "rerank_id": "rerank_model", | |||
| "vector_similarity_weight": "keywords_similarity_weight", | |||
| "do_refer": "show_quotation"} | |||
| key_mapping = { | |||
| "parameters": "variables", | |||
| "prologue": "opener", | |||
| "quote": "show_quote", | |||
| "system": "prompt", | |||
| "rerank_id": "rerank_model", | |||
| "vector_similarity_weight": "keywords_similarity_weight", | |||
| "do_refer": "show_quotation", | |||
| } | |||
| key_list = ["similarity_threshold", "vector_similarity_weight", "top_n", "rerank_id"] | |||
| for res in chats: | |||
| renamed_dict = {} | |||
| @@ -331,10 +303,7 @@ def list_chat(tenant_id): | |||
| renamed_dict[new_key] = value | |||
| res["prompt"] = renamed_dict | |||
| del res["prompt_config"] | |||
| new_dict = {"similarity_threshold": res["similarity_threshold"], | |||
| "keywords_similarity_weight": 1-res["vector_similarity_weight"], | |||
| "top_n": res["top_n"], | |||
| "rerank_model": res['rerank_id']} | |||
| new_dict = {"similarity_threshold": res["similarity_threshold"], "keywords_similarity_weight": 1 - res["vector_similarity_weight"], "top_n": res["top_n"], "rerank_model": res["rerank_id"]} | |||
| res["prompt"].update(new_dict) | |||
| for key in key_list: | |||
| del res[key] | |||
| @@ -13,36 +13,37 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import logging | |||
| import json | |||
| import logging | |||
| import re | |||
| from datetime import datetime | |||
| from flask import request, session, redirect | |||
| from werkzeug.security import generate_password_hash, check_password_hash | |||
| from flask_login import login_required, current_user, login_user, logout_user | |||
| from flask import redirect, request, session | |||
| from flask_login import current_user, login_required, login_user, logout_user | |||
| from werkzeug.security import check_password_hash, generate_password_hash | |||
| from api import settings | |||
| from api.apps.auth import get_auth_client | |||
| from api.db import FileType, UserTenantRole | |||
| from api.db.db_models import TenantLLM | |||
| from api.db.services.llm_service import TenantLLMService, LLMService | |||
| from api.utils.api_utils import ( | |||
| server_error_response, | |||
| validate_request, | |||
| get_data_error_result, | |||
| ) | |||
| from api.db.services.file_service import FileService | |||
| from api.db.services.llm_service import LLMService, TenantLLMService | |||
| from api.db.services.user_service import TenantService, UserService, UserTenantService | |||
| from api.utils import ( | |||
| get_uuid, | |||
| get_format_time, | |||
| decrypt, | |||
| download_img, | |||
| current_timestamp, | |||
| datetime_format, | |||
| decrypt, | |||
| download_img, | |||
| get_format_time, | |||
| get_uuid, | |||
| ) | |||
| from api.utils.api_utils import ( | |||
| construct_response, | |||
| get_data_error_result, | |||
| get_json_result, | |||
| server_error_response, | |||
| validate_request, | |||
| ) | |||
| from api.db import UserTenantRole, FileType | |||
| from api import settings | |||
| from api.db.services.user_service import UserService, TenantService, UserTenantService | |||
| from api.db.services.file_service import FileService | |||
| from api.utils.api_utils import get_json_result, construct_response | |||
| from api.apps.auth import get_auth_client | |||
| @manager.route("/login", methods=["POST", "GET"]) # noqa: F821 | |||
| @@ -77,9 +78,7 @@ def login(): | |||
| type: object | |||
| """ | |||
| if not request.json: | |||
| return get_json_result( | |||
| data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!" | |||
| ) | |||
| return get_json_result(data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!") | |||
| email = request.json.get("email", "") | |||
| users = UserService.query(email=email) | |||
| @@ -94,9 +93,7 @@ def login(): | |||
| try: | |||
| password = decrypt(password) | |||
| except BaseException: | |||
| return get_json_result( | |||
| data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password" | |||
| ) | |||
| return get_json_result(data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password") | |||
| user = UserService.query_user(email, password) | |||
| if user: | |||
| @@ -116,7 +113,7 @@ def login(): | |||
| ) | |||
| @manager.route("/login/channels", methods=["GET"]) # noqa: F821 | |||
| @manager.route("/login/channels", methods=["GET"]) # noqa: F821 | |||
| def get_login_channels(): | |||
| """ | |||
| Get all supported authentication channels. | |||
| @@ -124,22 +121,20 @@ def get_login_channels(): | |||
| try: | |||
| channels = [] | |||
| for channel, config in settings.OAUTH_CONFIG.items(): | |||
| channels.append({ | |||
| "channel": channel, | |||
| "display_name": config.get("display_name", channel.title()), | |||
| "icon": config.get("icon", "sso"), | |||
| }) | |||
| channels.append( | |||
| { | |||
| "channel": channel, | |||
| "display_name": config.get("display_name", channel.title()), | |||
| "icon": config.get("icon", "sso"), | |||
| } | |||
| ) | |||
| return get_json_result(data=channels) | |||
| except Exception as e: | |||
| logging.exception(e) | |||
| return get_json_result( | |||
| data=[], | |||
| message=f"Load channels failure, error: {str(e)}", | |||
| code=settings.RetCode.EXCEPTION_ERROR | |||
| ) | |||
| return get_json_result(data=[], message=f"Load channels failure, error: {str(e)}", code=settings.RetCode.EXCEPTION_ERROR) | |||
| @manager.route("/login/<channel>", methods=["GET"]) # noqa: F821 | |||
| @manager.route("/login/<channel>", methods=["GET"]) # noqa: F821 | |||
| def oauth_login(channel): | |||
| channel_config = settings.OAUTH_CONFIG.get(channel) | |||
| if not channel_config: | |||
| @@ -152,7 +147,7 @@ def oauth_login(channel): | |||
| return redirect(auth_url) | |||
| @manager.route("/oauth/callback/<channel>", methods=["GET"]) # noqa: F821 | |||
| @manager.route("/oauth/callback/<channel>", methods=["GET"]) # noqa: F821 | |||
| def oauth_callback(channel): | |||
| """ | |||
| Handle the OAuth/OIDC callback for various channels dynamically. | |||
| @@ -190,7 +185,7 @@ def oauth_callback(channel): | |||
| # Login or register | |||
| users = UserService.query(email=user_info.email) | |||
| user_id = get_uuid() | |||
| if not users: | |||
| try: | |||
| try: | |||
| @@ -434,9 +429,7 @@ def user_info_from_feishu(access_token): | |||
| "Content-Type": "application/json; charset=utf-8", | |||
| "Authorization": f"Bearer {access_token}", | |||
| } | |||
| res = requests.get( | |||
| "https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers | |||
| ) | |||
| res = requests.get("https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers) | |||
| user_info = res.json()["data"] | |||
| user_info["email"] = None if user_info.get("email") == "" else user_info["email"] | |||
| return user_info | |||
| @@ -446,17 +439,13 @@ def user_info_from_github(access_token): | |||
| import requests | |||
| headers = {"Accept": "application/json", "Authorization": f"token {access_token}"} | |||
| res = requests.get( | |||
| f"https://api.github.com/user?access_token={access_token}", headers=headers | |||
| ) | |||
| res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers) | |||
| user_info = res.json() | |||
| email_info = requests.get( | |||
| f"https://api.github.com/user/emails?access_token={access_token}", | |||
| headers=headers, | |||
| ).json() | |||
| user_info["email"] = next( | |||
| (email for email in email_info if email["primary"]), None | |||
| )["email"] | |||
| user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"] | |||
| return user_info | |||
| @@ -516,9 +505,7 @@ def setting_user(): | |||
| request_data = request.json | |||
| if request_data.get("password"): | |||
| new_password = request_data.get("new_password") | |||
| if not check_password_hash( | |||
| current_user.password, decrypt(request_data["password"]) | |||
| ): | |||
| if not check_password_hash(current_user.password, decrypt(request_data["password"])): | |||
| return get_json_result( | |||
| data=False, | |||
| code=settings.RetCode.AUTHENTICATION_ERROR, | |||
| @@ -549,9 +536,7 @@ def setting_user(): | |||
| return get_json_result(data=True) | |||
| except Exception as e: | |||
| logging.exception(e) | |||
| return get_json_result( | |||
| data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR | |||
| ) | |||
| return get_json_result(data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR) | |||
| @manager.route("/info", methods=["GET"]) # noqa: F821 | |||
| @@ -643,9 +628,23 @@ def user_register(user_id, user): | |||
| "model_type": llm.model_type, | |||
| "api_key": settings.API_KEY, | |||
| "api_base": settings.LLM_BASE_URL, | |||
| "max_tokens": llm.max_tokens if llm.max_tokens else 8192 | |||
| "max_tokens": llm.max_tokens if llm.max_tokens else 8192, | |||
| } | |||
| ) | |||
| if settings.LIGHTEN != 1: | |||
| for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS: | |||
| mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model) | |||
| tenant_llm.append( | |||
| { | |||
| "tenant_id": user_id, | |||
| "llm_factory": fid, | |||
| "llm_name": mdlnm, | |||
| "model_type": "embedding", | |||
| "api_key": "", | |||
| "api_base": "", | |||
| "max_tokens": 1024 if buildin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512, | |||
| } | |||
| ) | |||
| if not UserService.save(**user): | |||
| return | |||
| @@ -81,7 +81,7 @@ def init_settings(): | |||
| DATABASE = decrypt_database_config(name=DATABASE_TYPE) | |||
| LLM = get_base_config("user_default_llm", {}) | |||
| LLM_DEFAULT_MODELS = LLM.get("default_models", {}) | |||
| LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen") | |||
| LLM_FACTORY = LLM.get("factory") | |||
| LLM_BASE_URL = LLM.get("base_url") | |||
| try: | |||
| REGISTER_ENABLED = int(os.environ.get("REGISTER_ENABLED", "1")) | |||
| @@ -567,7 +567,7 @@ | |||
| { | |||
| "name": "Youdao", | |||
| "logo": "", | |||
| "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION", | |||
| "tags": "TEXT EMBEDDING", | |||
| "status": "1", | |||
| "llm": [ | |||
| { | |||
| @@ -755,7 +755,7 @@ | |||
| { | |||
| "name": "BAAI", | |||
| "logo": "", | |||
| "tags": "TEXT EMBEDDING, TEXT RE-RANK", | |||
| "tags": "TEXT EMBEDDING", | |||
| "status": "1", | |||
| "llm": [ | |||
| { | |||
| @@ -20,7 +20,9 @@ import pytest | |||
| import requests | |||
| HOST_ADDRESS = os.getenv("HOST_ADDRESS", "http://127.0.0.1:9380") | |||
| ZHIPU_AI_API_KEY = os.getenv("ZHIPU_AI_API_KEY", "ca148e43209c40109e2bc2f56281dd11.BltyA2N1B043B7Ra") | |||
| if ZHIPU_AI_API_KEY is None: | |||
| pytest.exit("Error: Environment variable ZHIPU_AI_API_KEY must be set") | |||
| # def generate_random_email(): | |||
| # return 'user_' + ''.join(random.choices(string.ascii_lowercase + string.digits, k=8))+'@1.com' | |||
| @@ -87,3 +89,64 @@ def get_auth(): | |||
| @pytest.fixture(scope="session") | |||
| def get_email(): | |||
| return EMAIL | |||
| def get_my_llms(auth, name): | |||
| url = HOST_ADDRESS + "/v1/llm/my_llms" | |||
| authorization = {"Authorization": auth} | |||
| response = requests.get(url=url, headers=authorization) | |||
| res = response.json() | |||
| if res.get("code") != 0: | |||
| raise Exception(res.get("message")) | |||
| if name in res.get("data"): | |||
| return True | |||
| return False | |||
| def add_models(auth): | |||
| url = HOST_ADDRESS + "/v1/llm/set_api_key" | |||
| authorization = {"Authorization": auth} | |||
| models_info = { | |||
| "ZHIPU-AI": {"llm_factory": "ZHIPU-AI", "api_key": ZHIPU_AI_API_KEY}, | |||
| } | |||
| for name, model_info in models_info.items(): | |||
| if not get_my_llms(auth, name): | |||
| response = requests.post(url=url, headers=authorization, json=model_info) | |||
| res = response.json() | |||
| if res.get("code") != 0: | |||
| pytest.exit(f"Critical error in add_models: {res.get('message')}") | |||
| def get_tenant_info(auth): | |||
| url = HOST_ADDRESS + "/v1/user/tenant_info" | |||
| authorization = {"Authorization": auth} | |||
| response = requests.get(url=url, headers=authorization) | |||
| res = response.json() | |||
| if res.get("code") != 0: | |||
| raise Exception(res.get("message")) | |||
| return res["data"].get("tenant_id") | |||
| @pytest.fixture(scope="session", autouse=True) | |||
| def set_tenant_info(get_auth): | |||
| auth = get_auth | |||
| try: | |||
| add_models(auth) | |||
| tenant_id = get_tenant_info(auth) | |||
| except Exception as e: | |||
| pytest.exit(f"Error in set_tenant_info: {str(e)}") | |||
| url = HOST_ADDRESS + "/v1/user/set_tenant_info" | |||
| authorization = {"Authorization": get_auth} | |||
| tenant_info = { | |||
| "tenant_id": tenant_id, | |||
| "llm_id": "glm-4-flash@ZHIPU-AI", | |||
| "embd_id": "BAAI/bge-large-zh-v1.5@BAAI", | |||
| "img2txt_id": "glm-4v@ZHIPU-AI", | |||
| "asr_id": "", | |||
| "tts_id": None, | |||
| } | |||
| response = requests.post(url=url, headers=authorization, json=tenant_info) | |||
| res = response.json() | |||
| if res.get("code") != 0: | |||
| raise Exception(res.get("message")) | |||
| @@ -16,7 +16,6 @@ | |||
| import os | |||
| import pytest | |||
| import requests | |||
| from common import ( | |||
| add_chunk, | |||
| batch_create_datasets, | |||
| @@ -49,9 +48,6 @@ MARKER_EXPRESSIONS = { | |||
| "p3": "p1 or p2 or p3", | |||
| } | |||
| HOST_ADDRESS = os.getenv("HOST_ADDRESS", "http://127.0.0.1:9380") | |||
| ZHIPU_AI_API_KEY = os.getenv("ZHIPU_AI_API_KEY", "ca148e43209c40109e2bc2f56281dd11.BltyA2N1B043B7Ra") | |||
| if ZHIPU_AI_API_KEY is None: | |||
| pytest.exit("Error: Environment variable ZHIPU_AI_API_KEY must be set") | |||
| def pytest_addoption(parser: pytest.Parser) -> None: | |||
| @@ -85,67 +81,6 @@ def get_http_api_auth(get_api_key_fixture): | |||
| return RAGFlowHttpApiAuth(get_api_key_fixture) | |||
| def get_my_llms(auth, name): | |||
| url = HOST_ADDRESS + "/v1/llm/my_llms" | |||
| authorization = {"Authorization": auth} | |||
| response = requests.get(url=url, headers=authorization) | |||
| res = response.json() | |||
| if res.get("code") != 0: | |||
| raise Exception(res.get("message")) | |||
| if name in res.get("data"): | |||
| return True | |||
| return False | |||
| def add_models(auth): | |||
| url = HOST_ADDRESS + "/v1/llm/set_api_key" | |||
| authorization = {"Authorization": auth} | |||
| models_info = { | |||
| "ZHIPU-AI": {"llm_factory": "ZHIPU-AI", "api_key": ZHIPU_AI_API_KEY}, | |||
| } | |||
| for name, model_info in models_info.items(): | |||
| if not get_my_llms(auth, name): | |||
| response = requests.post(url=url, headers=authorization, json=model_info) | |||
| res = response.json() | |||
| if res.get("code") != 0: | |||
| pytest.exit(f"Critical error in add_models: {res.get('message')}") | |||
| def get_tenant_info(auth): | |||
| url = HOST_ADDRESS + "/v1/user/tenant_info" | |||
| authorization = {"Authorization": auth} | |||
| response = requests.get(url=url, headers=authorization) | |||
| res = response.json() | |||
| if res.get("code") != 0: | |||
| raise Exception(res.get("message")) | |||
| return res["data"].get("tenant_id") | |||
| @pytest.fixture(scope="session", autouse=True) | |||
| def set_tenant_info(get_auth): | |||
| auth = get_auth | |||
| try: | |||
| add_models(auth) | |||
| tenant_id = get_tenant_info(auth) | |||
| except Exception as e: | |||
| pytest.exit(f"Error in set_tenant_info: {str(e)}") | |||
| url = HOST_ADDRESS + "/v1/user/set_tenant_info" | |||
| authorization = {"Authorization": get_auth} | |||
| tenant_info = { | |||
| "tenant_id": tenant_id, | |||
| "llm_id": "glm-4-flash@ZHIPU-AI", | |||
| "embd_id": "BAAI/bge-large-zh-v1.5@BAAI", | |||
| "img2txt_id": "glm-4v@ZHIPU-AI", | |||
| "asr_id": "", | |||
| "tts_id": None, | |||
| } | |||
| response = requests.post(url=url, headers=authorization, json=tenant_info) | |||
| res = response.json() | |||
| if res.get("code") != 0: | |||
| raise Exception(res.get("message")) | |||
| @pytest.fixture(scope="function") | |||
| def clear_datasets(request, get_http_api_auth): | |||
| def cleanup(): | |||
| @@ -14,8 +14,9 @@ | |||
| # limitations under the License. | |||
| # | |||
| from ragflow_sdk import RAGFlow | |||
| from common import HOST_ADDRESS | |||
| from ragflow_sdk import RAGFlow | |||
| from ragflow_sdk.modules.chat import Chat | |||
| def test_create_chat_with_name(get_api_key_fixture): | |||
| @@ -31,7 +32,18 @@ def test_create_chat_with_name(get_api_key_fixture): | |||
| docs = kb.upload_documents(documents) | |||
| for doc in docs: | |||
| doc.add_chunk("This is a test to add chunk") | |||
| rag.create_chat("test_create_chat", dataset_ids=[kb.id]) | |||
| llm = Chat.LLM( | |||
| rag, | |||
| { | |||
| "model_name": "glm-4-flash@ZHIPU-AI", | |||
| "temperature": 0.1, | |||
| "top_p": 0.3, | |||
| "presence_penalty": 0.4, | |||
| "frequency_penalty": 0.7, | |||
| "max_tokens": 512, | |||
| }, | |||
| ) | |||
| rag.create_chat("test_create_chat", dataset_ids=[kb.id], llm=llm) | |||
| def test_update_chat_with_name(get_api_key_fixture): | |||
| @@ -47,7 +59,18 @@ def test_update_chat_with_name(get_api_key_fixture): | |||
| docs = kb.upload_documents(documents) | |||
| for doc in docs: | |||
| doc.add_chunk("This is a test to add chunk") | |||
| chat = rag.create_chat("test_update_chat", dataset_ids=[kb.id]) | |||
| llm = Chat.LLM( | |||
| rag, | |||
| { | |||
| "model_name": "glm-4-flash@ZHIPU-AI", | |||
| "temperature": 0.1, | |||
| "top_p": 0.3, | |||
| "presence_penalty": 0.4, | |||
| "frequency_penalty": 0.7, | |||
| "max_tokens": 512, | |||
| }, | |||
| ) | |||
| chat = rag.create_chat("test_update_chat", dataset_ids=[kb.id], llm=llm) | |||
| chat.update({"name": "new_chat"}) | |||
| @@ -64,7 +87,18 @@ def test_delete_chats_with_success(get_api_key_fixture): | |||
| docs = kb.upload_documents(documents) | |||
| for doc in docs: | |||
| doc.add_chunk("This is a test to add chunk") | |||
| chat = rag.create_chat("test_delete_chat", dataset_ids=[kb.id]) | |||
| llm = Chat.LLM( | |||
| rag, | |||
| { | |||
| "model_name": "glm-4-flash@ZHIPU-AI", | |||
| "temperature": 0.1, | |||
| "top_p": 0.3, | |||
| "presence_penalty": 0.4, | |||
| "frequency_penalty": 0.7, | |||
| "max_tokens": 512, | |||
| }, | |||
| ) | |||
| chat = rag.create_chat("test_delete_chat", dataset_ids=[kb.id], llm=llm) | |||
| rag.delete_chats(ids=[chat.id]) | |||
| @@ -81,6 +115,17 @@ def test_list_chats_with_success(get_api_key_fixture): | |||
| docs = kb.upload_documents(documents) | |||
| for doc in docs: | |||
| doc.add_chunk("This is a test to add chunk") | |||
| rag.create_chat("test_list_1", dataset_ids=[kb.id]) | |||
| rag.create_chat("test_list_2", dataset_ids=[kb.id]) | |||
| llm = Chat.LLM( | |||
| rag, | |||
| { | |||
| "model_name": "glm-4-flash@ZHIPU-AI", | |||
| "temperature": 0.1, | |||
| "top_p": 0.3, | |||
| "presence_penalty": 0.4, | |||
| "frequency_penalty": 0.7, | |||
| "max_tokens": 512, | |||
| }, | |||
| ) | |||
| rag.create_chat("test_list_1", dataset_ids=[kb.id], llm=llm) | |||
| rag.create_chat("test_list_2", dataset_ids=[kb.id], llm=llm) | |||
| rag.list_chats() | |||