Vous ne pouvez pas sélectionner plus de 25 sujets Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

ops_trace_manager.py 31KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826
  1. import json
  2. import logging
  3. import os
  4. import queue
  5. import threading
  6. import time
  7. from datetime import timedelta
  8. from typing import Any, Optional, Union
  9. from uuid import UUID, uuid4
  10. from cachetools import LRUCache
  11. from flask import current_app
  12. from sqlalchemy import select
  13. from sqlalchemy.orm import Session
  14. from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
  15. from core.ops.entities.config_entity import (
  16. OPS_FILE_PATH,
  17. LangfuseConfig,
  18. LangSmithConfig,
  19. OpikConfig,
  20. TracingProviderEnum,
  21. )
  22. from core.ops.entities.trace_entity import (
  23. DatasetRetrievalTraceInfo,
  24. GenerateNameTraceInfo,
  25. MessageTraceInfo,
  26. ModerationTraceInfo,
  27. SuggestedQuestionTraceInfo,
  28. TaskData,
  29. ToolTraceInfo,
  30. TraceTaskName,
  31. WorkflowTraceInfo,
  32. )
  33. from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
  34. from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
  35. from core.ops.utils import get_message_data
  36. from extensions.ext_database import db
  37. from extensions.ext_storage import storage
  38. from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
  39. from models.workflow import WorkflowAppLog, WorkflowRun
  40. from tasks.ops_trace_task import process_trace_tasks
  41. def build_opik_trace_instance(config: OpikConfig):
  42. from core.ops.opik_trace.opik_trace import OpikDataTrace
  43. return OpikDataTrace(config)
  44. provider_config_map: dict[str, dict[str, Any]] = {
  45. TracingProviderEnum.LANGFUSE.value: {
  46. "config_class": LangfuseConfig,
  47. "secret_keys": ["public_key", "secret_key"],
  48. "other_keys": ["host", "project_key"],
  49. "trace_instance": LangFuseDataTrace,
  50. },
  51. TracingProviderEnum.LANGSMITH.value: {
  52. "config_class": LangSmithConfig,
  53. "secret_keys": ["api_key"],
  54. "other_keys": ["project", "endpoint"],
  55. "trace_instance": LangSmithDataTrace,
  56. },
  57. TracingProviderEnum.OPIK.value: {
  58. "config_class": OpikConfig,
  59. "secret_keys": ["api_key"],
  60. "other_keys": ["project", "url", "workspace"],
  61. "trace_instance": lambda config: build_opik_trace_instance(config),
  62. },
  63. }
  64. class OpsTraceManager:
  65. ops_trace_instances_cache: LRUCache = LRUCache(maxsize=128)
  66. @classmethod
  67. def encrypt_tracing_config(
  68. cls, tenant_id: str, tracing_provider: str, tracing_config: dict, current_trace_config=None
  69. ):
  70. """
  71. Encrypt tracing config.
  72. :param tenant_id: tenant id
  73. :param tracing_provider: tracing provider
  74. :param tracing_config: tracing config dictionary to be encrypted
  75. :param current_trace_config: current tracing configuration for keeping existing values
  76. :return: encrypted tracing configuration
  77. """
  78. # Get the configuration class and the keys that require encryption
  79. config_class, secret_keys, other_keys = (
  80. provider_config_map[tracing_provider]["config_class"],
  81. provider_config_map[tracing_provider]["secret_keys"],
  82. provider_config_map[tracing_provider]["other_keys"],
  83. )
  84. new_config = {}
  85. # Encrypt necessary keys
  86. for key in secret_keys:
  87. if key in tracing_config:
  88. if "*" in tracing_config[key]:
  89. # If the key contains '*', retain the original value from the current config
  90. new_config[key] = current_trace_config.get(key, tracing_config[key])
  91. else:
  92. # Otherwise, encrypt the key
  93. new_config[key] = encrypt_token(tenant_id, tracing_config[key])
  94. for key in other_keys:
  95. new_config[key] = tracing_config.get(key, "")
  96. # Create a new instance of the config class with the new configuration
  97. encrypted_config = config_class(**new_config)
  98. return encrypted_config.model_dump()
  99. @classmethod
  100. def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict):
  101. """
  102. Decrypt tracing config
  103. :param tenant_id: tenant id
  104. :param tracing_provider: tracing provider
  105. :param tracing_config: tracing config
  106. :return:
  107. """
  108. config_class, secret_keys, other_keys = (
  109. provider_config_map[tracing_provider]["config_class"],
  110. provider_config_map[tracing_provider]["secret_keys"],
  111. provider_config_map[tracing_provider]["other_keys"],
  112. )
  113. new_config = {}
  114. for key in secret_keys:
  115. if key in tracing_config:
  116. new_config[key] = decrypt_token(tenant_id, tracing_config[key])
  117. for key in other_keys:
  118. new_config[key] = tracing_config.get(key, "")
  119. return config_class(**new_config).model_dump()
  120. @classmethod
  121. def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
  122. """
  123. Decrypt tracing config
  124. :param tracing_provider: tracing provider
  125. :param decrypt_tracing_config: tracing config
  126. :return:
  127. """
  128. config_class, secret_keys, other_keys = (
  129. provider_config_map[tracing_provider]["config_class"],
  130. provider_config_map[tracing_provider]["secret_keys"],
  131. provider_config_map[tracing_provider]["other_keys"],
  132. )
  133. new_config = {}
  134. for key in secret_keys:
  135. if key in decrypt_tracing_config:
  136. new_config[key] = obfuscated_token(decrypt_tracing_config[key])
  137. for key in other_keys:
  138. new_config[key] = decrypt_tracing_config.get(key, "")
  139. return config_class(**new_config).model_dump()
  140. @classmethod
  141. def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str):
  142. """
  143. Get decrypted tracing config
  144. :param app_id: app id
  145. :param tracing_provider: tracing provider
  146. :return:
  147. """
  148. trace_config_data: Optional[TraceAppConfig] = (
  149. db.session.query(TraceAppConfig)
  150. .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
  151. .first()
  152. )
  153. if not trace_config_data:
  154. return None
  155. # decrypt_token
  156. app = db.session.query(App).filter(App.id == app_id).first()
  157. if not app:
  158. raise ValueError("App not found")
  159. tenant_id = app.tenant_id
  160. decrypt_tracing_config = cls.decrypt_tracing_config(
  161. tenant_id, tracing_provider, trace_config_data.tracing_config
  162. )
  163. return decrypt_tracing_config
  164. @classmethod
  165. def get_ops_trace_instance(
  166. cls,
  167. app_id: Optional[Union[UUID, str]] = None,
  168. ):
  169. """
  170. Get ops trace through model config
  171. :param app_id: app_id
  172. :return:
  173. """
  174. if isinstance(app_id, UUID):
  175. app_id = str(app_id)
  176. if app_id is None:
  177. return None
  178. app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
  179. if app is None:
  180. return None
  181. app_ops_trace_config = json.loads(app.tracing) if app.tracing else None
  182. if app_ops_trace_config is None:
  183. return None
  184. if not app_ops_trace_config.get("enabled"):
  185. return None
  186. tracing_provider = app_ops_trace_config.get("tracing_provider")
  187. if tracing_provider is None or tracing_provider not in provider_config_map:
  188. return None
  189. # decrypt_token
  190. decrypt_trace_config = cls.get_decrypted_tracing_config(app_id, tracing_provider)
  191. if not decrypt_trace_config:
  192. return None
  193. trace_instance, config_class = (
  194. provider_config_map[tracing_provider]["trace_instance"],
  195. provider_config_map[tracing_provider]["config_class"],
  196. )
  197. decrypt_trace_config_key = str(decrypt_trace_config)
  198. tracing_instance = cls.ops_trace_instances_cache.get(decrypt_trace_config_key)
  199. if tracing_instance is None:
  200. # create new tracing_instance and update the cache if it absent
  201. tracing_instance = trace_instance(config_class(**decrypt_trace_config))
  202. cls.ops_trace_instances_cache[decrypt_trace_config_key] = tracing_instance
  203. logging.info(f"new tracing_instance for app_id: {app_id}")
  204. return tracing_instance
  205. @classmethod
  206. def get_app_config_through_message_id(cls, message_id: str):
  207. app_model_config = None
  208. message_data = db.session.query(Message).filter(Message.id == message_id).first()
  209. if not message_data:
  210. return None
  211. conversation_id = message_data.conversation_id
  212. conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
  213. if not conversation_data:
  214. return None
  215. if conversation_data.app_model_config_id:
  216. app_model_config = (
  217. db.session.query(AppModelConfig)
  218. .filter(AppModelConfig.id == conversation_data.app_model_config_id)
  219. .first()
  220. )
  221. elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
  222. app_model_config = conversation_data.override_model_configs
  223. return app_model_config
  224. @classmethod
  225. def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: str):
  226. """
  227. Update app tracing config
  228. :param app_id: app id
  229. :param enabled: enabled
  230. :param tracing_provider: tracing provider
  231. :return:
  232. """
  233. # auth check
  234. if tracing_provider not in provider_config_map and tracing_provider is not None:
  235. raise ValueError(f"Invalid tracing provider: {tracing_provider}")
  236. app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
  237. if not app_config:
  238. raise ValueError("App not found")
  239. app_config.tracing = json.dumps(
  240. {
  241. "enabled": enabled,
  242. "tracing_provider": tracing_provider,
  243. }
  244. )
  245. db.session.commit()
  246. @classmethod
  247. def get_app_tracing_config(cls, app_id: str):
  248. """
  249. Get app tracing config
  250. :param app_id: app id
  251. :return:
  252. """
  253. app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
  254. if not app:
  255. raise ValueError("App not found")
  256. if not app.tracing:
  257. return {"enabled": False, "tracing_provider": None}
  258. app_trace_config = json.loads(app.tracing)
  259. return app_trace_config
  260. @staticmethod
  261. def check_trace_config_is_effective(tracing_config: dict, tracing_provider: str):
  262. """
  263. Check trace config is effective
  264. :param tracing_config: tracing config
  265. :param tracing_provider: tracing provider
  266. :return:
  267. """
  268. config_type, trace_instance = (
  269. provider_config_map[tracing_provider]["config_class"],
  270. provider_config_map[tracing_provider]["trace_instance"],
  271. )
  272. tracing_config = config_type(**tracing_config)
  273. return trace_instance(tracing_config).api_check()
  274. @staticmethod
  275. def get_trace_config_project_key(tracing_config: dict, tracing_provider: str):
  276. """
  277. get trace config is project key
  278. :param tracing_config: tracing config
  279. :param tracing_provider: tracing provider
  280. :return:
  281. """
  282. config_type, trace_instance = (
  283. provider_config_map[tracing_provider]["config_class"],
  284. provider_config_map[tracing_provider]["trace_instance"],
  285. )
  286. tracing_config = config_type(**tracing_config)
  287. return trace_instance(tracing_config).get_project_key()
  288. @staticmethod
  289. def get_trace_config_project_url(tracing_config: dict, tracing_provider: str):
  290. """
  291. get trace config is project key
  292. :param tracing_config: tracing config
  293. :param tracing_provider: tracing provider
  294. :return:
  295. """
  296. config_type, trace_instance = (
  297. provider_config_map[tracing_provider]["config_class"],
  298. provider_config_map[tracing_provider]["trace_instance"],
  299. )
  300. tracing_config = config_type(**tracing_config)
  301. return trace_instance(tracing_config).get_project_url()
  302. class TraceTask:
  303. def __init__(
  304. self,
  305. trace_type: Any,
  306. message_id: Optional[str] = None,
  307. workflow_run: Optional[WorkflowRun] = None,
  308. conversation_id: Optional[str] = None,
  309. user_id: Optional[str] = None,
  310. timer: Optional[Any] = None,
  311. **kwargs,
  312. ):
  313. self.trace_type = trace_type
  314. self.message_id = message_id
  315. self.workflow_run_id = workflow_run.id if workflow_run else None
  316. self.conversation_id = conversation_id
  317. self.user_id = user_id
  318. self.timer = timer
  319. self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
  320. self.app_id = None
  321. self.kwargs = kwargs
  322. def execute(self):
  323. return self.preprocess()
  324. def preprocess(self):
  325. preprocess_map = {
  326. TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs),
  327. TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
  328. workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
  329. ),
  330. TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
  331. TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
  332. message_id=self.message_id, timer=self.timer, **self.kwargs
  333. ),
  334. TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace(
  335. message_id=self.message_id, timer=self.timer, **self.kwargs
  336. ),
  337. TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace(
  338. message_id=self.message_id, timer=self.timer, **self.kwargs
  339. ),
  340. TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(
  341. message_id=self.message_id, timer=self.timer, **self.kwargs
  342. ),
  343. TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
  344. conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
  345. ),
  346. }
  347. return preprocess_map.get(self.trace_type, lambda: None)()
  348. # process methods for different trace types
  349. def conversation_trace(self, **kwargs):
  350. return kwargs
  351. def workflow_trace(
  352. self,
  353. *,
  354. workflow_run_id: str | None,
  355. conversation_id: str | None,
  356. user_id: str | None,
  357. ):
  358. if not workflow_run_id:
  359. return {}
  360. with Session(db.engine) as session:
  361. workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
  362. workflow_run = session.scalars(workflow_run_stmt).first()
  363. if not workflow_run:
  364. raise ValueError("Workflow run not found")
  365. workflow_id = workflow_run.workflow_id
  366. tenant_id = workflow_run.tenant_id
  367. workflow_run_id = workflow_run.id
  368. workflow_run_elapsed_time = workflow_run.elapsed_time
  369. workflow_run_status = workflow_run.status
  370. workflow_run_inputs = workflow_run.inputs_dict
  371. workflow_run_outputs = workflow_run.outputs_dict
  372. workflow_run_version = workflow_run.version
  373. error = workflow_run.error or ""
  374. total_tokens = workflow_run.total_tokens
  375. file_list = workflow_run_inputs.get("sys.file") or []
  376. query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
  377. # get workflow_app_log_id
  378. workflow_app_log_data_stmt = select(WorkflowAppLog.id).where(
  379. WorkflowAppLog.tenant_id == tenant_id,
  380. WorkflowAppLog.app_id == workflow_run.app_id,
  381. WorkflowAppLog.workflow_run_id == workflow_run.id,
  382. )
  383. workflow_app_log_id = session.scalar(workflow_app_log_data_stmt)
  384. # get message_id
  385. message_id = None
  386. if conversation_id:
  387. message_data_stmt = select(Message.id).where(
  388. Message.conversation_id == conversation_id,
  389. Message.workflow_run_id == workflow_run_id,
  390. )
  391. message_id = session.scalar(message_data_stmt)
  392. metadata = {
  393. "workflow_id": workflow_id,
  394. "conversation_id": conversation_id,
  395. "workflow_run_id": workflow_run_id,
  396. "tenant_id": tenant_id,
  397. "elapsed_time": workflow_run_elapsed_time,
  398. "status": workflow_run_status,
  399. "version": workflow_run_version,
  400. "total_tokens": total_tokens,
  401. "file_list": file_list,
  402. "triggered_form": workflow_run.triggered_from,
  403. "user_id": user_id,
  404. }
  405. workflow_trace_info = WorkflowTraceInfo(
  406. workflow_data=workflow_run.to_dict(),
  407. conversation_id=conversation_id,
  408. workflow_id=workflow_id,
  409. tenant_id=tenant_id,
  410. workflow_run_id=workflow_run_id,
  411. workflow_run_elapsed_time=workflow_run_elapsed_time,
  412. workflow_run_status=workflow_run_status,
  413. workflow_run_inputs=workflow_run_inputs,
  414. workflow_run_outputs=workflow_run_outputs,
  415. workflow_run_version=workflow_run_version,
  416. error=error,
  417. total_tokens=total_tokens,
  418. file_list=file_list,
  419. query=query,
  420. metadata=metadata,
  421. workflow_app_log_id=workflow_app_log_id,
  422. message_id=message_id,
  423. start_time=workflow_run.created_at,
  424. end_time=workflow_run.finished_at,
  425. )
  426. return workflow_trace_info
  427. def message_trace(self, message_id: str | None):
  428. if not message_id:
  429. return {}
  430. message_data = get_message_data(message_id)
  431. if not message_data:
  432. return {}
  433. conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id)
  434. conversation_mode = db.session.scalars(conversation_mode_stmt).all()
  435. if not conversation_mode or len(conversation_mode) == 0:
  436. return {}
  437. conversation_mode = conversation_mode[0]
  438. created_at = message_data.created_at
  439. inputs = message_data.message
  440. # get message file data
  441. message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
  442. file_list = []
  443. if message_file_data and message_file_data.url is not None:
  444. file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
  445. file_list.append(file_url)
  446. metadata = {
  447. "conversation_id": message_data.conversation_id,
  448. "ls_provider": message_data.model_provider,
  449. "ls_model_name": message_data.model_id,
  450. "status": message_data.status,
  451. "from_end_user_id": message_data.from_end_user_id,
  452. "from_account_id": message_data.from_account_id,
  453. "agent_based": message_data.agent_based,
  454. "workflow_run_id": message_data.workflow_run_id,
  455. "from_source": message_data.from_source,
  456. "message_id": message_id,
  457. }
  458. message_tokens = message_data.message_tokens
  459. message_trace_info = MessageTraceInfo(
  460. message_id=message_id,
  461. message_data=message_data.to_dict(),
  462. conversation_model=conversation_mode,
  463. message_tokens=message_tokens,
  464. answer_tokens=message_data.answer_tokens,
  465. total_tokens=message_tokens + message_data.answer_tokens,
  466. error=message_data.error or "",
  467. inputs=inputs,
  468. outputs=message_data.answer,
  469. file_list=file_list,
  470. start_time=created_at,
  471. end_time=created_at + timedelta(seconds=message_data.provider_response_latency),
  472. metadata=metadata,
  473. message_file_data=message_file_data,
  474. conversation_mode=conversation_mode,
  475. )
  476. return message_trace_info
  477. def moderation_trace(self, message_id, timer, **kwargs):
  478. moderation_result = kwargs.get("moderation_result")
  479. if not moderation_result:
  480. return {}
  481. inputs = kwargs.get("inputs")
  482. message_data = get_message_data(message_id)
  483. if not message_data:
  484. return {}
  485. metadata = {
  486. "message_id": message_id,
  487. "action": moderation_result.action,
  488. "preset_response": moderation_result.preset_response,
  489. "query": moderation_result.query,
  490. }
  491. # get workflow_app_log_id
  492. workflow_app_log_id = None
  493. if message_data.workflow_run_id:
  494. workflow_app_log_data = (
  495. db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
  496. )
  497. workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
  498. moderation_trace_info = ModerationTraceInfo(
  499. message_id=workflow_app_log_id or message_id,
  500. inputs=inputs,
  501. message_data=message_data.to_dict(),
  502. flagged=moderation_result.flagged,
  503. action=moderation_result.action,
  504. preset_response=moderation_result.preset_response,
  505. query=moderation_result.query,
  506. start_time=timer.get("start"),
  507. end_time=timer.get("end"),
  508. metadata=metadata,
  509. )
  510. return moderation_trace_info
  511. def suggested_question_trace(self, message_id, timer, **kwargs):
  512. suggested_question = kwargs.get("suggested_question", [])
  513. message_data = get_message_data(message_id)
  514. if not message_data:
  515. return {}
  516. metadata = {
  517. "message_id": message_id,
  518. "ls_provider": message_data.model_provider,
  519. "ls_model_name": message_data.model_id,
  520. "status": message_data.status,
  521. "from_end_user_id": message_data.from_end_user_id,
  522. "from_account_id": message_data.from_account_id,
  523. "agent_based": message_data.agent_based,
  524. "workflow_run_id": message_data.workflow_run_id,
  525. "from_source": message_data.from_source,
  526. }
  527. # get workflow_app_log_id
  528. workflow_app_log_id = None
  529. if message_data.workflow_run_id:
  530. workflow_app_log_data = (
  531. db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
  532. )
  533. workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
  534. suggested_question_trace_info = SuggestedQuestionTraceInfo(
  535. message_id=workflow_app_log_id or message_id,
  536. message_data=message_data.to_dict(),
  537. inputs=message_data.message,
  538. outputs=message_data.answer,
  539. start_time=timer.get("start"),
  540. end_time=timer.get("end"),
  541. metadata=metadata,
  542. total_tokens=message_data.message_tokens + message_data.answer_tokens,
  543. status=message_data.status,
  544. error=message_data.error,
  545. from_account_id=message_data.from_account_id,
  546. agent_based=message_data.agent_based,
  547. from_source=message_data.from_source,
  548. model_provider=message_data.model_provider,
  549. model_id=message_data.model_id,
  550. suggested_question=suggested_question,
  551. level=message_data.status,
  552. status_message=message_data.error,
  553. )
  554. return suggested_question_trace_info
  555. def dataset_retrieval_trace(self, message_id, timer, **kwargs):
  556. documents = kwargs.get("documents")
  557. message_data = get_message_data(message_id)
  558. if not message_data:
  559. return {}
  560. metadata = {
  561. "message_id": message_id,
  562. "ls_provider": message_data.model_provider,
  563. "ls_model_name": message_data.model_id,
  564. "status": message_data.status,
  565. "from_end_user_id": message_data.from_end_user_id,
  566. "from_account_id": message_data.from_account_id,
  567. "agent_based": message_data.agent_based,
  568. "workflow_run_id": message_data.workflow_run_id,
  569. "from_source": message_data.from_source,
  570. }
  571. dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
  572. message_id=message_id,
  573. inputs=message_data.query or message_data.inputs,
  574. documents=[doc.model_dump() for doc in documents] if documents else [],
  575. start_time=timer.get("start"),
  576. end_time=timer.get("end"),
  577. metadata=metadata,
  578. message_data=message_data.to_dict(),
  579. )
  580. return dataset_retrieval_trace_info
  581. def tool_trace(self, message_id, timer, **kwargs):
  582. tool_name = kwargs.get("tool_name", "")
  583. tool_inputs = kwargs.get("tool_inputs", {})
  584. tool_outputs = kwargs.get("tool_outputs", {})
  585. message_data = get_message_data(message_id)
  586. if not message_data:
  587. return {}
  588. tool_config = {}
  589. time_cost = 0
  590. error = None
  591. tool_parameters = {}
  592. created_time = message_data.created_at
  593. end_time = message_data.updated_at
  594. agent_thoughts = message_data.agent_thoughts
  595. for agent_thought in agent_thoughts:
  596. if tool_name in agent_thought.tools:
  597. created_time = agent_thought.created_at
  598. tool_meta_data = agent_thought.tool_meta.get(tool_name, {})
  599. tool_config = tool_meta_data.get("tool_config", {})
  600. time_cost = tool_meta_data.get("time_cost", 0)
  601. end_time = created_time + timedelta(seconds=time_cost)
  602. error = tool_meta_data.get("error", "")
  603. tool_parameters = tool_meta_data.get("tool_parameters", {})
  604. metadata = {
  605. "message_id": message_id,
  606. "tool_name": tool_name,
  607. "tool_inputs": tool_inputs,
  608. "tool_outputs": tool_outputs,
  609. "tool_config": tool_config,
  610. "time_cost": time_cost,
  611. "error": error,
  612. "tool_parameters": tool_parameters,
  613. }
  614. file_url = ""
  615. message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
  616. if message_file_data:
  617. message_file_id = message_file_data.id if message_file_data else None
  618. type = message_file_data.type
  619. created_by_role = message_file_data.created_by_role
  620. created_user_id = message_file_data.created_by
  621. file_url = f"{self.file_base_url}/{message_file_data.url}"
  622. metadata.update(
  623. {
  624. "message_file_id": message_file_id,
  625. "created_by_role": created_by_role,
  626. "created_user_id": created_user_id,
  627. "type": type,
  628. }
  629. )
  630. tool_trace_info = ToolTraceInfo(
  631. message_id=message_id,
  632. message_data=message_data.to_dict(),
  633. tool_name=tool_name,
  634. start_time=timer.get("start") if timer else created_time,
  635. end_time=timer.get("end") if timer else end_time,
  636. tool_inputs=tool_inputs,
  637. tool_outputs=tool_outputs,
  638. metadata=metadata,
  639. message_file_data=message_file_data,
  640. error=error,
  641. inputs=message_data.message,
  642. outputs=message_data.answer,
  643. tool_config=tool_config,
  644. time_cost=time_cost,
  645. tool_parameters=tool_parameters,
  646. file_url=file_url,
  647. )
  648. return tool_trace_info
  649. def generate_name_trace(self, conversation_id, timer, **kwargs):
  650. generate_conversation_name = kwargs.get("generate_conversation_name")
  651. inputs = kwargs.get("inputs")
  652. tenant_id = kwargs.get("tenant_id")
  653. if not tenant_id:
  654. return {}
  655. start_time = timer.get("start")
  656. end_time = timer.get("end")
  657. metadata = {
  658. "conversation_id": conversation_id,
  659. "tenant_id": tenant_id,
  660. }
  661. generate_name_trace_info = GenerateNameTraceInfo(
  662. conversation_id=conversation_id,
  663. inputs=inputs,
  664. outputs=generate_conversation_name,
  665. start_time=start_time,
  666. end_time=end_time,
  667. metadata=metadata,
  668. tenant_id=tenant_id,
  669. )
  670. return generate_name_trace_info
  671. trace_manager_timer: Optional[threading.Timer] = None
  672. trace_manager_queue: queue.Queue = queue.Queue()
  673. trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5))
  674. trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))
  675. class TraceQueueManager:
  676. def __init__(self, app_id=None, user_id=None):
  677. global trace_manager_timer
  678. self.app_id = app_id
  679. self.user_id = user_id
  680. self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
  681. self.flask_app = current_app._get_current_object() # type: ignore
  682. if trace_manager_timer is None:
  683. self.start_timer()
  684. def add_trace_task(self, trace_task: TraceTask):
  685. global trace_manager_timer, trace_manager_queue
  686. try:
  687. if self.trace_instance:
  688. trace_task.app_id = self.app_id
  689. trace_manager_queue.put(trace_task)
  690. except Exception as e:
  691. logging.exception(f"Error adding trace task, trace_type {trace_task.trace_type}")
  692. finally:
  693. self.start_timer()
  694. def collect_tasks(self):
  695. global trace_manager_queue
  696. tasks: list[TraceTask] = []
  697. while len(tasks) < trace_manager_batch_size and not trace_manager_queue.empty():
  698. task = trace_manager_queue.get_nowait()
  699. tasks.append(task)
  700. trace_manager_queue.task_done()
  701. return tasks
  702. def run(self):
  703. try:
  704. tasks = self.collect_tasks()
  705. if tasks:
  706. self.send_to_celery(tasks)
  707. except Exception as e:
  708. logging.exception("Error processing trace tasks")
  709. def start_timer(self):
  710. global trace_manager_timer
  711. if trace_manager_timer is None or not trace_manager_timer.is_alive():
  712. trace_manager_timer = threading.Timer(trace_manager_interval, self.run)
  713. trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}"
  714. trace_manager_timer.daemon = False
  715. trace_manager_timer.start()
  716. def send_to_celery(self, tasks: list[TraceTask]):
  717. with self.flask_app.app_context():
  718. for task in tasks:
  719. if task.app_id is None:
  720. continue
  721. file_id = uuid4().hex
  722. trace_info = task.execute()
  723. task_data = TaskData(
  724. app_id=task.app_id,
  725. trace_info_type=type(trace_info).__name__,
  726. trace_info=trace_info.model_dump() if trace_info else None,
  727. )
  728. file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json"
  729. storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
  730. file_info = {
  731. "file_id": file_id,
  732. "app_id": task.app_id,
  733. }
  734. process_trace_tasks.delay(file_info)