| @@ -140,10 +140,13 @@ class ConversationMessageTask: | |||
| def save_message(self, llm_message: LLMMessage, by_stopped: bool = False): | |||
| message_tokens = llm_message.prompt_tokens | |||
| answer_tokens = llm_message.completion_tokens | |||
| message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN) | |||
| answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT) | |||
| total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price) | |||
| message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN) | |||
| answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT) | |||
| message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN) | |||
| answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT) | |||
| total_price = message_total_price + answer_total_price | |||
| self.message.message = llm_message.prompt | |||
| self.message.message_tokens = message_tokens | |||
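A quick sanity check of the arithmetic the new `calc_tokens_price` calls perform (a minimal, self-contained sketch; the token counts are made up, and the gpt-3.5-turbo prices are taken from the openai rules JSON further down in this diff):

```python
from decimal import Decimal, ROUND_HALF_UP

# gpt-3.5-turbo entry: prompt=0.0015, completion=0.002 per "unit" of 0.001 (i.e. per 1K tokens)
prompt_price, completion_price, unit = Decimal('0.0015'), Decimal('0.002'), Decimal('0.001')

message_tokens, answer_tokens = 120, 80   # illustrative counts only

message_total_price = (message_tokens * prompt_price * unit).quantize(
    Decimal('0.0000001'), rounding=ROUND_HALF_UP)
answer_total_price = (answer_tokens * completion_price * unit).quantize(
    Decimal('0.0000001'), rounding=ROUND_HALF_UP)
total_price = message_total_price + answer_total_price

print(message_total_price, answer_total_price, total_price)  # 0.0001800 0.0001600 0.0003400
```

This mirrors what `calc_tokens_price` does internally, so for these inputs the sum matches what the removed `calc_total_price` helper returned.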
| @@ -206,18 +209,15 @@ class ConversationMessageTask: | |||
| def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM, | |||
| agent_loop: AgentLoop): | |||
| agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN) | |||
| agent_answer_unit_price = agent_model_instant.get_token_price(1, MessageType.ASSISTANT) | |||
| agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN) | |||
| agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT) | |||
| loop_message_tokens = agent_loop.prompt_tokens | |||
| loop_answer_tokens = agent_loop.completion_tokens | |||
| loop_total_price = self.calc_total_price( | |||
| loop_message_tokens, | |||
| agent_message_unit_price, | |||
| loop_answer_tokens, | |||
| agent_answer_unit_price | |||
| ) | |||
| loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN) | |||
| loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT) | |||
| loop_total_price = loop_message_total_price + loop_answer_total_price | |||
| message_agent_thought.observation = agent_loop.tool_output | |||
| message_agent_thought.tool_process_data = '' # currently not support | |||
| @@ -243,15 +243,6 @@ class ConversationMessageTask: | |||
| db.session.add(dataset_query) | |||
| def calc_total_price(self, message_tokens, message_unit_price, answer_tokens, answer_unit_price): | |||
| message_tokens_per_1k = (decimal.Decimal(message_tokens) / 1000).quantize(decimal.Decimal('0.001'), | |||
| rounding=decimal.ROUND_HALF_UP) | |||
| answer_tokens_per_1k = (decimal.Decimal(answer_tokens) / 1000).quantize(decimal.Decimal('0.001'), | |||
| rounding=decimal.ROUND_HALF_UP) | |||
| total_price = message_tokens_per_1k * message_unit_price + answer_tokens_per_1k * answer_unit_price | |||
| return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) | |||
| def end(self): | |||
| self._pub_handler.pub_end() | |||
| @@ -278,7 +278,7 @@ class IndexingRunner: | |||
| "total_segments": total_segments * 20, | |||
| "tokens": total_segments * 2000, | |||
| "total_price": '{:f}'.format( | |||
| text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)), | |||
| text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)), | |||
| "currency": embedding_model.get_currency(), | |||
| "qa_preview": document_qa_list, | |||
| "preview": preview_texts | |||
| @@ -286,7 +286,7 @@ class IndexingRunner: | |||
| return { | |||
| "total_segments": total_segments, | |||
| "tokens": tokens, | |||
| "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)), | |||
| "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)), | |||
| "currency": embedding_model.get_currency(), | |||
| "preview": preview_texts | |||
| } | |||
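To make these estimates concrete, a self-contained worked example (segment and token counts are made up; the prices are the gpt-3.5-turbo and text-embedding-ada-002 entries from the openai rules JSON later in this diff, and the tenant's actual text generation model may differ):

```python
from decimal import Decimal

total_segments = 10

# QA branch: the text generation model is charged at its prompt rate
# for an assumed ~2000 tokens per segment
qa_tokens = total_segments * 2000
qa_price = qa_tokens * Decimal('0.0015') * Decimal('0.001')                # -> 0.030 USD

# plain indexing branch: the embedding model is charged per token embedded
embedding_tokens = 8000
embedding_price = embedding_tokens * Decimal('0.0001') * Decimal('0.001')  # -> 0.0008 USD

print(qa_price, embedding_price)
```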
| @@ -371,7 +371,7 @@ class IndexingRunner: | |||
| "total_segments": total_segments * 20, | |||
| "tokens": total_segments * 2000, | |||
| "total_price": '{:f}'.format( | |||
| text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)), | |||
| text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)), | |||
| "currency": embedding_model.get_currency(), | |||
| "qa_preview": document_qa_list, | |||
| "preview": preview_texts | |||
| @@ -31,6 +31,15 @@ class AzureOpenAIEmbedding(BaseEmbedding): | |||
| ) | |||
| super().__init__(model_provider, client, name) | |||
| @property | |||
| def base_model_name(self) -> str: | |||
| """ | |||
| get base model name (not deployment) | |||
| :return: str | |||
| """ | |||
| return self.credentials.get("base_model_name") | |||
| def get_num_tokens(self, text: str) -> int: | |||
| """ | |||
| @@ -49,16 +58,6 @@ class AzureOpenAIEmbedding(BaseEmbedding): | |||
| # calculate the number of tokens in the encoded text | |||
| return len(tokenized_text) | |||
| def get_token_price(self, tokens: int): | |||
| tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), | |||
| rounding=decimal.ROUND_HALF_UP) | |||
| total_price = tokens_per_1k * decimal.Decimal('0.0001') | |||
| return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) | |||
| def get_currency(self): | |||
| return 'USD' | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| if isinstance(ex, openai.error.InvalidRequestError): | |||
| logging.warning("Invalid request to Azure OpenAI API.") | |||
| @@ -1,5 +1,6 @@ | |||
| from abc import abstractmethod | |||
| from typing import Any | |||
| import decimal | |||
| import tiktoken | |||
| from langchain.schema.language_model import _get_token_ids_default_method | |||
| @@ -7,7 +8,8 @@ from langchain.schema.language_model import _get_token_ids_default_method | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from core.model_providers.models.entity.model_params import ModelType | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| import logging | |||
| logger = logging.getLogger(__name__) | |||
| class BaseEmbedding(BaseProviderModel): | |||
| name: str | |||
| @@ -17,6 +19,65 @@ class BaseEmbedding(BaseProviderModel): | |||
| super().__init__(model_provider, client) | |||
| self.name = name | |||
| @property | |||
| def base_model_name(self) -> str: | |||
| """ | |||
| get base model name | |||
| :return: str | |||
| """ | |||
| return self.name | |||
| @property | |||
| def price_config(self) -> dict: | |||
| def get_or_default(): | |||
| default_price_config = { | |||
| 'prompt': decimal.Decimal('0'), | |||
| 'completion': decimal.Decimal('0'), | |||
| 'unit': decimal.Decimal('0'), | |||
| 'currency': 'USD' | |||
| } | |||
| rules = self.model_provider.get_rules() | |||
| price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config | |||
| price_config = { | |||
| 'prompt': decimal.Decimal(price_config.get('prompt', '0')),  # embedding entries may only define "completion" (see text-embedding-ada-002 below) | |||
| 'completion': decimal.Decimal(price_config.get('completion', '0')), | |||
| 'unit': decimal.Decimal(price_config.get('unit', '0')), | |||
| 'currency': price_config.get('currency', 'USD') | |||
| } | |||
| return price_config | |||
| self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default() | |||
| logger.debug(f"model: {self.name} price_config: {self._price_config}") | |||
| return self._price_config | |||
| def calc_tokens_price(self, tokens: int) -> decimal.Decimal: | |||
| """ | |||
| calc tokens total price. | |||
| :param tokens: | |||
| :return: decimal.Decimal('0.0000001') | |||
| """ | |||
| unit_price = self.price_config['completion']  # embeddings are billed at the "completion" rate | |||
| unit = self.price_config['unit'] | |||
| total_price = tokens * unit_price * unit | |||
| total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) | |||
| logger.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}") | |||
| return total_price | |||
| def get_tokens_unit_price(self) -> decimal.Decimal: | |||
| """ | |||
| get tokens unit price. | |||
| :return: decimal.Decimal('0.0001') | |||
| """ | |||
| unit_price = self.price_config['completion'] | |||
| unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP) | |||
| logger.debug(f'unit_price:{unit_price}') | |||
| return unit_price | |||
| def get_num_tokens(self, text: str) -> int: | |||
| """ | |||
| get num tokens of text. | |||
| @@ -29,11 +90,14 @@ class BaseEmbedding(BaseProviderModel): | |||
| return len(_get_token_ids_default_method(text)) | |||
| def get_token_price(self, tokens: int): | |||
| return 0 | |||
| def get_currency(self): | |||
| return 'USD' | |||
| """ | |||
| get token currency. | |||
| :return: get from price config, default 'USD' | |||
| """ | |||
| currency = self.price_config['currency'] | |||
| return currency | |||
| @abstractmethod | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| @@ -22,9 +22,6 @@ class MinimaxEmbedding(BaseEmbedding): | |||
| super().__init__(model_provider, client, name) | |||
| def get_token_price(self, tokens: int): | |||
| return decimal.Decimal('0') | |||
| def get_currency(self): | |||
| return 'RMB' | |||
| @@ -42,16 +42,6 @@ class OpenAIEmbedding(BaseEmbedding): | |||
| # calculate the number of tokens in the encoded text | |||
| return len(tokenized_text) | |||
| def get_token_price(self, tokens: int): | |||
| tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), | |||
| rounding=decimal.ROUND_HALF_UP) | |||
| total_price = tokens_per_1k * decimal.Decimal('0.0001') | |||
| return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) | |||
| def get_currency(self): | |||
| return 'USD' | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| if isinstance(ex, openai.error.InvalidRequestError): | |||
| logging.warning("Invalid request to OpenAI API.") | |||
| @@ -22,13 +22,6 @@ class ReplicateEmbedding(BaseEmbedding): | |||
| super().__init__(model_provider, client, name) | |||
| def get_token_price(self, tokens: int): | |||
| # replicate only pay for prediction seconds | |||
| return decimal.Decimal('0') | |||
| def get_currency(self): | |||
| return 'USD' | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| if isinstance(ex, (ModelError, ReplicateError)): | |||
| return LLMBadRequestError(f"Replicate: {str(ex)}") | |||
| @@ -54,32 +54,6 @@ class AnthropicModel(BaseLLM): | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0) | |||
| def get_token_price(self, tokens: int, message_type: MessageType): | |||
| model_unit_prices = { | |||
| 'claude-instant-1': { | |||
| 'prompt': decimal.Decimal('1.63'), | |||
| 'completion': decimal.Decimal('5.51'), | |||
| }, | |||
| 'claude-2': { | |||
| 'prompt': decimal.Decimal('11.02'), | |||
| 'completion': decimal.Decimal('32.68'), | |||
| }, | |||
| } | |||
| if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: | |||
| unit_price = model_unit_prices[self.name]['prompt'] | |||
| else: | |||
| unit_price = model_unit_prices[self.name]['completion'] | |||
| tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'), | |||
| rounding=decimal.ROUND_HALF_UP) | |||
| total_price = tokens_per_1m * unit_price | |||
| return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP) | |||
| def get_currency(self): | |||
| return 'USD' | |||
| def _set_model_kwargs(self, model_kwargs: ModelKwargs): | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) | |||
| for k, v in provider_model_kwargs.items(): | |||
| @@ -29,7 +29,6 @@ class AzureOpenAIModel(BaseLLM): | |||
| self.model_mode = ModelMode.COMPLETION | |||
| else: | |||
| self.model_mode = ModelMode.CHAT | |||
| super().__init__(model_provider, name, model_kwargs, streaming, callbacks) | |||
| def _init_client(self) -> Any: | |||
| @@ -83,6 +82,15 @@ class AzureOpenAIModel(BaseLLM): | |||
| """ | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| return self._client.generate([prompts], stop, callbacks) | |||
| @property | |||
| def base_model_name(self) -> str: | |||
| """ | |||
| get base model name (not deployment) | |||
| :return: str | |||
| """ | |||
| return self.credentials.get("base_model_name") | |||
| def get_num_tokens(self, messages: List[PromptMessage]) -> int: | |||
| """ | |||
| @@ -97,45 +105,6 @@ class AzureOpenAIModel(BaseLLM): | |||
| else: | |||
| return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0) | |||
| def get_token_price(self, tokens: int, message_type: MessageType): | |||
| model_unit_prices = { | |||
| 'gpt-4': { | |||
| 'prompt': decimal.Decimal('0.03'), | |||
| 'completion': decimal.Decimal('0.06'), | |||
| }, | |||
| 'gpt-4-32k': { | |||
| 'prompt': decimal.Decimal('0.06'), | |||
| 'completion': decimal.Decimal('0.12') | |||
| }, | |||
| 'gpt-35-turbo': { | |||
| 'prompt': decimal.Decimal('0.0015'), | |||
| 'completion': decimal.Decimal('0.002') | |||
| }, | |||
| 'gpt-35-turbo-16k': { | |||
| 'prompt': decimal.Decimal('0.003'), | |||
| 'completion': decimal.Decimal('0.004') | |||
| }, | |||
| 'text-davinci-003': { | |||
| 'prompt': decimal.Decimal('0.02'), | |||
| 'completion': decimal.Decimal('0.02') | |||
| }, | |||
| } | |||
| base_model_name = self.credentials.get("base_model_name") | |||
| if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: | |||
| unit_price = model_unit_prices[base_model_name]['prompt'] | |||
| else: | |||
| unit_price = model_unit_prices[base_model_name]['completion'] | |||
| tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), | |||
| rounding=decimal.ROUND_HALF_UP) | |||
| total_price = tokens_per_1k * unit_price | |||
| return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) | |||
| def get_currency(self): | |||
| return 'USD' | |||
| def _set_model_kwargs(self, model_kwargs: ModelKwargs): | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) | |||
| if self.name == 'text-davinci-003': | |||
| @@ -1,5 +1,6 @@ | |||
| from abc import abstractmethod | |||
| from typing import List, Optional, Any, Union | |||
| import decimal | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration | |||
| @@ -10,6 +11,8 @@ from core.model_providers.models.entity.message import PromptMessage, MessageTyp | |||
| from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| from core.third_party.langchain.llms.fake import FakeLLM | |||
| import logging | |||
| logger = logging.getLogger(__name__) | |||
| class BaseLLM(BaseProviderModel): | |||
| @@ -60,6 +63,39 @@ class BaseLLM(BaseProviderModel): | |||
| def _init_client(self) -> Any: | |||
| raise NotImplementedError | |||
| @property | |||
| def base_model_name(self) -> str: | |||
| """ | |||
| get llm base model name | |||
| :return: str | |||
| """ | |||
| return self.name | |||
| @property | |||
| def price_config(self) -> dict: | |||
| def get_or_default(): | |||
| default_price_config = { | |||
| 'prompt': decimal.Decimal('0'), | |||
| 'completion': decimal.Decimal('0'), | |||
| 'unit': decimal.Decimal('0'), | |||
| 'currency': 'USD' | |||
| } | |||
| rules = self.model_provider.get_rules() | |||
| price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config | |||
| price_config = { | |||
| 'prompt': decimal.Decimal(price_config['prompt']), | |||
| 'completion': decimal.Decimal(price_config['completion']), | |||
| 'unit': decimal.Decimal(price_config['unit']), | |||
| 'currency': price_config['currency'] | |||
| } | |||
| return price_config | |||
| self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default() | |||
| logger.debug(f"model: {self.name} price_config: {self._price_config}") | |||
| return self._price_config | |||
| def run(self, messages: List[PromptMessage], | |||
| stop: Optional[List[str]] = None, | |||
| callbacks: Callbacks = None, | |||
| @@ -161,25 +197,48 @@ class BaseLLM(BaseProviderModel): | |||
| """ | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def get_token_price(self, tokens: int, message_type: MessageType): | |||
| def calc_tokens_price(self, tokens: int, message_type: MessageType): | |||
| """ | |||
| get token price. | |||
| calc tokens total price. | |||
| :param tokens: | |||
| :param message_type: | |||
| :return: | |||
| """ | |||
| raise NotImplementedError | |||
| if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: | |||
| unit_price = self.price_config['prompt'] | |||
| else: | |||
| unit_price = self.price_config['completion'] | |||
| unit = self.price_config['unit'] | |||
| total_price = tokens * unit_price * unit | |||
| total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) | |||
| logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}") | |||
| return total_price | |||
| def get_tokens_unit_price(self, message_type: MessageType): | |||
| """ | |||
| get tokens unit price. | |||
| :param message_type: | |||
| :return: decimal.Decimal('0.0001') | |||
| """ | |||
| if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: | |||
| unit_price = self.price_config['prompt'] | |||
| else: | |||
| unit_price = self.price_config['completion'] | |||
| unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP) | |||
| logging.debug(f"unit_price={unit_price}") | |||
| return unit_price | |||
| @abstractmethod | |||
| def get_currency(self): | |||
| """ | |||
| get token currency. | |||
| :return: | |||
| :return: get from price config, default 'USD' | |||
| """ | |||
| raise NotImplementedError | |||
| currency = self.price_config['currency'] | |||
| return currency | |||
| def get_model_kwargs(self): | |||
| return self.model_kwargs | |||
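The `price_config` property above resolves prices from the provider's rules instead of hard-coded tables. A minimal, self-contained sketch of that lookup (it assumes `model_provider.get_rules()` returns the provider rules JSON shown later in this diff parsed into a dict, and it inlines the zero-price fallback for readability; the real property also caches the result in `self._price_config`):

```python
from decimal import Decimal

rules = {                                    # stand-in for model_provider.get_rules()
    "price_config": {
        "gpt-3.5-turbo": {"prompt": "0.0015", "completion": "0.002",
                          "unit": "0.001", "currency": "USD"}
    }
}
base_model_name = "gpt-3.5-turbo"            # BaseLLM.base_model_name defaults to self.name

default_raw = {"prompt": "0", "completion": "0", "unit": "0", "currency": "USD"}
raw = rules["price_config"].get(base_model_name, default_raw) if "price_config" in rules else default_raw

price_config = {
    "prompt": Decimal(raw["prompt"]),        # prices stay strings in JSON, so Decimal sees exact values
    "completion": Decimal(raw["completion"]),
    "unit": Decimal(raw["unit"]),
    "currency": raw["currency"],
}
print(price_config)
```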
| @@ -47,9 +47,6 @@ class ChatGLMModel(BaseLLM): | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| return max(self._client.get_num_tokens(prompts), 0) | |||
| def get_token_price(self, tokens: int, message_type: MessageType): | |||
| return decimal.Decimal('0') | |||
| def get_currency(self): | |||
| return 'RMB' | |||
| @@ -62,13 +62,6 @@ class HuggingfaceHubModel(BaseLLM): | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| return self._client.get_num_tokens(prompts) | |||
| def get_token_price(self, tokens: int, message_type: MessageType): | |||
| # not support calc price | |||
| return decimal.Decimal('0') | |||
| def get_currency(self): | |||
| return 'USD' | |||
| def _set_model_kwargs(self, model_kwargs: ModelKwargs): | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) | |||
| self.client.model_kwargs = provider_model_kwargs | |||
| @@ -51,9 +51,6 @@ class MinimaxModel(BaseLLM): | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| return max(self._client.get_num_tokens(prompts), 0) | |||
| def get_token_price(self, tokens: int, message_type: MessageType): | |||
| return decimal.Decimal('0') | |||
| def get_currency(self): | |||
| return 'RMB' | |||
| @@ -46,7 +46,8 @@ class OpenAIModel(BaseLLM): | |||
| self.model_mode = ModelMode.COMPLETION | |||
| else: | |||
| self.model_mode = ModelMode.CHAT | |||
| # TODO load price config from configs(db) | |||
| super().__init__(model_provider, name, model_kwargs, streaming, callbacks) | |||
| def _init_client(self) -> Any: | |||
| @@ -117,44 +118,6 @@ class OpenAIModel(BaseLLM): | |||
| else: | |||
| return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0) | |||
| def get_token_price(self, tokens: int, message_type: MessageType): | |||
| model_unit_prices = { | |||
| 'gpt-4': { | |||
| 'prompt': decimal.Decimal('0.03'), | |||
| 'completion': decimal.Decimal('0.06'), | |||
| }, | |||
| 'gpt-4-32k': { | |||
| 'prompt': decimal.Decimal('0.06'), | |||
| 'completion': decimal.Decimal('0.12') | |||
| }, | |||
| 'gpt-3.5-turbo': { | |||
| 'prompt': decimal.Decimal('0.0015'), | |||
| 'completion': decimal.Decimal('0.002') | |||
| }, | |||
| 'gpt-3.5-turbo-16k': { | |||
| 'prompt': decimal.Decimal('0.003'), | |||
| 'completion': decimal.Decimal('0.004') | |||
| }, | |||
| 'text-davinci-003': { | |||
| 'prompt': decimal.Decimal('0.02'), | |||
| 'completion': decimal.Decimal('0.02') | |||
| }, | |||
| } | |||
| if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: | |||
| unit_price = model_unit_prices[self.name]['prompt'] | |||
| else: | |||
| unit_price = model_unit_prices[self.name]['completion'] | |||
| tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), | |||
| rounding=decimal.ROUND_HALF_UP) | |||
| total_price = tokens_per_1k * unit_price | |||
| return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) | |||
| def get_currency(self): | |||
| return 'USD' | |||
| def _set_model_kwargs(self, model_kwargs: ModelKwargs): | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) | |||
| if self.name in COMPLETION_MODELS: | |||
| @@ -81,13 +81,6 @@ class ReplicateModel(BaseLLM): | |||
| return self._client.get_num_tokens(prompts) | |||
| def get_token_price(self, tokens: int, message_type: MessageType): | |||
| # replicate only pay for prediction seconds | |||
| return decimal.Decimal('0') | |||
| def get_currency(self): | |||
| return 'USD' | |||
| def _set_model_kwargs(self, model_kwargs: ModelKwargs): | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) | |||
| self.client.input = provider_model_kwargs | |||
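Providers whose rules carry no `price_config` at all (Replicate above bills for prediction seconds, and HuggingFace Hub similarly has no per-token price) now fall back to the zero-priced default instead of overriding `get_token_price`. A small runnable check of what that fallback yields (values copied from `default_price_config` in BaseLLM above):

```python
from decimal import Decimal, ROUND_HALF_UP

default_price_config = {
    'prompt': Decimal('0'),
    'completion': Decimal('0'),
    'unit': Decimal('0'),
    'currency': 'USD',
}

tokens = 512
total = (tokens * default_price_config['prompt'] * default_price_config['unit']).quantize(
    Decimal('0.0000001'), rounding=ROUND_HALF_UP)
print(total, default_price_config['currency'])   # 0E-7 USD, i.e. effectively zero
```

One detail worth noting: the removed stubs for Minimax, Spark, Tongyi and ChatGLM returned 'RMB' as the currency, while the shared default is 'USD', so those providers presumably need their own `price_config` entries (as wenxin gets further down) to keep reporting RMB.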
| @@ -50,9 +50,6 @@ class SparkModel(BaseLLM): | |||
| contents = [message.content for message in messages] | |||
| return max(self._client.get_num_tokens("".join(contents)), 0) | |||
| def get_token_price(self, tokens: int, message_type: MessageType): | |||
| return decimal.Decimal('0') | |||
| def get_currency(self): | |||
| return 'RMB' | |||
| @@ -53,9 +53,6 @@ class TongyiModel(BaseLLM): | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| return max(self._client.get_num_tokens(prompts), 0) | |||
| def get_token_price(self, tokens: int, message_type: MessageType): | |||
| return decimal.Decimal('0') | |||
| def get_currency(self): | |||
| return 'RMB' | |||
| @@ -16,6 +16,7 @@ class WenxinModel(BaseLLM): | |||
| def _init_client(self) -> Any: | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) | |||
| # TODO load price_config from configs(db) | |||
| return Wenxin( | |||
| streaming=self.streaming, | |||
| callbacks=self.callbacks, | |||
| @@ -48,36 +49,6 @@ class WenxinModel(BaseLLM): | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| return max(self._client.get_num_tokens(prompts), 0) | |||
| def get_token_price(self, tokens: int, message_type: MessageType): | |||
| model_unit_prices = { | |||
| 'ernie-bot': { | |||
| 'prompt': decimal.Decimal('0.012'), | |||
| 'completion': decimal.Decimal('0.012'), | |||
| }, | |||
| 'ernie-bot-turbo': { | |||
| 'prompt': decimal.Decimal('0.008'), | |||
| 'completion': decimal.Decimal('0.008') | |||
| }, | |||
| 'bloomz-7b': { | |||
| 'prompt': decimal.Decimal('0.006'), | |||
| 'completion': decimal.Decimal('0.006') | |||
| } | |||
| } | |||
| if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: | |||
| unit_price = model_unit_prices[self.name]['prompt'] | |||
| else: | |||
| unit_price = model_unit_prices[self.name]['completion'] | |||
| tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), | |||
| rounding=decimal.ROUND_HALF_UP) | |||
| total_price = tokens_per_1k * unit_price | |||
| return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) | |||
| def get_currency(self): | |||
| return 'RMB' | |||
| def _set_model_kwargs(self, model_kwargs: ModelKwargs): | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) | |||
| for k, v in provider_model_kwargs.items(): | |||
| @@ -11,5 +11,19 @@ | |||
| "quota_unit": "tokens", | |||
| "quota_limit": 600000 | |||
| }, | |||
| "model_flexibility": "fixed" | |||
| "model_flexibility": "fixed", | |||
| "price_config": { | |||
| "claude-instant-1": { | |||
| "prompt": "1.63", | |||
| "completion": "5.51", | |||
| "unit": "0.000001", | |||
| "currency": "USD" | |||
| }, | |||
| "claude-2": { | |||
| "prompt": "11.02", | |||
| "completion": "32.68", | |||
| "unit": "0.000001", | |||
| "currency": "USD" | |||
| } | |||
| } | |||
| } | |||
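Unlike the OpenAI entries below (which use `"unit": "0.001"`, i.e. prices per 1K tokens), Anthropic's published prices are quoted per million tokens, which is why this file uses `"unit": "0.000001"`. A quick self-contained check:

```python
from decimal import Decimal

# claude-instant-1 prompt side: 1.63 USD per 1M tokens
print(1000 * Decimal('1.63') * Decimal('0.000001'))   # -> 0.00163 USD for 1,000 prompt tokens
```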
| @@ -3,5 +3,48 @@ | |||
| "custom" | |||
| ], | |||
| "system_config": null, | |||
| "model_flexibility": "configurable" | |||
| "model_flexibility": "configurable", | |||
| "price_config":{ | |||
| "gpt-4": { | |||
| "prompt": "0.03", | |||
| "completion": "0.06", | |||
| "unit": "0.001", | |||
| "currency": "USD" | |||
| }, | |||
| "gpt-4-32k": { | |||
| "prompt": "0.06", | |||
| "completion": "0.12", | |||
| "unit": "0.001", | |||
| "currency": "USD" | |||
| }, | |||
| "gpt-35-turbo": { | |||
| "prompt": "0.0015", | |||
| "completion": "0.002", | |||
| "unit": "0.001", | |||
| "currency": "USD" | |||
| }, | |||
| "gpt-35-turbo-16k": { | |||
| "prompt": "0.003", | |||
| "completion": "0.004", | |||
| "unit": "0.001", | |||
| "currency": "USD" | |||
| }, | |||
| "text-davinci-002": { | |||
| "prompt": "0.02", | |||
| "completion": "0.02", | |||
| "unit": "0.001", | |||
| "currency": "USD" | |||
| }, | |||
| "text-davinci-003": { | |||
| "prompt": "0.02", | |||
| "completion": "0.02", | |||
| "unit": "0.001", | |||
| "currency": "USD" | |||
| }, | |||
| "text-embedding-ada-002":{ | |||
| "completion": "0.0001", | |||
| "unit": "0.001", | |||
| "currency": "USD" | |||
| } | |||
| } | |||
| } | |||
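For Azure the deployment name is user-chosen, so the lookup keys above are base model names, and the new `base_model_name` property on `AzureOpenAIModel` / `AzureOpenAIEmbedding` reads that value from the stored credentials. A small sketch of the mapping (the credential values here are hypothetical):

```python
# Hypothetical deployment credentials; only base_model_name feeds the price lookup.
credentials = {
    "deployment_name": "my-gpt35-deployment",   # illustrative, not used for pricing
    "base_model_name": "gpt-35-turbo",          # must match a price_config key above
}

price_config_keys = {"gpt-4", "gpt-4-32k", "gpt-35-turbo", "gpt-35-turbo-16k",
                     "text-davinci-002", "text-davinci-003", "text-embedding-ada-002"}
assert credentials["base_model_name"] in price_config_keys
```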
| @@ -10,5 +10,42 @@ | |||
| "quota_unit": "times", | |||
| "quota_limit": 200 | |||
| }, | |||
| "model_flexibility": "fixed" | |||
| "model_flexibility": "fixed", | |||
| "price_config": { | |||
| "gpt-4": { | |||
| "prompt": "0.03", | |||
| "completion": "0.06", | |||
| "unit": "0.001", | |||
| "currency": "USD" | |||
| }, | |||
| "gpt-4-32k": { | |||
| "prompt": "0.06", | |||
| "completion": "0.12", | |||
| "unit": "0.001", | |||
| "currency": "USD" | |||
| }, | |||
| "gpt-3.5-turbo": { | |||
| "prompt": "0.0015", | |||
| "completion": "0.002", | |||
| "unit": "0.001", | |||
| "currency": "USD" | |||
| }, | |||
| "gpt-3.5-turbo-16k": { | |||
| "prompt": "0.003", | |||
| "completion": "0.004", | |||
| "unit": "0.001", | |||
| "currency": "USD" | |||
| }, | |||
| "text-davinci-003": { | |||
| "prompt": "0.02", | |||
| "completion": "0.02", | |||
| "unit": "0.001", | |||
| "currency": "USD" | |||
| }, | |||
| "text-embedding-ada-002":{ | |||
| "completion": "0.0001", | |||
| "unit": "0.001", | |||
| "currency": "USD" | |||
| } | |||
| } | |||
| } | |||
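The `text-embedding-ada-002` entry here (and in the Azure file above) defines only a `"completion"` price; that matches the embedding price path in `BaseEmbedding.calc_tokens_price`, which reads exactly that field. For scale:

```python
from decimal import Decimal

# 1M embedded tokens at 0.0001 USD per 1K tokens
print(1_000_000 * Decimal('0.0001') * Decimal('0.001'))   # -> 0.1 USD
```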
| @@ -3,5 +3,25 @@ | |||
| "custom" | |||
| ], | |||
| "system_config": null, | |||
| "model_flexibility": "fixed" | |||
| "model_flexibility": "fixed", | |||
| "price_config": { | |||
| "ernie-bot": { | |||
| "prompt": "0.012", | |||
| "completion": "0.012", | |||
| "unit": "0.001", | |||
| "currency": "RMB" | |||
| }, | |||
| "ernie-bot-turbo": { | |||
| "prompt": "0.008", | |||
| "completion": "0.008", | |||
| "unit": "0.001", | |||
| "currency": "RMB" | |||
| }, | |||
| "bloomz-7b": { | |||
| "prompt": "0.006", | |||
| "completion": "0.006", | |||
| "unit": "0.001", | |||
| "currency": "RMB" | |||
| } | |||
| } | |||
| } | |||