from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils

_SUMMARY_PROMPT = """You are a professional language researcher. You can quickly identify the main
points of a webpage and reproduce them in your own words while retaining the original meaning and
the key points.
However, the text you received is too long and may be only part of the full text.
Please summarize the text you received.
"""


class BuiltinTool(Tool):
    """
    Builtin tool

    :param provider: the id of the provider this tool belongs to
    """

    provider: str

    def __init__(self, provider: str, **kwargs):
        super().__init__(**kwargs)
        self.provider = provider

    def fork_tool_runtime(self, runtime: ToolRuntime) -> "BuiltinTool":
        """
        fork a new tool with the given runtime

        :param runtime: the runtime to attach to the forked tool
        :return: the new tool
        """
        return self.__class__(
            entity=self.entity.model_copy(),
            runtime=runtime,
            provider=self.provider,
        )

    def invoke_model(self, user_id: str, prompt_messages: list[PromptMessage], stop: list[str]) -> LLMResult:
        """
        invoke model

        :param user_id: the user id
        :param prompt_messages: the prompt messages
        :param stop: the stop words (accepted for interface compatibility; not forwarded below)
        :return: the model result
        """
        return ModelInvocationUtils.invoke(
            user_id=user_id,
            tenant_id=self.runtime.tenant_id or "",
            tool_type="builtin",
            tool_name=self.entity.identity.name,
            prompt_messages=prompt_messages,
        )
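
    # A hypothetical usage sketch (not part of this module): assuming `tool` is a
    # fully configured BuiltinTool with a runtime and tenant, something like the
    # following would run a one-off completion through the tool's model. The
    # variable names and sample ids are illustrative only:
    #
    #     result = tool.invoke_model(
    #         user_id="user-123",
    #         prompt_messages=[UserPromptMessage(content="Hello")],
    #         stop=[],
    #     )
    #     text = result.message.content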

    def tool_provider_type(self) -> ToolProviderType:
        return ToolProviderType.BUILT_IN

    def get_max_tokens(self) -> int:
        """
        get max tokens

        :return: the max tokens
        """
        if self.runtime is None:
            raise ValueError("runtime is required")

        return ModelInvocationUtils.get_max_llm_context_tokens(
            tenant_id=self.runtime.tenant_id or "",
        )

    def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
        """
        get prompt tokens

        :param prompt_messages: the prompt messages
        :return: the tokens
        """
        if self.runtime is None:
            raise ValueError("runtime is required")

        return ModelInvocationUtils.calculate_tokens(
            tenant_id=self.runtime.tenant_id or "",
            prompt_messages=prompt_messages,
        )

    def summary(self, user_id: str, content: str) -> str:
        """
        recursively summarize content that exceeds the model's context budget

        :param user_id: the user id
        :param content: the text to summarize
        :return: the summary, or the original content if it is short enough
        """
        max_tokens = self.get_max_tokens()

        # short enough already: return the content unchanged
        if self.get_prompt_tokens(prompt_messages=[UserPromptMessage(content=content)]) < max_tokens * 0.6:
            return content

        def get_prompt_tokens(content: str) -> int:
            return self.get_prompt_tokens(
                prompt_messages=[SystemPromptMessage(content=_SUMMARY_PROMPT), UserPromptMessage(content=content)]
            )

        def summarize(content: str) -> str:
            summary = self.invoke_model(
                user_id=user_id,
                prompt_messages=[SystemPromptMessage(content=_SUMMARY_PROMPT), UserPromptMessage(content=content)],
                stop=[],
            )
            assert isinstance(summary.message.content, str)
            return summary.message.content

        lines = content.split("\n")
        new_lines = []

        # split overly long lines into chunks that fit the token budget
        for line in lines:
            if not line.strip():
                continue
            if len(line) < max_tokens * 0.5:
                # cheap fast path: short enough by character count alone
                new_lines.append(line)
            elif get_prompt_tokens(line) > max_tokens * 0.7:
                while get_prompt_tokens(line) > max_tokens * 0.7:
                    new_lines.append(line[: int(max_tokens * 0.5)])
                    line = line[int(max_tokens * 0.5) :]
                new_lines.append(line)
            else:
                new_lines.append(line)

        # merge lines into messages, each staying within the token budget
        messages: list[str] = []
        for line in new_lines:
            if not messages:
                messages.append(line)
            elif len(messages[-1]) + len(line) < max_tokens * 0.5:
                # cheap fast path: clearly within budget by character count
                messages[-1] += line
            elif get_prompt_tokens(messages[-1] + line) > max_tokens * 0.7:
                # the current message is full, start a new one
                messages.append(line)
            else:
                messages[-1] += line

        # summarize each chunk independently, then join the partial summaries
        summaries = [summarize(message) for message in messages]
        result = "\n".join(summaries)

        # if the joined summaries are still too long, summarize again recursively
        if self.get_prompt_tokens(prompt_messages=[UserPromptMessage(content=result)]) > max_tokens * 0.7:
            return self.summary(user_id=user_id, content=result)

        return result
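

# What follows is a small, self-contained sketch of the split-and-merge
# chunking strategy that `summary` implements above. It is illustrative only:
# `_count_tokens` and `_fake_summarize` are hypothetical stand-ins for the real
# token accounting (`get_prompt_tokens`) and the LLM call (`invoke_model`), so
# the numbers and output differ from production behavior.
if __name__ == "__main__":

    def _count_tokens(text: str) -> int:
        # crude stand-in: count whitespace-delimited words as "tokens"
        return len(text.split())

    def _fake_summarize(text: str) -> str:
        # stand-in for the LLM call: keep only the first ten words
        return " ".join(text.split()[:10])

    def chunked_summary(content: str, budget: int) -> str:
        # short enough already: return unchanged, mirroring the 0.6 threshold above
        if _count_tokens(content) < budget * 0.6:
            return content

        # pack non-empty lines into chunks that stay under 70% of the budget
        chunks: list[str] = []
        for line in content.split("\n"):
            if not line.strip():
                continue
            if not chunks or _count_tokens(chunks[-1] + " " + line) > budget * 0.7:
                chunks.append(line)
            else:
                chunks[-1] += " " + line

        # summarize each chunk, then recurse if the result is still too long
        result = "\n".join(_fake_summarize(chunk) for chunk in chunks)
        if _count_tokens(result) > budget * 0.7:
            return chunked_summary(result, budget)
        return result

    sample = "\n".join(f"line {n} with some filler words here" for n in range(40))
    print(chunked_summary(sample, budget=60))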