| @@ -10,6 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration | |||
| from core.callback_handler.entity.agent_loop import AgentLoop | |||
| from core.conversation_message_task import ConversationMessageTask | |||
| from core.model_providers.models.entity.message import PromptMessage | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| @@ -68,6 +69,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): | |||
| self._current_loop.status = 'llm_end' | |||
| if response.llm_output: | |||
| self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens'] | |||
| else: | |||
| self._current_loop.prompt_tokens = self.model_instant.get_num_tokens( | |||
| [PromptMessage(content=self._current_loop.prompt)] | |||
| ) | |||
| completion_generation = response.generations[0][0] | |||
| if isinstance(completion_generation, ChatGeneration): | |||
| completion_message = completion_generation.message | |||
| @@ -81,6 +86,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): | |||
| if response.llm_output: | |||
| self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens'] | |||
| else: | |||
| self._current_loop.completion_tokens = self.model_instant.get_num_tokens( | |||
| [PromptMessage(content=self._current_loop.completion)] | |||
| ) | |||
| def on_llm_error( | |||
| self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any | |||
| @@ -119,9 +119,11 @@ class ConversationMessageTask: | |||
| message="", | |||
| message_tokens=0, | |||
| message_unit_price=0, | |||
| message_price_unit=0, | |||
| answer="", | |||
| answer_tokens=0, | |||
| answer_unit_price=0, | |||
| answer_price_unit=0, | |||
| provider_response_latency=0, | |||
| total_price=0, | |||
| currency=self.model_instance.get_currency(), | |||
| @@ -142,7 +144,9 @@ class ConversationMessageTask: | |||
| answer_tokens = llm_message.completion_tokens | |||
| message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN) | |||
| message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN) | |||
| answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT) | |||
| answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT) | |||
| message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN) | |||
| answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT) | |||
| @@ -151,9 +155,11 @@ class ConversationMessageTask: | |||
| self.message.message = llm_message.prompt | |||
| self.message.message_tokens = message_tokens | |||
| self.message.message_unit_price = message_unit_price | |||
| self.message.message_price_unit = message_price_unit | |||
| self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else '' | |||
| self.message.answer_tokens = answer_tokens | |||
| self.message.answer_unit_price = answer_unit_price | |||
| self.message.answer_price_unit = answer_price_unit | |||
| self.message.provider_response_latency = llm_message.latency | |||
| self.message.total_price = total_price | |||
| @@ -195,7 +201,9 @@ class ConversationMessageTask: | |||
| tool=agent_loop.tool_name, | |||
| tool_input=agent_loop.tool_input, | |||
| message=agent_loop.prompt, | |||
| message_price_unit=0, | |||
| answer=agent_loop.completion, | |||
| answer_price_unit=0, | |||
| created_by_role=('account' if isinstance(self.user, Account) else 'end_user'), | |||
| created_by=self.user.id | |||
| ) | |||
| @@ -210,7 +218,9 @@ class ConversationMessageTask: | |||
| def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM, | |||
| agent_loop: AgentLoop): | |||
| agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN) | |||
| agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN) | |||
| agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT) | |||
| agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT) | |||
| loop_message_tokens = agent_loop.prompt_tokens | |||
| loop_answer_tokens = agent_loop.completion_tokens | |||
| @@ -223,8 +233,10 @@ class ConversationMessageTask: | |||
| message_agent_thought.tool_process_data = '' # currently not support | |||
| message_agent_thought.message_token = loop_message_tokens | |||
| message_agent_thought.message_unit_price = agent_message_unit_price | |||
| message_agent_thought.message_price_unit = agent_message_price_unit | |||
| message_agent_thought.answer_token = loop_answer_tokens | |||
| message_agent_thought.answer_unit_price = agent_answer_unit_price | |||
| message_agent_thought.answer_price_unit = agent_answer_price_unit | |||
| message_agent_thought.latency = agent_loop.latency | |||
| message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens | |||
| message_agent_thought.total_price = loop_total_price | |||
| @@ -197,7 +197,7 @@ class BaseLLM(BaseProviderModel): | |||
| """ | |||
| raise NotImplementedError | |||
| def calc_tokens_price(self, tokens:int, message_type: MessageType): | |||
| def calc_tokens_price(self, tokens: int, message_type: MessageType) -> decimal.Decimal: | |||
| """ | |||
| calc tokens total price. | |||
| @@ -209,14 +209,14 @@ class BaseLLM(BaseProviderModel): | |||
| unit_price = self.price_config['prompt'] | |||
| else: | |||
| unit_price = self.price_config['completion'] | |||
| unit = self.price_config['unit'] | |||
| unit = self.get_price_unit(message_type) | |||
| total_price = tokens * unit_price * unit | |||
| total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) | |||
| logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}") | |||
| return total_price | |||
| def get_tokens_unit_price(self, message_type: MessageType): | |||
| def get_tokens_unit_price(self, message_type: MessageType) -> decimal.Decimal: | |||
| """ | |||
| get token price. | |||
| @@ -231,7 +231,23 @@ class BaseLLM(BaseProviderModel): | |||
| logging.debug(f"unit_price={unit_price}") | |||
| return unit_price | |||
| def get_currency(self): | |||
def get_price_unit(self, message_type: MessageType) -> decimal.Decimal:
    """
    get price unit.

    The unit is read from ``self.price_config['unit']`` for every message
    type. ``message_type`` is kept so the signature matches
    ``get_tokens_unit_price``, but it does not currently influence the result.

    :param message_type: type of the message being priced (not used to select the unit)
    :return: the configured unit rounded half-up to six decimal places,
        e.g. decimal.Decimal('0.000001')
    """
    # Prompt-side and completion-side messages share one configured unit, so
    # there is nothing to branch on here.
    price_unit = self.price_config['unit'].quantize(
        decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP
    )

    logging.debug(f"price_unit={price_unit}")
    return price_unit
| def get_currency(self) -> str: | |||
| """ | |||
| get token currency. | |||
| @@ -0,0 +1,43 @@ | |||
| """add message price unit | |||
| Revision ID: 853f9b9cd3b6 | |||
| Revises: e8883b0148c9 | |||
| Create Date: 2023-08-19 17:01:57.471562 | |||
| """ | |||
| from alembic import op | |||
| import sqlalchemy as sa | |||
| from sqlalchemy.dialects import postgresql | |||
| # revision identifiers, used by Alembic. | |||
| revision = '853f9b9cd3b6' | |||
| down_revision = 'e8883b0148c9' | |||
| branch_labels = None | |||
| depends_on = None | |||
def upgrade():
    """Add the price-unit columns to ``message_agent_thoughts`` and ``messages``.

    Each table receives ``message_price_unit`` and ``answer_price_unit`` as
    non-nullable ``Numeric(10, 7)`` columns with a server default of ``0.001``.
    """
    for table_name in ('message_agent_thoughts', 'messages'):
        with op.batch_alter_table(table_name, schema=None) as batch_op:
            for column_name in ('message_price_unit', 'answer_price_unit'):
                # A fresh Column object is built for every add_column call;
                # the server default lets the NOT NULL column be added in one step.
                price_unit_column = sa.Column(
                    column_name,
                    sa.Numeric(precision=10, scale=7),
                    server_default=sa.text('0.001'),
                    nullable=False,
                )
                batch_op.add_column(price_unit_column)
