You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

ops_trace_manager.py 35KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897
  1. import json
  2. import logging
  3. import os
  4. import queue
  5. import threading
  6. import time
  7. from datetime import timedelta
  8. from typing import Any, Optional, Union
  9. from uuid import UUID, uuid4
  10. from cachetools import LRUCache
  11. from flask import current_app
  12. from sqlalchemy import select
  13. from sqlalchemy.orm import Session
  14. from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
  15. from core.ops.entities.config_entity import (
  16. OPS_FILE_PATH,
  17. TracingProviderEnum,
  18. )
  19. from core.ops.entities.trace_entity import (
  20. DatasetRetrievalTraceInfo,
  21. GenerateNameTraceInfo,
  22. MessageTraceInfo,
  23. ModerationTraceInfo,
  24. SuggestedQuestionTraceInfo,
  25. TaskData,
  26. ToolTraceInfo,
  27. TraceTaskName,
  28. WorkflowTraceInfo,
  29. )
  30. from core.ops.utils import get_message_data
  31. from core.workflow.entities.workflow_execution import WorkflowExecution
  32. from extensions.ext_database import db
  33. from extensions.ext_storage import storage
  34. from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
  35. from models.workflow import WorkflowAppLog, WorkflowRun
  36. from tasks.ops_trace_task import process_trace_tasks
  37. logger = logging.getLogger(__name__)
  38. class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
  39. def __getitem__(self, provider: str) -> dict[str, Any]:
  40. match provider:
  41. case TracingProviderEnum.LANGFUSE:
  42. from core.ops.entities.config_entity import LangfuseConfig
  43. from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
  44. return {
  45. "config_class": LangfuseConfig,
  46. "secret_keys": ["public_key", "secret_key"],
  47. "other_keys": ["host", "project_key"],
  48. "trace_instance": LangFuseDataTrace,
  49. }
  50. case TracingProviderEnum.LANGSMITH:
  51. from core.ops.entities.config_entity import LangSmithConfig
  52. from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
  53. return {
  54. "config_class": LangSmithConfig,
  55. "secret_keys": ["api_key"],
  56. "other_keys": ["project", "endpoint"],
  57. "trace_instance": LangSmithDataTrace,
  58. }
  59. case TracingProviderEnum.OPIK:
  60. from core.ops.entities.config_entity import OpikConfig
  61. from core.ops.opik_trace.opik_trace import OpikDataTrace
  62. return {
  63. "config_class": OpikConfig,
  64. "secret_keys": ["api_key"],
  65. "other_keys": ["project", "url", "workspace"],
  66. "trace_instance": OpikDataTrace,
  67. }
  68. case TracingProviderEnum.WEAVE:
  69. from core.ops.entities.config_entity import WeaveConfig
  70. from core.ops.weave_trace.weave_trace import WeaveDataTrace
  71. return {
  72. "config_class": WeaveConfig,
  73. "secret_keys": ["api_key"],
  74. "other_keys": ["project", "entity", "endpoint", "host"],
  75. "trace_instance": WeaveDataTrace,
  76. }
  77. case TracingProviderEnum.ARIZE:
  78. from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
  79. from core.ops.entities.config_entity import ArizeConfig
  80. return {
  81. "config_class": ArizeConfig,
  82. "secret_keys": ["api_key", "space_id"],
  83. "other_keys": ["project", "endpoint"],
  84. "trace_instance": ArizePhoenixDataTrace,
  85. }
  86. case TracingProviderEnum.PHOENIX:
  87. from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
  88. from core.ops.entities.config_entity import PhoenixConfig
  89. return {
  90. "config_class": PhoenixConfig,
  91. "secret_keys": ["api_key"],
  92. "other_keys": ["project", "endpoint"],
  93. "trace_instance": ArizePhoenixDataTrace,
  94. }
  95. case TracingProviderEnum.ALIYUN:
  96. from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
  97. from core.ops.entities.config_entity import AliyunConfig
  98. return {
  99. "config_class": AliyunConfig,
  100. "secret_keys": ["license_key"],
  101. "other_keys": ["endpoint", "app_name"],
  102. "trace_instance": AliyunDataTrace,
  103. }
  104. case _:
  105. raise KeyError(f"Unsupported tracing provider: {provider}")
  106. provider_config_map: dict[str, dict[str, Any]] = OpsTraceProviderConfigMap()
  107. class OpsTraceManager:
  108. ops_trace_instances_cache: LRUCache = LRUCache(maxsize=128)
  109. @classmethod
  110. def encrypt_tracing_config(
  111. cls, tenant_id: str, tracing_provider: str, tracing_config: dict, current_trace_config=None
  112. ):
  113. """
  114. Encrypt tracing config.
  115. :param tenant_id: tenant id
  116. :param tracing_provider: tracing provider
  117. :param tracing_config: tracing config dictionary to be encrypted
  118. :param current_trace_config: current tracing configuration for keeping existing values
  119. :return: encrypted tracing configuration
  120. """
  121. # Get the configuration class and the keys that require encryption
  122. config_class, secret_keys, other_keys = (
  123. provider_config_map[tracing_provider]["config_class"],
  124. provider_config_map[tracing_provider]["secret_keys"],
  125. provider_config_map[tracing_provider]["other_keys"],
  126. )
  127. new_config = {}
  128. # Encrypt necessary keys
  129. for key in secret_keys:
  130. if key in tracing_config:
  131. if "*" in tracing_config[key]:
  132. # If the key contains '*', retain the original value from the current config
  133. new_config[key] = current_trace_config.get(key, tracing_config[key])
  134. else:
  135. # Otherwise, encrypt the key
  136. new_config[key] = encrypt_token(tenant_id, tracing_config[key])
  137. for key in other_keys:
  138. new_config[key] = tracing_config.get(key, "")
  139. # Create a new instance of the config class with the new configuration
  140. encrypted_config = config_class(**new_config)
  141. return encrypted_config.model_dump()
  142. @classmethod
  143. def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict):
  144. """
  145. Decrypt tracing config
  146. :param tenant_id: tenant id
  147. :param tracing_provider: tracing provider
  148. :param tracing_config: tracing config
  149. :return:
  150. """
  151. config_class, secret_keys, other_keys = (
  152. provider_config_map[tracing_provider]["config_class"],
  153. provider_config_map[tracing_provider]["secret_keys"],
  154. provider_config_map[tracing_provider]["other_keys"],
  155. )
  156. new_config = {}
  157. for key in secret_keys:
  158. if key in tracing_config:
  159. new_config[key] = decrypt_token(tenant_id, tracing_config[key])
  160. for key in other_keys:
  161. new_config[key] = tracing_config.get(key, "")
  162. return config_class(**new_config).model_dump()
  163. @classmethod
  164. def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
  165. """
  166. Decrypt tracing config
  167. :param tracing_provider: tracing provider
  168. :param decrypt_tracing_config: tracing config
  169. :return:
  170. """
  171. config_class, secret_keys, other_keys = (
  172. provider_config_map[tracing_provider]["config_class"],
  173. provider_config_map[tracing_provider]["secret_keys"],
  174. provider_config_map[tracing_provider]["other_keys"],
  175. )
  176. new_config = {}
  177. for key in secret_keys:
  178. if key in decrypt_tracing_config:
  179. new_config[key] = obfuscated_token(decrypt_tracing_config[key])
  180. for key in other_keys:
  181. new_config[key] = decrypt_tracing_config.get(key, "")
  182. return config_class(**new_config).model_dump()
  183. @classmethod
  184. def get_decrypted_tracing_config(cls, app_id: str, tracing_provider: str):
  185. """
  186. Get decrypted tracing config
  187. :param app_id: app id
  188. :param tracing_provider: tracing provider
  189. :return:
  190. """
  191. trace_config_data: Optional[TraceAppConfig] = (
  192. db.session.query(TraceAppConfig)
  193. .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
  194. .first()
  195. )
  196. if not trace_config_data:
  197. return None
  198. # decrypt_token
  199. stmt = select(App).where(App.id == app_id)
  200. app = db.session.scalar(stmt)
  201. if not app:
  202. raise ValueError("App not found")
  203. tenant_id = app.tenant_id
  204. decrypt_tracing_config = cls.decrypt_tracing_config(
  205. tenant_id, tracing_provider, trace_config_data.tracing_config
  206. )
  207. return decrypt_tracing_config
  208. @classmethod
  209. def get_ops_trace_instance(
  210. cls,
  211. app_id: Optional[Union[UUID, str]] = None,
  212. ):
  213. """
  214. Get ops trace through model config
  215. :param app_id: app_id
  216. :return:
  217. """
  218. if isinstance(app_id, UUID):
  219. app_id = str(app_id)
  220. if app_id is None:
  221. return None
  222. app: Optional[App] = db.session.query(App).where(App.id == app_id).first()
  223. if app is None:
  224. return None
  225. app_ops_trace_config = json.loads(app.tracing) if app.tracing else None
  226. if app_ops_trace_config is None:
  227. return None
  228. if not app_ops_trace_config.get("enabled"):
  229. return None
  230. tracing_provider = app_ops_trace_config.get("tracing_provider")
  231. if tracing_provider is None:
  232. return None
  233. try:
  234. provider_config_map[tracing_provider]
  235. except KeyError:
  236. return None
  237. # decrypt_token
  238. decrypt_trace_config = cls.get_decrypted_tracing_config(app_id, tracing_provider)
  239. if not decrypt_trace_config:
  240. return None
  241. trace_instance, config_class = (
  242. provider_config_map[tracing_provider]["trace_instance"],
  243. provider_config_map[tracing_provider]["config_class"],
  244. )
  245. decrypt_trace_config_key = json.dumps(decrypt_trace_config, sort_keys=True)
  246. tracing_instance = cls.ops_trace_instances_cache.get(decrypt_trace_config_key)
  247. if tracing_instance is None:
  248. # create new tracing_instance and update the cache if it absent
  249. tracing_instance = trace_instance(config_class(**decrypt_trace_config))
  250. cls.ops_trace_instances_cache[decrypt_trace_config_key] = tracing_instance
  251. logger.info("new tracing_instance for app_id: %s", app_id)
  252. return tracing_instance
  253. @classmethod
  254. def get_app_config_through_message_id(cls, message_id: str):
  255. app_model_config = None
  256. message_stmt = select(Message).where(Message.id == message_id)
  257. message_data = db.session.scalar(message_stmt)
  258. if not message_data:
  259. return None
  260. conversation_id = message_data.conversation_id
  261. conversation_stmt = select(Conversation).where(Conversation.id == conversation_id)
  262. conversation_data = db.session.scalar(conversation_stmt)
  263. if not conversation_data:
  264. return None
  265. if conversation_data.app_model_config_id:
  266. config_stmt = select(AppModelConfig).where(AppModelConfig.id == conversation_data.app_model_config_id)
  267. app_model_config = db.session.scalar(config_stmt)
  268. elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
  269. app_model_config = conversation_data.override_model_configs
  270. return app_model_config
  271. @classmethod
  272. def update_app_tracing_config(cls, app_id: str, enabled: bool, tracing_provider: str):
  273. """
  274. Update app tracing config
  275. :param app_id: app id
  276. :param enabled: enabled
  277. :param tracing_provider: tracing provider
  278. :return:
  279. """
  280. # auth check
  281. try:
  282. if enabled or tracing_provider is not None:
  283. provider_config_map[tracing_provider]
  284. except KeyError:
  285. raise ValueError(f"Invalid tracing provider: {tracing_provider}")
  286. app_config: Optional[App] = db.session.query(App).where(App.id == app_id).first()
  287. if not app_config:
  288. raise ValueError("App not found")
  289. app_config.tracing = json.dumps(
  290. {
  291. "enabled": enabled,
  292. "tracing_provider": tracing_provider,
  293. }
  294. )
  295. db.session.commit()
  296. @classmethod
  297. def get_app_tracing_config(cls, app_id: str):
  298. """
  299. Get app tracing config
  300. :param app_id: app id
  301. :return:
  302. """
  303. app: Optional[App] = db.session.query(App).where(App.id == app_id).first()
  304. if not app:
  305. raise ValueError("App not found")
  306. if not app.tracing:
  307. return {"enabled": False, "tracing_provider": None}
  308. app_trace_config = json.loads(app.tracing)
  309. return app_trace_config
  310. @staticmethod
  311. def check_trace_config_is_effective(tracing_config: dict, tracing_provider: str):
  312. """
  313. Check trace config is effective
  314. :param tracing_config: tracing config
  315. :param tracing_provider: tracing provider
  316. :return:
  317. """
  318. config_type, trace_instance = (
  319. provider_config_map[tracing_provider]["config_class"],
  320. provider_config_map[tracing_provider]["trace_instance"],
  321. )
  322. tracing_config = config_type(**tracing_config)
  323. return trace_instance(tracing_config).api_check()
  324. @staticmethod
  325. def get_trace_config_project_key(tracing_config: dict, tracing_provider: str):
  326. """
  327. get trace config is project key
  328. :param tracing_config: tracing config
  329. :param tracing_provider: tracing provider
  330. :return:
  331. """
  332. config_type, trace_instance = (
  333. provider_config_map[tracing_provider]["config_class"],
  334. provider_config_map[tracing_provider]["trace_instance"],
  335. )
  336. tracing_config = config_type(**tracing_config)
  337. return trace_instance(tracing_config).get_project_key()
  338. @staticmethod
  339. def get_trace_config_project_url(tracing_config: dict, tracing_provider: str):
  340. """
  341. get trace config is project key
  342. :param tracing_config: tracing config
  343. :param tracing_provider: tracing provider
  344. :return:
  345. """
  346. config_type, trace_instance = (
  347. provider_config_map[tracing_provider]["config_class"],
  348. provider_config_map[tracing_provider]["trace_instance"],
  349. )
  350. tracing_config = config_type(**tracing_config)
  351. return trace_instance(tracing_config).get_project_url()
  352. class TraceTask:
  353. def __init__(
  354. self,
  355. trace_type: Any,
  356. message_id: Optional[str] = None,
  357. workflow_execution: Optional[WorkflowExecution] = None,
  358. conversation_id: Optional[str] = None,
  359. user_id: Optional[str] = None,
  360. timer: Optional[Any] = None,
  361. **kwargs,
  362. ):
  363. self.trace_type = trace_type
  364. self.message_id = message_id
  365. self.workflow_run_id = workflow_execution.id_ if workflow_execution else None
  366. self.conversation_id = conversation_id
  367. self.user_id = user_id
  368. self.timer = timer
  369. self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
  370. self.app_id = None
  371. self.trace_id = None
  372. self.kwargs = kwargs
  373. external_trace_id = kwargs.get("external_trace_id")
  374. if external_trace_id:
  375. self.trace_id = external_trace_id
  376. def execute(self):
  377. return self.preprocess()
  378. def preprocess(self):
  379. preprocess_map = {
  380. TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs),
  381. TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
  382. workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
  383. ),
  384. TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
  385. TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
  386. message_id=self.message_id, timer=self.timer, **self.kwargs
  387. ),
  388. TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace(
  389. message_id=self.message_id, timer=self.timer, **self.kwargs
  390. ),
  391. TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace(
  392. message_id=self.message_id, timer=self.timer, **self.kwargs
  393. ),
  394. TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(
  395. message_id=self.message_id, timer=self.timer, **self.kwargs
  396. ),
  397. TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
  398. conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
  399. ),
  400. }
  401. return preprocess_map.get(self.trace_type, lambda: None)()
  402. # process methods for different trace types
  403. def conversation_trace(self, **kwargs):
  404. return kwargs
  405. def workflow_trace(
  406. self,
  407. *,
  408. workflow_run_id: str | None,
  409. conversation_id: str | None,
  410. user_id: str | None,
  411. ):
  412. if not workflow_run_id:
  413. return {}
  414. with Session(db.engine) as session:
  415. workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
  416. workflow_run = session.scalars(workflow_run_stmt).first()
  417. if not workflow_run:
  418. raise ValueError("Workflow run not found")
  419. workflow_id = workflow_run.workflow_id
  420. tenant_id = workflow_run.tenant_id
  421. workflow_run_id = workflow_run.id
  422. workflow_run_elapsed_time = workflow_run.elapsed_time
  423. workflow_run_status = workflow_run.status
  424. workflow_run_inputs = workflow_run.inputs_dict
  425. workflow_run_outputs = workflow_run.outputs_dict
  426. workflow_run_version = workflow_run.version
  427. error = workflow_run.error or ""
  428. total_tokens = workflow_run.total_tokens
  429. file_list = workflow_run_inputs.get("sys.file") or []
  430. query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
  431. # get workflow_app_log_id
  432. workflow_app_log_data_stmt = select(WorkflowAppLog.id).where(
  433. WorkflowAppLog.tenant_id == tenant_id,
  434. WorkflowAppLog.app_id == workflow_run.app_id,
  435. WorkflowAppLog.workflow_run_id == workflow_run.id,
  436. )
  437. workflow_app_log_id = session.scalar(workflow_app_log_data_stmt)
  438. # get message_id
  439. message_id = None
  440. if conversation_id:
  441. message_data_stmt = select(Message.id).where(
  442. Message.conversation_id == conversation_id,
  443. Message.workflow_run_id == workflow_run_id,
  444. )
  445. message_id = session.scalar(message_data_stmt)
  446. metadata = {
  447. "workflow_id": workflow_id,
  448. "conversation_id": conversation_id,
  449. "workflow_run_id": workflow_run_id,
  450. "tenant_id": tenant_id,
  451. "elapsed_time": workflow_run_elapsed_time,
  452. "status": workflow_run_status,
  453. "version": workflow_run_version,
  454. "total_tokens": total_tokens,
  455. "file_list": file_list,
  456. "triggered_from": workflow_run.triggered_from,
  457. "user_id": user_id,
  458. "app_id": workflow_run.app_id,
  459. }
  460. workflow_trace_info = WorkflowTraceInfo(
  461. trace_id=self.trace_id,
  462. workflow_data=workflow_run.to_dict(),
  463. conversation_id=conversation_id,
  464. workflow_id=workflow_id,
  465. tenant_id=tenant_id,
  466. workflow_run_id=workflow_run_id,
  467. workflow_run_elapsed_time=workflow_run_elapsed_time,
  468. workflow_run_status=workflow_run_status,
  469. workflow_run_inputs=workflow_run_inputs,
  470. workflow_run_outputs=workflow_run_outputs,
  471. workflow_run_version=workflow_run_version,
  472. error=error,
  473. total_tokens=total_tokens,
  474. file_list=file_list,
  475. query=query,
  476. metadata=metadata,
  477. workflow_app_log_id=workflow_app_log_id,
  478. message_id=message_id,
  479. start_time=workflow_run.created_at,
  480. end_time=workflow_run.finished_at,
  481. )
  482. return workflow_trace_info
  483. def message_trace(self, message_id: str | None):
  484. if not message_id:
  485. return {}
  486. message_data = get_message_data(message_id)
  487. if not message_data:
  488. return {}
  489. conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id)
  490. conversation_mode = db.session.scalars(conversation_mode_stmt).all()
  491. if not conversation_mode or len(conversation_mode) == 0:
  492. return {}
  493. conversation_mode = conversation_mode[0]
  494. created_at = message_data.created_at
  495. inputs = message_data.message
  496. # get message file data
  497. message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
  498. file_list = []
  499. if message_file_data and message_file_data.url is not None:
  500. file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
  501. file_list.append(file_url)
  502. metadata = {
  503. "conversation_id": message_data.conversation_id,
  504. "ls_provider": message_data.model_provider,
  505. "ls_model_name": message_data.model_id,
  506. "status": message_data.status,
  507. "from_end_user_id": message_data.from_end_user_id,
  508. "from_account_id": message_data.from_account_id,
  509. "agent_based": message_data.agent_based,
  510. "workflow_run_id": message_data.workflow_run_id,
  511. "from_source": message_data.from_source,
  512. "message_id": message_id,
  513. }
  514. message_tokens = message_data.message_tokens
  515. message_trace_info = MessageTraceInfo(
  516. trace_id=self.trace_id,
  517. message_id=message_id,
  518. message_data=message_data.to_dict(),
  519. conversation_model=conversation_mode,
  520. message_tokens=message_tokens,
  521. answer_tokens=message_data.answer_tokens,
  522. total_tokens=message_tokens + message_data.answer_tokens,
  523. error=message_data.error or "",
  524. inputs=inputs,
  525. outputs=message_data.answer,
  526. file_list=file_list,
  527. start_time=created_at,
  528. end_time=created_at + timedelta(seconds=message_data.provider_response_latency),
  529. metadata=metadata,
  530. message_file_data=message_file_data,
  531. conversation_mode=conversation_mode,
  532. )
  533. return message_trace_info
  534. def moderation_trace(self, message_id, timer, **kwargs):
  535. moderation_result = kwargs.get("moderation_result")
  536. if not moderation_result:
  537. return {}
  538. inputs = kwargs.get("inputs")
  539. message_data = get_message_data(message_id)
  540. if not message_data:
  541. return {}
  542. metadata = {
  543. "message_id": message_id,
  544. "action": moderation_result.action,
  545. "preset_response": moderation_result.preset_response,
  546. "query": moderation_result.query,
  547. }
  548. # get workflow_app_log_id
  549. workflow_app_log_id = None
  550. if message_data.workflow_run_id:
  551. workflow_app_log_data = (
  552. db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
  553. )
  554. workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
  555. moderation_trace_info = ModerationTraceInfo(
  556. trace_id=self.trace_id,
  557. message_id=workflow_app_log_id or message_id,
  558. inputs=inputs,
  559. message_data=message_data.to_dict(),
  560. flagged=moderation_result.flagged,
  561. action=moderation_result.action,
  562. preset_response=moderation_result.preset_response,
  563. query=moderation_result.query,
  564. start_time=timer.get("start"),
  565. end_time=timer.get("end"),
  566. metadata=metadata,
  567. )
  568. return moderation_trace_info
  569. def suggested_question_trace(self, message_id, timer, **kwargs):
  570. suggested_question = kwargs.get("suggested_question", [])
  571. message_data = get_message_data(message_id)
  572. if not message_data:
  573. return {}
  574. metadata = {
  575. "message_id": message_id,
  576. "ls_provider": message_data.model_provider,
  577. "ls_model_name": message_data.model_id,
  578. "status": message_data.status,
  579. "from_end_user_id": message_data.from_end_user_id,
  580. "from_account_id": message_data.from_account_id,
  581. "agent_based": message_data.agent_based,
  582. "workflow_run_id": message_data.workflow_run_id,
  583. "from_source": message_data.from_source,
  584. }
  585. # get workflow_app_log_id
  586. workflow_app_log_id = None
  587. if message_data.workflow_run_id:
  588. workflow_app_log_data = (
  589. db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first()
  590. )
  591. workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
  592. suggested_question_trace_info = SuggestedQuestionTraceInfo(
  593. trace_id=self.trace_id,
  594. message_id=workflow_app_log_id or message_id,
  595. message_data=message_data.to_dict(),
  596. inputs=message_data.message,
  597. outputs=message_data.answer,
  598. start_time=timer.get("start"),
  599. end_time=timer.get("end"),
  600. metadata=metadata,
  601. total_tokens=message_data.message_tokens + message_data.answer_tokens,
  602. status=message_data.status,
  603. error=message_data.error,
  604. from_account_id=message_data.from_account_id,
  605. agent_based=message_data.agent_based,
  606. from_source=message_data.from_source,
  607. model_provider=message_data.model_provider,
  608. model_id=message_data.model_id,
  609. suggested_question=suggested_question,
  610. level=message_data.status,
  611. status_message=message_data.error,
  612. )
  613. return suggested_question_trace_info
  614. def dataset_retrieval_trace(self, message_id, timer, **kwargs):
  615. documents = kwargs.get("documents")
  616. message_data = get_message_data(message_id)
  617. if not message_data:
  618. return {}
  619. metadata = {
  620. "message_id": message_id,
  621. "ls_provider": message_data.model_provider,
  622. "ls_model_name": message_data.model_id,
  623. "status": message_data.status,
  624. "from_end_user_id": message_data.from_end_user_id,
  625. "from_account_id": message_data.from_account_id,
  626. "agent_based": message_data.agent_based,
  627. "workflow_run_id": message_data.workflow_run_id,
  628. "from_source": message_data.from_source,
  629. }
  630. dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
  631. trace_id=self.trace_id,
  632. message_id=message_id,
  633. inputs=message_data.query or message_data.inputs,
  634. documents=[doc.model_dump() for doc in documents] if documents else [],
  635. start_time=timer.get("start"),
  636. end_time=timer.get("end"),
  637. metadata=metadata,
  638. message_data=message_data.to_dict(),
  639. )
  640. return dataset_retrieval_trace_info
  641. def tool_trace(self, message_id, timer, **kwargs):
  642. tool_name = kwargs.get("tool_name", "")
  643. tool_inputs = kwargs.get("tool_inputs", {})
  644. tool_outputs = kwargs.get("tool_outputs", {})
  645. message_data = get_message_data(message_id)
  646. if not message_data:
  647. return {}
  648. tool_config = {}
  649. time_cost = 0
  650. error = None
  651. tool_parameters = {}
  652. created_time = message_data.created_at
  653. end_time = message_data.updated_at
  654. agent_thoughts = message_data.agent_thoughts
  655. for agent_thought in agent_thoughts:
  656. if tool_name in agent_thought.tools:
  657. created_time = agent_thought.created_at
  658. tool_meta_data = agent_thought.tool_meta.get(tool_name, {})
  659. tool_config = tool_meta_data.get("tool_config", {})
  660. time_cost = tool_meta_data.get("time_cost", 0)
  661. end_time = created_time + timedelta(seconds=time_cost)
  662. error = tool_meta_data.get("error", "")
  663. tool_parameters = tool_meta_data.get("tool_parameters", {})
  664. metadata = {
  665. "message_id": message_id,
  666. "tool_name": tool_name,
  667. "tool_inputs": tool_inputs,
  668. "tool_outputs": tool_outputs,
  669. "tool_config": tool_config,
  670. "time_cost": time_cost,
  671. "error": error,
  672. "tool_parameters": tool_parameters,
  673. }
  674. file_url = ""
  675. message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
  676. if message_file_data:
  677. message_file_id = message_file_data.id if message_file_data else None
  678. type = message_file_data.type
  679. created_by_role = message_file_data.created_by_role
  680. created_user_id = message_file_data.created_by
  681. file_url = f"{self.file_base_url}/{message_file_data.url}"
  682. metadata.update(
  683. {
  684. "message_file_id": message_file_id,
  685. "created_by_role": created_by_role,
  686. "created_user_id": created_user_id,
  687. "type": type,
  688. }
  689. )
  690. tool_trace_info = ToolTraceInfo(
  691. trace_id=self.trace_id,
  692. message_id=message_id,
  693. message_data=message_data.to_dict(),
  694. tool_name=tool_name,
  695. start_time=timer.get("start") if timer else created_time,
  696. end_time=timer.get("end") if timer else end_time,
  697. tool_inputs=tool_inputs,
  698. tool_outputs=tool_outputs,
  699. metadata=metadata,
  700. message_file_data=message_file_data,
  701. error=error,
  702. inputs=message_data.message,
  703. outputs=message_data.answer,
  704. tool_config=tool_config,
  705. time_cost=time_cost,
  706. tool_parameters=tool_parameters,
  707. file_url=file_url,
  708. )
  709. return tool_trace_info
  710. def generate_name_trace(self, conversation_id, timer, **kwargs):
  711. generate_conversation_name = kwargs.get("generate_conversation_name")
  712. inputs = kwargs.get("inputs")
  713. tenant_id = kwargs.get("tenant_id")
  714. if not tenant_id:
  715. return {}
  716. start_time = timer.get("start")
  717. end_time = timer.get("end")
  718. metadata = {
  719. "conversation_id": conversation_id,
  720. "tenant_id": tenant_id,
  721. }
  722. generate_name_trace_info = GenerateNameTraceInfo(
  723. trace_id=self.trace_id,
  724. conversation_id=conversation_id,
  725. inputs=inputs,
  726. outputs=generate_conversation_name,
  727. start_time=start_time,
  728. end_time=end_time,
  729. metadata=metadata,
  730. tenant_id=tenant_id,
  731. )
  732. return generate_name_trace_info
  733. trace_manager_timer: Optional[threading.Timer] = None
  734. trace_manager_queue: queue.Queue = queue.Queue()
  735. trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 5))
  736. trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))
  737. class TraceQueueManager:
  738. def __init__(self, app_id=None, user_id=None):
  739. global trace_manager_timer
  740. self.app_id = app_id
  741. self.user_id = user_id
  742. self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
  743. self.flask_app = current_app._get_current_object() # type: ignore
  744. if trace_manager_timer is None:
  745. self.start_timer()
  746. def add_trace_task(self, trace_task: TraceTask):
  747. global trace_manager_timer, trace_manager_queue
  748. try:
  749. if self.trace_instance:
  750. trace_task.app_id = self.app_id
  751. trace_manager_queue.put(trace_task)
  752. except Exception:
  753. logger.exception("Error adding trace task, trace_type %s", trace_task.trace_type)
  754. finally:
  755. self.start_timer()
  756. def collect_tasks(self):
  757. global trace_manager_queue
  758. tasks: list[TraceTask] = []
  759. while len(tasks) < trace_manager_batch_size and not trace_manager_queue.empty():
  760. task = trace_manager_queue.get_nowait()
  761. tasks.append(task)
  762. trace_manager_queue.task_done()
  763. return tasks
  764. def run(self):
  765. try:
  766. tasks = self.collect_tasks()
  767. if tasks:
  768. self.send_to_celery(tasks)
  769. except Exception:
  770. logger.exception("Error processing trace tasks")
  771. def start_timer(self):
  772. global trace_manager_timer
  773. if trace_manager_timer is None or not trace_manager_timer.is_alive():
  774. trace_manager_timer = threading.Timer(trace_manager_interval, self.run)
  775. trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}"
  776. trace_manager_timer.daemon = False
  777. trace_manager_timer.start()
  778. def send_to_celery(self, tasks: list[TraceTask]):
  779. with self.flask_app.app_context():
  780. for task in tasks:
  781. if task.app_id is None:
  782. continue
  783. file_id = uuid4().hex
  784. trace_info = task.execute()
  785. task_data = TaskData(
  786. app_id=task.app_id,
  787. trace_info_type=type(trace_info).__name__,
  788. trace_info=trace_info.model_dump() if trace_info else None,
  789. )
  790. file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json"
  791. storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
  792. file_info = {
  793. "file_id": file_id,
  794. "app_id": task.app_id,
  795. }
  796. process_trace_tasks.delay(file_info)