Browse Source

fix: correct agent node token counting to properly separate prompt and completion tokens (#24368)

tags/1.8.0
-LAN- 2 months ago
parent
commit
2e47558f4b
No account linked to committer's email address

+ 31
- 7
api/core/model_runtime/entities/llm_entities.py View File

from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from decimal import Decimal from decimal import Decimal
from enum import StrEnum from enum import StrEnum
from typing import Any, Optional
from typing import Any, Optional, TypedDict, Union


from pydantic import BaseModel, Field from pydantic import BaseModel, Field


CHAT = "chat" CHAT = "chat"




class LLMUsageMetadata(TypedDict, total=False):
    """
    TypedDict for LLM usage metadata.

    All fields are optional (``total=False``), so consumers must read
    entries with ``metadata.get(key, default)`` rather than direct
    subscripting.
    """

    # Token counts. Prompt and completion tokens are tracked separately;
    # total_tokens may be absent and derived from the other two.
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    # Monetary values may arrive either as a float or as a numeric string
    # (e.g. "0.001"); consumers normalize via Decimal(str(value)).
    prompt_unit_price: Union[float, str]
    completion_unit_price: Union[float, str]
    total_price: Union[float, str]
    # Currency code, e.g. "USD".
    currency: str
    prompt_price_unit: Union[float, str]
    completion_price_unit: Union[float, str]
    prompt_price: Union[float, str]
    completion_price: Union[float, str]
    # Latency of the LLM call — presumably seconds; confirm against producer.
    latency: float

class LLMUsage(ModelUsage): class LLMUsage(ModelUsage):
""" """
Model class for llm usage. Model class for llm usage.
) )


@classmethod @classmethod
def from_metadata(cls, metadata: dict) -> LLMUsage:
def from_metadata(cls, metadata: LLMUsageMetadata) -> LLMUsage:
""" """
Create LLMUsage instance from metadata dictionary with default values. Create LLMUsage instance from metadata dictionary with default values.


Args: Args:
metadata: Dictionary containing usage metadata
metadata: TypedDict containing usage metadata


Returns: Returns:
LLMUsage instance with values from metadata or defaults LLMUsage instance with values from metadata or defaults
""" """
total_tokens = metadata.get("total_tokens", 0)
prompt_tokens = metadata.get("prompt_tokens", 0)
completion_tokens = metadata.get("completion_tokens", 0) completion_tokens = metadata.get("completion_tokens", 0)
if total_tokens > 0 and completion_tokens == 0:
completion_tokens = total_tokens
total_tokens = metadata.get("total_tokens", 0)

# If total_tokens is not provided but prompt and completion tokens are,
# calculate total_tokens
if total_tokens == 0 and (prompt_tokens > 0 or completion_tokens > 0):
total_tokens = prompt_tokens + completion_tokens


return cls( return cls(
prompt_tokens=metadata.get("prompt_tokens", 0),
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=total_tokens, total_tokens=total_tokens,
prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))), prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))),

+ 2
- 2
api/core/workflow/nodes/agent/agent_node.py View File

