You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

ops_trace_manager.py 32KB


  1. import json
  2. import logging
  3. import os
  4. import queue
  5. import threading
  6. import time
  7. from datetime import timedelta
  8. from typing import Any, Optional, Union
  9. from uuid import UUID, uuid4
  10. from cachetools import LRUCache
  11. from flask import current_app
  12. from sqlalchemy import select
  13. from sqlalchemy.orm import Session
  14. from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
  15. from core.ops.entities.config_entity import (
  16. OPS_FILE_PATH,
  17. LangfuseConfig,
  18. LangSmithConfig,
  19. OpikConfig,
  20. TracingProviderEnum,
  21. WeaveConfig,
  22. )
  23. from core.ops.entities.trace_entity import (
  24. DatasetRetrievalTraceInfo,
  25. GenerateNameTraceInfo,
  26. MessageTraceInfo,
  27. ModerationTraceInfo,
  28. SuggestedQuestionTraceInfo,
  29. TaskData,
  30. ToolTraceInfo,
  31. TraceTaskName,
  32. WorkflowTraceInfo,
  33. )
  34. from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
  35. from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
  36. from core.ops.opik_trace.opik_trace import OpikDataTrace
  37. from core.ops.utils import get_message_data
  38. from core.ops.weave_trace.weave_trace import WeaveDataTrace
  39. from extensions.ext_database import db
  40. from extensions.ext_storage import storage
  41. from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
  42. from models.workflow import WorkflowAppLog, WorkflowRun
  43. from tasks.ops_trace_task import process_trace_tasks
  44. def build_opik_trace_instance(config: OpikConfig):
  45. return OpikDataTrace(config)
  46. provider_config_map: dict[str, dict[str, Any]] = {
  47. TracingProviderEnum.LANGFUSE.value: {
  48. "config_class": LangfuseConfig,
  49. "secret_keys": ["public_key", "secret_key"],
  50. "other_keys": ["host", "project_key"],
  51. "trace_instance": LangFuseDataTrace,
  52. },
  53. TracingProviderEnum.LANGSMITH.value: {
  54. "config_class": LangSmithConfig,
  55. "secret_keys": ["api_key"],
  56. "other_keys": ["project", "endpoint"],
  57. "trace_instance": LangSmithDataTrace,
  58. },
  59. TracingProviderEnum.OPIK.value: {
  60. "config_class": OpikConfig,
  61. "secret_keys": ["api_key"],
  62. "other_keys": ["project", "url", "workspace"],
  63. "trace_instance": lambda config: build_opik_trace_instance(config),
  64. },
  65. TracingProviderEnum.WEAVE.value: {
  66. "config_class": WeaveConfig,
  67. "secret_keys": ["api_key"],
  68. "other_keys": ["project", "entity", "endpoint"],
  69. "trace_instance": WeaveDataTrace,
  70. },
  71. }
  72. class OpsTraceManager:
  73. ops_trace_instances_cache: LRUCache = LRUCache(maxsize=128)
  74. @classmethod
  75. def encrypt_tracing_config(
  76. cls, tenant_id: str, tracing_provider: str, tracing_config: dict, current_trace_config=None
  77. ):
  78. """
  79. Encrypt tracing config.
  80. :param tenant_id: tenant id
  81. :param tracing_provider: tracing provider
  82. :param tracing_config: tracing config dictionary to be encrypted
  83. :param current_trace_config: current tracing configuration for keeping existing values
  84. :return: encrypted tracing configuration
  85. """
  86. # Get the configuration class and the keys that require encryption
  87. config_class, secret_keys, other_keys = (
  88. provider_config_map[tracing_provider]["config_class"],
  89. provider_config_map[tracing_provider]["secret_keys"],
  90. provider_config_map[tracing_provider]["other_keys"],
  91. )
  92. new_config = {}
  93. # Encrypt necessary keys
  94. for key in secret_keys:
  95. if key in tracing_config:
  96. if "*" in tracing_config[key]:
  97. # If the key contains '*', retain the original value from the current config
  98. new_config[key] = current_trace_config.get(key, tracing_config[key])
  99. else:
  100. # Otherwise, encrypt the key
  101. new_config[key] = encrypt_token(tenant_id, tracing_config[key])
  102. for key in other_keys:
  103. new_config[key] = tracing_config.get(key, "")
  104. # Create a new instance of the config class with the new configuration
  105. encrypted_config = config_class(**new_config)
  106. return encrypted_config.model_dump()
  107. @classmethod
  108. def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict):
  109. """
  110. Decrypt tracing config
  111. :param tenant_id: tenant id
  112. :param tracing_provider: tracing provider
  113. :param tracing_config: tracing config
  114. :return:
  115. """
  116. config_class, secret_keys, other_keys = (
  117. provider_config_map[tracing_provider]["config_class"],
  118. provider_config_map[tracing_provider]["secret_keys"],
  119. provider_config_map[tracing_provider]["other_keys"],
  120. )
  121. new_config = {}
  122. for key in secret_keys:
  123. if key in tracing_config:
  124. new_config[key] = decrypt_token(tenant_id, tracing_config[key])
  125. for key in other_keys:
  126. new_config[key] = tracing_config.get(key, "")
  127. return config_class(**new_config).model_dump()
  128. @classmethod
  129. def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
  130. """
  131. Decrypt tracing config
  132. :param tracing_provider: tracing provider
  133. :param decrypt_tracing_config: tracing config
  134. :return:
  135. """
  136. config_class, secret_keys, other_keys = (
  137. provider_config_map[tracing_provider]["config_class"],
  138. provider_config_map[tracing_provider]["secret_keys"],
  139. provider_config_map[tracing_provider]["other_keys"],
  140. )
  141. new_config = {}
  142. for key in secret_keys:
  143. if key in decrypt_tracing_config:
  144. new_config[key] = obfuscated_token(decrypt_tracing_config[key])
  145. for key in other_keys:
  146. new_config[key] = decrypt_tracing_config.get(key, "")
  147. return config_class(**new_config).model_dump()
  148. @classmethod
  149. def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str):
  150. """
  151. Get decrypted tracing config
  152. :param app_id: app id
  153. :param tracing_provider: tracing provider
  154. :return:
  155. """
  156. trace_config_data: Optional[TraceAppConfig] = (
  157. db.session.query(TraceAppConfig)
  158. .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
  159. .first()
  160. )
  161. if not trace_config_data:
  162. return None
  163. # decrypt_token
  164. app = db.session.query(App).filter(App.id == app_id).first()
  165. if not app:
  166. raise ValueError("App not found")
  167. tenant_id = app.tenant_id
  168. decrypt_tracing_config = cls.decrypt_tracing_config(
  169. tenant_id, tracing_provider, trace_config_data.tracing_config
  170. )
  171. return decrypt_tracing_config
  172. @classmethod
  173. def get_ops_trace_instance(
  174. cls,
  175. app_id: Optional[Union[UUID, str]] = None,
  176. ):
  177. """
  178. Get ops trace through model config
  179. :param app_id: app_id
  180. :return:
  181. """
  182. if isinstance(app_id, UUID):
  183. app_id = str(app_id)
  184. if app_id is None:
  185. return None
  186. app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
  187. if app is None:
  188. return None
  189. app_ops_trace_config = json.loads(app.tracing) if app.tracing else None
  190. if app_ops_trace_config is None:
  191. return None
  192. if not app_ops_trace_config.get("enabled"):
  193. return None
  194. tracing_provider = app_ops_trace_config.get("tracing_provider")
  195. if tracing_provider is None or tracing_provider not in provider_config_map:
  196. return None
  197. # decrypt_token
  198. decrypt_trace_config = cls.get_decrypted_tracing_config(app_id, tracing_provider)
  199. if not decrypt_trace_config:
  200. return None
  201. trace_instance, config_class = (
  202. provider_config_map[tracing_provider]["trace_instance"],
  203. provider_config_map[tracing_provider]["config_class"],
  204. )
  205. decrypt_trace_config_key = str(decrypt_trace_config)
  206. tracing_instance = cls.ops_trace_instances_cache.get(decrypt_trace_config_key)
  207. if tracing_instance is None:
  208. # create new tracing_instance and update the cache if it absent
  209. tracing_instance = trace_instance(config_class(**decrypt_trace_config))
  210. cls.ops_trace_instances_cache[decrypt_trace_config_key] = tracing_instance
  211. logging.info(f"new tracing_instance for app_id: {app_id}")
  212. return tracing_instance
  213. @classmethod
  214. def get_app_config_through_message_id(cls, message_id: str):
  215. app_model_config = None
  216. message_data = db.session.query(Message).filter(Message.id == message_id).first()
  217. if not message_data:
  218. return None
  219. conversation_id = message_data.conversation_id
  220. conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
  221. if not conversation_data:
  222. return None
  223. if conversation_data.app_model_config_id:
  224. app_model_config = (
  225. db.session.query(AppModelConfig)
  226. .filter(AppModelConfig.id == conversation_data.app_model_config_id)
  227. .first()
  228. )
  229. elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
  230. app_model_config = conversation_data.override_model_configs
  231. return app_model_config
  232. @classmethod
  233. def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: str):
  234. """
  235. Update app tracing config
  236. :param app_id: app id
  237. :param enabled: enabled
  238. :param tracing_provider: tracing provider
  239. :return:
  240. """
  241. # auth check
  242. if tracing_provider not in provider_config_map and tracing_provider is not None:
  243. raise ValueError(f"Invalid tracing provider: {tracing_provider}")
  244. app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
  245. if not app_config:
  246. raise ValueError("App not found")
  247. app_config.tracing = json.dumps(
  248. {
  249. "enabled": enabled,
  250. "tracing_provider": tracing_provider,
  251. }
  252. )
  253. db.session.commit()
  254. @classmethod
  255. def get_app_tracing_config(cls, app_id: str):
  256. """
  257. Get app tracing config
  258. :param app_id: app id
  259. :return:
  260. """
  261. app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
  262. if not app:
  263. raise ValueError("App not found")
  264. if not app.tracing:
  265. return {"enabled": False, "tracing_provider": None}
  266. app_trace_config = json.loads(app.tracing)
  267. return app_trace_config
  268. @staticmethod
  269. def check_trace_config_is_effective(tracing_config: dict, tracing_provider: str):
  270. """
  271. Check trace config is effective
  272. :param tracing_config: tracing config
  273. :param tracing_provider: tracing provider
  274. :return:
  275. """
  276. config_type, trace_instance = (
  277. provider_config_map[tracing_provider]["config_class"],
  278. provider_config_map[tracing_provider]["trace_instance"],
  279. )
  280. tracing_config = config_type(**tracing_config)
  281. return trace_instance(tracing_config).api_check()
  282. @staticmethod
  283. def get_trace_config_project_key(tracing_config: dict, tracing_provider: str):
  284. """
  285. get trace config is project key
  286. :param tracing_config: tracing config
  287. :param tracing_provider: tracing provider
  288. :return:
  289. """
  290. config_type, trace_instance = (
  291. provider_config_map[tracing_provider]["config_class"],
  292. provider_config_map[tracing_provider]["trace_instance"],
  293. )
  294. tracing_config = config_type(**tracing_config)
  295. return trace_instance(tracing_config).get_project_key()
  296. @staticmethod
  297. def get_trace_config_project_url(tracing_config: dict, tracing_provider: str):
  298. """
  299. get trace config is project key
  300. :param tracing_config: tracing config
  301. :param tracing_provider: tracing provider
  302. :return:
  303. """
  304. config_type, trace_instance = (
  305. provider_config_map[tracing_provider]["config_class"],
  306. provider_config_map[tracing_provider]["trace_instance"],
  307. )
  308. tracing_config = config_type(**tracing_config)
  309. return trace_instance(tracing_config).get_project_url()
  310. class TraceTask:
  311. def __init__(
  312. self,
  313. trace_type: Any,
  314. message_id: Optional[str] = None,
  315. workflow_run: Optional[WorkflowRun] = None,
  316. conversation_id: Optional[str] = None,
  317. user_id: Optional[str] = None,
  318. timer: Optional[Any] = None,
  319. **kwargs,
  320. ):
  321. self.trace_type = trace_type
  322. self.message_id = message_id
  323. self.workflow_run_id = workflow_run.id if workflow_run else None
  324. self.conversation_id = conversation_id
  325. self.user_id = user_id
  326. self.timer = timer
  327. self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
  328. self.app_id = None
  329. self.kwargs = kwargs
  330. def execute(self):
  331. return self.preprocess()
  332. def preprocess(self):
  333. preprocess_map = {
  334. TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs),
  335. TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
  336. workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
  337. ),
  338. TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
  339. TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
  340. message_id=self.message_id, timer=self.timer, **self.kwargs
  341. ),
  342. TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace(
  343. message_id=self.message_id, timer=self.timer, **self.kwargs
  344. ),
  345. TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace(
  346. message_id=self.message_id, timer=self.timer, **self.kwargs
  347. ),
  348. TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(
  349. message_id=self.message_id, timer=self.timer, **self.kwargs
  350. ),
  351. TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
  352. conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
  353. ),
  354. }
  355. return preprocess_map.get(self.trace_type, lambda: None)()
  356. # process methods for different trace types
  357. def conversation_trace(self, **kwargs):
  358. return kwargs
  359. def workflow_trace(
  360. self,
  361. *,
  362. workflow_run_id: str | None,
  363. conversation_id: str | None,
  364. user_id: str | None,
  365. ):
  366. if not workflow_run_id:
  367. return {}
  368. with Session(db.engine) as session:
  369. workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
  370. workflow_run = session.scalars(workflow_run_stmt).first()
  371. if not workflow_run:
  372. raise ValueError("Workflow run not found")
  373. workflow_id = workflow_run.workflow_id
  374. tenant_id = workflow_run.tenant_id
  375. workflow_run_id = workflow_run.id
  376. workflow_run_elapsed_time = workflow_run.elapsed_time
  377. workflow_run_status = workflow_run.status
  378. workflow_run_inputs = workflow_run.inputs_dict
  379. workflow_run_outputs = workflow_run.outputs_dict
  380. workflow_run_version = workflow_run.version
  381. error = workflow_run.error or ""
  382. total_tokens = workflow_run.total_tokens
  383. file_list = workflow_run_inputs.get("sys.file") or []
  384. query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
  385. # get workflow_app_log_id
  386. workflow_app_log_data_stmt = select(WorkflowAppLog.id).where(
  387. WorkflowAppLog.tenant_id == tenant_id,
  388. WorkflowAppLog.app_id == workflow_run.app_id,
  389. WorkflowAppLog.workflow_run_id == workflow_run.id,
  390. )
  391. workflow_app_log_id = session.scalar(workflow_app_log_data_stmt)
  392. # get message_id
  393. message_id = None
  394. if conversation_id:
  395. message_data_stmt = select(Message.id).where(
  396. Message.conversation_id == conversation_id,
  397. Message.workflow_run_id == workflow_run_id,
  398. )
  399. message_id = session.scalar(message_data_stmt)
  400. metadata = {
  401. "workflow_id": workflow_id,
  402. "conversation_id": conversation_id,
  403. "workflow_run_id": workflow_run_id,
  404. "tenant_id": tenant_id,
  405. "elapsed_time": workflow_run_elapsed_time,
  406. "status": workflow_run_status,
  407. "version": workflow_run_version,
  408. "total_tokens": total_tokens,
  409. "file_list": file_list,
  410. "triggered_from": workflow_run.triggered_from,
  411. "user_id": user_id,
  412. }
  413. workflow_trace_info = WorkflowTraceInfo(
  414. workflow_data=workflow_run.to_dict(),
  415. conversation_id=conversation_id,
  416. workflow_id=workflow_id,
  417. tenant_id=tenant_id,
  418. workflow_run_id=workflow_run_id,
  419. workflow_run_elapsed_time=workflow_run_elapsed_time,
  420. workflow_run_status=workflow_run_status,
  421. workflow_run_inputs=workflow_run_inputs,
  422. workflow_run_outputs=workflow_run_outputs,
  423. workflow_run_version=workflow_run_version,
  424. error=error,
  425. total_tokens=total_tokens,
  426. file_list=file_list,
  427. query=query,
  428. metadata=metadata,
  429. workflow_app_log_id=workflow_app_log_id,
  430. message_id=message_id,
  431. start_time=workflow_run.created_at,
  432. end_time=workflow_run.finished_at,
  433. )
  434. return workflow_trace_info
  435. def message_trace(self, message_id: str | None):
  436. if not message_id:
  437. return {}
  438. message_data = get_message_data(message_id)
  439. if not message_data:
  440. return {}
  441. conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id)
  442. conversation_mode = db.session.scalars(conversation_mode_stmt).all()
  443. if not conversation_mode or len(conversation_mode) == 0:
  444. return {}
  445. conversation_mode = conversation_mode[0]
  446. created_at = message_data.created_at
  447. inputs = message_data.message
  448. # get message file data
  449. message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
  450. file_list = []
  451. if message_file_data and message_file_data.url is not None:
  452. file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
  453. file_list.append(file_url)
  454. metadata = {
  455. "conversation_id": message_data.conversation_id,
  456. "ls_provider": message_data.model_provider,
  457. "ls_model_name": message_data.model_id,
  458. "status": message_data.status,
  459. "from_end_user_id": message_data.from_end_user_id,
  460. "from_account_id": message_data.from_account_id,
  461. "agent_based": message_data.agent_based,
  462. "workflow_run_id": message_data.workflow_run_id,
  463. "from_source": message_data.from_source,
  464. "message_id": message_id,
  465. }
  466. message_tokens = message_data.message_tokens
  467. message_trace_info = MessageTraceInfo(
  468. message_id=message_id,
  469. message_data=message_data.to_dict(),
  470. conversation_model=conversation_mode,
  471. message_tokens=message_tokens,
  472. answer_tokens=message_data.answer_tokens,
  473. total_tokens=message_tokens + message_data.answer_tokens,
  474. error=message_data.error or "",
  475. inputs=inputs,
  476. outputs=message_data.answer,
  477. file_list=file_list,
  478. start_time=created_at,
  479. end_time=created_at + timedelta(seconds=message_data.provider_response_latency),
  480. metadata=metadata,
  481. message_file_data=message_file_data,
  482. conversation_mode=conversation_mode,
  483. )
  484. return message_trace_info
  485. def moderation_trace(self, message_id, timer, **kwargs):
  486. moderation_result = kwargs.get("moderation_result")
  487. if not moderation_result:
  488. return {}
  489. inputs = kwargs.get("inputs")
  490. message_data = get_message_data(message_id)
  491. if not message_data:
  492. return {}
  493. metadata = {
  494. "message_id": message_id,
  495. "action": moderation_result.action,
  496. "preset_response": moderation_result.preset_response,
  497. "query": moderation_result.query,
  498. }
  499. # get workflow_app_log_id
  500. workflow_app_log_id = None
  501. if message_data.workflow_run_id:
  502. workflow_app_log_data = (
  503. db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
  504. )
  505. workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
  506. moderation_trace_info = ModerationTraceInfo(
  507. message_id=workflow_app_log_id or message_id,
  508. inputs=inputs,
  509. message_data=message_data.to_dict(),
  510. flagged=moderation_result.flagged,
  511. action=moderation_result.action,
  512. preset_response=moderation_result.preset_response,
  513. query=moderation_result.query,
  514. start_time=timer.get("start"),
  515. end_time=timer.get("end"),
  516. metadata=metadata,
  517. )
  518. return moderation_trace_info
  519. def suggested_question_trace(self, message_id, timer, **kwargs):
  520. suggested_question = kwargs.get("suggested_question", [])
  521. message_data = get_message_data(message_id)
  522. if not message_data:
  523. return {}
  524. metadata = {
  525. "message_id": message_id,
  526. "ls_provider": message_data.model_provider,
  527. "ls_model_name": message_data.model_id,
  528. "status": message_data.status,
  529. "from_end_user_id": message_data.from_end_user_id,
  530. "from_account_id": message_data.from_account_id,
  531. "agent_based": message_data.agent_based,
  532. "workflow_run_id": message_data.workflow_run_id,
  533. "from_source": message_data.from_source,
  534. }
  535. # get workflow_app_log_id
  536. workflow_app_log_id = None
  537. if message_data.workflow_run_id:
  538. workflow_app_log_data = (
  539. db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
  540. )
  541. workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
  542. suggested_question_trace_info = SuggestedQuestionTraceInfo(
  543. message_id=workflow_app_log_id or message_id,
  544. message_data=message_data.to_dict(),
  545. inputs=message_data.message,
  546. outputs=message_data.answer,
  547. start_time=timer.get("start"),
  548. end_time=timer.get("end"),
  549. metadata=metadata,
  550. total_tokens=message_data.message_tokens + message_data.answer_tokens,
  551. status=message_data.status,
  552. error=message_data.error,
  553. from_account_id=message_data.from_account_id,
  554. agent_based=message_data.agent_based,
  555. from_source=message_data.from_source,
  556. model_provider=message_data.model_provider,
  557. model_id=message_data.model_id,
  558. suggested_question=suggested_question,
  559. level=message_data.status,
  560. status_message=message_data.error,
  561. )
  562. return suggested_question_trace_info
  563. def dataset_retrieval_trace(self, message_id, timer, **kwargs):
  564. documents = kwargs.get("documents")
  565. message_data = get_message_data(message_id)
  566. if not message_data:
  567. return {}
  568. metadata = {
  569. "message_id": message_id,
  570. "ls_provider": message_data.model_provider,
  571. "ls_model_name": message_data.model_id,
  572. "status": message_data.status,
  573. "from_end_user_id": message_data.from_end_user_id,
  574. "from_account_id": message_data.from_account_id,
  575. "agent_based": message_data.agent_based,
  576. "workflow_run_id": message_data.workflow_run_id,
  577. "from_source": message_data.from_source,
  578. }
  579. dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
  580. message_id=message_id,
  581. inputs=message_data.query or message_data.inputs,
  582. documents=[doc.model_dump() for doc in documents] if documents else [],
  583. start_time=timer.get("start"),
  584. end_time=timer.get("end"),
  585. metadata=metadata,
  586. message_data=message_data.to_dict(),
  587. )
  588. return dataset_retrieval_trace_info
  589. def tool_trace(self, message_id, timer, **kwargs):
  590. tool_name = kwargs.get("tool_name", "")
  591. tool_inputs = kwargs.get("tool_inputs", {})
  592. tool_outputs = kwargs.get("tool_outputs", {})
  593. message_data = get_message_data(message_id)
  594. if not message_data:
  595. return {}
  596. tool_config = {}
  597. time_cost = 0
  598. error = None
  599. tool_parameters = {}
  600. created_time = message_data.created_at
  601. end_time = message_data.updated_at
  602. agent_thoughts = message_data.agent_thoughts
  603. for agent_thought in agent_thoughts:
  604. if tool_name in agent_thought.tools:
  605. created_time = agent_thought.created_at
  606. tool_meta_data = agent_thought.tool_meta.get(tool_name, {})
  607. tool_config = tool_meta_data.get("tool_config", {})
  608. time_cost = tool_meta_data.get("time_cost", 0)
  609. end_time = created_time + timedelta(seconds=time_cost)
  610. error = tool_meta_data.get("error", "")
  611. tool_parameters = tool_meta_data.get("tool_parameters", {})
  612. metadata = {
  613. "message_id": message_id,
  614. "tool_name": tool_name,
  615. "tool_inputs": tool_inputs,
  616. "tool_outputs": tool_outputs,
  617. "tool_config": tool_config,
  618. "time_cost": time_cost,
  619. "error": error,
  620. "tool_parameters": tool_parameters,
  621. }
  622. file_url = ""
  623. message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
  624. if message_file_data:
  625. message_file_id = message_file_data.id if message_file_data else None
  626. type = message_file_data.type
  627. created_by_role = message_file_data.created_by_role
  628. created_user_id = message_file_data.created_by
  629. file_url = f"{self.file_base_url}/{message_file_data.url}"
  630. metadata.update(
  631. {
  632. "message_file_id": message_file_id,
  633. "created_by_role": created_by_role,
  634. "created_user_id": created_user_id,
  635. "type": type,
  636. }
  637. )
  638. tool_trace_info = ToolTraceInfo(
  639. message_id=message_id,
  640. message_data=message_data.to_dict(),
  641. tool_name=tool_name,
  642. start_time=timer.get("start") if timer else created_time,
  643. end_time=timer.get("end") if timer else end_time,
  644. tool_inputs=tool_inputs,
  645. tool_outputs=tool_outputs,
  646. metadata=metadata,
  647. message_file_data=message_file_data,
  648. error=error,
  649. inputs=message_data.message,
  650. outputs=message_data.answer,
  651. tool_config=tool_config,
  652. time_cost=time_cost,
  653. tool_parameters=tool_parameters,
  654. file_url=file_url,
  655. )
  656. return tool_trace_info
  657. def generate_name_trace(self, conversation_id, timer, **kwargs):
  658. generate_conversation_name = kwargs.get("generate_conversation_name")
  659. inputs = kwargs.get("inputs")
  660. tenant_id = kwargs.get("tenant_id")
  661. if not tenant_id:
  662. return {}
  663. start_time = timer.get("start")
  664. end_time = timer.get("end")
  665. metadata = {
  666. "conversation_id": conversation_id,
  667. "tenant_id": tenant_id,
  668. }
  669. generate_name_trace_info = GenerateNameTraceInfo(
  670. conversation_id=conversation_id,
  671. inputs=inputs,
  672. outputs=generate_conversation_name,
  673. start_time=start_time,
  674. end_time=end_time,
  675. metadata=metadata,
  676. tenant_id=tenant_id,
  677. )
  678. return generate_name_trace_info
  679. trace_manager_timer: Optional[threading.Timer] = None
  680. trace_manager_queue: queue.Queue = queue.Queue()
  681. trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5))
  682. trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))
  683. class TraceQueueManager:
  684. def __init__(self, app_id=None, user_id=None):
  685. global trace_manager_timer
  686. self.app_id = app_id
  687. self.user_id = user_id
  688. self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
  689. self.flask_app = current_app._get_current_object() # type: ignore
  690. if trace_manager_timer is None:
  691. self.start_timer()
  692. def add_trace_task(self, trace_task: TraceTask):
  693. global trace_manager_timer, trace_manager_queue
  694. try:
  695. if self.trace_instance:
  696. trace_task.app_id = self.app_id
  697. trace_manager_queue.put(trace_task)
  698. except Exception as e:
  699. logging.exception(f"Error adding trace task, trace_type {trace_task.trace_type}")
  700. finally:
  701. self.start_timer()
  702. def collect_tasks(self):
  703. global trace_manager_queue
  704. tasks: list[TraceTask] = []
  705. while len(tasks) < trace_manager_batch_size and not trace_manager_queue.empty():
  706. task = trace_manager_queue.get_nowait()
  707. tasks.append(task)
  708. trace_manager_queue.task_done()
  709. return tasks
  710. def run(self):
  711. try:
  712. tasks = self.collect_tasks()
  713. if tasks:
  714. self.send_to_celery(tasks)
  715. except Exception as e:
  716. logging.exception("Error processing trace tasks")
  717. def start_timer(self):
  718. global trace_manager_timer
  719. if trace_manager_timer is None or not trace_manager_timer.is_alive():
  720. trace_manager_timer = threading.Timer(trace_manager_interval, self.run)
  721. trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}"
  722. trace_manager_timer.daemon = False
  723. trace_manager_timer.start()
  724. def send_to_celery(self, tasks: list[TraceTask]):
  725. with self.flask_app.app_context():
  726. for task in tasks:
  727. if task.app_id is None:
  728. continue
  729. file_id = uuid4().hex
  730. trace_info = task.execute()
  731. task_data = TaskData(
  732. app_id=task.app_id,
  733. trace_info_type=type(trace_info).__name__,
  734. trace_info=trace_info.model_dump() if trace_info else None,
  735. )
  736. file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json"
  737. storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
  738. file_info = {
  739. "file_id": file_id,
  740. "app_id": task.app_id,
  741. }
  742. process_trace_tasks.delay(file_info)