您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

ops_trace_manager.py 32KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849
  1. import json
  2. import logging
  3. import os
  4. import queue
  5. import threading
  6. import time
  7. from datetime import timedelta
  8. from typing import Any, Optional, Union
  9. from uuid import UUID, uuid4
  10. from cachetools import LRUCache
  11. from flask import current_app
  12. from sqlalchemy import select
  13. from sqlalchemy.orm import Session
  14. from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
  15. from core.ops.entities.config_entity import (
  16. OPS_FILE_PATH,
  17. TracingProviderEnum,
  18. )
  19. from core.ops.entities.trace_entity import (
  20. DatasetRetrievalTraceInfo,
  21. GenerateNameTraceInfo,
  22. MessageTraceInfo,
  23. ModerationTraceInfo,
  24. SuggestedQuestionTraceInfo,
  25. TaskData,
  26. ToolTraceInfo,
  27. TraceTaskName,
  28. WorkflowTraceInfo,
  29. )
  30. from core.ops.utils import get_message_data
  31. from extensions.ext_database import db
  32. from extensions.ext_storage import storage
  33. from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
  34. from models.workflow import WorkflowAppLog, WorkflowRun
  35. from tasks.ops_trace_task import process_trace_tasks
  36. class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
  37. def __getitem__(self, provider: str) -> dict[str, Any]:
  38. match provider:
  39. case TracingProviderEnum.LANGFUSE:
  40. from core.ops.entities.config_entity import LangfuseConfig
  41. from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
  42. return {
  43. "config_class": LangfuseConfig,
  44. "secret_keys": ["public_key", "secret_key"],
  45. "other_keys": ["host", "project_key"],
  46. "trace_instance": LangFuseDataTrace,
  47. }
  48. case TracingProviderEnum.LANGSMITH:
  49. from core.ops.entities.config_entity import LangSmithConfig
  50. from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
  51. return {
  52. "config_class": LangSmithConfig,
  53. "secret_keys": ["api_key"],
  54. "other_keys": ["project", "endpoint"],
  55. "trace_instance": LangSmithDataTrace,
  56. }
  57. case TracingProviderEnum.OPIK:
  58. from core.ops.entities.config_entity import OpikConfig
  59. from core.ops.opik_trace.opik_trace import OpikDataTrace
  60. return {
  61. "config_class": OpikConfig,
  62. "secret_keys": ["api_key"],
  63. "other_keys": ["project", "url", "workspace"],
  64. "trace_instance": OpikDataTrace,
  65. }
  66. case TracingProviderEnum.WEAVE:
  67. from core.ops.entities.config_entity import WeaveConfig
  68. from core.ops.weave_trace.weave_trace import WeaveDataTrace
  69. return {
  70. "config_class": WeaveConfig,
  71. "secret_keys": ["api_key"],
  72. "other_keys": ["project", "entity", "endpoint"],
  73. "trace_instance": WeaveDataTrace,
  74. }
  75. case _:
  76. raise KeyError(f"Unsupported tracing provider: {provider}")
  77. provider_config_map: dict[str, dict[str, Any]] = OpsTraceProviderConfigMap()
  78. class OpsTraceManager:
  79. ops_trace_instances_cache: LRUCache = LRUCache(maxsize=128)
  80. @classmethod
  81. def encrypt_tracing_config(
  82. cls, tenant_id: str, tracing_provider: str, tracing_config: dict, current_trace_config=None
  83. ):
  84. """
  85. Encrypt tracing config.
  86. :param tenant_id: tenant id
  87. :param tracing_provider: tracing provider
  88. :param tracing_config: tracing config dictionary to be encrypted
  89. :param current_trace_config: current tracing configuration for keeping existing values
  90. :return: encrypted tracing configuration
  91. """
  92. # Get the configuration class and the keys that require encryption
  93. config_class, secret_keys, other_keys = (
  94. provider_config_map[tracing_provider]["config_class"],
  95. provider_config_map[tracing_provider]["secret_keys"],
  96. provider_config_map[tracing_provider]["other_keys"],
  97. )
  98. new_config = {}
  99. # Encrypt necessary keys
  100. for key in secret_keys:
  101. if key in tracing_config:
  102. if "*" in tracing_config[key]:
  103. # If the key contains '*', retain the original value from the current config
  104. new_config[key] = current_trace_config.get(key, tracing_config[key])
  105. else:
  106. # Otherwise, encrypt the key
  107. new_config[key] = encrypt_token(tenant_id, tracing_config[key])
  108. for key in other_keys:
  109. new_config[key] = tracing_config.get(key, "")
  110. # Create a new instance of the config class with the new configuration
  111. encrypted_config = config_class(**new_config)
  112. return encrypted_config.model_dump()
  113. @classmethod
  114. def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict):
  115. """
  116. Decrypt tracing config
  117. :param tenant_id: tenant id
  118. :param tracing_provider: tracing provider
  119. :param tracing_config: tracing config
  120. :return:
  121. """
  122. config_class, secret_keys, other_keys = (
  123. provider_config_map[tracing_provider]["config_class"],
  124. provider_config_map[tracing_provider]["secret_keys"],
  125. provider_config_map[tracing_provider]["other_keys"],
  126. )
  127. new_config = {}
  128. for key in secret_keys:
  129. if key in tracing_config:
  130. new_config[key] = decrypt_token(tenant_id, tracing_config[key])
  131. for key in other_keys:
  132. new_config[key] = tracing_config.get(key, "")
  133. return config_class(**new_config).model_dump()
  134. @classmethod
  135. def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
  136. """
  137. Decrypt tracing config
  138. :param tracing_provider: tracing provider
  139. :param decrypt_tracing_config: tracing config
  140. :return:
  141. """
  142. config_class, secret_keys, other_keys = (
  143. provider_config_map[tracing_provider]["config_class"],
  144. provider_config_map[tracing_provider]["secret_keys"],
  145. provider_config_map[tracing_provider]["other_keys"],
  146. )
  147. new_config = {}
  148. for key in secret_keys:
  149. if key in decrypt_tracing_config:
  150. new_config[key] = obfuscated_token(decrypt_tracing_config[key])
  151. for key in other_keys:
  152. new_config[key] = decrypt_tracing_config.get(key, "")
  153. return config_class(**new_config).model_dump()
  154. @classmethod
  155. def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str):
  156. """
  157. Get decrypted tracing config
  158. :param app_id: app id
  159. :param tracing_provider: tracing provider
  160. :return:
  161. """
  162. trace_config_data: Optional[TraceAppConfig] = (
  163. db.session.query(TraceAppConfig)
  164. .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
  165. .first()
  166. )
  167. if not trace_config_data:
  168. return None
  169. # decrypt_token
  170. app = db.session.query(App).filter(App.id == app_id).first()
  171. if not app:
  172. raise ValueError("App not found")
  173. tenant_id = app.tenant_id
  174. decrypt_tracing_config = cls.decrypt_tracing_config(
  175. tenant_id, tracing_provider, trace_config_data.tracing_config
  176. )
  177. return decrypt_tracing_config
  178. @classmethod
  179. def get_ops_trace_instance(
  180. cls,
  181. app_id: Optional[Union[UUID, str]] = None,
  182. ):
  183. """
  184. Get ops trace through model config
  185. :param app_id: app_id
  186. :return:
  187. """
  188. if isinstance(app_id, UUID):
  189. app_id = str(app_id)
  190. if app_id is None:
  191. return None
  192. app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
  193. if app is None:
  194. return None
  195. app_ops_trace_config = json.loads(app.tracing) if app.tracing else None
  196. if app_ops_trace_config is None:
  197. return None
  198. if not app_ops_trace_config.get("enabled"):
  199. return None
  200. tracing_provider = app_ops_trace_config.get("tracing_provider")
  201. if tracing_provider is None or tracing_provider not in provider_config_map:
  202. return None
  203. # decrypt_token
  204. decrypt_trace_config = cls.get_decrypted_tracing_config(app_id, tracing_provider)
  205. if not decrypt_trace_config:
  206. return None
  207. trace_instance, config_class = (
  208. provider_config_map[tracing_provider]["trace_instance"],
  209. provider_config_map[tracing_provider]["config_class"],
  210. )
  211. decrypt_trace_config_key = str(decrypt_trace_config)
  212. tracing_instance = cls.ops_trace_instances_cache.get(decrypt_trace_config_key)
  213. if tracing_instance is None:
  214. # create new tracing_instance and update the cache if it absent
  215. tracing_instance = trace_instance(config_class(**decrypt_trace_config))
  216. cls.ops_trace_instances_cache[decrypt_trace_config_key] = tracing_instance
  217. logging.info(f"new tracing_instance for app_id: {app_id}")
  218. return tracing_instance
  219. @classmethod
  220. def get_app_config_through_message_id(cls, message_id: str):
  221. app_model_config = None
  222. message_data = db.session.query(Message).filter(Message.id == message_id).first()
  223. if not message_data:
  224. return None
  225. conversation_id = message_data.conversation_id
  226. conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
  227. if not conversation_data:
  228. return None
  229. if conversation_data.app_model_config_id:
  230. app_model_config = (
  231. db.session.query(AppModelConfig)
  232. .filter(AppModelConfig.id == conversation_data.app_model_config_id)
  233. .first()
  234. )
  235. elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
  236. app_model_config = conversation_data.override_model_configs
  237. return app_model_config
  238. @classmethod
  239. def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: str):
  240. """
  241. Update app tracing config
  242. :param app_id: app id
  243. :param enabled: enabled
  244. :param tracing_provider: tracing provider
  245. :return:
  246. """
  247. # auth check
  248. try:
  249. provider_config_map[tracing_provider]
  250. except KeyError:
  251. raise ValueError(f"Invalid tracing provider: {tracing_provider}")
  252. app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
  253. if not app_config:
  254. raise ValueError("App not found")
  255. app_config.tracing = json.dumps(
  256. {
  257. "enabled": enabled,
  258. "tracing_provider": tracing_provider,
  259. }
  260. )
  261. db.session.commit()
  262. @classmethod
  263. def get_app_tracing_config(cls, app_id: str):
  264. """
  265. Get app tracing config
  266. :param app_id: app id
  267. :return:
  268. """
  269. app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
  270. if not app:
  271. raise ValueError("App not found")
  272. if not app.tracing:
  273. return {"enabled": False, "tracing_provider": None}
  274. app_trace_config = json.loads(app.tracing)
  275. return app_trace_config
  276. @staticmethod
  277. def check_trace_config_is_effective(tracing_config: dict, tracing_provider: str):
  278. """
  279. Check trace config is effective
  280. :param tracing_config: tracing config
  281. :param tracing_provider: tracing provider
  282. :return:
  283. """
  284. config_type, trace_instance = (
  285. provider_config_map[tracing_provider]["config_class"],
  286. provider_config_map[tracing_provider]["trace_instance"],
  287. )
  288. tracing_config = config_type(**tracing_config)
  289. return trace_instance(tracing_config).api_check()
  290. @staticmethod
  291. def get_trace_config_project_key(tracing_config: dict, tracing_provider: str):
  292. """
  293. get trace config is project key
  294. :param tracing_config: tracing config
  295. :param tracing_provider: tracing provider
  296. :return:
  297. """
  298. config_type, trace_instance = (
  299. provider_config_map[tracing_provider]["config_class"],
  300. provider_config_map[tracing_provider]["trace_instance"],
  301. )
  302. tracing_config = config_type(**tracing_config)
  303. return trace_instance(tracing_config).get_project_key()
  304. @staticmethod
  305. def get_trace_config_project_url(tracing_config: dict, tracing_provider: str):
  306. """
  307. get trace config is project key
  308. :param tracing_config: tracing config
  309. :param tracing_provider: tracing provider
  310. :return:
  311. """
  312. config_type, trace_instance = (
  313. provider_config_map[tracing_provider]["config_class"],
  314. provider_config_map[tracing_provider]["trace_instance"],
  315. )
  316. tracing_config = config_type(**tracing_config)
  317. return trace_instance(tracing_config).get_project_url()
  318. class TraceTask:
  319. def __init__(
  320. self,
  321. trace_type: Any,
  322. message_id: Optional[str] = None,
  323. workflow_run: Optional[WorkflowRun] = None,
  324. conversation_id: Optional[str] = None,
  325. user_id: Optional[str] = None,
  326. timer: Optional[Any] = None,
  327. **kwargs,
  328. ):
  329. self.trace_type = trace_type
  330. self.message_id = message_id
  331. self.workflow_run_id = workflow_run.id if workflow_run else None
  332. self.conversation_id = conversation_id
  333. self.user_id = user_id
  334. self.timer = timer
  335. self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
  336. self.app_id = None
  337. self.kwargs = kwargs
  338. def execute(self):
  339. return self.preprocess()
  340. def preprocess(self):
  341. preprocess_map = {
  342. TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs),
  343. TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
  344. workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
  345. ),
  346. TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
  347. TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
  348. message_id=self.message_id, timer=self.timer, **self.kwargs
  349. ),
  350. TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace(
  351. message_id=self.message_id, timer=self.timer, **self.kwargs
  352. ),
  353. TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace(
  354. message_id=self.message_id, timer=self.timer, **self.kwargs
  355. ),
  356. TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(
  357. message_id=self.message_id, timer=self.timer, **self.kwargs
  358. ),
  359. TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
  360. conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
  361. ),
  362. }
  363. return preprocess_map.get(self.trace_type, lambda: None)()
  364. # process methods for different trace types
  365. def conversation_trace(self, **kwargs):
  366. return kwargs
  367. def workflow_trace(
  368. self,
  369. *,
  370. workflow_run_id: str | None,
  371. conversation_id: str | None,
  372. user_id: str | None,
  373. ):
  374. if not workflow_run_id:
  375. return {}
  376. with Session(db.engine) as session:
  377. workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
  378. workflow_run = session.scalars(workflow_run_stmt).first()
  379. if not workflow_run:
  380. raise ValueError("Workflow run not found")
  381. workflow_id = workflow_run.workflow_id
  382. tenant_id = workflow_run.tenant_id
  383. workflow_run_id = workflow_run.id
  384. workflow_run_elapsed_time = workflow_run.elapsed_time
  385. workflow_run_status = workflow_run.status
  386. workflow_run_inputs = workflow_run.inputs_dict
  387. workflow_run_outputs = workflow_run.outputs_dict
  388. workflow_run_version = workflow_run.version
  389. error = workflow_run.error or ""
  390. total_tokens = workflow_run.total_tokens
  391. file_list = workflow_run_inputs.get("sys.file") or []
  392. query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
  393. # get workflow_app_log_id
  394. workflow_app_log_data_stmt = select(WorkflowAppLog.id).where(
  395. WorkflowAppLog.tenant_id == tenant_id,
  396. WorkflowAppLog.app_id == workflow_run.app_id,
  397. WorkflowAppLog.workflow_run_id == workflow_run.id,
  398. )
  399. workflow_app_log_id = session.scalar(workflow_app_log_data_stmt)
  400. # get message_id
  401. message_id = None
  402. if conversation_id:
  403. message_data_stmt = select(Message.id).where(
  404. Message.conversation_id == conversation_id,
  405. Message.workflow_run_id == workflow_run_id,
  406. )
  407. message_id = session.scalar(message_data_stmt)
  408. metadata = {
  409. "workflow_id": workflow_id,
  410. "conversation_id": conversation_id,
  411. "workflow_run_id": workflow_run_id,
  412. "tenant_id": tenant_id,
  413. "elapsed_time": workflow_run_elapsed_time,
  414. "status": workflow_run_status,
  415. "version": workflow_run_version,
  416. "total_tokens": total_tokens,
  417. "file_list": file_list,
  418. "triggered_from": workflow_run.triggered_from,
  419. "user_id": user_id,
  420. }
  421. workflow_trace_info = WorkflowTraceInfo(
  422. workflow_data=workflow_run.to_dict(),
  423. conversation_id=conversation_id,
  424. workflow_id=workflow_id,
  425. tenant_id=tenant_id,
  426. workflow_run_id=workflow_run_id,
  427. workflow_run_elapsed_time=workflow_run_elapsed_time,
  428. workflow_run_status=workflow_run_status,
  429. workflow_run_inputs=workflow_run_inputs,
  430. workflow_run_outputs=workflow_run_outputs,
  431. workflow_run_version=workflow_run_version,
  432. error=error,
  433. total_tokens=total_tokens,
  434. file_list=file_list,
  435. query=query,
  436. metadata=metadata,
  437. workflow_app_log_id=workflow_app_log_id,
  438. message_id=message_id,
  439. start_time=workflow_run.created_at,
  440. end_time=workflow_run.finished_at,
  441. )
  442. return workflow_trace_info
  443. def message_trace(self, message_id: str | None):
  444. if not message_id:
  445. return {}
  446. message_data = get_message_data(message_id)
  447. if not message_data:
  448. return {}
  449. conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id)
  450. conversation_mode = db.session.scalars(conversation_mode_stmt).all()
  451. if not conversation_mode or len(conversation_mode) == 0:
  452. return {}
  453. conversation_mode = conversation_mode[0]
  454. created_at = message_data.created_at
  455. inputs = message_data.message
  456. # get message file data
  457. message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
  458. file_list = []
  459. if message_file_data and message_file_data.url is not None:
  460. file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
  461. file_list.append(file_url)
  462. metadata = {
  463. "conversation_id": message_data.conversation_id,
  464. "ls_provider": message_data.model_provider,
  465. "ls_model_name": message_data.model_id,
  466. "status": message_data.status,
  467. "from_end_user_id": message_data.from_end_user_id,
  468. "from_account_id": message_data.from_account_id,
  469. "agent_based": message_data.agent_based,
  470. "workflow_run_id": message_data.workflow_run_id,
  471. "from_source": message_data.from_source,
  472. "message_id": message_id,
  473. }
  474. message_tokens = message_data.message_tokens
  475. message_trace_info = MessageTraceInfo(
  476. message_id=message_id,
  477. message_data=message_data.to_dict(),
  478. conversation_model=conversation_mode,
  479. message_tokens=message_tokens,
  480. answer_tokens=message_data.answer_tokens,
  481. total_tokens=message_tokens + message_data.answer_tokens,
  482. error=message_data.error or "",
  483. inputs=inputs,
  484. outputs=message_data.answer,
  485. file_list=file_list,
  486. start_time=created_at,
  487. end_time=created_at + timedelta(seconds=message_data.provider_response_latency),
  488. metadata=metadata,
  489. message_file_data=message_file_data,
  490. conversation_mode=conversation_mode,
  491. )
  492. return message_trace_info
  493. def moderation_trace(self, message_id, timer, **kwargs):
  494. moderation_result = kwargs.get("moderation_result")
  495. if not moderation_result:
  496. return {}
  497. inputs = kwargs.get("inputs")
  498. message_data = get_message_data(message_id)
  499. if not message_data:
  500. return {}
  501. metadata = {
  502. "message_id": message_id,
  503. "action": moderation_result.action,
  504. "preset_response": moderation_result.preset_response,
  505. "query": moderation_result.query,
  506. }
  507. # get workflow_app_log_id
  508. workflow_app_log_id = None
  509. if message_data.workflow_run_id:
  510. workflow_app_log_data = (
  511. db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
  512. )
  513. workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
  514. moderation_trace_info = ModerationTraceInfo(
  515. message_id=workflow_app_log_id or message_id,
  516. inputs=inputs,
  517. message_data=message_data.to_dict(),
  518. flagged=moderation_result.flagged,
  519. action=moderation_result.action,
  520. preset_response=moderation_result.preset_response,
  521. query=moderation_result.query,
  522. start_time=timer.get("start"),
  523. end_time=timer.get("end"),
  524. metadata=metadata,
  525. )
  526. return moderation_trace_info
  527. def suggested_question_trace(self, message_id, timer, **kwargs):
  528. suggested_question = kwargs.get("suggested_question", [])
  529. message_data = get_message_data(message_id)
  530. if not message_data:
  531. return {}
  532. metadata = {
  533. "message_id": message_id,
  534. "ls_provider": message_data.model_provider,
  535. "ls_model_name": message_data.model_id,
  536. "status": message_data.status,
  537. "from_end_user_id": message_data.from_end_user_id,
  538. "from_account_id": message_data.from_account_id,
  539. "agent_based": message_data.agent_based,
  540. "workflow_run_id": message_data.workflow_run_id,
  541. "from_source": message_data.from_source,
  542. }
  543. # get workflow_app_log_id
  544. workflow_app_log_id = None
  545. if message_data.workflow_run_id:
  546. workflow_app_log_data = (
  547. db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
  548. )
  549. workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
  550. suggested_question_trace_info = SuggestedQuestionTraceInfo(
  551. message_id=workflow_app_log_id or message_id,
  552. message_data=message_data.to_dict(),
  553. inputs=message_data.message,
  554. outputs=message_data.answer,
  555. start_time=timer.get("start"),
  556. end_time=timer.get("end"),
  557. metadata=metadata,
  558. total_tokens=message_data.message_tokens + message_data.answer_tokens,
  559. status=message_data.status,
  560. error=message_data.error,
  561. from_account_id=message_data.from_account_id,
  562. agent_based=message_data.agent_based,
  563. from_source=message_data.from_source,
  564. model_provider=message_data.model_provider,
  565. model_id=message_data.model_id,
  566. suggested_question=suggested_question,
  567. level=message_data.status,
  568. status_message=message_data.error,
  569. )
  570. return suggested_question_trace_info
  571. def dataset_retrieval_trace(self, message_id, timer, **kwargs):
  572. documents = kwargs.get("documents")
  573. message_data = get_message_data(message_id)
  574. if not message_data:
  575. return {}
  576. metadata = {
  577. "message_id": message_id,
  578. "ls_provider": message_data.model_provider,
  579. "ls_model_name": message_data.model_id,
  580. "status": message_data.status,
  581. "from_end_user_id": message_data.from_end_user_id,
  582. "from_account_id": message_data.from_account_id,
  583. "agent_based": message_data.agent_based,
  584. "workflow_run_id": message_data.workflow_run_id,
  585. "from_source": message_data.from_source,
  586. }
  587. dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
  588. message_id=message_id,
  589. inputs=message_data.query or message_data.inputs,
  590. documents=[doc.model_dump() for doc in documents] if documents else [],
  591. start_time=timer.get("start"),
  592. end_time=timer.get("end"),
  593. metadata=metadata,
  594. message_data=message_data.to_dict(),
  595. )
  596. return dataset_retrieval_trace_info
  597. def tool_trace(self, message_id, timer, **kwargs):
  598. tool_name = kwargs.get("tool_name", "")
  599. tool_inputs = kwargs.get("tool_inputs", {})
  600. tool_outputs = kwargs.get("tool_outputs", {})
  601. message_data = get_message_data(message_id)
  602. if not message_data:
  603. return {}
  604. tool_config = {}
  605. time_cost = 0
  606. error = None
  607. tool_parameters = {}
  608. created_time = message_data.created_at
  609. end_time = message_data.updated_at
  610. agent_thoughts = message_data.agent_thoughts
  611. for agent_thought in agent_thoughts:
  612. if tool_name in agent_thought.tools:
  613. created_time = agent_thought.created_at
  614. tool_meta_data = agent_thought.tool_meta.get(tool_name, {})
  615. tool_config = tool_meta_data.get("tool_config", {})
  616. time_cost = tool_meta_data.get("time_cost", 0)
  617. end_time = created_time + timedelta(seconds=time_cost)
  618. error = tool_meta_data.get("error", "")
  619. tool_parameters = tool_meta_data.get("tool_parameters", {})
  620. metadata = {
  621. "message_id": message_id,
  622. "tool_name": tool_name,
  623. "tool_inputs": tool_inputs,
  624. "tool_outputs": tool_outputs,
  625. "tool_config": tool_config,
  626. "time_cost": time_cost,
  627. "error": error,
  628. "tool_parameters": tool_parameters,
  629. }
  630. file_url = ""
  631. message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
  632. if message_file_data:
  633. message_file_id = message_file_data.id if message_file_data else None
  634. type = message_file_data.type
  635. created_by_role = message_file_data.created_by_role
  636. created_user_id = message_file_data.created_by
  637. file_url = f"{self.file_base_url}/{message_file_data.url}"
  638. metadata.update(
  639. {
  640. "message_file_id": message_file_id,
  641. "created_by_role": created_by_role,
  642. "created_user_id": created_user_id,
  643. "type": type,
  644. }
  645. )
  646. tool_trace_info = ToolTraceInfo(
  647. message_id=message_id,
  648. message_data=message_data.to_dict(),
  649. tool_name=tool_name,
  650. start_time=timer.get("start") if timer else created_time,
  651. end_time=timer.get("end") if timer else end_time,
  652. tool_inputs=tool_inputs,
  653. tool_outputs=tool_outputs,
  654. metadata=metadata,
  655. message_file_data=message_file_data,
  656. error=error,
  657. inputs=message_data.message,
  658. outputs=message_data.answer,
  659. tool_config=tool_config,
  660. time_cost=time_cost,
  661. tool_parameters=tool_parameters,
  662. file_url=file_url,
  663. )
  664. return tool_trace_info
  665. def generate_name_trace(self, conversation_id, timer, **kwargs):
  666. generate_conversation_name = kwargs.get("generate_conversation_name")
  667. inputs = kwargs.get("inputs")
  668. tenant_id = kwargs.get("tenant_id")
  669. if not tenant_id:
  670. return {}
  671. start_time = timer.get("start")
  672. end_time = timer.get("end")
  673. metadata = {
  674. "conversation_id": conversation_id,
  675. "tenant_id": tenant_id,
  676. }
  677. generate_name_trace_info = GenerateNameTraceInfo(
  678. conversation_id=conversation_id,
  679. inputs=inputs,
  680. outputs=generate_conversation_name,
  681. start_time=start_time,
  682. end_time=end_time,
  683. metadata=metadata,
  684. tenant_id=tenant_id,
  685. )
  686. return generate_name_trace_info
  687. trace_manager_timer: Optional[threading.Timer] = None
  688. trace_manager_queue: queue.Queue = queue.Queue()
  689. trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5))
  690. trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))
  691. class TraceQueueManager:
  692. def __init__(self, app_id=None, user_id=None):
  693. global trace_manager_timer
  694. self.app_id = app_id
  695. self.user_id = user_id
  696. self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
  697. self.flask_app = current_app._get_current_object() # type: ignore
  698. if trace_manager_timer is None:
  699. self.start_timer()
  700. def add_trace_task(self, trace_task: TraceTask):
  701. global trace_manager_timer, trace_manager_queue
  702. try:
  703. if self.trace_instance:
  704. trace_task.app_id = self.app_id
  705. trace_manager_queue.put(trace_task)
  706. except Exception as e:
  707. logging.exception(f"Error adding trace task, trace_type {trace_task.trace_type}")
  708. finally:
  709. self.start_timer()
  710. def collect_tasks(self):
  711. global trace_manager_queue
  712. tasks: list[TraceTask] = []
  713. while len(tasks) < trace_manager_batch_size and not trace_manager_queue.empty():
  714. task = trace_manager_queue.get_nowait()
  715. tasks.append(task)
  716. trace_manager_queue.task_done()
  717. return tasks
  718. def run(self):
  719. try:
  720. tasks = self.collect_tasks()
  721. if tasks:
  722. self.send_to_celery(tasks)
  723. except Exception as e:
  724. logging.exception("Error processing trace tasks")
  725. def start_timer(self):
  726. global trace_manager_timer
  727. if trace_manager_timer is None or not trace_manager_timer.is_alive():
  728. trace_manager_timer = threading.Timer(trace_manager_interval, self.run)
  729. trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}"
  730. trace_manager_timer.daemon = False
  731. trace_manager_timer.start()
  732. def send_to_celery(self, tasks: list[TraceTask]):
  733. with self.flask_app.app_context():
  734. for task in tasks:
  735. if task.app_id is None:
  736. continue
  737. file_id = uuid4().hex
  738. trace_info = task.execute()
  739. task_data = TaskData(
  740. app_id=task.app_id,
  741. trace_info_type=type(trace_info).__name__,
  742. trace_info=trace_info.model_dump() if trace_info else None,
  743. )
  744. file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json"
  745. storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
  746. file_info = {
  747. "file_id": file_id,
  748. "app_id": task.app_id,
  749. }
  750. process_trace_tasks.delay(file_info)