You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 5.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. import json
  2. from collections.abc import Mapping
  3. from typing import Any
  4. from opentelemetry.trace import Link, Status, StatusCode
  5. from core.ops.aliyun_trace.entities.semconv import (
  6. GEN_AI_FRAMEWORK,
  7. GEN_AI_SESSION_ID,
  8. GEN_AI_SPAN_KIND,
  9. GEN_AI_USER_ID,
  10. INPUT_VALUE,
  11. OUTPUT_VALUE,
  12. GenAISpanKind,
  13. )
  14. from core.rag.models.document import Document
  15. from core.workflow.entities import WorkflowNodeExecution
  16. from core.workflow.enums import WorkflowNodeExecutionStatus
  17. from extensions.ext_database import db
  18. from models import EndUser
  19. # Constants
  20. DEFAULT_JSON_ENSURE_ASCII = False
  21. DEFAULT_FRAMEWORK_NAME = "dify"
  22. def get_user_id_from_message_data(message_data) -> str:
  23. user_id = message_data.from_account_id
  24. if message_data.from_end_user_id:
  25. end_user_data: EndUser | None = (
  26. db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
  27. )
  28. if end_user_data is not None:
  29. user_id = end_user_data.session_id
  30. return user_id
  31. def create_status_from_error(error: str | None) -> Status:
  32. if error:
  33. return Status(StatusCode.ERROR, error)
  34. return Status(StatusCode.OK)
  35. def get_workflow_node_status(node_execution: WorkflowNodeExecution) -> Status:
  36. if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED:
  37. return Status(StatusCode.OK)
  38. if node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]:
  39. return Status(StatusCode.ERROR, str(node_execution.error))
  40. return Status(StatusCode.UNSET)
  41. def create_links_from_trace_id(trace_id: str | None) -> list[Link]:
  42. from core.ops.aliyun_trace.data_exporter.traceclient import create_link
  43. links = []
  44. if trace_id:
  45. links.append(create_link(trace_id_str=trace_id))
  46. return links
  47. def extract_retrieval_documents(documents: list[Document]) -> list[dict[str, Any]]:
  48. documents_data = []
  49. for document in documents:
  50. document_data = {
  51. "content": document.page_content,
  52. "metadata": {
  53. "dataset_id": document.metadata.get("dataset_id"),
  54. "doc_id": document.metadata.get("doc_id"),
  55. "document_id": document.metadata.get("document_id"),
  56. },
  57. "score": document.metadata.get("score"),
  58. }
  59. documents_data.append(document_data)
  60. return documents_data
  61. def serialize_json_data(data: Any, ensure_ascii: bool = DEFAULT_JSON_ENSURE_ASCII) -> str:
  62. return json.dumps(data, ensure_ascii=ensure_ascii)
  63. def create_common_span_attributes(
  64. session_id: str = "",
  65. user_id: str = "",
  66. span_kind: str = GenAISpanKind.CHAIN,
  67. framework: str = DEFAULT_FRAMEWORK_NAME,
  68. inputs: str = "",
  69. outputs: str = "",
  70. ) -> dict[str, Any]:
  71. return {
  72. GEN_AI_SESSION_ID: session_id,
  73. GEN_AI_USER_ID: user_id,
  74. GEN_AI_SPAN_KIND: span_kind,
  75. GEN_AI_FRAMEWORK: framework,
  76. INPUT_VALUE: inputs,
  77. OUTPUT_VALUE: outputs,
  78. }
  79. def format_retrieval_documents(retrieval_documents: list) -> list:
  80. try:
  81. if not isinstance(retrieval_documents, list):
  82. return []
  83. semantic_documents = []
  84. for doc in retrieval_documents:
  85. if not isinstance(doc, dict):
  86. continue
  87. metadata = doc.get("metadata", {})
  88. content = doc.get("content", "")
  89. title = doc.get("title", "")
  90. score = metadata.get("score", 0.0)
  91. document_id = metadata.get("document_id", "")
  92. semantic_metadata = {}
  93. if title:
  94. semantic_metadata["title"] = title
  95. if metadata.get("source"):
  96. semantic_metadata["source"] = metadata["source"]
  97. elif metadata.get("_source"):
  98. semantic_metadata["source"] = metadata["_source"]
  99. if metadata.get("doc_metadata"):
  100. doc_metadata = metadata["doc_metadata"]
  101. if isinstance(doc_metadata, dict):
  102. semantic_metadata.update(doc_metadata)
  103. semantic_doc = {
  104. "document": {"content": content, "metadata": semantic_metadata, "score": score, "id": document_id}
  105. }
  106. semantic_documents.append(semantic_doc)
  107. return semantic_documents
  108. except Exception:
  109. return []
  110. def format_input_messages(process_data: Mapping[str, Any]) -> str:
  111. try:
  112. if not isinstance(process_data, dict):
  113. return serialize_json_data([])
  114. prompts = process_data.get("prompts", [])
  115. if not prompts:
  116. return serialize_json_data([])
  117. valid_roles = {"system", "user", "assistant", "tool"}
  118. input_messages = []
  119. for prompt in prompts:
  120. if not isinstance(prompt, dict):
  121. continue
  122. role = prompt.get("role", "")
  123. text = prompt.get("text", "")
  124. if not role or role not in valid_roles:
  125. continue
  126. if text:
  127. message = {"role": role, "parts": [{"type": "text", "content": text}]}
  128. input_messages.append(message)
  129. return serialize_json_data(input_messages)
  130. except Exception:
  131. return serialize_json_data([])
  132. def format_output_messages(outputs: Mapping[str, Any]) -> str:
  133. try:
  134. if not isinstance(outputs, dict):
  135. return serialize_json_data([])
  136. text = outputs.get("text", "")
  137. finish_reason = outputs.get("finish_reason", "")
  138. if not text:
  139. return serialize_json_data([])
  140. valid_finish_reasons = {"stop", "length", "content_filter", "tool_call", "error"}
  141. if finish_reason not in valid_finish_reasons:
  142. finish_reason = "stop"
  143. output_message = {
  144. "role": "assistant",
  145. "parts": [{"type": "text", "content": text}],
  146. "finish_reason": finish_reason,
  147. }
  148. return serialize_json_data([output_message])
  149. except Exception:
  150. return serialize_json_data([])