You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 2.9KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import json
  2. from typing import Any
  3. from opentelemetry.trace import Link, Status, StatusCode
  4. from core.ops.aliyun_trace.entities.semconv import (
  5. GEN_AI_FRAMEWORK,
  6. GEN_AI_SESSION_ID,
  7. GEN_AI_SPAN_KIND,
  8. GEN_AI_USER_ID,
  9. INPUT_VALUE,
  10. OUTPUT_VALUE,
  11. GenAISpanKind,
  12. )
  13. from core.rag.models.document import Document
  14. from core.workflow.entities import WorkflowNodeExecution
  15. from core.workflow.enums import WorkflowNodeExecutionStatus
  16. from extensions.ext_database import db
  17. from models import EndUser
  # Module-level defaults used by the helper functions in this file.
  # Default for the ``ensure_ascii`` argument of ``serialize_json_data``.
  DEFAULT_JSON_ENSURE_ASCII: bool = False
  # Default for the ``framework`` argument of ``create_common_span_attributes``.
  DEFAULT_FRAMEWORK_NAME: str = "dify"
  def get_user_id_from_message_data(message_data) -> str:
      """Pick the user identifier to report for a message.

      The account id (``message_data.from_account_id``) is the fallback.
      When ``message_data.from_end_user_id`` is set, the matching ``EndUser``
      row is queried and, if one exists, its ``session_id`` is returned
      instead.

      Args:
          message_data: Object exposing ``from_account_id`` and
              ``from_end_user_id`` attributes.

      Returns:
          The end user's ``session_id`` when the end user is found,
          otherwise ``message_data.from_account_id``.
      """
      account_id = message_data.from_account_id

      # No end user attached to the message: nothing to look up.
      if not message_data.from_end_user_id:
          return account_id

      end_user: EndUser | None = (
          db.session.query(EndUser)
          .where(EndUser.id == message_data.from_end_user_id)
          .first()
      )

      # The referenced end user row may be missing; fall back to the account id.
      if end_user is None:
          return account_id
      return end_user.session_id
  def create_status_from_error(error: str | None) -> Status:
      """Translate an optional error message into a span ``Status``.

      Args:
          error: Error text, or ``None`` / empty string when nothing failed.

      Returns:
          ``Status(StatusCode.OK)`` when ``error`` is falsy, otherwise
          ``Status(StatusCode.ERROR, error)`` carrying the message as the
          status description.
      """
      if not error:
          return Status(StatusCode.OK)
      return Status(StatusCode.ERROR, error)
  def get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status:
      """Map a workflow node execution's status onto a span ``Status``.

      Args:
          node_execution: The node execution whose ``status`` (and, on
              failure, ``error``) is inspected.

      Returns:
          * ``StatusCode.OK`` for ``SUCCEEDED``;
          * ``StatusCode.ERROR`` with ``str(node_execution.error)`` as the
            description for ``FAILED`` or ``EXCEPTION``;
          * ``StatusCode.UNSET`` for every other status.
      """
      node_status = node_execution.status

      if node_status == WorkflowNodeExecutionStatus.SUCCEEDED:
          return Status(StatusCode.OK)

      failure_states = (
          WorkflowNodeExecutionStatus.FAILED,
          WorkflowNodeExecutionStatus.EXCEPTION,
      )
      if node_status in failure_states:
          # str() is applied unconditionally, so an ``error`` of None would be
          # reported as the literal text "None".
          return Status(StatusCode.ERROR, str(node_execution.error))

      return Status(StatusCode.UNSET)
  def create_links_from_trace_id(trace_id: str | None) -> list[Link]:
      """Build the list of span links for an optional trace id.

      Args:
          trace_id: Trace id string to link to, or ``None`` / empty string.

      Returns:
          A one-element list holding ``create_link(trace_id_str=trace_id)``
          when ``trace_id`` is truthy, otherwise an empty list.
      """
      # Imported at call time rather than at module level.
      # NOTE(review): presumably this avoids a circular import - confirm.
      from core.ops.aliyun_trace.data_exporter.traceclient import create_link

      return [create_link(trace_id_str=trace_id)] if trace_id else []
  def extract_retrieval_documents(documents: list[Document]) -> list[dict[str, Any]]:
      """Convert retrieved documents into plain dictionaries.

      Args:
          documents: Documents exposing ``page_content`` and a ``metadata``
              mapping.

      Returns:
          One dict per document, in input order, with the keys:

          * ``"content"``: the document's ``page_content``;
          * ``"metadata"``: the ``dataset_id``, ``doc_id`` and ``document_id``
            metadata entries (``None`` for any that are absent);
          * ``"score"``: the ``score`` metadata entry, or ``None`` if absent.
      """
      # Only these metadata entries are copied into the nested "metadata" dict.
      metadata_keys = ("dataset_id", "doc_id", "document_id")

      return [
          {
              "content": doc.page_content,
              "metadata": {key: doc.metadata.get(key) for key in metadata_keys},
              "score": doc.metadata.get("score"),
          }
          for doc in documents
      ]
  def serialize_json_data(data: Any, ensure_ascii: bool = DEFAULT_JSON_ENSURE_ASCII) -> str:
      """Serialize ``data`` to a JSON string using ``json.dumps``.

      Args:
          data: The value to serialize.
          ensure_ascii: Forwarded to ``json.dumps``; defaults to
              ``DEFAULT_JSON_ENSURE_ASCII``.

      Returns:
          The JSON text produced by ``json.dumps``.

      Raises:
          TypeError: Propagated from ``json.dumps`` when ``data`` contains a
              value it cannot serialize.
      """
      serialized = json.dumps(data, ensure_ascii=ensure_ascii)
      return serialized
  def create_common_span_attributes(
      session_id: str = "",
      user_id: str = "",
      span_kind: str = GenAISpanKind.CHAIN,
      framework: str = DEFAULT_FRAMEWORK_NAME,
      inputs: str = "",
      outputs: str = "",
  ) -> dict[str, Any]:
      """Assemble the attribute dictionary shared by spans.

      Each argument is stored under its semantic-convention key:
      ``GEN_AI_SESSION_ID``, ``GEN_AI_USER_ID``, ``GEN_AI_SPAN_KIND``,
      ``GEN_AI_FRAMEWORK``, ``INPUT_VALUE`` and ``OUTPUT_VALUE``.

      Args:
          session_id: Value for the session id attribute.
          user_id: Value for the user id attribute.
          span_kind: Value for the span kind attribute; defaults to
              ``GenAISpanKind.CHAIN``.
          framework: Value for the framework attribute; defaults to
              ``DEFAULT_FRAMEWORK_NAME``.
          inputs: Value for the input attribute.
          outputs: Value for the output attribute.

      Returns:
          A new dict mapping the six attribute keys to the given values.
      """
      attributes: dict[str, Any] = {}
      attributes[GEN_AI_SESSION_ID] = session_id
      attributes[GEN_AI_USER_ID] = user_id
      attributes[GEN_AI_SPAN_KIND] = span_kind
      attributes[GEN_AI_FRAMEWORK] = framework
      attributes[INPUT_VALUE] = inputs
      attributes[OUTPUT_VALUE] = outputs
      return attributes