#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import functools
import json
import logging
import queue
import random
import threading
import time
from base64 import b64encode
from copy import deepcopy
from functools import wraps
from hmac import HMAC
from io import BytesIO
from typing import Any, Callable, Coroutine, Optional, Type, Union
from urllib.parse import quote, urlencode
from uuid import uuid1

import requests
import trio
from flask import (
    Response,
    jsonify,
    make_response,
    send_file,
)
from flask import (
    request as flask_request,
)
from itsdangerous import URLSafeTimedSerializer
from peewee import OperationalError
from werkzeug.http import HTTP_STATUS_CODES

from api import settings
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
from api.db.db_models import APIToken
from api.db.services.llm_service import LLMService, TenantLLMService
from api.utils import CustomJSONEncoder, get_uuid, json_dumps
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions

requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder)


def request(**kwargs):
    sess = requests.Session()
    stream = kwargs.pop("stream", sess.stream)
    timeout = kwargs.pop("timeout", None)
    kwargs["headers"] = {k.replace("_", "-").upper(): v for k, v in kwargs.get("headers", {}).items()}
    prepped = requests.Request(**kwargs).prepare()

    if settings.CLIENT_AUTHENTICATION and settings.HTTP_APP_KEY and settings.SECRET_KEY:
        timestamp = str(round(time.time() * 1000))
        nonce = str(uuid1())
        signature = b64encode(
            HMAC(
                settings.SECRET_KEY.encode("ascii"),
                b"\n".join(
                    [
                        timestamp.encode("ascii"),
                        nonce.encode("ascii"),
                        settings.HTTP_APP_KEY.encode("ascii"),
                        prepped.path_url.encode("ascii"),
                        prepped.body if kwargs.get("json") else b"",
                        urlencode(sorted(kwargs["data"].items()), quote_via=quote, safe="-._~").encode("ascii") if kwargs.get("data") and isinstance(kwargs["data"], dict) else b"",
                    ]
                ),
                "sha1",
            ).digest()
        ).decode("ascii")

        prepped.headers.update(
            {
                "TIMESTAMP": timestamp,
                "NONCE": nonce,
                "APP-KEY": settings.HTTP_APP_KEY,
                "SIGNATURE": signature,
            }
        )

    return sess.send(prepped, stream=stream, timeout=timeout)


def get_exponential_backoff_interval(retries, full_jitter=False):
    """Calculate the exponential backoff wait time."""
    # Will be zero if factor equals 0
    countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2**retries))
    # Full jitter according to
    # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
    if full_jitter:
        countdown = random.randrange(countdown + 1)
    # Adjust according to maximum wait time and account for negative values.
    return max(0, countdown)


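# Illustrative only: assuming REQUEST_WAIT_SEC = 2 and REQUEST_MAX_WAIT_SEC = 60 purely for the
# sake of the example (the real values come from api.constants), retries=0/1/2/3 yield 2s/4s/8s/16s,
# capped at 60s from retries=5 onward. With full_jitter=True the wait is instead drawn uniformly
# from the integers in [0, capped value].

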
def get_data_error_result(code=settings.RetCode.DATA_ERROR, message="Sorry! Data missing!"):
    logging.exception(Exception(message))
    result_dict = {"code": code, "message": message}
    response = {}
    for key, value in result_dict.items():
        if value is None and key != "code":
            continue
        else:
            response[key] = value
    return jsonify(response)


def server_error_response(e):
    logging.exception(e)
    try:
        if e.code == 401:
            return get_json_result(code=401, message=repr(e))
    except BaseException:
        pass
    if len(e.args) > 1:
        return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
    if repr(e).find("index_not_found_exception") >= 0:
        return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message="No chunk found, please upload file and parse it.")
    return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))


def error_response(response_code, message=None):
    if message is None:
        message = HTTP_STATUS_CODES.get(response_code, "Unknown Error")

    return Response(
        json.dumps(
            {
                "message": message,
                "code": response_code,
            }
        ),
        status=response_code,
        mimetype="application/json",
    )


def validate_request(*args, **kwargs):
    def wrapper(func):
        @wraps(func)
        def decorated_function(*_args, **_kwargs):
            input_arguments = flask_request.json or flask_request.form.to_dict()
            no_arguments = []
            error_arguments = []
            for arg in args:
                if arg not in input_arguments:
                    no_arguments.append(arg)
            for k, v in kwargs.items():
                config_value = input_arguments.get(k, None)
                if config_value is None:
                    no_arguments.append(k)
                elif isinstance(v, (tuple, list)):
                    if config_value not in v:
                        error_arguments.append((k, set(v)))
                elif config_value != v:
                    error_arguments.append((k, v))
            if no_arguments or error_arguments:
                error_string = ""
                if no_arguments:
                    error_string += "required arguments are missing: {}; ".format(",".join(no_arguments))
                if error_arguments:
                    error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
                return get_json_result(code=settings.RetCode.ARGUMENT_ERROR, message=error_string)
            return func(*_args, **_kwargs)

        return decorated_function

    return wrapper


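# Usage sketch (the blueprint, route, and field names below are hypothetical, not part of this
# module): the decorators read the JSON or form body of the current Flask request, so they sit
# between the route decorator and the view function.
#
#   @manager.route("/datasets", methods=["POST"])  # hypothetical route registration
#   @validate_request("name", "kb_id")             # both keys must be present in the body
#   @not_allowed_parameters("tenant_id")           # this key must not appear in the body
#   def create_dataset():
#       ...

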
def not_allowed_parameters(*params):
    def decorator(f):
        def wrapper(*args, **kwargs):
            input_arguments = flask_request.json or flask_request.form.to_dict()
            for param in params:
                if param in input_arguments:
                    return get_json_result(code=settings.RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
            return f(*args, **kwargs)

        return wrapper

    return decorator


def is_localhost(ip):
    return ip in {"127.0.0.1", "::1", "[::1]", "localhost"}


def send_file_in_mem(data, filename):
    if not isinstance(data, (str, bytes)):
        data = json_dumps(data)
    if isinstance(data, str):
        data = data.encode("utf-8")

    f = BytesIO()
    f.write(data)
    f.seek(0)

    # NOTE: `attachment_filename` was renamed to `download_name` in Flask 2.0 and removed in Flask 2.2.
    return send_file(f, as_attachment=True, attachment_filename=filename)


def get_json_result(code=settings.RetCode.SUCCESS, message="success", data=None):
    response = {"code": code, "message": message, "data": data}
    return jsonify(response)


def apikey_required(func):
    @wraps(func)
    def decorated_function(*args, **kwargs):
        token = flask_request.headers.get("Authorization").split()[1]
        objs = APIToken.query(token=token)
        if not objs:
            return build_error_result(message="API-KEY is invalid!", code=settings.RetCode.FORBIDDEN)
        kwargs["tenant_id"] = objs[0].tenant_id
        return func(*args, **kwargs)

    return decorated_function


def build_error_result(code=settings.RetCode.FORBIDDEN, message="success"):
    response = {"code": code, "message": message}
    response = jsonify(response)
    response.status_code = code
    return response


def construct_response(code=settings.RetCode.SUCCESS, message="success", data=None, auth=None):
    result_dict = {"code": code, "message": message, "data": data}
    response_dict = {}
    for key, value in result_dict.items():
        if value is None and key != "code":
            continue
        else:
            response_dict[key] = value
    response = make_response(jsonify(response_dict))
    if auth:
        response.headers["Authorization"] = auth
    response.headers["Access-Control-Allow-Origin"] = "*"
    response.headers["Access-Control-Allow-Method"] = "*"
    response.headers["Access-Control-Allow-Headers"] = "*"
    response.headers["Access-Control-Expose-Headers"] = "Authorization"
    return response


def construct_result(code=settings.RetCode.DATA_ERROR, message="data is missing"):
    result_dict = {"code": code, "message": message}
    response = {}
    for key, value in result_dict.items():
        if value is None and key != "code":
            continue
        else:
            response[key] = value
    return jsonify(response)


def construct_json_result(code=settings.RetCode.SUCCESS, message="success", data=None):
    if data is None:
        return jsonify({"code": code, "message": message})
    else:
        return jsonify({"code": code, "message": message, "data": data})


def construct_error_response(e):
    logging.exception(e)
    try:
        if e.code == 401:
            return construct_json_result(code=settings.RetCode.UNAUTHORIZED, message=repr(e))
    except BaseException:
        pass
    if len(e.args) > 1:
        return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
    return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))


def token_required(func):
    @wraps(func)
    def decorated_function(*args, **kwargs):
        authorization_str = flask_request.headers.get("Authorization")
        if not authorization_str:
            return get_json_result(data=False, message="`Authorization` can't be empty")
        authorization_list = authorization_str.split()
        if len(authorization_list) < 2:
            return get_json_result(data=False, message="Please check your authorization format.")
        token = authorization_list[1]
        objs = APIToken.query(token=token)
        if not objs:
            return get_json_result(data=False, message="Authentication error: API key is invalid!", code=settings.RetCode.AUTHENTICATION_ERROR)
        kwargs["tenant_id"] = objs[0].tenant_id
        return func(*args, **kwargs)

    return decorated_function


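# Usage sketch (the route below is hypothetical): both `apikey_required` and `token_required`
# expect an "Authorization: <scheme> <token>" header (e.g. "Bearer <api-key>"), look the token up
# in APIToken, and inject the resolved tenant_id into the view as a keyword argument;
# `token_required` additionally validates that the header is present and well formed.
#
#   @manager.route("/v1/example", methods=["GET"])  # hypothetical route registration
#   @token_required
#   def list_example(tenant_id):
#       return get_result(data={"tenant_id": tenant_id})

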
def get_result(code=settings.RetCode.SUCCESS, message="", data=None):
    if code == 0:
        if data is not None:
            response = {"code": code, "data": data}
        else:
            response = {"code": code}
    else:
        response = {"code": code, "message": message}
    return jsonify(response)


def get_error_data_result(
    message="Sorry! Data missing!",
    code=settings.RetCode.DATA_ERROR,
):
    result_dict = {"code": code, "message": message}
    response = {}
    for key, value in result_dict.items():
        if value is None and key != "code":
            continue
        else:
            response[key] = value
    return jsonify(response)


def get_error_argument_result(message="Invalid arguments"):
    return get_result(code=settings.RetCode.ARGUMENT_ERROR, message=message)


def get_error_permission_result(message="Permission error"):
    return get_result(code=settings.RetCode.PERMISSION_ERROR, message=message)


def get_error_operating_result(message="Operating error"):
    return get_result(code=settings.RetCode.OPERATING_ERROR, message=message)


def generate_confirmation_token(tenant_id):
    serializer = URLSafeTimedSerializer(tenant_id)
    return "ragflow-" + serializer.dumps(get_uuid(), salt=tenant_id)[2:34]


def get_parser_config(chunk_method, parser_config):
    if not chunk_method:
        chunk_method = "naive"

    # Define default configurations for each chunk method
    key_mapping = {
        "naive": {"chunk_token_num": 512, "delimiter": r"\n", "html4excel": False, "layout_recognize": "DeepDOC", "raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}},
        "qa": {"raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}},
        "tag": None,
        "resume": None,
        "manual": {"raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}},
        "table": None,
        "paper": {"raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}},
        "book": {"raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}},
        "laws": {"raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}},
        "presentation": {"raptor": {"use_raptor": False}, "graphrag": {"use_graphrag": False}},
        "one": None,
        "knowledge_graph": {
            "chunk_token_num": 8192,
            "delimiter": r"\n",
            "entity_types": ["organization", "person", "location", "event", "time"],
            "raptor": {"use_raptor": False},
            "graphrag": {"use_graphrag": False},
        },
        "email": None,
        "picture": None,
    }
    default_config = key_mapping[chunk_method]

    # If no parser_config provided, return default
    if not parser_config:
        return default_config

    # If parser_config is provided, merge with defaults to ensure required fields exist
    if default_config is None:
        return parser_config

    # Ensure raptor and graphrag fields have default values if not provided
    merged_config = deep_merge(default_config, parser_config)
    return merged_config


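# Merge behaviour, using the defaults defined in key_mapping above (the override values in the
# second call are made up for illustration):
#
#   get_parser_config("naive", None)
#   -> the full "naive" default dict (chunk_token_num=512, delimiter=r"\n", html4excel=False, ...)
#
#   get_parser_config("naive", {"chunk_token_num": 1024, "raptor": {"use_raptor": True}})
#   -> defaults deep-merged with the overrides: chunk_token_num becomes 1024,
#      raptor["use_raptor"] becomes True, and every other default key is preserved.
#
#   get_parser_config("one", {"anything": 1})
#   -> {"anything": 1} unchanged, because "one" has no default configuration (None).

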
def get_data_openai(
    id=None,
    created=None,
    model=None,
    prompt_tokens=0,
    completion_tokens=0,
    content=None,
    finish_reason=None,
    object="chat.completion",
    param=None,
):
    total_tokens = prompt_tokens + completion_tokens
    return {
        "id": f"{id}",
        "object": object,
        "created": int(time.time()) if created else None,
        "model": model,
        "param": param,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens,
            "completion_tokens_details": {"reasoning_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0},
        },
        "choices": [{"message": {"role": "assistant", "content": content}, "logprobs": None, "finish_reason": finish_reason, "index": 0}],
    }


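# Shape of the OpenAI-compatible payload this helper returns (argument values are illustrative):
#
#   get_data_openai(id="chatcmpl-123", created=True, model="some-model",
#                   prompt_tokens=10, completion_tokens=5, content="Hi!", finish_reason="stop")
#   -> {"id": "chatcmpl-123", "object": "chat.completion", "created": <current unix time>,
#       "model": "some-model", "param": None,
#       "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15, ...},
#       "choices": [{"message": {"role": "assistant", "content": "Hi!"}, "logprobs": None,
#                    "finish_reason": "stop", "index": 0}]}
#
# Note that `created` acts as a flag: the current timestamp is emitted only when it is truthy,
# otherwise the field is None.

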
def check_duplicate_ids(ids, id_type="item"):
    """
    Check for duplicate IDs in a list and return unique IDs and error messages.

    Args:
        ids (list): List of IDs to check for duplicates
        id_type (str): Type of ID for error messages (e.g., 'document', 'dataset', 'chunk')

    Returns:
        tuple: (unique_ids, error_messages)
            - unique_ids (list): List of unique IDs
            - error_messages (list): List of error messages for duplicate IDs
    """
    id_count = {}
    duplicate_messages = []

    # Count occurrences of each ID
    for id_value in ids:
        id_count[id_value] = id_count.get(id_value, 0) + 1

    # Check for duplicates
    for id_value, count in id_count.items():
        if count > 1:
            duplicate_messages.append(f"Duplicate {id_type} ids: {id_value}")

    # Return unique IDs and error messages
    return list(set(ids)), duplicate_messages


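# Example (the order of the unique IDs is not guaranteed because they are deduplicated via set()):
#
#   >>> check_duplicate_ids(["d1", "d2", "d1"], "document")
#   (['d1', 'd2'], ['Duplicate document ids: d1'])

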
def verify_embedding_availability(embd_id: str, tenant_id: str) -> tuple[bool, Response | None]:
    """
    Verifies availability of an embedding model for a specific tenant.

    Performs comprehensive verification through:
    1. Identifier Parsing: Decomposes embd_id into name and factory components
    2. System Verification: Checks model registration in LLMService
    3. Tenant Authorization: Validates tenant-specific model assignments
    4. Built-in Model Check: Confirms inclusion in predefined system models

    Args:
        embd_id (str): Unique identifier for the embedding model in format "model_name@factory"
        tenant_id (str): Tenant identifier for access control

    Returns:
        tuple[bool, Response | None]:
        - First element (bool):
            - True: Model is available and authorized
            - False: Validation failed
        - Second element contains:
            - None on success
            - Error detail dict on failure

    Raises:
        ValueError: When model identifier format is invalid
        OperationalError: When database connection fails (auto-handled)

    Examples:
        >>> verify_embedding_availability("text-embedding@openai", "tenant_123")
        (True, None)

        >>> verify_embedding_availability("invalid_model", "tenant_123")
        (False, {'code': 101, 'message': "Unsupported model: <invalid_model>"})
    """
    try:
        llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(embd_id)
        in_llm_service = bool(LLMService.query(llm_name=llm_name, fid=llm_factory, model_type="embedding"))

        tenant_llms = TenantLLMService.get_my_llms(tenant_id=tenant_id)
        is_tenant_model = any(llm["llm_name"] == llm_name and llm["llm_factory"] == llm_factory and llm["model_type"] == "embedding" for llm in tenant_llms)

        is_builtin_model = embd_id in settings.BUILTIN_EMBEDDING_MODELS
        if not (is_builtin_model or is_tenant_model or in_llm_service):
            return False, get_error_argument_result(f"Unsupported model: <{embd_id}>")
        if not (is_builtin_model or is_tenant_model):
            return False, get_error_argument_result(f"Unauthorized model: <{embd_id}>")
    except OperationalError as e:
        logging.exception(e)
        return False, get_error_data_result(message="Database operation failed")

    return True, None


def deep_merge(default: dict, custom: dict) -> dict:
    """
    Recursively merges two dictionaries with priority given to `custom` values.

    Creates a deep copy of the `default` dictionary and iteratively merges nested
    dictionaries using a stack-based approach. Non-dict values in `custom` will
    completely override corresponding entries in `default`.

    Args:
        default (dict): Base dictionary containing default values.
        custom (dict): Dictionary containing overriding values.

    Returns:
        dict: New merged dictionary combining values from both inputs.

    Example:
        >>> from copy import deepcopy
        >>> default = {"a": 1, "nested": {"x": 10, "y": 20}}
        >>> custom = {"b": 2, "nested": {"y": 99, "z": 30}}
        >>> deep_merge(default, custom)
        {'a': 1, 'b': 2, 'nested': {'x': 10, 'y': 99, 'z': 30}}

        >>> deep_merge({"config": {"mode": "auto"}}, {"config": "manual"})
        {'config': 'manual'}

    Notes:
        1. Merge priority is always given to `custom` values at all nesting levels
        2. Non-dict values (e.g. list, str) in `custom` will replace entire values
           in `default`, even if the original value was a dictionary
        3. Time complexity: O(N) where N is total key-value pairs in `custom`
        4. Recommended for configuration merging and nested data updates
    """
    merged = deepcopy(default)
    stack = [(merged, custom)]

    while stack:
        base_dict, override_dict = stack.pop()

        for key, val in override_dict.items():
            if key in base_dict and isinstance(val, dict) and isinstance(base_dict[key], dict):
                stack.append((base_dict[key], val))
            else:
                base_dict[key] = val

    return merged


def remap_dictionary_keys(source_data: dict, key_aliases: dict = None) -> dict:
    """
    Transform dictionary keys using a configurable mapping schema.

    Args:
        source_data: Original dictionary to process
        key_aliases: Custom key transformation rules (Optional)
            When provided, overrides default key mapping
            Format: {<original_key>: <new_key>, ...}

    Returns:
        dict: New dictionary with transformed keys preserving original values

    Example:
        >>> input_data = {"old_key": "value", "another_field": 42}
        >>> remap_dictionary_keys(input_data, {"old_key": "new_key"})
        {'new_key': 'value', 'another_field': 42}
    """
    DEFAULT_KEY_MAP = {
        "chunk_num": "chunk_count",
        "doc_num": "document_count",
        "parser_id": "chunk_method",
        "embd_id": "embedding_model",
    }

    transformed_data = {}
    mapping = key_aliases or DEFAULT_KEY_MAP

    for original_key, value in source_data.items():
        mapped_key = mapping.get(original_key, original_key)
        transformed_data[mapped_key] = value

    return transformed_data


def get_mcp_tools(mcp_servers: list, timeout: float | int = 10) -> tuple[dict, str]:
    results = {}
    tool_call_sessions = []
    try:
        for mcp_server in mcp_servers:
            server_key = mcp_server.id
            cached_tools = mcp_server.variables.get("tools", {})

            tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
            tool_call_sessions.append(tool_call_session)
            try:
                tools = tool_call_session.get_tools(timeout)
            except Exception:
                tools = []

            results[server_key] = []
            for tool in tools:
                tool_dict = tool.model_dump()
                cached_tool = cached_tools.get(tool_dict["name"], {})
                tool_dict["enabled"] = cached_tool.get("enabled", True)
                results[server_key].append(tool_dict)

        # PERF: blocking call to close sessions; consider moving to a background thread or task queue
        close_multiple_mcp_toolcall_sessions(tool_call_sessions)

        return results, ""
    except Exception as e:
        return {}, str(e)


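# Shape of the successful return value, sketched for a single hypothetical server object:
#
#   results, err = get_mcp_tools([mcp_server])  # err == "" on success
#   # results == {mcp_server.id: [{...fields from tool.model_dump() (e.g. name, description,
#   #             inputSchema, depending on the MCP SDK version)..., "enabled": True}, ...]}
#
# A failed tool fetch for one server yields an empty list for that server; any other failure
# returns an empty dict together with the error string.

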
TimeoutException = Union[Type[BaseException], BaseException]
OnTimeoutCallback = Union[Callable[..., Any], Coroutine[Any, Any, Any]]


def timeout(seconds: float | int = None, attempts: int = 2, *, exception: Optional[TimeoutException] = None, on_timeout: Optional[OnTimeoutCallback] = None):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            result_queue = queue.Queue(maxsize=1)

            def target():
                try:
                    result = func(*args, **kwargs)
                    result_queue.put(result)
                except Exception as e:
                    result_queue.put(e)

            thread = threading.Thread(target=target)
            thread.daemon = True
            thread.start()

            for a in range(attempts):
                try:
                    result = result_queue.get(timeout=seconds)
                    if isinstance(result, Exception):
                        raise result
                    return result
                except queue.Empty:
                    pass
            raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds and {attempts} attempts.")

        @wraps(func)
        async def async_wrapper(*args, **kwargs) -> Any:
            if seconds is None:
                return await func(*args, **kwargs)

            for a in range(attempts):
                try:
                    with trio.fail_after(seconds):
                        return await func(*args, **kwargs)
                except trio.TooSlowError:
                    if a < attempts - 1:
                        continue

                    if on_timeout is not None:
                        if callable(on_timeout):
                            result = on_timeout()
                            if isinstance(result, Coroutine):
                                return await result
                            return result
                        return on_timeout

                    if exception is None:
                        raise TimeoutError(f"Operation timed out after {seconds} seconds and {attempts} attempts.")
                    if isinstance(exception, BaseException):
                        raise exception
                    if isinstance(exception, type) and issubclass(exception, BaseException):
                        raise exception(f"Operation timed out after {seconds} seconds and {attempts} attempts.")
                    raise RuntimeError("Invalid exception type provided")

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        return wrapper

    return decorator


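# Usage sketch (the decorated function names are made up for illustration). The decorator picks
# the trio-based async path for coroutine functions and the thread-based path for plain functions:
#
#   @timeout(5, attempts=2)
#   def fetch_sync():
#       ...  # waits up to 5s per attempt (same worker thread keeps running), then raises TimeoutError
#
#   @timeout(5, attempts=2, on_timeout=lambda: None)
#   async def fetch_async():
#       ...  # returns None instead of raising when both attempts time out

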
async def is_strong_enough(chat_model, embedding_model):
    @timeout(60, 2)
    async def _is_strong_enough():
        nonlocal chat_model, embedding_model
        if embedding_model:
            with trio.fail_after(10):
                _ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"]))
        if chat_model:
            with trio.fail_after(30):
                res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}], {}))
            if res.find("**ERROR**") >= 0:
                raise Exception(res)

    # Pressure test for GraphRAG task
    async with trio.open_nursery() as nursery:
        for _ in range(32):
            nursery.start_soon(_is_strong_enough)