from core.file import File, FileTransferMethod from core.file import File, FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.request import InvokeCredentials from core.plugin.entities.request import InvokeCredentials
assert isinstance(message.message, ToolInvokeMessage.JsonMessage) assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if node_type == NodeType.AGENT: if node_type == NodeType.AGENT:
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {}) msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
llm_usage = LLMUsage.from_metadata(msg_metadata)
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
agent_execution_metadata = { agent_execution_metadata = {
WorkflowNodeExecutionMetadataKey(key): value WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items() for key, value in msg_metadata.items()

+ 148
- 0
api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py View File

"""Tests for LLMUsage entity."""

from decimal import Decimal

from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata


class TestLLMUsage:
    """Unit tests for LLMUsage.from_metadata token and price accounting."""

    def test_from_metadata_with_all_tokens(self):
        """Every token and price field supplied: values pass through unchanged."""
        payload: LLMUsageMetadata = {
            "prompt_tokens": 100,
            "completion_tokens": 50,
            "total_tokens": 150,
            "prompt_unit_price": 0.001,
            "completion_unit_price": 0.002,
            "total_price": 0.2,
            "currency": "USD",
            "latency": 1.5,
        }

        result = LLMUsage.from_metadata(payload)

        assert (result.prompt_tokens, result.completion_tokens, result.total_tokens) == (100, 50, 150)
        assert result.prompt_unit_price == Decimal("0.001")
        assert result.completion_unit_price == Decimal("0.002")
        assert result.total_price == Decimal("0.2")
        assert result.currency == "USD"
        assert result.latency == 1.5

    def test_from_metadata_with_prompt_tokens_only(self):
        """Only prompt_tokens present: completion stays zero."""
        payload: LLMUsageMetadata = {"prompt_tokens": 100, "total_tokens": 100}

        result = LLMUsage.from_metadata(payload)

        assert (result.prompt_tokens, result.completion_tokens, result.total_tokens) == (100, 0, 100)

    def test_from_metadata_with_completion_tokens_only(self):
        """Only completion_tokens present: prompt stays zero."""
        payload: LLMUsageMetadata = {"completion_tokens": 50, "total_tokens": 50}

        result = LLMUsage.from_metadata(payload)

        assert (result.prompt_tokens, result.completion_tokens, result.total_tokens) == (0, 50, 50)

    def test_from_metadata_calculates_total_when_missing(self):
        """total_tokens absent: derived as prompt + completion."""
        payload: LLMUsageMetadata = {"prompt_tokens": 100, "completion_tokens": 50}

        result = LLMUsage.from_metadata(payload)

        # 150 must be computed, not defaulted to zero.
        assert (result.prompt_tokens, result.completion_tokens, result.total_tokens) == (100, 50, 150)

    def test_from_metadata_with_total_but_no_completion(self):
        """
        total_tokens given while completion_tokens is 0.

        Regression guard for issue #24360: prompt tokens must never be
        reassigned to completion_tokens.
        """
        payload: LLMUsageMetadata = {
            "prompt_tokens": 479,
            "completion_tokens": 0,
            "total_tokens": 521,
        }

        result = LLMUsage.from_metadata(payload)

        # Core of the fix: prompt tokens remain prompt tokens.
        assert (result.prompt_tokens, result.completion_tokens, result.total_tokens) == (479, 0, 521)

    def test_from_metadata_with_empty_metadata(self):
        """No keys at all: every field falls back to its default."""
        payload: LLMUsageMetadata = {}

        result = LLMUsage.from_metadata(payload)

        assert (result.prompt_tokens, result.completion_tokens, result.total_tokens) == (0, 0, 0)
        assert result.currency == "USD"
        assert result.latency == 0.0

    def test_from_metadata_preserves_zero_completion_tokens(self):
        """
        An explicit completion_tokens of 0 must survive conversion — agent
        nodes legitimately report prompt-only usage.
        """
        payload: LLMUsageMetadata = {
            "prompt_tokens": 1000,
            "completion_tokens": 0,
            "total_tokens": 1000,
            "prompt_unit_price": 0.15,
            "completion_unit_price": 0.60,
            "prompt_price": 0.00015,
            "completion_price": 0,
            "total_price": 0.00015,
        }

        result = LLMUsage.from_metadata(payload)

        assert (result.prompt_tokens, result.completion_tokens, result.total_tokens) == (1000, 0, 1000)
        assert result.prompt_price == Decimal("0.00015")
        assert result.completion_price == Decimal(0)
        assert result.total_price == Decimal("0.00015")

    def test_from_metadata_with_decimal_values(self):
        """String-typed prices convert cleanly to exact Decimals."""
        payload: LLMUsageMetadata = {
            "prompt_tokens": 100,
            "completion_tokens": 50,
            "total_tokens": 150,
            "prompt_unit_price": "0.001",
            "completion_unit_price": "0.002",
            "prompt_price": "0.1",
            "completion_price": "0.1",
            "total_price": "0.2",
        }

        result = LLMUsage.from_metadata(payload)

        assert result.prompt_unit_price == Decimal("0.001")
        assert result.completion_unit_price == Decimal("0.002")
        assert result.prompt_price == Decimal("0.1")
        assert result.completion_price == Decimal("0.1")
        assert result.total_price == Decimal("0.2")

Loading…
Cancel
Save