You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

ops_trace_manager.py 33KB


  1. import json
  2. import logging
  3. import os
  4. import queue
  5. import threading
  6. import time
  7. from datetime import timedelta
  8. from typing import Any, Optional, Union
  9. from uuid import UUID, uuid4
  10. from cachetools import LRUCache
  11. from flask import current_app
  12. from sqlalchemy import select
  13. from sqlalchemy.orm import Session
  14. from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
  15. from core.ops.entities.config_entity import (
  16. OPS_FILE_PATH,
  17. TracingProviderEnum,
  18. )
  19. from core.ops.entities.trace_entity import (
  20. DatasetRetrievalTraceInfo,
  21. GenerateNameTraceInfo,
  22. MessageTraceInfo,
  23. ModerationTraceInfo,
  24. SuggestedQuestionTraceInfo,
  25. TaskData,
  26. ToolTraceInfo,
  27. TraceTaskName,
  28. WorkflowTraceInfo,
  29. )
  30. from core.ops.utils import get_message_data
  31. from core.workflow.entities.workflow_execution_entities import WorkflowExecution
  32. from extensions.ext_database import db
  33. from extensions.ext_storage import storage
  34. from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
  35. from models.workflow import WorkflowAppLog, WorkflowRun
  36. from tasks.ops_trace_task import process_trace_tasks
  37. class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
  38. def __getitem__(self, provider: str) -> dict[str, Any]:
  39. match provider:
  40. case TracingProviderEnum.LANGFUSE:
  41. from core.ops.entities.config_entity import LangfuseConfig
  42. from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
  43. return {
  44. "config_class": LangfuseConfig,
  45. "secret_keys": ["public_key", "secret_key"],
  46. "other_keys": ["host", "project_key"],
  47. "trace_instance": LangFuseDataTrace,
  48. }
  49. case TracingProviderEnum.LANGSMITH:
  50. from core.ops.entities.config_entity import LangSmithConfig
  51. from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
  52. return {
  53. "config_class": LangSmithConfig,
  54. "secret_keys": ["api_key"],
  55. "other_keys": ["project", "endpoint"],
  56. "trace_instance": LangSmithDataTrace,
  57. }
  58. case TracingProviderEnum.OPIK:
  59. from core.ops.entities.config_entity import OpikConfig
  60. from core.ops.opik_trace.opik_trace import OpikDataTrace
  61. return {
  62. "config_class": OpikConfig,
  63. "secret_keys": ["api_key"],
  64. "other_keys": ["project", "url", "workspace"],
  65. "trace_instance": OpikDataTrace,
  66. }
  67. case TracingProviderEnum.WEAVE:
  68. from core.ops.entities.config_entity import WeaveConfig
  69. from core.ops.weave_trace.weave_trace import WeaveDataTrace
  70. return {
  71. "config_class": WeaveConfig,
  72. "secret_keys": ["api_key"],
  73. "other_keys": ["project", "entity", "endpoint"],
  74. "trace_instance": WeaveDataTrace,
  75. }
  76. case _:
  77. raise KeyError(f"Unsupported tracing provider: {provider}")
  78. provider_config_map: dict[str, dict[str, Any]] = OpsTraceProviderConfigMap()
  79. class OpsTraceManager:
  80. ops_trace_instances_cache: LRUCache = LRUCache(maxsize=128)
  81. @classmethod
  82. def encrypt_tracing_config(
  83. cls, tenant_id: str, tracing_provider: str, tracing_config: dict, current_trace_config=None
  84. ):
  85. """
  86. Encrypt tracing config.
  87. :param tenant_id: tenant id
  88. :param tracing_provider: tracing provider
  89. :param tracing_config: tracing config dictionary to be encrypted
  90. :param current_trace_config: current tracing configuration for keeping existing values
  91. :return: encrypted tracing configuration
  92. """
  93. # Get the configuration class and the keys that require encryption
  94. config_class, secret_keys, other_keys = (
  95. provider_config_map[tracing_provider]["config_class"],
  96. provider_config_map[tracing_provider]["secret_keys"],
  97. provider_config_map[tracing_provider]["other_keys"],
  98. )
  99. new_config = {}
  100. # Encrypt necessary keys
  101. for key in secret_keys:
  102. if key in tracing_config:
  103. if "*" in tracing_config[key]:
  104. # If the key contains '*', retain the original value from the current config
  105. new_config[key] = current_trace_config.get(key, tracing_config[key])
  106. else:
  107. # Otherwise, encrypt the key
  108. new_config[key] = encrypt_token(tenant_id, tracing_config[key])
  109. for key in other_keys:
  110. new_config[key] = tracing_config.get(key, "")
  111. # Create a new instance of the config class with the new configuration
  112. encrypted_config = config_class(**new_config)
  113. return encrypted_config.model_dump()
  114. @classmethod
  115. def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict):
  116. """
  117. Decrypt tracing config
  118. :param tenant_id: tenant id
  119. :param tracing_provider: tracing provider
  120. :param tracing_config: tracing config
  121. :return:
  122. """
  123. config_class, secret_keys, other_keys = (
  124. provider_config_map[tracing_provider]["config_class"],
  125. provider_config_map[tracing_provider]["secret_keys"],
  126. provider_config_map[tracing_provider]["other_keys"],
  127. )
  128. new_config = {}
  129. for key in secret_keys:
  130. if key in tracing_config:
  131. new_config[key] = decrypt_token(tenant_id, tracing_config[key])
  132. for key in other_keys:
  133. new_config[key] = tracing_config.get(key, "")
  134. return config_class(**new_config).model_dump()
  135. @classmethod
  136. def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
  137. """
  138. Decrypt tracing config
  139. :param tracing_provider: tracing provider
  140. :param decrypt_tracing_config: tracing config
  141. :return:
  142. """
  143. config_class, secret_keys, other_keys = (
  144. provider_config_map[tracing_provider]["config_class"],
  145. provider_config_map[tracing_provider]["secret_keys"],
  146. provider_config_map[tracing_provider]["other_keys"],
  147. )
  148. new_config = {}
  149. for key in secret_keys:
  150. if key in decrypt_tracing_config:
  151. new_config[key] = obfuscated_token(decrypt_tracing_config[key])
  152. for key in other_keys:
  153. new_config[key] = decrypt_tracing_config.get(key, "")
  154. return config_class(**new_config).model_dump()
  155. @classmethod
  156. def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str):
  157. """
  158. Get decrypted tracing config
  159. :param app_id: app id
  160. :param tracing_provider: tracing provider
  161. :return:
  162. """
  163. trace_config_data: Optional[TraceAppConfig] = (
  164. db.session.query(TraceAppConfig)
  165. .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
  166. .first()
  167. )
  168. if not trace_config_data:
  169. return None
  170. # decrypt_token
  171. app = db.session.query(App).filter(App.id == app_id).first()
  172. if not app:
  173. raise ValueError("App not found")
  174. tenant_id = app.tenant_id
  175. decrypt_tracing_config = cls.decrypt_tracing_config(
  176. tenant_id, tracing_provider, trace_config_data.tracing_config
  177. )
  178. return decrypt_tracing_config
  179. @classmethod
  180. def get_ops_trace_instance(
  181. cls,
  182. app_id: Optional[Union[UUID, str]] = None,
  183. ):
  184. """
  185. Get ops trace through model config
  186. :param app_id: app_id
  187. :return:
  188. """
  189. if isinstance(app_id, UUID):
  190. app_id = str(app_id)
  191. if app_id is None:
  192. return None
  193. app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
  194. if app is None:
  195. return None
  196. app_ops_trace_config = json.loads(app.tracing) if app.tracing else None
  197. if app_ops_trace_config is None:
  198. return None
  199. if not app_ops_trace_config.get("enabled"):
  200. return None
  201. tracing_provider = app_ops_trace_config.get("tracing_provider")
  202. if tracing_provider is None:
  203. return None
  204. try:
  205. provider_config_map[tracing_provider]
  206. except KeyError:
  207. return None
  208. # decrypt_token
  209. decrypt_trace_config = cls.get_decrypted_tracing_config(app_id, tracing_provider)
  210. if not decrypt_trace_config:
  211. return None
  212. trace_instance, config_class = (
  213. provider_config_map[tracing_provider]["trace_instance"],
  214. provider_config_map[tracing_provider]["config_class"],
  215. )
  216. decrypt_trace_config_key = str(decrypt_trace_config)
  217. tracing_instance = cls.ops_trace_instances_cache.get(decrypt_trace_config_key)
  218. if tracing_instance is None:
  219. # create new tracing_instance and update the cache if it absent
  220. tracing_instance = trace_instance(config_class(**decrypt_trace_config))
  221. cls.ops_trace_instances_cache[decrypt_trace_config_key] = tracing_instance
  222. logging.info(f"new tracing_instance for app_id: {app_id}")
  223. return tracing_instance
  224. @classmethod
  225. def get_app_config_through_message_id(cls, message_id: str):
  226. app_model_config = None
  227. message_data = db.session.query(Message).filter(Message.id == message_id).first()
  228. if not message_data:
  229. return None
  230. conversation_id = message_data.conversation_id
  231. conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
  232. if not conversation_data:
  233. return None
  234. if conversation_data.app_model_config_id:
  235. app_model_config = (
  236. db.session.query(AppModelConfig)
  237. .filter(AppModelConfig.id == conversation_data.app_model_config_id)
  238. .first()
  239. )
  240. elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
  241. app_model_config = conversation_data.override_model_configs
  242. return app_model_config
  243. @classmethod
  244. def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: str):
  245. """
  246. Update app tracing config
  247. :param app_id: app id
  248. :param enabled: enabled
  249. :param tracing_provider: tracing provider
  250. :return:
  251. """
  252. # auth check
  253. try:
  254. provider_config_map[tracing_provider]
  255. except KeyError:
  256. raise ValueError(f"Invalid tracing provider: {tracing_provider}")
  257. app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
  258. if not app_config:
  259. raise ValueError("App not found")
  260. app_config.tracing = json.dumps(
  261. {
  262. "enabled": enabled,
  263. "tracing_provider": tracing_provider,
  264. }
  265. )
  266. db.session.commit()
  267. @classmethod
  268. def get_app_tracing_config(cls, app_id: str):
  269. """
  270. Get app tracing config
  271. :param app_id: app id
  272. :return:
  273. """
  274. app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
  275. if not app:
  276. raise ValueError("App not found")
  277. if not app.tracing:
  278. return {"enabled": False, "tracing_provider": None}
  279. app_trace_config = json.loads(app.tracing)
  280. return app_trace_config
  281. @staticmethod
  282. def check_trace_config_is_effective(tracing_config: dict, tracing_provider: str):
  283. """
  284. Check trace config is effective
  285. :param tracing_config: tracing config
  286. :param tracing_provider: tracing provider
  287. :return:
  288. """
  289. config_type, trace_instance = (
  290. provider_config_map[tracing_provider]["config_class"],
  291. provider_config_map[tracing_provider]["trace_instance"],
  292. )
  293. tracing_config = config_type(**tracing_config)
  294. return trace_instance(tracing_config).api_check()
  295. @staticmethod
  296. def get_trace_config_project_key(tracing_config: dict, tracing_provider: str):
  297. """
  298. get trace config is project key
  299. :param tracing_config: tracing config
  300. :param tracing_provider: tracing provider
  301. :return:
  302. """
  303. config_type, trace_instance = (
  304. provider_config_map[tracing_provider]["config_class"],
  305. provider_config_map[tracing_provider]["trace_instance"],
  306. )
  307. tracing_config = config_type(**tracing_config)
  308. return trace_instance(tracing_config).get_project_key()
  309. @staticmethod
  310. def get_trace_config_project_url(tracing_config: dict, tracing_provider: str):
  311. """
  312. get trace config is project key
  313. :param tracing_config: tracing config
  314. :param tracing_provider: tracing provider
  315. :return:
  316. """
  317. config_type, trace_instance = (
  318. provider_config_map[tracing_provider]["config_class"],
  319. provider_config_map[tracing_provider]["trace_instance"],
  320. )
  321. tracing_config = config_type(**tracing_config)
  322. return trace_instance(tracing_config).get_project_url()
  323. class TraceTask:
  324. def __init__(
  325. self,
  326. trace_type: Any,
  327. message_id: Optional[str] = None,
  328. workflow_execution: Optional[WorkflowExecution] = None,
  329. conversation_id: Optional[str] = None,
  330. user_id: Optional[str] = None,
  331. timer: Optional[Any] = None,
  332. **kwargs,
  333. ):
  334. self.trace_type = trace_type
  335. self.message_id = message_id
  336. self.workflow_run_id = workflow_execution.id if workflow_execution else None
  337. self.conversation_id = conversation_id
  338. self.user_id = user_id
  339. self.timer = timer
  340. self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
  341. self.app_id = None
  342. self.kwargs = kwargs
  343. def execute(self):
  344. return self.preprocess()
  345. def preprocess(self):
  346. preprocess_map = {
  347. TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs),
  348. TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
  349. workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
  350. ),
  351. TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
  352. TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
  353. message_id=self.message_id, timer=self.timer, **self.kwargs
  354. ),
  355. TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace(
  356. message_id=self.message_id, timer=self.timer, **self.kwargs
  357. ),
  358. TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace(
  359. message_id=self.message_id, timer=self.timer, **self.kwargs
  360. ),
  361. TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(
  362. message_id=self.message_id, timer=self.timer, **self.kwargs
  363. ),
  364. TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
  365. conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
  366. ),
  367. }
  368. return preprocess_map.get(self.trace_type, lambda: None)()
  369. # process methods for different trace types
  370. def conversation_trace(self, **kwargs):
  371. return kwargs
  372. def workflow_trace(
  373. self,
  374. *,
  375. workflow_run_id: str | None,
  376. conversation_id: str | None,
  377. user_id: str | None,
  378. ):
  379. if not workflow_run_id:
  380. return {}
  381. with Session(db.engine) as session:
  382. workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
  383. workflow_run = session.scalars(workflow_run_stmt).first()
  384. if not workflow_run:
  385. raise ValueError("Workflow run not found")
  386. workflow_id = workflow_run.workflow_id
  387. tenant_id = workflow_run.tenant_id
  388. workflow_run_id = workflow_run.id
  389. workflow_run_elapsed_time = workflow_run.elapsed_time
  390. workflow_run_status = workflow_run.status
  391. workflow_run_inputs = workflow_run.inputs_dict
  392. workflow_run_outputs = workflow_run.outputs_dict
  393. workflow_run_version = workflow_run.version
  394. error = workflow_run.error or ""
  395. total_tokens = workflow_run.total_tokens
  396. file_list = workflow_run_inputs.get("sys.file") or []
  397. query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
  398. # get workflow_app_log_id
  399. workflow_app_log_data_stmt = select(WorkflowAppLog.id).where(
  400. WorkflowAppLog.tenant_id == tenant_id,
  401. WorkflowAppLog.app_id == workflow_run.app_id,
  402. WorkflowAppLog.workflow_run_id == workflow_run.id,
  403. )
  404. workflow_app_log_id = session.scalar(workflow_app_log_data_stmt)
  405. # get message_id
  406. message_id = None
  407. if conversation_id:
  408. message_data_stmt = select(Message.id).where(
  409. Message.conversation_id == conversation_id,
  410. Message.workflow_run_id == workflow_run_id,
  411. )
  412. message_id = session.scalar(message_data_stmt)
  413. metadata = {
  414. "workflow_id": workflow_id,
  415. "conversation_id": conversation_id,
  416. "workflow_run_id": workflow_run_id,
  417. "tenant_id": tenant_id,
  418. "elapsed_time": workflow_run_elapsed_time,
  419. "status": workflow_run_status,
  420. "version": workflow_run_version,
  421. "total_tokens": total_tokens,
  422. "file_list": file_list,
  423. "triggered_from": workflow_run.triggered_from,
  424. "user_id": user_id,
  425. }
  426. workflow_trace_info = WorkflowTraceInfo(
  427. workflow_data=workflow_run.to_dict(),
  428. conversation_id=conversation_id,
  429. workflow_id=workflow_id,
  430. tenant_id=tenant_id,
  431. workflow_run_id=workflow_run_id,
  432. workflow_run_elapsed_time=workflow_run_elapsed_time,
  433. workflow_run_status=workflow_run_status,
  434. workflow_run_inputs=workflow_run_inputs,
  435. workflow_run_outputs=workflow_run_outputs,
  436. workflow_run_version=workflow_run_version,
  437. error=error,
  438. total_tokens=total_tokens,
  439. file_list=file_list,
  440. query=query,
  441. metadata=metadata,
  442. workflow_app_log_id=workflow_app_log_id,
  443. message_id=message_id,
  444. start_time=workflow_run.created_at,
  445. end_time=workflow_run.finished_at,
  446. )
  447. return workflow_trace_info
  448. def message_trace(self, message_id: str | None):
  449. if not message_id:
  450. return {}
  451. message_data = get_message_data(message_id)
  452. if not message_data:
  453. return {}
  454. conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id)
  455. conversation_mode = db.session.scalars(conversation_mode_stmt).all()
  456. if not conversation_mode or len(conversation_mode) == 0:
  457. return {}
  458. conversation_mode = conversation_mode[0]
  459. created_at = message_data.created_at
  460. inputs = message_data.message
  461. # get message file data
  462. message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
  463. file_list = []
  464. if message_file_data and message_file_data.url is not None:
  465. file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
  466. file_list.append(file_url)
  467. metadata = {
  468. "conversation_id": message_data.conversation_id,
  469. "ls_provider": message_data.model_provider,
  470. "ls_model_name": message_data.model_id,
  471. "status": message_data.status,
  472. "from_end_user_id": message_data.from_end_user_id,
  473. "from_account_id": message_data.from_account_id,
  474. "agent_based": message_data.agent_based,
  475. "workflow_run_id": message_data.workflow_run_id,
  476. "from_source": message_data.from_source,
  477. "message_id": message_id,
  478. }
  479. message_tokens = message_data.message_tokens
  480. message_trace_info = MessageTraceInfo(
  481. message_id=message_id,
  482. message_data=message_data.to_dict(),
  483. conversation_model=conversation_mode,
  484. message_tokens=message_tokens,
  485. answer_tokens=message_data.answer_tokens,
  486. total_tokens=message_tokens + message_data.answer_tokens,
  487. error=message_data.error or "",
  488. inputs=inputs,
  489. outputs=message_data.answer,
  490. file_list=file_list,
  491. start_time=created_at,
  492. end_time=created_at + timedelta(seconds=message_data.provider_response_latency),
  493. metadata=metadata,
  494. message_file_data=message_file_data,
  495. conversation_mode=conversation_mode,
  496. )
  497. return message_trace_info
  498. def moderation_trace(self, message_id, timer, **kwargs):
  499. moderation_result = kwargs.get("moderation_result")
  500. if not moderation_result:
  501. return {}
  502. inputs = kwargs.get("inputs")
  503. message_data = get_message_data(message_id)
  504. if not message_data:
  505. return {}
  506. metadata = {
  507. "message_id": message_id,
  508. "action": moderation_result.action,
  509. "preset_response": moderation_result.preset_response,
  510. "query": moderation_result.query,
  511. }
  512. # get workflow_app_log_id
  513. workflow_app_log_id = None
  514. if message_data.workflow_run_id:
  515. workflow_app_log_data = (
  516. db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
  517. )
  518. workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
  519. moderation_trace_info = ModerationTraceInfo(
  520. message_id=workflow_app_log_id or message_id,
  521. inputs=inputs,
  522. message_data=message_data.to_dict(),
  523. flagged=moderation_result.flagged,
  524. action=moderation_result.action,
  525. preset_response=moderation_result.preset_response,
  526. query=moderation_result.query,
  527. start_time=timer.get("start"),
  528. end_time=timer.get("end"),
  529. metadata=metadata,
  530. )
  531. return moderation_trace_info
  532. def suggested_question_trace(self, message_id, timer, **kwargs):
  533. suggested_question = kwargs.get("suggested_question", [])
  534. message_data = get_message_data(message_id)
  535. if not message_data:
  536. return {}
  537. metadata = {
  538. "message_id": message_id,
  539. "ls_provider": message_data.model_provider,
  540. "ls_model_name": message_data.model_id,
  541. "status": message_data.status,
  542. "from_end_user_id": message_data.from_end_user_id,
  543. "from_account_id": message_data.from_account_id,
  544. "agent_based": message_data.agent_based,
  545. "workflow_run_id": message_data.workflow_run_id,
  546. "from_source": message_data.from_source,
  547. }
  548. # get workflow_app_log_id
  549. workflow_app_log_id = None
  550. if message_data.workflow_run_id:
  551. workflow_app_log_data = (
  552. db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
  553. )
  554. workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
  555. suggested_question_trace_info = SuggestedQuestionTraceInfo(
  556. message_id=workflow_app_log_id or message_id,
  557. message_data=message_data.to_dict(),
  558. inputs=message_data.message,
  559. outputs=message_data.answer,
  560. start_time=timer.get("start"),
  561. end_time=timer.get("end"),
  562. metadata=metadata,
  563. total_tokens=message_data.message_tokens + message_data.answer_tokens,
  564. status=message_data.status,
  565. error=message_data.error,
  566. from_account_id=message_data.from_account_id,
  567. agent_based=message_data.agent_based,
  568. from_source=message_data.from_source,
  569. model_provider=message_data.model_provider,
  570. model_id=message_data.model_id,
  571. suggested_question=suggested_question,
  572. level=message_data.status,
  573. status_message=message_data.error,
  574. )
  575. return suggested_question_trace_info
  576. def dataset_retrieval_trace(self, message_id, timer, **kwargs):
  577. documents = kwargs.get("documents")
  578. message_data = get_message_data(message_id)
  579. if not message_data:
  580. return {}
  581. metadata = {
  582. "message_id": message_id,
  583. "ls_provider": message_data.model_provider,
  584. "ls_model_name": message_data.model_id,
  585. "status": message_data.status,
  586. "from_end_user_id": message_data.from_end_user_id,
  587. "from_account_id": message_data.from_account_id,
  588. "agent_based": message_data.agent_based,
  589. "workflow_run_id": message_data.workflow_run_id,
  590. "from_source": message_data.from_source,
  591. }
  592. dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
  593. message_id=message_id,
  594. inputs=message_data.query or message_data.inputs,
  595. documents=[doc.model_dump() for doc in documents] if documents else [],
  596. start_time=timer.get("start"),
  597. end_time=timer.get("end"),
  598. metadata=metadata,
  599. message_data=message_data.to_dict(),
  600. )
  601. return dataset_retrieval_trace_info
  602. def tool_trace(self, message_id, timer, **kwargs):
  603. tool_name = kwargs.get("tool_name", "")
  604. tool_inputs = kwargs.get("tool_inputs", {})
  605. tool_outputs = kwargs.get("tool_outputs", {})
  606. message_data = get_message_data(message_id)
  607. if not message_data:
  608. return {}
  609. tool_config = {}
  610. time_cost = 0
  611. error = None
  612. tool_parameters = {}
  613. created_time = message_data.created_at
  614. end_time = message_data.updated_at
  615. agent_thoughts = message_data.agent_thoughts
  616. for agent_thought in agent_thoughts:
  617. if tool_name in agent_thought.tools:
  618. created_time = agent_thought.created_at
  619. tool_meta_data = agent_thought.tool_meta.get(tool_name, {})
  620. tool_config = tool_meta_data.get("tool_config", {})
  621. time_cost = tool_meta_data.get("time_cost", 0)
  622. end_time = created_time + timedelta(seconds=time_cost)
  623. error = tool_meta_data.get("error", "")
  624. tool_parameters = tool_meta_data.get("tool_parameters", {})
  625. metadata = {
  626. "message_id": message_id,
  627. "tool_name": tool_name,
  628. "tool_inputs": tool_inputs,
  629. "tool_outputs": tool_outputs,
  630. "tool_config": tool_config,
  631. "time_cost": time_cost,
  632. "error": error,
  633. "tool_parameters": tool_parameters,
  634. }
  635. file_url = ""
  636. message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
  637. if message_file_data:
  638. message_file_id = message_file_data.id if message_file_data else None
  639. type = message_file_data.type
  640. created_by_role = message_file_data.created_by_role
  641. created_user_id = message_file_data.created_by
  642. file_url = f"{self.file_base_url}/{message_file_data.url}"
  643. metadata.update(
  644. {
  645. "message_file_id": message_file_id,
  646. "created_by_role": created_by_role,
  647. "created_user_id": created_user_id,
  648. "type": type,
  649. }
  650. )
  651. tool_trace_info = ToolTraceInfo(
  652. message_id=message_id,
  653. message_data=message_data.to_dict(),
  654. tool_name=tool_name,
  655. start_time=timer.get("start") if timer else created_time,
  656. end_time=timer.get("end") if timer else end_time,
  657. tool_inputs=tool_inputs,
  658. tool_outputs=tool_outputs,
  659. metadata=metadata,
  660. message_file_data=message_file_data,
  661. error=error,
  662. inputs=message_data.message,
  663. outputs=message_data.answer,
  664. tool_config=tool_config,
  665. time_cost=time_cost,
  666. tool_parameters=tool_parameters,
  667. file_url=file_url,
  668. )
  669. return tool_trace_info
  670. def generate_name_trace(self, conversation_id, timer, **kwargs):
  671. generate_conversation_name = kwargs.get("generate_conversation_name")
  672. inputs = kwargs.get("inputs")
  673. tenant_id = kwargs.get("tenant_id")
  674. if not tenant_id:
  675. return {}
  676. start_time = timer.get("start")
  677. end_time = timer.get("end")
  678. metadata = {
  679. "conversation_id": conversation_id,
  680. "tenant_id": tenant_id,
  681. }
  682. generate_name_trace_info = GenerateNameTraceInfo(
  683. conversation_id=conversation_id,
  684. inputs=inputs,
  685. outputs=generate_conversation_name,
  686. start_time=start_time,
  687. end_time=end_time,
  688. metadata=metadata,
  689. tenant_id=tenant_id,
  690. )
  691. return generate_name_trace_info
  692. trace_manager_timer: Optional[threading.Timer] = None
  693. trace_manager_queue: queue.Queue = queue.Queue()
  694. trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5))
  695. trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))
  696. class TraceQueueManager:
  697. def __init__(self, app_id=None, user_id=None):
  698. global trace_manager_timer
  699. self.app_id = app_id
  700. self.user_id = user_id
  701. self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
  702. self.flask_app = current_app._get_current_object() # type: ignore
  703. if trace_manager_timer is None:
  704. self.start_timer()
  705. def add_trace_task(self, trace_task: TraceTask):
  706. global trace_manager_timer, trace_manager_queue
  707. try:
  708. if self.trace_instance:
  709. trace_task.app_id = self.app_id
  710. trace_manager_queue.put(trace_task)
  711. except Exception as e:
  712. logging.exception(f"Error adding trace task, trace_type {trace_task.trace_type}")
  713. finally:
  714. self.start_timer()
  715. def collect_tasks(self):
  716. global trace_manager_queue
  717. tasks: list[TraceTask] = []
  718. while len(tasks) < trace_manager_batch_size and not trace_manager_queue.empty():
  719. task = trace_manager_queue.get_nowait()
  720. tasks.append(task)
  721. trace_manager_queue.task_done()
  722. return tasks
  723. def run(self):
  724. try:
  725. tasks = self.collect_tasks()
  726. if tasks:
  727. self.send_to_celery(tasks)
  728. except Exception as e:
  729. logging.exception("Error processing trace tasks")
  730. def start_timer(self):
  731. global trace_manager_timer
  732. if trace_manager_timer is None or not trace_manager_timer.is_alive():
  733. trace_manager_timer = threading.Timer(trace_manager_interval, self.run)
  734. trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}"
  735. trace_manager_timer.daemon = False
  736. trace_manager_timer.start()
  737. def send_to_celery(self, tasks: list[TraceTask]):
  738. with self.flask_app.app_context():
  739. for task in tasks:
  740. if task.app_id is None:
  741. continue
  742. file_id = uuid4().hex
  743. trace_info = task.execute()
  744. task_data = TaskData(
  745. app_id=task.app_id,
  746. trace_info_type=type(trace_info).__name__,
  747. trace_info=trace_info.model_dump() if trace_info else None,
  748. )
  749. file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json"
  750. storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
  751. file_info = {
  752. "file_id": file_id,
  753. "app_id": task.app_id,
  754. }
  755. process_trace_tasks.delay(file_info)