Você não pode selecionar mais de 25 tópicos. Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

api_utils.py 18KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import functools
  17. import json
  18. import logging
  19. import random
  20. import time
  21. from base64 import b64encode
  22. from copy import deepcopy
  23. from functools import wraps
  24. from hmac import HMAC
  25. from io import BytesIO
  26. from urllib.parse import quote, urlencode
  27. from uuid import uuid1
  28. import requests
  29. from flask import (
  30. Response,
  31. jsonify,
  32. make_response,
  33. send_file,
  34. )
  35. from flask import (
  36. request as flask_request,
  37. )
  38. from itsdangerous import URLSafeTimedSerializer
  39. from peewee import OperationalError
  40. from werkzeug.http import HTTP_STATUS_CODES
  41. from api import settings
  42. from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
  43. from api.db.db_models import APIToken
  44. from api.db.services.llm_service import LLMService, TenantLLMService
  45. from api.utils import CustomJSONEncoder, get_uuid, json_dumps
# Route JSON serialization done inside ``requests`` (e.g. ``json=`` request
# bodies) through CustomJSONEncoder so project-specific types can be encoded.
# NOTE(review): ``requests`` may alias ``complexjson`` to the stdlib ``json``
# module (when simplejson is not installed); in that case this assignment
# replaces ``json.dumps`` for the whole process, not just for ``requests`` —
# confirm that is intended.
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
  47. def request(**kwargs):
  48. sess = requests.Session()
  49. stream = kwargs.pop("stream", sess.stream)
  50. timeout = kwargs.pop("timeout", None)
  51. kwargs["headers"] = {k.replace("_", "-").upper(): v for k, v in kwargs.get("headers", {}).items()}
  52. prepped = requests.Request(**kwargs).prepare()
  53. if settings.CLIENT_AUTHENTICATION and settings.HTTP_APP_KEY and settings.SECRET_KEY:
  54. timestamp = str(round(time() * 1000))
  55. nonce = str(uuid1())
  56. signature = b64encode(
  57. HMAC(
  58. settings.SECRET_KEY.encode("ascii"),
  59. b"\n".join(
  60. [
  61. timestamp.encode("ascii"),
  62. nonce.encode("ascii"),
  63. settings.HTTP_APP_KEY.encode("ascii"),
  64. prepped.path_url.encode("ascii"),
  65. prepped.body if kwargs.get("json") else b"",
  66. urlencode(sorted(kwargs["data"].items()), quote_via=quote, safe="-._~").encode("ascii") if kwargs.get("data") and isinstance(kwargs["data"], dict) else b"",
  67. ]
  68. ),
  69. "sha1",
  70. ).digest()
  71. ).decode("ascii")
  72. prepped.headers.update(
  73. {
  74. "TIMESTAMP": timestamp,
  75. "NONCE": nonce,
  76. "APP-KEY": settings.HTTP_APP_KEY,
  77. "SIGNATURE": signature,
  78. }
  79. )
  80. return sess.send(prepped, stream=stream, timeout=timeout)
  81. def get_exponential_backoff_interval(retries, full_jitter=False):
  82. """Calculate the exponential backoff wait time."""
  83. # Will be zero if factor equals 0
  84. countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2**retries))
  85. # Full jitter according to
  86. # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
  87. if full_jitter:
  88. countdown = random.randrange(countdown + 1)
  89. # Adjust according to maximum wait time and account for negative values.
  90. return max(0, countdown)
  91. def get_data_error_result(code=settings.RetCode.DATA_ERROR, message="Sorry! Data missing!"):
  92. logging.exception(Exception(message))
  93. result_dict = {"code": code, "message": message}
  94. response = {}
  95. for key, value in result_dict.items():
  96. if value is None and key != "code":
  97. continue
  98. else:
  99. response[key] = value
  100. return jsonify(response)
  101. def server_error_response(e):
  102. logging.exception(e)
  103. try:
  104. if e.code == 401:
  105. return get_json_result(code=401, message=repr(e))
  106. except BaseException:
  107. pass
  108. if len(e.args) > 1:
  109. return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
  110. if repr(e).find("index_not_found_exception") >= 0:
  111. return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
  112. return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
  113. def error_response(response_code, message=None):
  114. if message is None:
  115. message = HTTP_STATUS_CODES.get(response_code, "Unknown Error")
  116. return Response(
  117. json.dumps(
  118. {
  119. "message": message,
  120. "code": response_code,
  121. }
  122. ),
  123. status=response_code,
  124. mimetype="application/json",
  125. )
  126. def validate_request(*args, **kwargs):
  127. def wrapper(func):
  128. @wraps(func)
  129. def decorated_function(*_args, **_kwargs):
  130. input_arguments = flask_request.json or flask_request.form.to_dict()
  131. no_arguments = []
  132. error_arguments = []
  133. for arg in args:
  134. if arg not in input_arguments:
  135. no_arguments.append(arg)
  136. for k, v in kwargs.items():
  137. config_value = input_arguments.get(k, None)
  138. if config_value is None:
  139. no_arguments.append(k)
  140. elif isinstance(v, (tuple, list)):
  141. if config_value not in v:
  142. error_arguments.append((k, set(v)))
  143. elif config_value != v:
  144. error_arguments.append((k, v))
  145. if no_arguments or error_arguments:
  146. error_string = ""
  147. if no_arguments:
  148. error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
  149. if error_arguments:
  150. error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
  151. return get_json_result(code=settings.RetCode.ARGUMENT_ERROR, message=error_string)
  152. return func(*_args, **_kwargs)
  153. return decorated_function
  154. return wrapper
  155. def not_allowed_parameters(*params):
  156. def decorator(f):
  157. def wrapper(*args, **kwargs):
  158. input_arguments = flask_request.json or flask_request.form.to_dict()
  159. for param in params:
  160. if param in input_arguments:
  161. return get_json_result(code=settings.RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
  162. return f(*args, **kwargs)
  163. return wrapper
  164. return decorator
  165. def is_localhost(ip):
  166. return ip in {"127.0.0.1", "::1", "[::1]", "localhost"}
  167. def send_file_in_mem(data, filename):
  168. if not isinstance(data, (str, bytes)):
  169. data = json_dumps(data)
  170. if isinstance(data, str):
  171. data = data.encode("utf-8")
  172. f = BytesIO()
  173. f.write(data)
  174. f.seek(0)
  175. return send_file(f, as_attachment=True, attachment_filename=filename)
  176. def get_json_result(code=settings.RetCode.SUCCESS, message="success", data=None):
  177. response = {"code": code, "message": message, "data": data}
  178. return jsonify(response)
  179. def apikey_required(func):
  180. @wraps(func)
  181. def decorated_function(*args, **kwargs):
  182. token = flask_request.headers.get("Authorization").split()[1]
  183. objs = APIToken.query(token=token)
  184. if not objs:
  185. return build_error_result(message="API-KEY is invalid!", code=settings.RetCode.FORBIDDEN)
  186. kwargs["tenant_id"] = objs[0].tenant_id
  187. return func(*args, **kwargs)
  188. return decorated_function
  189. def build_error_result(code=settings.RetCode.FORBIDDEN, message="success"):
  190. response = {"code": code, "message": message}
  191. response = jsonify(response)
  192. response.status_code = code
  193. return response
  194. def construct_response(code=settings.RetCode.SUCCESS, message="success", data=None, auth=None):
  195. result_dict = {"code": code, "message": message, "data": data}
  196. response_dict = {}
  197. for key, value in result_dict.items():
  198. if value is None and key != "code":
  199. continue
  200. else:
  201. response_dict[key] = value
  202. response = make_response(jsonify(response_dict))
  203. if auth:
  204. response.headers["Authorization"] = auth
  205. response.headers["Access-Control-Allow-Origin"] = "*"
  206. response.headers["Access-Control-Allow-Method"] = "*"
  207. response.headers["Access-Control-Allow-Headers"] = "*"
  208. response.headers["Access-Control-Allow-Headers"] = "*"
  209. response.headers["Access-Control-Expose-Headers"] = "Authorization"
  210. return response
  211. def construct_result(code=settings.RetCode.DATA_ERROR, message="data is missing"):
  212. result_dict = {"code": code, "message": message}
  213. response = {}
  214. for key, value in result_dict.items():
  215. if value is None and key != "code":
  216. continue
  217. else:
  218. response[key] = value
  219. return jsonify(response)
  220. def construct_json_result(code=settings.RetCode.SUCCESS, message="success", data=None):
  221. if data is None:
  222. return jsonify({"code": code, "message": message})
  223. else:
  224. return jsonify({"code": code, "message": message, "data": data})
  225. def construct_error_response(e):
  226. logging.exception(e)
  227. try:
  228. if e.code == 401:
  229. return construct_json_result(code=settings.RetCode.UNAUTHORIZED, message=repr(e))
  230. except BaseException:
  231. pass
  232. if len(e.args) > 1:
  233. return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
  234. return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
  235. def token_required(func):
  236. @wraps(func)
  237. def decorated_function(*args, **kwargs):
  238. authorization_str = flask_request.headers.get("Authorization")
  239. if not authorization_str:
  240. return get_json_result(data=False, message="`Authorization` can't be empty")
  241. authorization_list = authorization_str.split()
  242. if len(authorization_list) < 2:
  243. return get_json_result(data=False, message="Please check your authorization format.")
  244. token = authorization_list[1]
  245. objs = APIToken.query(token=token)
  246. if not objs:
  247. return get_json_result(data=False, message="Authentication error: API key is invalid!", code=settings.RetCode.AUTHENTICATION_ERROR)
  248. kwargs["tenant_id"] = objs[0].tenant_id
  249. return func(*args, **kwargs)
  250. return decorated_function
  251. def get_result(code=settings.RetCode.SUCCESS, message="", data=None):
  252. if code == 0:
  253. if data is not None:
  254. response = {"code": code, "data": data}
  255. else:
  256. response = {"code": code}
  257. else:
  258. response = {"code": code, "message": message}
  259. return jsonify(response)
  260. def get_error_data_result(
  261. message="Sorry! Data missing!",
  262. code=settings.RetCode.DATA_ERROR,
  263. ):
  264. result_dict = {"code": code, "message": message}
  265. response = {}
  266. for key, value in result_dict.items():
  267. if value is None and key != "code":
  268. continue
  269. else:
  270. response[key] = value
  271. return jsonify(response)
  272. def get_error_argument_result(message="Invalid arguments"):
  273. return get_result(code=settings.RetCode.ARGUMENT_ERROR, message=message)
  274. def generate_confirmation_token(tenant_id):
  275. serializer = URLSafeTimedSerializer(tenant_id)
  276. return "ragflow-" + serializer.dumps(get_uuid(), salt=tenant_id)[2:34]
  277. def get_parser_config(chunk_method, parser_config):
  278. if parser_config:
  279. return parser_config
  280. if not chunk_method:
  281. chunk_method = "naive"
  282. key_mapping = {
  283. "naive": {"chunk_token_num": 128, "delimiter": r"\n", "html4excel": False, "layout_recognize": "DeepDOC", "raptor": {"use_raptor": False}},
  284. "qa": {"raptor": {"use_raptor": False}},
  285. "tag": None,
  286. "resume": None,
  287. "manual": {"raptor": {"use_raptor": False}},
  288. "table": None,
  289. "paper": {"raptor": {"use_raptor": False}},
  290. "book": {"raptor": {"use_raptor": False}},
  291. "laws": {"raptor": {"use_raptor": False}},
  292. "presentation": {"raptor": {"use_raptor": False}},
  293. "one": None,
  294. "knowledge_graph": {"chunk_token_num": 8192, "delimiter": r"\n", "entity_types": ["organization", "person", "location", "event", "time"]},
  295. "email": None,
  296. "picture": None,
  297. }
  298. parser_config = key_mapping[chunk_method]
  299. return parser_config
  300. def get_data_openai(
  301. id=None,
  302. created=None,
  303. model=None,
  304. prompt_tokens=0,
  305. completion_tokens=0,
  306. content=None,
  307. finish_reason=None,
  308. object="chat.completion",
  309. param=None,
  310. ):
  311. total_tokens = prompt_tokens + completion_tokens
  312. return {
  313. "id": f"{id}",
  314. "object": object,
  315. "created": int(time.time()) if created else None,
  316. "model": model,
  317. "param": param,
  318. "usage": {
  319. "prompt_tokens": prompt_tokens,
  320. "completion_tokens": completion_tokens,
  321. "total_tokens": total_tokens,
  322. "completion_tokens_details": {"reasoning_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0},
  323. },
  324. "choices": [{"message": {"role": "assistant", "content": content}, "logprobs": None, "finish_reason": finish_reason, "index": 0}],
  325. }
  326. def check_duplicate_ids(ids, id_type="item"):
  327. """
  328. Check for duplicate IDs in a list and return unique IDs and error messages.
  329. Args:
  330. ids (list): List of IDs to check for duplicates
  331. id_type (str): Type of ID for error messages (e.g., 'document', 'dataset', 'chunk')
  332. Returns:
  333. tuple: (unique_ids, error_messages)
  334. - unique_ids (list): List of unique IDs
  335. - error_messages (list): List of error messages for duplicate IDs
  336. """
  337. id_count = {}
  338. duplicate_messages = []
  339. # Count occurrences of each ID
  340. for id_value in ids:
  341. id_count[id_value] = id_count.get(id_value, 0) + 1
  342. # Check for duplicates
  343. for id_value, count in id_count.items():
  344. if count > 1:
  345. duplicate_messages.append(f"Duplicate {id_type} ids: {id_value}")
  346. # Return unique IDs and error messages
  347. return list(set(ids)), duplicate_messages
  348. def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]:
  349. """
  350. Verifies availability of an embedding model for a specific tenant.
  351. Implements a four-stage validation process:
  352. 1. Model identifier parsing and validation
  353. 2. System support verification
  354. 3. Tenant authorization check
  355. 4. Database operation error handling
  356. Args:
  357. embd_id (str): Unique identifier for the embedding model in format "model_name@factory"
  358. tenant_id (str): Tenant identifier for access control
  359. Returns:
  360. tuple[bool, Response | None]:
  361. - First element (bool):
  362. - True: Model is available and authorized
  363. - False: Validation failed
  364. - Second element contains:
  365. - None on success
  366. - Error detail dict on failure
  367. Raises:
  368. ValueError: When model identifier format is invalid
  369. OperationalError: When database connection fails (auto-handled)
  370. Examples:
  371. >>> verify_embedding_availability("text-embedding@openai", "tenant_123")
  372. (True, None)
  373. >>> verify_embedding_availability("invalid_model", "tenant_123")
  374. (False, {'code': 101, 'message': "Unsupported model: <invalid_model>"})
  375. """
  376. try:
  377. llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(embd_id)
  378. if not LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding"):
  379. return False, get_error_argument_result(f"Unsupported model: <{embd_id}>")
  380. # Tongyi-Qianwen is added to TenantLLM by default, but remains unusable with empty api_key
  381. tenant_llms = TenantLLMService.get_my_llms(tenant_id=tenant_id)
  382. is_tenant_model = any(llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding" for llm in tenant_llms)
  383. is_builtin_model = embd_id in settings.BUILTIN_EMBEDDING_MODELS
  384. if not (is_builtin_model or is_tenant_model):
  385. return False, get_error_argument_result(f"Unauthorized model: <{embd_id}>")
  386. except OperationalError as e:
  387. logging.exception(e)
  388. return False, get_error_data_result(message="Database operation failed")
  389. return True, None
  390. def deep_merge(default: dict, custom: dict) -> dict:
  391. """
  392. Recursively merges two dictionaries with priority given to `custom` values.
  393. Creates a deep copy of the `default` dictionary and iteratively merges nested
  394. dictionaries using a stack-based approach. Non-dict values in `custom` will
  395. completely override corresponding entries in `default`.
  396. Args:
  397. default (dict): Base dictionary containing default values.
  398. custom (dict): Dictionary containing overriding values.
  399. Returns:
  400. dict: New merged dictionary combining values from both inputs.
  401. Example:
  402. >>> from copy import deepcopy
  403. >>> default = {"a": 1, "nested": {"x": 10, "y": 20}}
  404. >>> custom = {"b": 2, "nested": {"y": 99, "z": 30}}
  405. >>> deep_merge(default, custom)
  406. {'a': 1, 'b': 2, 'nested': {'x': 10, 'y': 99, 'z': 30}}
  407. >>> deep_merge({"config": {"mode": "auto"}}, {"config": "manual"})
  408. {'config': 'manual'}
  409. Notes:
  410. 1. Merge priority is always given to `custom` values at all nesting levels
  411. 2. Non-dict values (e.g. list, str) in `custom` will replace entire values
  412. in `default`, even if the original value was a dictionary
  413. 3. Time complexity: O(N) where N is total key-value pairs in `custom`
  414. 4. Recommended for configuration merging and nested data updates
  415. """
  416. merged = deepcopy(default)
  417. stack = [(merged, custom)]
  418. while stack:
  419. base_dict, override_dict = stack.pop()
  420. for key, val in override_dict.items():
  421. if key in base_dict and isinstance(val, dict) and isinstance(base_dict[key], dict):
  422. stack.append((base_dict[key], val))
  423. else:
  424. base_dict[key] = val
  425. return merged