def downgrade():
    """Drop the price-unit columns from ``messages`` and ``message_agent_thoughts``.

    Tables and columns are handled in the reverse of the order used by
    ``upgrade``.
    """
    for table_name in ('messages', 'message_agent_thoughts'):
        with op.batch_alter_table(table_name, schema=None) as batch_op:
            for column_name in ('answer_price_unit', 'message_price_unit'):
                batch_op.drop_column(column_name)
| @@ -421,9 +421,11 @@ class Message(db.Model): | |||
| message = db.Column(db.JSON, nullable=False) | |||
| message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) | |||
| message_unit_price = db.Column(db.Numeric(10, 4), nullable=False) | |||
| message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) | |||
| answer = db.Column(db.Text, nullable=False) | |||
| answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) | |||
| answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) | |||
| answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) | |||
| provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text('0')) | |||
| total_price = db.Column(db.Numeric(10, 7)) | |||
| currency = db.Column(db.String(255), nullable=False) | |||
| @@ -705,9 +707,11 @@ class MessageAgentThought(db.Model): | |||
| message = db.Column(db.Text, nullable=True) | |||
| message_token = db.Column(db.Integer, nullable=True) | |||
| message_unit_price = db.Column(db.Numeric, nullable=True) | |||
| message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) | |||
| answer = db.Column(db.Text, nullable=True) | |||
| answer_token = db.Column(db.Integer, nullable=True) | |||
| answer_unit_price = db.Column(db.Numeric, nullable=True) | |||
| answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) | |||
| tokens = db.Column(db.Integer, nullable=True) | |||
| total_price = db.Column(db.Numeric, nullable=True) | |||
| currency = db.Column(db.String, nullable=True) | |||