You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

ops_trace_manager.py 35KB


  1. import json
  2. import logging
  3. import os
  4. import queue
  5. import threading
  6. import time
  7. from datetime import timedelta
  8. from typing import TYPE_CHECKING, Any, Optional, Union
  9. from uuid import UUID, uuid4
  10. from cachetools import LRUCache
  11. from flask import current_app
  12. from sqlalchemy import select
  13. from sqlalchemy.orm import Session
  14. from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
  15. from core.ops.entities.config_entity import (
  16. OPS_FILE_PATH,
  17. TracingProviderEnum,
  18. )
  19. from core.ops.entities.trace_entity import (
  20. DatasetRetrievalTraceInfo,
  21. GenerateNameTraceInfo,
  22. MessageTraceInfo,
  23. ModerationTraceInfo,
  24. SuggestedQuestionTraceInfo,
  25. TaskData,
  26. ToolTraceInfo,
  27. TraceTaskName,
  28. WorkflowTraceInfo,
  29. )
  30. from core.ops.utils import get_message_data
  31. from extensions.ext_database import db
  32. from extensions.ext_storage import storage
  33. from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
  34. from models.workflow import WorkflowAppLog, WorkflowRun
  35. from tasks.ops_trace_task import process_trace_tasks
  36. if TYPE_CHECKING:
  37. from core.workflow.entities import WorkflowExecution
  logger = logging.getLogger(__name__)


  class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
      """Registry that resolves a tracing provider name to its descriptor.

      Only ``__getitem__`` is overridden, so the descriptor is computed on every
      subscript access and nothing is ever stored in the underlying dict. The
      provider-specific modules are imported inside the lookup, i.e. only when
      that provider is actually requested.
      """

      @staticmethod
      def _descriptor(config_class, trace_instance, secret_keys, other_keys) -> dict[str, Any]:
          """Assemble the descriptor dict returned for a provider."""
          return {
              "config_class": config_class,
              "secret_keys": secret_keys,
              "other_keys": other_keys,
              "trace_instance": trace_instance,
          }

      def __getitem__(self, provider: str) -> dict[str, Any]:
          """
          Return the descriptor of a tracing provider.
          :param provider: provider name, compared by equality with the TracingProviderEnum members
          :return: dict with the keys config_class, secret_keys, other_keys and trace_instance
          :raises KeyError: when the provider matches none of the supported members
          """
          if provider == TracingProviderEnum.LANGFUSE:
              from core.ops.entities.config_entity import LangfuseConfig
              from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace

              return self._descriptor(
                  LangfuseConfig,
                  LangFuseDataTrace,
                  ["public_key", "secret_key"],
                  ["host", "project_key"],
              )

          if provider == TracingProviderEnum.LANGSMITH:
              from core.ops.entities.config_entity import LangSmithConfig
              from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace

              return self._descriptor(
                  LangSmithConfig,
                  LangSmithDataTrace,
                  ["api_key"],
                  ["project", "endpoint"],
              )

          if provider == TracingProviderEnum.OPIK:
              from core.ops.entities.config_entity import OpikConfig
              from core.ops.opik_trace.opik_trace import OpikDataTrace

              return self._descriptor(
                  OpikConfig,
                  OpikDataTrace,
                  ["api_key"],
                  ["project", "url", "workspace"],
              )

          if provider == TracingProviderEnum.WEAVE:
              from core.ops.entities.config_entity import WeaveConfig
              from core.ops.weave_trace.weave_trace import WeaveDataTrace

              return self._descriptor(
                  WeaveConfig,
                  WeaveDataTrace,
                  ["api_key"],
                  ["project", "entity", "endpoint", "host"],
              )

          if provider == TracingProviderEnum.ARIZE:
              from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
              from core.ops.entities.config_entity import ArizeConfig

              return self._descriptor(
                  ArizeConfig,
                  ArizePhoenixDataTrace,
                  ["api_key", "space_id"],
                  ["project", "endpoint"],
              )

          if provider == TracingProviderEnum.PHOENIX:
              from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
              from core.ops.entities.config_entity import PhoenixConfig

              return self._descriptor(
                  PhoenixConfig,
                  ArizePhoenixDataTrace,
                  ["api_key"],
                  ["project", "endpoint"],
              )

          if provider == TracingProviderEnum.ALIYUN:
              from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
              from core.ops.entities.config_entity import AliyunConfig

              return self._descriptor(
                  AliyunConfig,
                  AliyunDataTrace,
                  ["license_key"],
                  ["endpoint", "app_name"],
              )

          # Nothing matched: behave like a missing dict key.
          raise KeyError(f"Unsupported tracing provider: {provider}")


  # Shared registry instance; subscript it with a provider name to get that provider's descriptor.
  provider_config_map: dict[str, dict[str, Any]] = OpsTraceProviderConfigMap()
  class OpsTraceManager:
      """Helpers for encrypting, loading and using per-app tracing configuration."""

      # Tracer instances keyed by the sorted-key JSON dump of their decrypted config.
      ops_trace_instances_cache: LRUCache = LRUCache(maxsize=128)

      @classmethod
      def encrypt_tracing_config(
          cls, tenant_id: str, tracing_provider: str, tracing_config: dict, current_trace_config=None
      ):
          """
          Encrypt tracing config.
          :param tenant_id: tenant id
          :param tracing_provider: tracing provider
          :param tracing_config: tracing config dictionary to be encrypted
          :param current_trace_config: current tracing configuration for keeping existing values
          :return: encrypted tracing configuration
          :raises KeyError: if the tracing provider is not supported
          """
          # One registry lookup; each subscript re-runs the provider's imports.
          provider_entry = provider_config_map[tracing_provider]
          config_class = provider_entry["config_class"]

          # ``current_trace_config`` defaults to None. Treat that as "nothing stored"
          # instead of failing with AttributeError when a masked secret is submitted.
          existing_config = current_trace_config or {}

          new_config = {}
          for key in provider_entry["secret_keys"]:
              if key not in tracing_config:
                  continue
              submitted_value = tracing_config[key]
              if "*" in submitted_value:
                  # A value containing '*' is taken to be an obfuscated echo of the
                  # stored secret: keep the stored value, falling back to the
                  # submitted one when nothing is stored for this key.
                  new_config[key] = existing_config.get(key, submitted_value)
              else:
                  new_config[key] = encrypt_token(tenant_id, submitted_value)

          # Non-secret keys are copied as-is; missing ones default to "".
          for key in provider_entry["other_keys"]:
              new_config[key] = tracing_config.get(key, "")

          # Round-trip through the provider's config class before dumping.
          return config_class(**new_config).model_dump()

      @classmethod
      def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict):
          """
          Decrypt tracing config
          :param tenant_id: tenant id
          :param tracing_provider: tracing provider
          :param tracing_config: stored tracing config whose secret keys are encrypted
          :return: config dict with the secret keys decrypted
          :raises KeyError: if the tracing provider is not supported
          """
          provider_entry = provider_config_map[tracing_provider]
          config_class = provider_entry["config_class"]

          new_config = {}
          for key in provider_entry["secret_keys"]:
              if key in tracing_config:
                  new_config[key] = decrypt_token(tenant_id, tracing_config[key])
          for key in provider_entry["other_keys"]:
              new_config[key] = tracing_config.get(key, "")
          return config_class(**new_config).model_dump()

      @classmethod
      def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
          """
          Obfuscate the secret keys of an already decrypted tracing config
          :param tracing_provider: tracing provider
          :param decrypt_tracing_config: decrypted tracing config
          :return: config dict with every present secret key passed through obfuscated_token
          :raises KeyError: if the tracing provider is not supported
          """
          provider_entry = provider_config_map[tracing_provider]
          config_class = provider_entry["config_class"]

          new_config = {}
          for key in provider_entry["secret_keys"]:
              if key in decrypt_tracing_config:
                  new_config[key] = obfuscated_token(decrypt_tracing_config[key])
          for key in provider_entry["other_keys"]:
              new_config[key] = decrypt_tracing_config.get(key, "")
          return config_class(**new_config).model_dump()

      @classmethod
      def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str):
          """
          Get decrypted tracing config
          :param app_id: app id
          :param tracing_provider: tracing provider
          :return: decrypted config dict, or None when no config is stored for the app/provider pair
          :raises ValueError: if a config is stored but the app does not exist
          """
          trace_config_data: Optional[TraceAppConfig] = (
              db.session.query(TraceAppConfig)
              .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
              .first()
          )
          if not trace_config_data:
              return None

          # The tenant id needed for decryption lives on the app row.
          app = db.session.scalar(select(App).where(App.id == app_id))
          if not app:
              raise ValueError("App not found")

          return cls.decrypt_tracing_config(app.tenant_id, tracing_provider, trace_config_data.tracing_config)

      @classmethod
      def get_ops_trace_instance(
          cls,
          app_id: Optional[Union[UUID, str]] = None,
      ):
          """
          Get the tracer instance configured for an app
          :param app_id: app_id
          :return: tracer instance, or None when the app is missing, tracing is disabled,
                   the provider is unknown or no config is stored
          """
          if isinstance(app_id, UUID):
              app_id = str(app_id)
          if app_id is None:
              return None

          app: Optional[App] = db.session.query(App).where(App.id == app_id).first()
          if app is None:
              return None

          app_ops_trace_config = json.loads(app.tracing) if app.tracing else None
          if app_ops_trace_config is None:
              return None
          if not app_ops_trace_config.get("enabled"):
              return None

          tracing_provider = app_ops_trace_config.get("tracing_provider")
          if tracing_provider is None:
              return None
          try:
              provider_entry = provider_config_map[tracing_provider]
          except KeyError:
              # Unknown provider stored on the app: treat as "no tracing".
              return None

          decrypt_trace_config = cls.get_decrypted_tracing_config(app_id, tracing_provider)
          if not decrypt_trace_config:
              return None

          # Identical configs share one tracer instance.
          cache_key = json.dumps(decrypt_trace_config, sort_keys=True)
          tracing_instance = cls.ops_trace_instances_cache.get(cache_key)
          if tracing_instance is None:
              # Create a new tracing_instance and cache it if it is absent.
              config = provider_entry["config_class"](**decrypt_trace_config)
              tracing_instance = provider_entry["trace_instance"](config)
              cls.ops_trace_instances_cache[cache_key] = tracing_instance
              logger.info("new tracing_instance for app_id: %s", app_id)
          return tracing_instance

      @classmethod
      def get_app_config_through_message_id(cls, message_id: str):
          """
          Resolve the model config in effect for a message
          :param message_id: message id
          :return: the AppModelConfig row referenced by the message's conversation, the
                   conversation's override_model_configs when no config id is set, or None
          """
          message_data = db.session.scalar(select(Message).where(Message.id == message_id))
          if not message_data:
              return None

          conversation_data = db.session.scalar(
              select(Conversation).where(Conversation.id == message_data.conversation_id)
          )
          if not conversation_data:
              return None

          app_model_config = None
          if conversation_data.app_model_config_id:
              config_stmt = select(AppModelConfig).where(AppModelConfig.id == conversation_data.app_model_config_id)
              app_model_config = db.session.scalar(config_stmt)
          elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
              app_model_config = conversation_data.override_model_configs
          return app_model_config

      @classmethod
      def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: str):
          """
          Update app tracing config
          :param app_id: app id
          :param enabled: enabled
          :param tracing_provider: tracing provider
          :return:
          :raises ValueError: if the provider is invalid or the app does not exist
          """
          # Validate the provider: it must be supported when enabling, and must
          # at least be given (not None) when disabling.
          if enabled:
              try:
                  provider_config_map[tracing_provider]
              except KeyError as exc:
                  raise ValueError(f"Invalid tracing provider: {tracing_provider}") from exc
          elif tracing_provider is None:
              raise ValueError(f"Invalid tracing provider: {tracing_provider}")

          app_config: Optional[App] = db.session.query(App).where(App.id == app_id).first()
          if not app_config:
              raise ValueError("App not found")
          app_config.tracing = json.dumps(
              {
                  "enabled": enabled,
                  "tracing_provider": tracing_provider,
              }
          )
          db.session.commit()

      @classmethod
      def get_app_tracing_config(cls, app_id: str):
          """
          Get app tracing config
          :param app_id: app id
          :return: the parsed tracing JSON stored on the app, or a disabled default when none is stored
          :raises ValueError: if the app does not exist
          """
          app: Optional[App] = db.session.query(App).where(App.id == app_id).first()
          if not app:
              raise ValueError("App not found")
          if not app.tracing:
              return {"enabled": False, "tracing_provider": None}
          return json.loads(app.tracing)

      @staticmethod
      def _build_trace_instance(tracing_config: dict, tracing_provider: str):
          """Build the provider's tracer from a plain config dict (KeyError for unknown providers)."""
          provider_entry = provider_config_map[tracing_provider]
          config = provider_entry["config_class"](**tracing_config)
          return provider_entry["trace_instance"](config)

      @staticmethod
      def check_trace_config_is_effective(tracing_config: dict, tracing_provider: str):
          """
          Check trace config is effective
          :param tracing_config: tracing config
          :param tracing_provider: tracing provider
          :return: result of the tracer's api_check()
          """
          return OpsTraceManager._build_trace_instance(tracing_config, tracing_provider).api_check()

      @staticmethod
      def get_trace_config_project_key(tracing_config: dict, tracing_provider: str):
          """
          Get the project key of a trace config
          :param tracing_config: tracing config
          :param tracing_provider: tracing provider
          :return: result of the tracer's get_project_key()
          """
          return OpsTraceManager._build_trace_instance(tracing_config, tracing_provider).get_project_key()

      @staticmethod
      def get_trace_config_project_url(tracing_config: dict, tracing_provider: str):
          """
          Get the project url of a trace config
          :param tracing_config: tracing config
          :param tracing_provider: tracing provider
          :return: result of the tracer's get_project_url()
          """
          return OpsTraceManager._build_trace_instance(tracing_config, tracing_provider).get_project_url()
  356. class TraceTask:
  357. def __init__(
  358. self,
  359. trace_type: Any,
  360. message_id: Optional[str] = None,
  361. workflow_execution: Optional["WorkflowExecution"] = None,
  362. conversation_id: Optional[str] = None,
  363. user_id: Optional[str] = None,
  364. timer: Optional[Any] = None,
  365. **kwargs,
  366. ):
  367. self.trace_type = trace_type
  368. self.message_id = message_id
  369. self.workflow_run_id = workflow_execution.id_ if workflow_execution else None
  370. self.conversation_id = conversation_id
  371. self.user_id = user_id
  372. self.timer = timer
  373. self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
  374. self.app_id = None
  375. self.trace_id = None
  376. self.kwargs = kwargs
  377. external_trace_id = kwargs.get("external_trace_id")
  378. if external_trace_id:
  379. self.trace_id = external_trace_id
  380. def execute(self):
  381. return self.preprocess()
  382. def preprocess(self):
  383. preprocess_map = {
  384. TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs),
  385. TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
  386. workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
  387. ),
  388. TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
  389. TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
  390. message_id=self.message_id, timer=self.timer, **self.kwargs
  391. ),
  392. TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace(
  393. message_id=self.message_id, timer=self.timer, **self.kwargs
  394. ),
  395. TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace(
  396. message_id=self.message_id, timer=self.timer, **self.kwargs
  397. ),
  398. TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(
  399. message_id=self.message_id, timer=self.timer, **self.kwargs
  400. ),
  401. TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
  402. conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
  403. ),
  404. }
  405. return preprocess_map.get(self.trace_type, lambda: None)()
  406. # process methods for different trace types
  407. def conversation_trace(self, **kwargs):
  408. return kwargs
  409. def workflow_trace(
  410. self,
  411. *,
  412. workflow_run_id: str | None,
  413. conversation_id: str | None,
  414. user_id: str | None,
  415. ):
  416. if not workflow_run_id:
  417. return {}
  418. with Session(db.engine) as session:
  419. workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
  420. workflow_run = session.scalars(workflow_run_stmt).first()
  421. if not workflow_run:
  422. raise ValueError("Workflow run not found")
  423. workflow_id = workflow_run.workflow_id
  424. tenant_id = workflow_run.tenant_id
  425. workflow_run_id = workflow_run.id
  426. workflow_run_elapsed_time = workflow_run.elapsed_time
  427. workflow_run_status = workflow_run.status
  428. workflow_run_inputs = workflow_run.inputs_dict
  429. workflow_run_outputs = workflow_run.outputs_dict
  430. workflow_run_version = workflow_run.version
  431. error = workflow_run.error or ""
  432. total_tokens = workflow_run.total_tokens
  433. file_list = workflow_run_inputs.get("sys.file") or []
  434. query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
  435. # get workflow_app_log_id
  436. workflow_app_log_data_stmt = select(WorkflowAppLog.id).where(
  437. WorkflowAppLog.tenant_id == tenant_id,
  438. WorkflowAppLog.app_id == workflow_run.app_id,
  439. WorkflowAppLog.workflow_run_id == workflow_run.id,
  440. )
  441. workflow_app_log_id = session.scalar(workflow_app_log_data_stmt)
  442. # get message_id
  443. message_id = None
  444. if conversation_id:
  445. message_data_stmt = select(Message.id).where(
  446. Message.conversation_id == conversation_id,
  447. Message.workflow_run_id == workflow_run_id,
  448. )
  449. message_id = session.scalar(message_data_stmt)
  450. metadata = {
  451. "workflow_id": workflow_id,
  452. "conversation_id": conversation_id,
  453. "workflow_run_id": workflow_run_id,
  454. "tenant_id": tenant_id,
  455. "elapsed_time": workflow_run_elapsed_time,
  456. "status": workflow_run_status,
  457. "version": workflow_run_version,
  458. "total_tokens": total_tokens,
  459. "file_list": file_list,
  460. "triggered_from": workflow_run.triggered_from,
  461. "user_id": user_id,
  462. "app_id": workflow_run.app_id,
  463. }
  464. workflow_trace_info = WorkflowTraceInfo(
  465. trace_id=self.trace_id,
  466. workflow_data=workflow_run.to_dict(),
  467. conversation_id=conversation_id,
  468. workflow_id=workflow_id,
  469. tenant_id=tenant_id,
  470. workflow_run_id=workflow_run_id,
  471. workflow_run_elapsed_time=workflow_run_elapsed_time,
  472. workflow_run_status=workflow_run_status,
  473. workflow_run_inputs=workflow_run_inputs,
  474. workflow_run_outputs=workflow_run_outputs,
  475. workflow_run_version=workflow_run_version,
  476. error=error,
  477. total_tokens=total_tokens,
  478. file_list=file_list,
  479. query=query,
  480. metadata=metadata,
  481. workflow_app_log_id=workflow_app_log_id,
  482. message_id=message_id,
  483. start_time=workflow_run.created_at,
  484. end_time=workflow_run.finished_at,
  485. )
  486. return workflow_trace_info
  487. def message_trace(self, message_id: str | None):
  488. if not message_id:
  489. return {}
  490. message_data = get_message_data(message_id)
  491. if not message_data:
  492. return {}
  493. conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id)
  494. conversation_mode = db.session.scalars(conversation_mode_stmt).all()
  495. if not conversation_mode or len(conversation_mode) == 0:
  496. return {}
  497. conversation_mode = conversation_mode[0]
  498. created_at = message_data.created_at
  499. inputs = message_data.message
  500. # get message file data
  501. message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
  502. file_list = []
  503. if message_file_data and message_file_data.url is not None:
  504. file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
  505. file_list.append(file_url)
  506. metadata = {
  507. "conversation_id": message_data.conversation_id,
  508. "ls_provider": message_data.model_provider,
  509. "ls_model_name": message_data.model_id,
  510. "status": message_data.status,
  511. "from_end_user_id": message_data.from_end_user_id,
  512. "from_account_id": message_data.from_account_id,
  513. "agent_based": message_data.agent_based,
  514. "workflow_run_id": message_data.workflow_run_id,
  515. "from_source": message_data.from_source,
  516. "message_id": message_id,
  517. }
  518. message_tokens = message_data.message_tokens
  519. message_trace_info = MessageTraceInfo(
  520. trace_id=self.trace_id,
  521. message_id=message_id,
  522. message_data=message_data.to_dict(),
  523. conversation_model=conversation_mode,
  524. message_tokens=message_tokens,
  525. answer_tokens=message_data.answer_tokens,
  526. total_tokens=message_tokens + message_data.answer_tokens,
  527. error=message_data.error or "",
  528. inputs=inputs,
  529. outputs=message_data.answer,
  530. file_list=file_list,
  531. start_time=created_at,
  532. end_time=created_at + timedelta(seconds=message_data.provider_response_latency),
  533. metadata=metadata,
  534. message_file_data=message_file_data,
  535. conversation_mode=conversation_mode,
  536. )
  537. return message_trace_info
  538. def moderation_trace(self, message_id, timer, **kwargs):
  539. moderation_result = kwargs.get("moderation_result")
  540. if not moderation_result:
  541. return {}
  542. inputs = kwargs.get("inputs")
  543. message_data = get_message_data(message_id)
  544. if not message_data:
  545. return {}
  546. metadata = {
  547. "message_id": message_id,
  548. "action": moderation_result.action,
  549. "preset_response": moderation_result.preset_response,
  550. "query": moderation_result.query,
  551. }
  552. # get workflow_app_log_id
  553. workflow_app_log_id = None
  554. if message_data.workflow_run_id:
  555. workflow_app_log_data = (
  556. db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
  557. )
  558. workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
  559. moderation_trace_info = ModerationTraceInfo(
  560. trace_id=self.trace_id,
  561. message_id=workflow_app_log_id or message_id,
  562. inputs=inputs,
  563. message_data=message_data.to_dict(),
  564. flagged=moderation_result.flagged,
  565. action=moderation_result.action,
  566. preset_response=moderation_result.preset_response,
  567. query=moderation_result.query,
  568. start_time=timer.get("start"),
  569. end_time=timer.get("end"),
  570. metadata=metadata,
  571. )
  572. return moderation_trace_info
  573. def suggested_question_trace(self, message_id, timer, **kwargs):
  574. suggested_question = kwargs.get("suggested_question", [])
  575. message_data = get_message_data(message_id)
  576. if not message_data:
  577. return {}
  578. metadata = {
  579. "message_id": message_id,
  580. "ls_provider": message_data.model_provider,
  581. "ls_model_name": message_data.model_id,
  582. "status": message_data.status,
  583. "from_end_user_id": message_data.from_end_user_id,
  584. "from_account_id": message_data.from_account_id,
  585. "agent_based": message_data.agent_based,
  586. "workflow_run_id": message_data.workflow_run_id,
  587. "from_source": message_data.from_source,
  588. }
  589. # get workflow_app_log_id
  590. workflow_app_log_id = None
  591. if message_data.workflow_run_id:
  592. workflow_app_log_data = (
  593. db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
  594. )
  595. workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
  596. suggested_question_trace_info = SuggestedQuestionTraceInfo(
  597. trace_id=self.trace_id,
  598. message_id=workflow_app_log_id or message_id,
  599. message_data=message_data.to_dict(),
  600. inputs=message_data.message,
  601. outputs=message_data.answer,
  602. start_time=timer.get("start"),
  603. end_time=timer.get("end"),
  604. metadata=metadata,
  605. total_tokens=message_data.message_tokens + message_data.answer_tokens,
  606. status=message_data.status,
  607. error=message_data.error,
  608. from_account_id=message_data.from_account_id,
  609. agent_based=message_data.agent_based,
  610. from_source=message_data.from_source,
  611. model_provider=message_data.model_provider,
  612. model_id=message_data.model_id,
  613. suggested_question=suggested_question,
  614. level=message_data.status,
  615. status_message=message_data.error,
  616. )
  617. return suggested_question_trace_info
  618. def dataset_retrieval_trace(self, message_id, timer, **kwargs):
  619. documents = kwargs.get("documents")
  620. message_data = get_message_data(message_id)
  621. if not message_data:
  622. return {}
  623. metadata = {
  624. "message_id": message_id,
  625. "ls_provider": message_data.model_provider,
  626. "ls_model_name": message_data.model_id,
  627. "status": message_data.status,
  628. "from_end_user_id": message_data.from_end_user_id,
  629. "from_account_id": message_data.from_account_id,
  630. "agent_based": message_data.agent_based,
  631. "workflow_run_id": message_data.workflow_run_id,
  632. "from_source": message_data.from_source,
  633. }
  634. dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
  635. trace_id=self.trace_id,
  636. message_id=message_id,
  637. inputs=message_data.query or message_data.inputs,
  638. documents=[doc.model_dump() for doc in documents] if documents else [],
  639. start_time=timer.get("start"),
  640. end_time=timer.get("end"),
  641. metadata=metadata,
  642. message_data=message_data.to_dict(),
  643. )
  644. return dataset_retrieval_trace_info
  645. def tool_trace(self, message_id, timer, **kwargs):
  646. tool_name = kwargs.get("tool_name", "")
  647. tool_inputs = kwargs.get("tool_inputs", {})
  648. tool_outputs = kwargs.get("tool_outputs", {})
  649. message_data = get_message_data(message_id)
  650. if not message_data:
  651. return {}
  652. tool_config = {}
  653. time_cost = 0
  654. error = None
  655. tool_parameters = {}
  656. created_time = message_data.created_at
  657. end_time = message_data.updated_at
  658. agent_thoughts = message_data.agent_thoughts
  659. for agent_thought in agent_thoughts:
  660. if tool_name in agent_thought.tools:
  661. created_time = agent_thought.created_at
  662. tool_meta_data = agent_thought.tool_meta.get(tool_name, {})
  663. tool_config = tool_meta_data.get("tool_config", {})
  664. time_cost = tool_meta_data.get("time_cost", 0)
  665. end_time = created_time + timedelta(seconds=time_cost)
  666. error = tool_meta_data.get("error", "")
  667. tool_parameters = tool_meta_data.get("tool_parameters", {})
  668. metadata = {
  669. "message_id": message_id,
  670. "tool_name": tool_name,
  671. "tool_inputs": tool_inputs,
  672. "tool_outputs": tool_outputs,
  673. "tool_config": tool_config,
  674. "time_cost": time_cost,
  675. "error": error,
  676. "tool_parameters": tool_parameters,
  677. }
  678. file_url = ""
  679. message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
  680. if message_file_data:
  681. message_file_id = message_file_data.id if message_file_data else None
  682. type = message_file_data.type
  683. created_by_role = message_file_data.created_by_role
  684. created_user_id = message_file_data.created_by
  685. file_url = f"{self.file_base_url}/{message_file_data.url}"
  686. metadata.update(
  687. {
  688. "message_file_id": message_file_id,
  689. "created_by_role": created_by_role,
  690. "created_user_id": created_user_id,
  691. "type": type,
  692. }
  693. )
  694. tool_trace_info = ToolTraceInfo(
  695. trace_id=self.trace_id,
  696. message_id=message_id,
  697. message_data=message_data.to_dict(),
  698. tool_name=tool_name,
  699. start_time=timer.get("start") if timer else created_time,
  700. end_time=timer.get("end") if timer else end_time,
  701. tool_inputs=tool_inputs,
  702. tool_outputs=tool_outputs,
  703. metadata=metadata,
  704. message_file_data=message_file_data,
  705. error=error,
  706. inputs=message_data.message,
  707. outputs=message_data.answer,
  708. tool_config=tool_config,
  709. time_cost=time_cost,
  710. tool_parameters=tool_parameters,
  711. file_url=file_url,
  712. )
  713. return tool_trace_info
  714. def generate_name_trace(self, conversation_id, timer, **kwargs):
  715. generate_conversation_name = kwargs.get("generate_conversation_name")
  716. inputs = kwargs.get("inputs")
  717. tenant_id = kwargs.get("tenant_id")
  718. if not tenant_id:
  719. return {}
  720. start_time = timer.get("start")
  721. end_time = timer.get("end")
  722. metadata = {
  723. "conversation_id": conversation_id,
  724. "tenant_id": tenant_id,
  725. }
  726. generate_name_trace_info = GenerateNameTraceInfo(
  727. trace_id=self.trace_id,
  728. conversation_id=conversation_id,
  729. inputs=inputs,
  730. outputs=generate_conversation_name,
  731. start_time=start_time,
  732. end_time=end_time,
  733. metadata=metadata,
  734. tenant_id=tenant_id,
  735. )
  736. return generate_name_trace_info
  # Module-wide state shared by every TraceQueueManager instance.
  # The timer currently scheduled to flush the queue, if any.
  trace_manager_timer: Optional[threading.Timer] = None
  # Pending trace tasks waiting to be flushed.
  trace_manager_queue: queue.Queue = queue.Queue()
  # Delay (seconds, as taken by threading.Timer) between scheduling and flushing.
  trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5))
  # Maximum number of tasks taken from the queue per flush.
  trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))


  class TraceQueueManager:
      """Queue trace tasks for an app and flush them in batches from a timer thread."""

      def __init__(self, app_id=None, user_id=None):
          """
          :param app_id: app id; used to resolve the app's tracer instance
          :param user_id: user id
          """
          self.app_id = app_id
          self.user_id = user_id
          self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
          # Keep the concrete Flask app so the timer thread can open an app context.
          self.flask_app = current_app._get_current_object()  # type: ignore
          if trace_manager_timer is None:
              self.start_timer()

      def add_trace_task(self, trace_task: TraceTask):
          """Enqueue ``trace_task`` when tracing is configured, and make sure a flush is scheduled."""
          try:
              if self.trace_instance:
                  trace_task.app_id = self.app_id
                  trace_manager_queue.put(trace_task)
          except Exception:
              logger.exception("Error adding trace task, trace_type %s", trace_task.trace_type)
          finally:
              self.start_timer()

      def collect_tasks(self):
          """Take up to ``trace_manager_batch_size`` tasks off the queue without blocking."""
          tasks: list[TraceTask] = []
          while len(tasks) < trace_manager_batch_size:
              # ``Queue.empty()`` gives no guarantee for a following ``get``, so
              # rely on ``queue.Empty`` instead; this keeps the tasks collected so
              # far even if another consumer drains the queue in between.
              try:
                  task = trace_manager_queue.get_nowait()
              except queue.Empty:
                  break
              tasks.append(task)
              trace_manager_queue.task_done()
          return tasks

      def run(self):
          """Timer callback: flush one batch; failures are logged, never raised."""
          try:
              tasks = self.collect_tasks()
              if tasks:
                  self.send_to_celery(tasks)
          except Exception:
              logger.exception("Error processing trace tasks")

      def start_timer(self):
          """Schedule a flush unless a timer is already alive."""
          global trace_manager_timer
          if trace_manager_timer is None or not trace_manager_timer.is_alive():
              trace_manager_timer = threading.Timer(trace_manager_interval, self.run)
              trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}"
              trace_manager_timer.daemon = False
              trace_manager_timer.start()

      def send_to_celery(self, tasks: list[TraceTask]):
          """
          Persist each task's trace info and enqueue it for asynchronous processing.
          :param tasks: tasks to dispatch; tasks without an ``app_id`` are skipped
          """
          with self.flask_app.app_context():
              for task in tasks:
                  if task.app_id is None:
                      continue
                  # The batch has already left the queue, so one failing task must
                  # not take the remaining ones down with it.
                  try:
                      self._dispatch_task(task)
                  except Exception:
                      logger.exception("Error dispatching trace task, trace_type %s", task.trace_type)

      def _dispatch_task(self, task: TraceTask):
          """Build one task's payload, store it as a JSON file and hand its location to the worker task."""
          file_id = uuid4().hex
          trace_info = task.execute()
          task_data = TaskData(
              app_id=task.app_id,
              trace_info_type=type(trace_info).__name__,
              # Builders return ``{}``/None when there is nothing to trace.
              trace_info=trace_info.model_dump() if trace_info else None,
          )
          file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json"
          storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
          file_info = {
              "file_id": file_id,
              "app_id": task.app_id,
          }
          process_trace_tasks.delay(file_info)