| from core.callback_handler.entity.agent_loop import AgentLoop | from core.callback_handler.entity.agent_loop import AgentLoop | ||||
| from core.conversation_message_task import ConversationMessageTask | from core.conversation_message_task import ConversationMessageTask | ||||
| from core.model_providers.models.entity.message import PromptMessage | |||||
| from core.model_providers.models.llm.base import BaseLLM | from core.model_providers.models.llm.base import BaseLLM | ||||
| self._current_loop.status = 'llm_end' | self._current_loop.status = 'llm_end' | ||||
| if response.llm_output: | if response.llm_output: | ||||
| self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens'] | self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens'] | ||||
| else: | |||||
| self._current_loop.prompt_tokens = self.model_instant.get_num_tokens( | |||||
| [PromptMessage(content=self._current_loop.prompt)] | |||||
| ) | |||||
| completion_generation = response.generations[0][0] | completion_generation = response.generations[0][0] | ||||
| if isinstance(completion_generation, ChatGeneration): | if isinstance(completion_generation, ChatGeneration): | ||||
| completion_message = completion_generation.message | completion_message = completion_generation.message | ||||
| if response.llm_output: | if response.llm_output: | ||||
| self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens'] | self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens'] | ||||
| else: | |||||
| self._current_loop.completion_tokens = self.model_instant.get_num_tokens( | |||||
| [PromptMessage(content=self._current_loop.completion)] | |||||
| ) | |||||
| def on_llm_error( | def on_llm_error( | ||||
| self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any | self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any |
| message="", | message="", | ||||
| message_tokens=0, | message_tokens=0, | ||||
| message_unit_price=0, | message_unit_price=0, | ||||
| message_price_unit=0, | |||||
| answer="", | answer="", | ||||
| answer_tokens=0, | answer_tokens=0, | ||||
| answer_unit_price=0, | answer_unit_price=0, | ||||
| answer_price_unit=0, | |||||
| provider_response_latency=0, | provider_response_latency=0, | ||||
| total_price=0, | total_price=0, | ||||
| currency=self.model_instance.get_currency(), | currency=self.model_instance.get_currency(), | ||||
| answer_tokens = llm_message.completion_tokens | answer_tokens = llm_message.completion_tokens | ||||
| message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN) | message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN) | ||||
| message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN) | |||||
| answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT) | answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT) | ||||
| answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT) | |||||
| message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN) | message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN) | ||||
| answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT) | answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT) | ||||
| self.message.message = llm_message.prompt | self.message.message = llm_message.prompt | ||||
| self.message.message_tokens = message_tokens | self.message.message_tokens = message_tokens | ||||
| self.message.message_unit_price = message_unit_price | self.message.message_unit_price = message_unit_price | ||||
| self.message.message_price_unit = message_price_unit | |||||
| self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else '' | self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else '' | ||||
| self.message.answer_tokens = answer_tokens | self.message.answer_tokens = answer_tokens | ||||
| self.message.answer_unit_price = answer_unit_price | self.message.answer_unit_price = answer_unit_price | ||||
| self.message.answer_price_unit = answer_price_unit | |||||
| self.message.provider_response_latency = llm_message.latency | self.message.provider_response_latency = llm_message.latency | ||||
| self.message.total_price = total_price | self.message.total_price = total_price | ||||
| tool=agent_loop.tool_name, | tool=agent_loop.tool_name, | ||||
| tool_input=agent_loop.tool_input, | tool_input=agent_loop.tool_input, | ||||
| message=agent_loop.prompt, | message=agent_loop.prompt, | ||||
| message_price_unit=0, | |||||
| answer=agent_loop.completion, | answer=agent_loop.completion, | ||||
| answer_price_unit=0, | |||||
| created_by_role=('account' if isinstance(self.user, Account) else 'end_user'), | created_by_role=('account' if isinstance(self.user, Account) else 'end_user'), | ||||
| created_by=self.user.id | created_by=self.user.id | ||||
| ) | ) | ||||
| def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM, | def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM, | ||||
| agent_loop: AgentLoop): | agent_loop: AgentLoop): | ||||
| agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN) | agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN) | ||||
| agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN) | |||||
| agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT) | agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT) | ||||
| agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT) | |||||
| loop_message_tokens = agent_loop.prompt_tokens | loop_message_tokens = agent_loop.prompt_tokens | ||||
| loop_answer_tokens = agent_loop.completion_tokens | loop_answer_tokens = agent_loop.completion_tokens | ||||
| message_agent_thought.tool_process_data = '' # currently not support | message_agent_thought.tool_process_data = '' # currently not support | ||||
| message_agent_thought.message_token = loop_message_tokens | message_agent_thought.message_token = loop_message_tokens | ||||
| message_agent_thought.message_unit_price = agent_message_unit_price | message_agent_thought.message_unit_price = agent_message_unit_price | ||||
| message_agent_thought.message_price_unit = agent_message_price_unit | |||||
| message_agent_thought.answer_token = loop_answer_tokens | message_agent_thought.answer_token = loop_answer_tokens | ||||
| message_agent_thought.answer_unit_price = agent_answer_unit_price | message_agent_thought.answer_unit_price = agent_answer_unit_price | ||||
| message_agent_thought.answer_price_unit = agent_answer_price_unit | |||||
| message_agent_thought.latency = agent_loop.latency | message_agent_thought.latency = agent_loop.latency | ||||
| message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens | message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens | ||||
| message_agent_thought.total_price = loop_total_price | message_agent_thought.total_price = loop_total_price |
| """ | """ | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| def calc_tokens_price(self, tokens:int, message_type: MessageType): | |||||
| def calc_tokens_price(self, tokens: int, message_type: MessageType) -> decimal.Decimal: | |||||
| """ | """ | ||||
| calc tokens total price. | calc tokens total price. | ||||
| unit_price = self.price_config['prompt'] | unit_price = self.price_config['prompt'] | ||||
| else: | else: | ||||
| unit_price = self.price_config['completion'] | unit_price = self.price_config['completion'] | ||||
| unit = self.price_config['unit'] | |||||
| unit = self.get_price_unit(message_type) | |||||
| total_price = tokens * unit_price * unit | total_price = tokens * unit_price * unit | ||||
| total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) | total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) | ||||
| logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}") | logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}") | ||||
| return total_price | return total_price | ||||
| def get_tokens_unit_price(self, message_type: MessageType): | |||||
| def get_tokens_unit_price(self, message_type: MessageType) -> decimal.Decimal: | |||||
| """ | """ | ||||
| get token price. | get token price. | ||||
| logging.debug(f"unit_price={unit_price}") | logging.debug(f"unit_price={unit_price}") | ||||
| return unit_price | return unit_price | ||||
| def get_currency(self): | |||||
def get_price_unit(self, message_type: MessageType) -> decimal.Decimal:
    """
    get price unit.

    Reads the 'unit' entry of ``self.price_config`` and rounds it to six
    decimal places using ROUND_HALF_UP.

    :param message_type: accepted so the signature matches the other pricing
        helpers; it does not influence the result, because a single 'unit'
        entry is configured for every message type.
    :return: the price unit, e.g. decimal.Decimal('0.000001')
    """
    # Prompt (HUMAN/SYSTEM) and completion messages share the same configured
    # unit, so there is nothing to branch on here.
    price_unit = self.price_config['unit'].quantize(
        decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP
    )

    # Lazy %-formatting: the message is only rendered when DEBUG is enabled.
    logging.debug("price_unit=%s", price_unit)
    return price_unit
