
api_utils.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import functools
import json
import logging
import os
import queue
import random
import threading
import time
from base64 import b64encode
from copy import deepcopy
from functools import wraps
from hmac import HMAC
from io import BytesIO
from typing import Any, Callable, Coroutine, Optional, Type, Union
from urllib.parse import quote, urlencode
from uuid import uuid1

import requests
import trio
from flask import (
    Response,
    jsonify,
    make_response,
    send_file,
)
from flask import (
    request as flask_request,
)
from itsdangerous import URLSafeTimedSerializer
from peewee import OperationalError
from werkzeug.http import HTTP_STATUS_CODES

from api import settings
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
from api.db.db_models import APIToken
from api.db.services.llm_service import LLMService
from api.db.services.tenant_llm_service import TenantLLMService
from api.utils import CustomJSONEncoder, get_uuid, json_dumps
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions

requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)


def request(**kwargs):
    sess = requests.Session()
    stream = kwargs.pop("stream", sess.stream)
    timeout = kwargs.pop("timeout", None)
    kwargs["headers"] = {k.replace("_", "-").upper(): v for k, v in kwargs.get("headers", {}).items()}
    prepped = requests.Request(**kwargs).prepare()

    if settings.CLIENT_AUTHENTICATION and settings.HTTP_APP_KEY and settings.SECRET_KEY:
        timestamp = str(round(time.time() * 1000))  # `time` is imported as a module, so the call must be time.time()
        nonce = str(uuid1())
        signature = b64encode(
            HMAC(
                settings.SECRET_KEY.encode("ascii"),
                b"\n".join(
                    [
                        timestamp.encode("ascii"),
                        nonce.encode("ascii"),
                        settings.HTTP_APP_KEY.encode("ascii"),
                        prepped.path_url.encode("ascii"),
                        prepped.body if kwargs.get("json") else b"",
                        urlencode(sorted(kwargs["data"].items()), quote_via=quote, safe="-._~").encode("ascii") if kwargs.get("data") and isinstance(kwargs["data"], dict) else b"",
                    ]
                ),
                "sha1",
            ).digest()
        ).decode("ascii")

        prepped.headers.update(
            {
                "TIMESTAMP": timestamp,
                "NONCE": nonce,
                "APP-KEY": settings.HTTP_APP_KEY,
                "SIGNATURE": signature,
            }
        )

    return sess.send(prepped, stream=stream, timeout=timeout)


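# Note on the signing scheme implemented above: when CLIENT_AUTHENTICATION is enabled,
# the prepared request is signed with HMAC-SHA1 over the newline-joined fields
# [timestamp, nonce, app key, path_url, JSON body or url-encoded form data], and the
# base64-encoded digest is sent in the SIGNATURE header alongside TIMESTAMP, NONCE and
# APP-KEY. A verifying server has to rebuild exactly the same byte string to check it.

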
def get_exponential_backoff_interval(retries, full_jitter=False):
    """Calculate the exponential backoff wait time."""
    # Will be zero if factor equals 0
    countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2**retries))
    # Full jitter according to
    # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
    if full_jitter:
        countdown = random.randrange(countdown + 1)
    # Adjust according to maximum wait time and account for negative values.
    return max(0, countdown)


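# Worked example (illustrative; the concrete REQUEST_WAIT_SEC / REQUEST_MAX_WAIT_SEC
# values below are assumptions, not taken from api.constants):
#   with REQUEST_WAIT_SEC = 2 and REQUEST_MAX_WAIT_SEC = 60,
#   retries = 0, 1, 2, 3, 4, 5 -> countdown = 2, 4, 8, 16, 32, 60 (capped) seconds;
#   with full_jitter=True the returned value is drawn uniformly from [0, countdown].

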
def get_data_error_result(code=settings.RetCode.DATA_ERROR, message="Sorry! Data missing!"):
    logging.exception(Exception(message))
    result_dict = {"code": code, "message": message}
    response = {}
    for key, value in result_dict.items():
        if value is None and key != "code":
            continue
        else:
            response[key] = value
    return jsonify(response)


def server_error_response(e):
    logging.exception(e)
    try:
        if e.code == 401:
            return get_json_result(code=401, message=repr(e))
    except BaseException:
        pass
    if len(e.args) > 1:
        return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
    if repr(e).find("index_not_found_exception") >= 0:
        return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
    return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))


def error_response(response_code, message=None):
    if message is None:
        message = HTTP_STATUS_CODES.get(response_code, "Unknown Error")

    return Response(
        json.dumps(
            {
                "message": message,
                "code": response_code,
            }
        ),
        status=response_code,
        mimetype="application/json",
    )


def validate_request(*args, **kwargs):
    def wrapper(func):
        @wraps(func)
        def decorated_function(*_args, **_kwargs):
            input_arguments = flask_request.json or flask_request.form.to_dict()
            no_arguments = []
            error_arguments = []
            for arg in args:
                if arg not in input_arguments:
                    no_arguments.append(arg)
            for k, v in kwargs.items():
                config_value = input_arguments.get(k, None)
                if config_value is None:
                    no_arguments.append(k)
                elif isinstance(v, (tuple, list)):
                    if config_value not in v:
                        error_arguments.append((k, set(v)))
                elif config_value != v:
                    error_arguments.append((k, v))
            if no_arguments or error_arguments:
                error_string = ""
                if no_arguments:
                    error_string += "required arguments are missing: {}; ".format(",".join(no_arguments))
                if error_arguments:
                    error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
                return get_json_result(code=settings.RetCode.ARGUMENT_ERROR, message=error_string)
            return func(*_args, **_kwargs)

        return decorated_function

    return wrapper


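# Illustrative usage of validate_request (a sketch; the handler and field names are
# hypothetical, not part of this module): require "name" in the JSON/form body and
# restrict "visibility" to two allowed values.
#
#   @validate_request("name", visibility=("public", "private"))
#   def create_dataset():
#       ...
#
# Missing or mismatched fields short-circuit the handler with an ARGUMENT_ERROR result.

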
def not_allowed_parameters(*params):
    def decorator(f):
        def wrapper(*args, **kwargs):
            input_arguments = flask_request.json or flask_request.form.to_dict()
            for param in params:
                if param in input_arguments:
                    return get_json_result(code=settings.RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
            return f(*args, **kwargs)

        return wrapper

    return decorator


def is_localhost(ip):
    return ip in {"127.0.0.1", "::1", "[::1]", "localhost"}


def send_file_in_mem(data, filename):
    if not isinstance(data, (str, bytes)):
        data = json_dumps(data)
    if isinstance(data, str):
        data = data.encode("utf-8")

    f = BytesIO()
    f.write(data)
    f.seek(0)

    # NOTE: Flask 2.0 renamed `attachment_filename` to `download_name`; this call
    # assumes a Flask version that still accepts the old keyword.
    return send_file(f, as_attachment=True, attachment_filename=filename)


def get_json_result(code=settings.RetCode.SUCCESS, message="success", data=None):
    response = {"code": code, "message": message, "data": data}
    return jsonify(response)


def apikey_required(func):
    @wraps(func)
    def decorated_function(*args, **kwargs):
        token = flask_request.headers.get("Authorization").split()[1]
        objs = APIToken.query(token=token)
        if not objs:
            return build_error_result(message="API-KEY is invalid!", code=settings.RetCode.FORBIDDEN)
        kwargs["tenant_id"] = objs[0].tenant_id
        return func(*args, **kwargs)

    return decorated_function


def build_error_result(code=settings.RetCode.FORBIDDEN, message="success"):
    response = {"code": code, "message": message}
    response = jsonify(response)
    response.status_code = code
    return response


def construct_response(code=settings.RetCode.SUCCESS, message="success", data=None, auth=None):
    result_dict = {"code": code, "message": message, "data": data}
    response_dict = {}
    for key, value in result_dict.items():
        if value is None and key != "code":
            continue
        else:
            response_dict[key] = value
    response = make_response(jsonify(response_dict))
    if auth:
        response.headers["Authorization"] = auth
    response.headers["Access-Control-Allow-Origin"] = "*"
    response.headers["Access-Control-Allow-Methods"] = "*"
    response.headers["Access-Control-Allow-Headers"] = "*"
    response.headers["Access-Control-Expose-Headers"] = "Authorization"
    return response


def construct_result(code=settings.RetCode.DATA_ERROR, message="data is missing"):
    result_dict = {"code": code, "message": message}
    response = {}
    for key, value in result_dict.items():
        if value is None and key != "code":
            continue
        else:
            response[key] = value
    return jsonify(response)


def construct_json_result(code=settings.RetCode.SUCCESS, message="success", data=None):
    if data is None:
        return jsonify({"code": code, "message": message})
    else:
        return jsonify({"code": code, "message": message, "data": data})


def construct_error_response(e):
    logging.exception(e)
    try:
        if e.code == 401:
            return construct_json_result(code=settings.RetCode.UNAUTHORIZED, message=repr(e))
    except BaseException:
        pass
    if len(e.args) > 1:
        return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
    return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))


def token_required(func):
    @wraps(func)
    def decorated_function(*args, **kwargs):
        authorization_str = flask_request.headers.get("Authorization")
        if not authorization_str:
            return get_json_result(data=False, message="`Authorization` can't be empty")
        authorization_list = authorization_str.split()
        if len(authorization_list) < 2:
            return get_json_result(data=False, message="Please check your authorization format.")
        token = authorization_list[1]
        objs = APIToken.query(token=token)
        if not objs:
            return get_json_result(data=False, message="Authentication error: API key is invalid!", code=settings.RetCode.AUTHENTICATION_ERROR)
        kwargs["tenant_id"] = objs[0].tenant_id
        return func(*args, **kwargs)

    return decorated_function


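# Illustrative usage of token_required (a sketch; the blueprint and handler names are
# hypothetical): the decorator expects an "Authorization: Bearer <api-key>"-style
# header, resolves the key to an APIToken row, and injects the owning tenant_id into
# the handler as a keyword argument.
#
#   @manager.route("/datasets", methods=["GET"])
#   @token_required
#   def list_datasets(tenant_id):
#       ...

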
def get_result(code=settings.RetCode.SUCCESS, message="", data=None):
    if code == 0:
        if data is not None:
            response = {"code": code, "data": data}
        else:
            response = {"code": code}
    else:
        response = {"code": code, "message": message}
    return jsonify(response)


def get_error_data_result(
    message="Sorry! Data missing!",
    code=settings.RetCode.DATA_ERROR,
):
    result_dict = {"code": code, "message": message}
    response = {}
    for key, value in result_dict.items():
        if value is None and key != "code":
            continue
        else:
            response[key] = value
    return jsonify(response)


def get_error_argument_result(message="Invalid arguments"):
    return get_result(code=settings.RetCode.ARGUMENT_ERROR, message=message)


def get_error_permission_result(message="Permission error"):
    return get_result(code=settings.RetCode.PERMISSION_ERROR, message=message)


def get_error_operating_result(message="Operating error"):
    return get_result(code=settings.RetCode.OPERATING_ERROR, message=message)


def generate_confirmation_token(tenant_id):
    serializer = URLSafeTimedSerializer(tenant_id)
    return "ragflow-" + serializer.dumps(get_uuid(), salt=tenant_id)[2:34]


def get_parser_config(chunk_method, parser_config):
    if not chunk_method:
        chunk_method = "naive"

    # Define default configurations for each chunking method
    key_mapping = {
        "naive": {"chunk_token_num": 512, "delimiter": r"\n", "html4excel": False, "layout_recognize": "DeepDOC", "raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}},
        "qa": {"raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}},
        "tag": None,
        "resume": None,
        "manual": {"raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}},
        "table": None,
        "paper": {"raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}},
        "book": {"raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}},
        "laws": {"raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}},
        "presentation": {"raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}},
        "one": None,
        "knowledge_graph": {
            "chunk_token_num": 8192,
            "delimiter": r"\n",
            "entity_types": ["organization", "person", "location", "event", "time"],
            "raptor": {"use_raptor": False},
            "graphrag": {"use_graphrag": False},
        },
        "email": None,
        "picture": None,
    }

    default_config = key_mapping[chunk_method]

    # If no parser_config provided, return default
    if not parser_config:
        return default_config

    # If parser_config is provided, merge with defaults to ensure required fields exist
    if default_config is None:
        return parser_config

    # Ensure raptor and graphrag fields have default values if not provided
    merged_config = deep_merge(default_config, parser_config)
    return merged_config


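# Illustrative merge behaviour (based on deep_merge defined below): a partial user
# config is layered over the defaults for the chosen chunk method.
#
#   get_parser_config("naive", {"chunk_token_num": 256, "raptor": {"use_raptor": True}})
#   -> the "naive" defaults with chunk_token_num=256 and raptor.use_raptor=True, while
#      delimiter, html4excel, layout_recognize and graphrag keep their default values.

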
def get_data_openai(
    id=None,
    created=None,
    model=None,
    prompt_tokens=0,
    completion_tokens=0,
    content=None,
    finish_reason=None,
    object="chat.completion",
    param=None,
    stream=False,
):
    total_tokens = prompt_tokens + completion_tokens

    if stream:
        return {
            "id": f"{id}",
            "object": "chat.completion.chunk",
            "model": model,
            "choices": [{
                "delta": {"content": content},
                "finish_reason": finish_reason,
                "index": 0,
            }],
        }

    return {
        "id": f"{id}",
        "object": object,
        "created": int(time.time()) if created else None,
        "model": model,
        "param": param,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens,
            "completion_tokens_details": {
                "reasoning_tokens": 0,
                "accepted_prediction_tokens": 0,
                "rejected_prediction_tokens": 0,
            },
        },
        "choices": [{
            "message": {
                "role": "assistant",
                "content": content,
            },
            "logprobs": None,
            "finish_reason": finish_reason,
            "index": 0,
        }],
    }


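# Illustrative payload shape (a sketch of the OpenAI-compatible objects this helper
# emulates; the id and model values are made up):
#   get_data_openai(id="chatcmpl-1", model="some-model", prompt_tokens=10,
#                   completion_tokens=5, content="Hi", finish_reason="stop")
#   -> {"object": "chat.completion", ...,
#       "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15, ...},
#       "choices": [{"message": {"role": "assistant", "content": "Hi"},
#                    "finish_reason": "stop", "index": 0, ...}]}
#   With stream=True only a "chat.completion.chunk" object carrying a "delta" is returned.

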
def check_duplicate_ids(ids, id_type="item"):
    """
    Check for duplicate IDs in a list and return unique IDs and error messages.

    Args:
        ids (list): List of IDs to check for duplicates
        id_type (str): Type of ID for error messages (e.g., 'document', 'dataset', 'chunk')

    Returns:
        tuple: (unique_ids, error_messages)
            - unique_ids (list): List of unique IDs
            - error_messages (list): List of error messages for duplicate IDs
    """
    id_count = {}
    duplicate_messages = []

    # Count occurrences of each ID
    for id_value in ids:
        id_count[id_value] = id_count.get(id_value, 0) + 1

    # Check for duplicates
    for id_value, count in id_count.items():
        if count > 1:
            duplicate_messages.append(f"Duplicate {id_type} ids: {id_value}")

    # Return unique IDs and error messages
    return list(set(ids)), duplicate_messages


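# Example (illustrative): the unique ids come back via set(), so their original
# order is not preserved.
#   check_duplicate_ids(["d1", "d2", "d1"], id_type="document")
#   -> (['d1', 'd2'] in some order, ['Duplicate document ids: d1'])

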
def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]:
    """
    Verifies availability of an embedding model for a specific tenant.

    Performs comprehensive verification through:
    1. Identifier Parsing: Decomposes embd_id into name and factory components
    2. System Verification: Checks model registration in LLMService
    3. Tenant Authorization: Validates tenant-specific model assignments
    4. Built-in Model Check: Confirms inclusion in predefined system models

    Args:
        embd_id (str): Unique identifier for the embedding model in format "model_name@factory"
        tenant_id (str): Tenant identifier for access control

    Returns:
        tuple[bool, Response | None]:
        - First element (bool):
            - True: Model is available and authorized
            - False: Validation failed
        - Second element contains:
            - None on success
            - Error detail dict on failure

    Raises:
        ValueError: When model identifier format is invalid
        OperationalError: When database connection fails (auto-handled)

    Examples:
        >>> verify_embedding_availability("text-embedding@openai", "tenant_123")
        (True, None)

        >>> verify_embedding_availability("invalid_model", "tenant_123")
        (False, {'code': 101, 'message': "Unsupported model: <invalid_model>"})
    """
    try:
        llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(embd_id)
        in_llm_service = bool(LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding"))

        tenant_llms = TenantLLMService.get_my_llms(tenant_id=tenant_id)
        is_tenant_model = any(llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding" for llm in tenant_llms)

        is_builtin_model = embd_id in settings.BUILTIN_EMBEDDING_MODELS
        if not (is_builtin_model or is_tenant_model or in_llm_service):
            return False, get_error_argument_result(f"Unsupported model: <{embd_id}>")
        if not (is_builtin_model or is_tenant_model):
            return False, get_error_argument_result(f"Unauthorized model: <{embd_id}>")
    except OperationalError as e:
        logging.exception(e)
        return False, get_error_data_result(message="Database operation failed")

    return True, None


def deep_merge(default: dict, custom: dict) -> dict:
    """
    Recursively merges two dictionaries with priority given to `custom` values.

    Creates a deep copy of the `default` dictionary and iteratively merges nested
    dictionaries using a stack-based approach. Non-dict values in `custom` will
    completely override corresponding entries in `default`.

    Args:
        default (dict): Base dictionary containing default values.
        custom (dict): Dictionary containing overriding values.

    Returns:
        dict: New merged dictionary combining values from both inputs.

    Example:
        >>> from copy import deepcopy
        >>> default = {"a": 1, "nested": {"x": 10, "y": 20}}
        >>> custom = {"b": 2, "nested": {"y": 99, "z": 30}}
        >>> deep_merge(default, custom)
        {'a': 1, 'b': 2, 'nested': {'x': 10, 'y': 99, 'z': 30}}

        >>> deep_merge({"config": {"mode": "auto"}}, {"config": "manual"})
        {'config': 'manual'}

    Notes:
        1. Merge priority is always given to `custom` values at all nesting levels
        2. Non-dict values (e.g. list, str) in `custom` will replace entire values
           in `default`, even if the original value was a dictionary
        3. Time complexity: O(N) where N is total key-value pairs in `custom`
        4. Recommended for configuration merging and nested data updates
    """
    merged = deepcopy(default)
    stack = [(merged, custom)]

    while stack:
        base_dict, override_dict = stack.pop()
        for key, val in override_dict.items():
            if key in base_dict and isinstance(val, dict) and isinstance(base_dict[key], dict):
                stack.append((base_dict[key], val))
            else:
                base_dict[key] = val

    return merged


def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
    """
    Transform dictionary keys using a configurable mapping schema.

    Args:
        source_data: Original dictionary to process
        key_aliases: Custom key transformation rules (optional).
            When provided, overrides the default key mapping.
            Format: {<original_key>: <new_key>, ...}

    Returns:
        dict: New dictionary with transformed keys preserving original values

    Example:
        >>> input_data = {"old_key": "value", "another_field": 42}
        >>> remap_dictionary_keys(input_data, {"old_key": "new_key"})
        {'new_key': 'value', 'another_field': 42}
    """
    DEFAULT_KEY_MAP = {
        "chunk_num": "chunk_count",
        "doc_num": "document_count",
        "parser_id": "chunk_method",
        "embd_id": "embedding_model",
    }

    transformed_data = {}
    mapping = key_aliases or DEFAULT_KEY_MAP

    for original_key, value in source_data.items():
        mapped_key = mapping.get(original_key, original_key)
        transformed_data[mapped_key] = value

    return transformed_data


def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, str]:
    results = {}
    tool_call_sessions = []
    try:
        for mcp_server in mcp_servers:
            server_key = mcp_server.id
            cached_tools = mcp_server.variables.get("tools", {})

            tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
            tool_call_sessions.append(tool_call_session)

            try:
                tools = tool_call_session.get_tools(timeout)
            except Exception:
                tools = []

            results[server_key] = []
            for tool in tools:
                tool_dict = tool.model_dump()
                cached_tool = cached_tools.get(tool_dict["name"], {})
                tool_dict["enabled"] = cached_tool.get("enabled", True)
                results[server_key].append(tool_dict)

        # PERF: blocking call to close sessions — consider moving to background thread or task queue
        close_multiple_mcp_toolcall_sessions(tool_call_sessions)
        return results, ""
    except Exception as e:
        return {}, str(e)


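# The mapping returned above is keyed by MCP server id; each entry is the tool's
# pydantic model_dump() plus an "enabled" flag carried over from the server's cached
# "tools" variable (defaulting to True for tools not seen before). On any unexpected
# error the function returns an empty dict together with the error message.

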
TimeoutException = Union[Type[BaseException], BaseException]
OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]


def timeout(seconds: float | int = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None, on_timeout: Optional[OnTimeoutCallback] = None):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            result_queue = queue.Queue(maxsize=1)

            def target():
                try:
                    result = func(*args, **kwargs)
                    result_queue.put(result)
                except Exception as e:
                    result_queue.put(e)

            thread = threading.Thread(target=target)
            thread.daemon = True
            thread.start()

            for a in range(attempts):
                try:
                    if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
                        result = result_queue.get(timeout=seconds)
                    else:
                        result = result_queue.get()
                    if isinstance(result, Exception):
                        raise result
                    return result
                except queue.Empty:
                    pass

            raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds and {attempts} attempts.")

        @wraps(func)
        async def async_wrapper(*args, **kwargs) -> Any:
            if seconds is None:
                return await func(*args, **kwargs)

            for a in range(attempts):
                try:
                    if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
                        with trio.fail_after(seconds):
                            return await func(*args, **kwargs)
                    else:
                        return await func(*args, **kwargs)
                except trio.TooSlowError:
                    if a < attempts - 1:
                        continue

                    if on_timeout is not None:
                        if callable(on_timeout):
                            result = on_timeout()
                            if isinstance(result, Coroutine):
                                return await result
                            return result
                        return on_timeout

                    if exception is None:
                        raise TimeoutError(f"Operation timed out after {seconds} seconds and {attempts} attempts.")
                    if isinstance(exception, BaseException):
                        raise exception
                    if isinstance(exception, type) and issubclass(exception, BaseException):
                        raise exception(f"Operation timed out after {seconds} seconds and {attempts} attempts.")
                    raise RuntimeError("Invalid exception type provided")

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        return wrapper

    return decorator


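# Illustrative usage of the timeout decorator (a sketch; fetch_page is hypothetical).
# For sync functions the body runs in a daemon thread and the result is polled from a
# one-slot queue; for async (trio) functions each attempt is wrapped in trio.fail_after.
# The timeout is only enforced when the ENABLE_TIMEOUT_ASSERTION environment variable
# is set; otherwise the call waits without a deadline.
#
#   @timeout(seconds=5, attempts=2, exception=TimeoutError)
#   async def fetch_page(url):
#       ...

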
async def is_strong_enough(chat_model, embedding_model):
    count = settings.STRONG_TEST_COUNT
    if not chat_model or not embedding_model:
        return
    if isinstance(count, int) and count <= 0:
        return

    @timeout(60, 2)
    async def _is_strong_enough():
        nonlocal chat_model, embedding_model
        if embedding_model:
            with trio.fail_after(10):
                _ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"]))
        if chat_model:
            with trio.fail_after(30):
                res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}], {}))
                if res.find("**ERROR**") >= 0:
                    raise Exception(res)

    # Pressure test for GraphRAG task
    async with trio.open_nursery() as nursery:
        for _ in range(count):
            nursery.start_soon(_is_strong_enough)