Ви не можете вибрати більше 25 тем. Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import functools
  17. import json
  18. import logging
  19. import random
  20. import time
  21. from base64 import b64encode
  22. from functools import wraps
  23. from hmac import HMAC
  24. from io import BytesIO
  25. from urllib.parse import quote, urlencode
  26. from uuid import uuid1
  27. import requests
  28. from flask import (
  29. Response,
  30. jsonify,
  31. make_response,
  32. send_file,
  33. )
  34. from flask import (
  35. request as flask_request,
  36. )
  37. from itsdangerous import URLSafeTimedSerializer
  38. from werkzeug.http import HTTP_STATUS_CODES
  39. from api import settings
  40. from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
  41. from api.db.db_models import APIToken
  42. from api.utils import CustomJSONEncoder, get_uuid, json_dumps
# Replace the JSON serializer used by `requests` with json.dumps bound to the
# project's CustomJSONEncoder, so JSON request bodies are encoded with it.
requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)
  44. def request(**kwargs):
  45. sess = requests.Session()
  46. stream = kwargs.pop("stream", sess.stream)
  47. timeout = kwargs.pop("timeout", None)
  48. kwargs["headers"] = {k.replace("_", "-").upper(): v for k, v in kwargs.get("headers", {}).items()}
  49. prepped = requests.Request(**kwargs).prepare()
  50. if settings.CLIENT_AUTHENTICATION and settings.HTTP_APP_KEY and settings.SECRET_KEY:
  51. timestamp = str(round(time() * 1000))
  52. nonce = str(uuid1())
  53. signature = b64encode(
  54. HMAC(
  55. settings.SECRET_KEY.encode("ascii"),
  56. b"\n".join(
  57. [
  58. timestamp.encode("ascii"),
  59. nonce.encode("ascii"),
  60. settings.HTTP_APP_KEY.encode("ascii"),
  61. prepped.path_url.encode("ascii"),
  62. prepped.body if kwargs.get("json") else b"",
  63. urlencode(sorted(kwargs["data"].items()), quote_via=quote, safe="-._~").encode("ascii") if kwargs.get("data") and isinstance(kwargs["data"], dict) else b"",
  64. ]
  65. ),
  66. "sha1",
  67. ).digest()
  68. ).decode("ascii")
  69. prepped.headers.update(
  70. {
  71. "TIMESTAMP": timestamp,
  72. "NONCE": nonce,
  73. "APP-KEY": settings.HTTP_APP_KEY,
  74. "SIGNATURE": signature,
  75. }
  76. )
  77. return sess.send(prepped, stream=stream, timeout=timeout)
  78. def get_exponential_backoff_interval(retries, full_jitter=False):
  79. """Calculate the exponential backoff wait time."""
  80. # Will be zero if factor equals 0
  81. countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2**retries))
  82. # Full jitter according to
  83. # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
  84. if full_jitter:
  85. countdown = random.randrange(countdown + 1)
  86. # Adjust according to maximum wait time and account for negative values.
  87. return max(0, countdown)
  88. def get_data_error_result(code=settings.RetCode.DATA_ERROR, message="Sorry! Data missing!"):
  89. logging.exception(Exception(message))
  90. result_dict = {"code": code, "message": message}
  91. response = {}
  92. for key, value in result_dict.items():
  93. if value is None and key != "code":
  94. continue
  95. else:
  96. response[key] = value
  97. return jsonify(response)
  98. def server_error_response(e):
  99. logging.exception(e)
  100. try:
  101. if e.code == 401:
  102. return get_json_result(code=401, message=repr(e))
  103. except BaseException:
  104. pass
  105. if len(e.args) > 1:
  106. return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
  107. if repr(e).find("index_not_found_exception") >= 0:
  108. return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
  109. return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
  110. def error_response(response_code, message=None):
  111. if message is None:
  112. message = HTTP_STATUS_CODES.get(response_code, "Unknown Error")
  113. return Response(
  114. json.dumps(
  115. {
  116. "message": message,
  117. "code": response_code,
  118. }
  119. ),
  120. status=response_code,
  121. mimetype="application/json",
  122. )
  123. def validate_request(*args, **kwargs):
  124. def wrapper(func):
  125. @wraps(func)
  126. def decorated_function(*_args, **_kwargs):
  127. input_arguments = flask_request.json or flask_request.form.to_dict()
  128. no_arguments = []
  129. error_arguments = []
  130. for arg in args:
  131. if arg not in input_arguments:
  132. no_arguments.append(arg)
  133. for k, v in kwargs.items():
  134. config_value = input_arguments.get(k, None)
  135. if config_value is None:
  136. no_arguments.append(k)
  137. elif isinstance(v, (tuple, list)):
  138. if config_value not in v:
  139. error_arguments.append((k, set(v)))
  140. elif config_value != v:
  141. error_arguments.append((k, v))
  142. if no_arguments or error_arguments:
  143. error_string = ""
  144. if no_arguments:
  145. error_string += "required argument are missing: {}; ".format(",".join(no_arguments))
  146. if error_arguments:
  147. error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
  148. return get_json_result(code=settings.RetCode.ARGUMENT_ERROR, message=error_string)
  149. return func(*_args, **_kwargs)
  150. return decorated_function
  151. return wrapper
  152. def not_allowed_parameters(*params):
  153. def decorator(f):
  154. def wrapper(*args, **kwargs):
  155. input_arguments = flask_request.json or flask_request.form.to_dict()
  156. for param in params:
  157. if param in input_arguments:
  158. return get_json_result(code=settings.RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
  159. return f(*args, **kwargs)
  160. return wrapper
  161. return decorator
  162. def is_localhost(ip):
  163. return ip in {"127.0.0.1", "::1", "[::1]", "localhost"}
  164. def send_file_in_mem(data, filename):
  165. if not isinstance(data, (str, bytes)):
  166. data = json_dumps(data)
  167. if isinstance(data, str):
  168. data = data.encode("utf-8")
  169. f = BytesIO()
  170. f.write(data)
  171. f.seek(0)
  172. return send_file(f, as_attachment=True, attachment_filename=filename)
  173. def get_json_result(code=settings.RetCode.SUCCESS, message="success", data=None):
  174. response = {"code": code, "message": message, "data": data}
  175. return jsonify(response)
  176. def apikey_required(func):
  177. @wraps(func)
  178. def decorated_function(*args, **kwargs):
  179. token = flask_request.headers.get("Authorization").split()[1]
  180. objs = APIToken.query(token=token)
  181. if not objs:
  182. return build_error_result(message="API-KEY is invalid!", code=settings.RetCode.FORBIDDEN)
  183. kwargs["tenant_id"] = objs[0].tenant_id
  184. return func(*args, **kwargs)
  185. return decorated_function
  186. def build_error_result(code=settings.RetCode.FORBIDDEN, message="success"):
  187. response = {"code": code, "message": message}
  188. response = jsonify(response)
  189. response.status_code = code
  190. return response
  191. def construct_response(code=settings.RetCode.SUCCESS, message="success", data=None, auth=None):
  192. result_dict = {"code": code, "message": message, "data": data}
  193. response_dict = {}
  194. for key, value in result_dict.items():
  195. if value is None and key != "code":
  196. continue
  197. else:
  198. response_dict[key] = value
  199. response = make_response(jsonify(response_dict))
  200. if auth:
  201. response.headers["Authorization"] = auth
  202. response.headers["Access-Control-Allow-Origin"] = "*"
  203. response.headers["Access-Control-Allow-Method"] = "*"
  204. response.headers["Access-Control-Allow-Headers"] = "*"
  205. response.headers["Access-Control-Allow-Headers"] = "*"
  206. response.headers["Access-Control-Expose-Headers"] = "Authorization"
  207. return response
  208. def construct_result(code=settings.RetCode.DATA_ERROR, message="data is missing"):
  209. result_dict = {"code": code, "message": message}
  210. response = {}
  211. for key, value in result_dict.items():
  212. if value is None and key != "code":
  213. continue
  214. else:
  215. response[key] = value
  216. return jsonify(response)
  217. def construct_json_result(code=settings.RetCode.SUCCESS, message="success", data=None):
  218. if data is None:
  219. return jsonify({"code": code, "message": message})
  220. else:
  221. return jsonify({"code": code, "message": message, "data": data})
  222. def construct_error_response(e):
  223. logging.exception(e)
  224. try:
  225. if e.code == 401:
  226. return construct_json_result(code=settings.RetCode.UNAUTHORIZED, message=repr(e))
  227. except BaseException:
  228. pass
  229. if len(e.args) > 1:
  230. return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
  231. return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
  232. def token_required(func):
  233. @wraps(func)
  234. def decorated_function(*args, **kwargs):
  235. authorization_str = flask_request.headers.get("Authorization")
  236. if not authorization_str:
  237. return get_json_result(data=False, message="`Authorization` can't be empty")
  238. authorization_list = authorization_str.split()
  239. if len(authorization_list) < 2:
  240. return get_json_result(data=False, message="Please check your authorization format.")
  241. token = authorization_list[1]
  242. objs = APIToken.query(token=token)
  243. if not objs:
  244. return get_json_result(data=False, message="Authentication error: API key is invalid!", code=settings.RetCode.AUTHENTICATION_ERROR)
  245. kwargs["tenant_id"] = objs[0].tenant_id
  246. return func(*args, **kwargs)
  247. return decorated_function
  248. def get_result(code=settings.RetCode.SUCCESS, message="", data=None):
  249. if code == 0:
  250. if data is not None:
  251. response = {"code": code, "data": data}
  252. else:
  253. response = {"code": code}
  254. else:
  255. response = {"code": code, "message": message}
  256. return jsonify(response)
  257. def get_error_data_result(
  258. message="Sorry! Data missing!",
  259. code=settings.RetCode.DATA_ERROR,
  260. ):
  261. result_dict = {"code": code, "message": message}
  262. response = {}
  263. for key, value in result_dict.items():
  264. if value is None and key != "code":
  265. continue
  266. else:
  267. response[key] = value
  268. return jsonify(response)
  269. def get_error_argument_result(message="Invalid arguments"):
  270. return get_result(code=settings.RetCode.ARGUMENT_ERROR, message=message)
  271. def generate_confirmation_token(tenant_id):
  272. serializer = URLSafeTimedSerializer(tenant_id)
  273. return "ragflow-" + serializer.dumps(get_uuid(), salt=tenant_id)[2:34]
  274. def valid(permission, valid_permission, chunk_method, valid_chunk_method):
  275. if valid_parameter(permission, valid_permission):
  276. return valid_parameter(permission, valid_permission)
  277. if valid_parameter(chunk_method, valid_chunk_method):
  278. return valid_parameter(chunk_method, valid_chunk_method)
  279. def valid_parameter(parameter, valid_values):
  280. if parameter and parameter not in valid_values:
  281. return get_error_data_result(f"'{parameter}' is not in {valid_values}")
  282. def dataset_readonly_fields(field_name):
  283. return field_name in ["chunk_count", "create_date", "create_time", "update_date", "update_time", "created_by", "document_count", "token_num", "status", "tenant_id", "id"]
  284. def get_parser_config(chunk_method, parser_config):
  285. if parser_config:
  286. return parser_config
  287. if not chunk_method:
  288. chunk_method = "naive"
  289. key_mapping = {
  290. "naive": {"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False, "layout_recognize": "DeepDOC", "raptor": {"use_raptor": False}},
  291. "qa": {"raptor": {"use_raptor": False}},
  292. "tag": None,
  293. "resume": None,
  294. "manual": {"raptor": {"use_raptor": False}},
  295. "table": None,
  296. "paper": {"raptor": {"use_raptor": False}},
  297. "book": {"raptor": {"use_raptor": False}},
  298. "laws": {"raptor": {"use_raptor": False}},
  299. "presentation": {"raptor": {"use_raptor": False}},
  300. "one": None,
  301. "knowledge_graph": {"chunk_token_num": 8192, "delimiter": "\\n!?;。;!?", "entity_types": ["organization", "person", "location", "event", "time"]},
  302. "email": None,
  303. "picture": None,
  304. }
  305. parser_config = key_mapping[chunk_method]
  306. return parser_config
  307. def get_data_openai(
  308. id=None,
  309. created=None,
  310. model=None,
  311. prompt_tokens=0,
  312. completion_tokens=0,
  313. content=None,
  314. finish_reason=None,
  315. object="chat.completion",
  316. param=None,
  317. ):
  318. total_tokens = prompt_tokens + completion_tokens
  319. return {
  320. "id": f"{id}",
  321. "object": object,
  322. "created": int(time.time()) if created else None,
  323. "model": model,
  324. "param": param,
  325. "usage": {
  326. "prompt_tokens": prompt_tokens,
  327. "completion_tokens": completion_tokens,
  328. "total_tokens": total_tokens,
  329. "completion_tokens_details": {"reasoning_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0},
  330. },
  331. "choices": [{"message": {"role": "assistant", "content": content}, "logprobs": None, "finish_reason": finish_reason, "index": 0}],
  332. }
  333. def valid_parser_config(parser_config):
  334. if not parser_config:
  335. return
  336. scopes = set(
  337. [
  338. "chunk_token_num",
  339. "delimiter",
  340. "raptor",
  341. "graphrag",
  342. "layout_recognize",
  343. "task_page_size",
  344. "pages",
  345. "html4excel",
  346. "auto_keywords",
  347. "auto_questions",
  348. "tag_kb_ids",
  349. "topn_tags",
  350. "filename_embd_weight",
  351. ]
  352. )
  353. for k in parser_config.keys():
  354. assert k in scopes, f"Abnormal 'parser_config'. Invalid key: {k}"
  355. assert isinstance(parser_config.get("chunk_token_num", 1), int), "chunk_token_num should be int"
  356. assert 1 <= parser_config.get("chunk_token_num", 1) < 100000000, "chunk_token_num should be in range from 1 to 100000000"
  357. assert isinstance(parser_config.get("task_page_size", 1), int), "task_page_size should be int"
  358. assert 1 <= parser_config.get("task_page_size", 1) < 100000000, "task_page_size should be in range from 1 to 100000000"
  359. assert isinstance(parser_config.get("auto_keywords", 1), int), "auto_keywords should be int"
  360. assert 0 <= parser_config.get("auto_keywords", 0) < 32, "auto_keywords should be in range from 0 to 32"
  361. assert isinstance(parser_config.get("auto_questions", 1), int), "auto_questions should be int"
  362. assert 0 <= parser_config.get("auto_questions", 0) < 10, "auto_questions should be in range from 0 to 10"
  363. assert isinstance(parser_config.get("topn_tags", 1), int), "topn_tags should be int"
  364. assert 0 <= parser_config.get("topn_tags", 0) < 10, "topn_tags should be in range from 0 to 10"
  365. assert isinstance(parser_config.get("html4excel", False), bool), "html4excel should be True or False"
  366. assert isinstance(parser_config.get("delimiter", ""), str), "delimiter should be str"
  367. def check_duplicate_ids(ids, id_type="item"):
  368. """
  369. Check for duplicate IDs in a list and return unique IDs and error messages.
  370. Args:
  371. ids (list): List of IDs to check for duplicates
  372. id_type (str): Type of ID for error messages (e.g., 'document', 'dataset', 'chunk')
  373. Returns:
  374. tuple: (unique_ids, error_messages)
  375. - unique_ids (list): List of unique IDs
  376. - error_messages (list): List of error messages for duplicate IDs
  377. """
  378. id_count = {}
  379. duplicate_messages = []
  380. # Count occurrences of each ID
  381. for id_value in ids:
  382. id_count[id_value] = id_count.get(id_value, 0) + 1
  383. # Check for duplicates
  384. for id_value, count in id_count.items():
  385. if count > 1:
  386. duplicate_messages.append(f"Duplicate {id_type} ids: {id_value}")
  387. # Return unique IDs and error messages
  388. return list(set(ids)), duplicate_messages