| def get_currency(self) -> str: | |||||
| """ | """ | ||||
| get token currency. | get token currency. | ||||
| """add message price unit | |||||
| Revision ID: 853f9b9cd3b6 | |||||
| Revises: e8883b0148c9 | |||||
| Create Date: 2023-08-19 17:01:57.471562 | |||||
| """ | |||||
| from alembic import op | |||||
| import sqlalchemy as sa | |||||
| from sqlalchemy.dialects import postgresql | |||||
| # revision identifiers, used by Alembic. | |||||
| revision = '853f9b9cd3b6' | |||||
| down_revision = 'e8883b0148c9' | |||||
| branch_labels = None | |||||
| depends_on = None | |||||
def upgrade():
    """Add the ``message_price_unit`` and ``answer_price_unit`` columns.

    Both columns are added to ``message_agent_thoughts`` first and then to
    ``messages``. Each is a non-nullable NUMERIC(10, 7) with a server default
    of 0.001.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    target_tables = ('message_agent_thoughts', 'messages')
    new_columns = ('message_price_unit', 'answer_price_unit')

    for table_name in target_tables:
        with op.batch_alter_table(table_name, schema=None) as batch_op:
            for column_name in new_columns:
                price_unit_column = sa.Column(
                    column_name,
                    sa.Numeric(precision=10, scale=7),
                    server_default=sa.text('0.001'),
                    nullable=False,
                )
                batch_op.add_column(price_unit_column)
    # ### end Alembic commands ###
def downgrade():
    """Drop the ``answer_price_unit`` and ``message_price_unit`` columns.

    The columns are removed from ``messages`` first and then from
    ``message_agent_thoughts``, with the answer column dropped before the
    message column in each table.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    target_tables = ('messages', 'message_agent_thoughts')
    dropped_columns = ('answer_price_unit', 'message_price_unit')

    for table_name in target_tables:
        with op.batch_alter_table(table_name, schema=None) as batch_op:
            for column_name in dropped_columns:
                batch_op.drop_column(column_name)
    # ### end Alembic commands ###
| message = db.Column(db.JSON, nullable=False) | message = db.Column(db.JSON, nullable=False) | ||||
| message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) | message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) | ||||
| message_unit_price = db.Column(db.Numeric(10, 4), nullable=False) | message_unit_price = db.Column(db.Numeric(10, 4), nullable=False) | ||||
| message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) | |||||
| answer = db.Column(db.Text, nullable=False) | answer = db.Column(db.Text, nullable=False) | ||||
| answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) | answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) | ||||
| answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) | answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) | ||||
| answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) | |||||
| provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text('0')) | provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text('0')) | ||||
| total_price = db.Column(db.Numeric(10, 7)) | total_price = db.Column(db.Numeric(10, 7)) | ||||
| currency = db.Column(db.String(255), nullable=False) | currency = db.Column(db.String(255), nullable=False) | ||||
| message = db.Column(db.Text, nullable=True) | message = db.Column(db.Text, nullable=True) | ||||
| message_token = db.Column(db.Integer, nullable=True) | message_token = db.Column(db.Integer, nullable=True) | ||||
| message_unit_price = db.Column(db.Numeric, nullable=True) | message_unit_price = db.Column(db.Numeric, nullable=True) | ||||
| message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) | |||||
| answer = db.Column(db.Text, nullable=True) | answer = db.Column(db.Text, nullable=True) | ||||
| answer_token = db.Column(db.Integer, nullable=True) | answer_token = db.Column(db.Integer, nullable=True) | ||||
| answer_unit_price = db.Column(db.Numeric, nullable=True) | answer_unit_price = db.Column(db.Numeric, nullable=True) | ||||
| answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001')) | |||||
| tokens = db.Column(db.Integer, nullable=True) | tokens = db.Column(db.Integer, nullable=True) | ||||
| total_price = db.Column(db.Numeric, nullable=True) | total_price = db.Column(db.Numeric, nullable=True) | ||||
| currency = db.Column(db.String, nullable=True) | currency = db.Column(db.String, nullable=True) |