소스 검색

Fix/price calc (#862)

tags/0.3.15
Krasus.Chen 2 년 전
부모
커밋
fd0fc8f4fe
No account linked to committer's email address

+ 11
- 20
api/core/conversation_message_task.py 파일 보기

def save_message(self, llm_message: LLMMessage, by_stopped: bool = False): def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
message_tokens = llm_message.prompt_tokens message_tokens = llm_message.prompt_tokens
answer_tokens = llm_message.completion_tokens answer_tokens = llm_message.completion_tokens
message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN)
answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT)


total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)

message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
total_price = message_total_price + answer_total_price


self.message.message = llm_message.prompt self.message.message = llm_message.prompt
self.message.message_tokens = message_tokens self.message.message_tokens = message_tokens


def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM, def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
agent_loop: AgentLoop): agent_loop: AgentLoop):
agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN)
agent_answer_unit_price = agent_model_instant.get_token_price(1, MessageType.ASSISTANT)
agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)


loop_message_tokens = agent_loop.prompt_tokens loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens loop_answer_tokens = agent_loop.completion_tokens


loop_total_price = self.calc_total_price(
loop_message_tokens,
agent_message_unit_price,
loop_answer_tokens,
agent_answer_unit_price
)
loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
loop_total_price = loop_message_total_price + loop_answer_total_price


message_agent_thought.observation = agent_loop.tool_output message_agent_thought.observation = agent_loop.tool_output
message_agent_thought.tool_process_data = '' # currently not support message_agent_thought.tool_process_data = '' # currently not support


db.session.add(dataset_query) db.session.add(dataset_query)


def calc_total_price(self, message_tokens, message_unit_price, answer_tokens, answer_unit_price):
message_tokens_per_1k = (decimal.Decimal(message_tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)
answer_tokens_per_1k = (decimal.Decimal(answer_tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)

total_price = message_tokens_per_1k * message_unit_price + answer_tokens_per_1k * answer_unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

def end(self): def end(self):
self._pub_handler.pub_end() self._pub_handler.pub_end()



+ 3
- 3
api/core/indexing_runner.py 파일 보기

"total_segments": total_segments * 20, "total_segments": total_segments * 20,
"tokens": total_segments * 2000, "tokens": total_segments * 2000,
"total_price": '{:f}'.format( "total_price": '{:f}'.format(
text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
"currency": embedding_model.get_currency(), "currency": embedding_model.get_currency(),
"qa_preview": document_qa_list, "qa_preview": document_qa_list,
"preview": preview_texts "preview": preview_texts
return { return {
"total_segments": total_segments, "total_segments": total_segments,
"tokens": tokens, "tokens": tokens,
"total_price": '{:f}'.format(embedding_model.get_token_price(tokens)),
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
"currency": embedding_model.get_currency(), "currency": embedding_model.get_currency(),
"preview": preview_texts "preview": preview_texts
} }
"total_segments": total_segments * 20, "total_segments": total_segments * 20,
"tokens": total_segments * 2000, "tokens": total_segments * 2000,
"total_price": '{:f}'.format( "total_price": '{:f}'.format(
text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)),
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
"currency": embedding_model.get_currency(), "currency": embedding_model.get_currency(),
"qa_preview": document_qa_list, "qa_preview": document_qa_list,
"preview": preview_texts "preview": preview_texts

+ 9
- 10
api/core/model_providers/models/embedding/azure_openai_embedding.py 파일 보기

) )


super().__init__(model_provider, client, name) super().__init__(model_provider, client, name)
@property
def base_model_name(self) -> str:
"""
get base model name (not deployment)
:return: str
"""
return self.credentials.get("base_model_name")


def get_num_tokens(self, text: str) -> int: def get_num_tokens(self, text: str) -> int:
""" """
# calculate the number of tokens in the encoded text # calculate the number of tokens in the encoded text
return len(tokenized_text) return len(tokenized_text)


def get_token_price(self, tokens: int):
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)

total_price = tokens_per_1k * decimal.Decimal('0.0001')
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

def get_currency(self):
return 'USD'

def handle_exceptions(self, ex: Exception) -> Exception: def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError): if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to Azure OpenAI API.") logging.warning("Invalid request to Azure OpenAI API.")

+ 69
- 5
api/core/model_providers/models/embedding/base.py 파일 보기

from abc import abstractmethod from abc import abstractmethod
from typing import Any from typing import Any
import decimal


import tiktoken import tiktoken
from langchain.schema.language_model import _get_token_ids_default_method from langchain.schema.language_model import _get_token_ids_default_method
from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider from core.model_providers.providers.base import BaseModelProvider

import logging
logger = logging.getLogger(__name__)


class BaseEmbedding(BaseProviderModel): class BaseEmbedding(BaseProviderModel):
name: str name: str
super().__init__(model_provider, client) super().__init__(model_provider, client)
self.name = name self.name = name


@property
def base_model_name(self) -> str:
"""
get base model name
:return: str
"""
return self.name

@property
def price_config(self) -> dict:
def get_or_default():
default_price_config = {
'prompt': decimal.Decimal('0'),
'completion': decimal.Decimal('0'),
'unit': decimal.Decimal('0'),
'currency': 'USD'
}
rules = self.model_provider.get_rules()
price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
price_config = {
'prompt': decimal.Decimal(price_config['prompt']),
'completion': decimal.Decimal(price_config['completion']),
'unit': decimal.Decimal(price_config['unit']),
'currency': price_config['currency']
}
return price_config
self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()

logger.debug(f"model: {self.name} price_config: {self._price_config}")
return self._price_config

def calc_tokens_price(self, tokens:int) -> decimal.Decimal:
"""
calc tokens total price.

:param tokens:
:return: decimal.Decimal('0.0000001')
"""
unit_price = self._price_config['completion']
unit = self._price_config['unit']
total_price = tokens * unit_price * unit
total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
return total_price

def get_tokens_unit_price(self) -> decimal.Decimal:
"""
get token price.

:return: decimal.Decimal('0.0001')
"""
unit_price = self._price_config['completion']
unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
logger.debug(f'unit_price:{unit_price}')
return unit_price

def get_num_tokens(self, text: str) -> int: def get_num_tokens(self, text: str) -> int:
""" """
get num tokens of text. get num tokens of text.


return len(_get_token_ids_default_method(text)) return len(_get_token_ids_default_method(text))


def get_token_price(self, tokens: int):
return 0

def get_currency(self): def get_currency(self):
return 'USD'
"""
get token currency.

:return: get from price config, default 'USD'
"""
currency = self._price_config['currency']
return currency


@abstractmethod @abstractmethod
def handle_exceptions(self, ex: Exception) -> Exception: def handle_exceptions(self, ex: Exception) -> Exception:

+ 0
- 3
api/core/model_providers/models/embedding/minimax_embedding.py 파일 보기



super().__init__(model_provider, client, name) super().__init__(model_provider, client, name)


def get_token_price(self, tokens: int):
return decimal.Decimal('0')

def get_currency(self): def get_currency(self):
return 'RMB' return 'RMB'



+ 0
- 10
api/core/model_providers/models/embedding/openai_embedding.py 파일 보기

# calculate the number of tokens in the encoded text # calculate the number of tokens in the encoded text
return len(tokenized_text) return len(tokenized_text)


def get_token_price(self, tokens: int):
tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)

total_price = tokens_per_1k * decimal.Decimal('0.0001')
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

def get_currency(self):
return 'USD'

def handle_exceptions(self, ex: Exception) -> Exception: def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError): if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to OpenAI API.") logging.warning("Invalid request to OpenAI API.")

+ 0
- 7
api/core/model_providers/models/embedding/replicate_embedding.py 파일 보기



super().__init__(model_provider, client, name) super().__init__(model_provider, client, name)


def get_token_price(self, tokens: int):
# replicate only pay for prediction seconds
return decimal.Decimal('0')

def get_currency(self):
return 'USD'

def handle_exceptions(self, ex: Exception) -> Exception: def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, (ModelError, ReplicateError)): if isinstance(ex, (ModelError, ReplicateError)):
return LLMBadRequestError(f"Replicate: {str(ex)}") return LLMBadRequestError(f"Replicate: {str(ex)}")

+ 0
- 26
api/core/model_providers/models/llm/anthropic_model.py 파일 보기

prompts = self._get_prompt_from_messages(messages) prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0) return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)


def get_token_price(self, tokens: int, message_type: MessageType):
model_unit_prices = {
'claude-instant-1': {
'prompt': decimal.Decimal('1.63'),
'completion': decimal.Decimal('5.51'),
},
'claude-2': {
'prompt': decimal.Decimal('11.02'),
'completion': decimal.Decimal('32.68'),
},
}

if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = model_unit_prices[self.name]['prompt']
else:
unit_price = model_unit_prices[self.name]['completion']

tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'),
rounding=decimal.ROUND_HALF_UP)

total_price = tokens_per_1m * unit_price
return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)

def get_currency(self):
return 'USD'

def _set_model_kwargs(self, model_kwargs: ModelKwargs): def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items(): for k, v in provider_model_kwargs.items():

+ 9
- 40
api/core/model_providers/models/llm/azure_openai_model.py 파일 보기

self.model_mode = ModelMode.COMPLETION self.model_mode = ModelMode.COMPLETION
else: else:
self.model_mode = ModelMode.CHAT self.model_mode = ModelMode.CHAT

super().__init__(model_provider, name, model_kwargs, streaming, callbacks) super().__init__(model_provider, name, model_kwargs, streaming, callbacks)


def _init_client(self) -> Any: def _init_client(self) -> Any:
""" """
prompts = self._get_prompt_from_messages(messages) prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks) return self._client.generate([prompts], stop, callbacks)
@property
def base_model_name(self) -> str:
"""
get base model name (not deployment)
:return: str
"""
return self.credentials.get("base_model_name")


def get_num_tokens(self, messages: List[PromptMessage]) -> int: def get_num_tokens(self, messages: List[PromptMessage]) -> int:
""" """
else: else:
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0) return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)


def get_token_price(self, tokens: int, message_type: MessageType):
model_unit_prices = {
'gpt-4': {
'prompt': decimal.Decimal('0.03'),
'completion': decimal.Decimal('0.06'),
},
'gpt-4-32k': {
'prompt': decimal.Decimal('0.06'),
'completion': decimal.Decimal('0.12')
},
'gpt-35-turbo': {
'prompt': decimal.Decimal('0.0015'),
'completion': decimal.Decimal('0.002')
},
'gpt-35-turbo-16k': {
'prompt': decimal.Decimal('0.003'),
'completion': decimal.Decimal('0.004')
},
'text-davinci-003': {
'prompt': decimal.Decimal('0.02'),
'completion': decimal.Decimal('0.02')
},
}

base_model_name = self.credentials.get("base_model_name")
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = model_unit_prices[base_model_name]['prompt']
else:
unit_price = model_unit_prices[base_model_name]['completion']

tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)

total_price = tokens_per_1k * unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

def get_currency(self):
return 'USD'

def _set_model_kwargs(self, model_kwargs: ModelKwargs): def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
if self.name == 'text-davinci-003': if self.name == 'text-davinci-003':

+ 66
- 7
api/core/model_providers/models/llm/base.py 파일 보기

from abc import abstractmethod from abc import abstractmethod
from typing import List, Optional, Any, Union from typing import List, Optional, Any, Union
import decimal


from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.fake import FakeLLM from core.third_party.langchain.llms.fake import FakeLLM
import logging
logger = logging.getLogger(__name__)




class BaseLLM(BaseProviderModel): class BaseLLM(BaseProviderModel):
def _init_client(self) -> Any: def _init_client(self) -> Any:
raise NotImplementedError raise NotImplementedError


@property
def base_model_name(self) -> str:
"""
get llm base model name

:return: str
"""
return self.name

@property
def price_config(self) -> dict:
def get_or_default():
default_price_config = {
'prompt': decimal.Decimal('0'),
'completion': decimal.Decimal('0'),
'unit': decimal.Decimal('0'),
'currency': 'USD'
}
rules = self.model_provider.get_rules()
price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
price_config = {
'prompt': decimal.Decimal(price_config['prompt']),
'completion': decimal.Decimal(price_config['completion']),
'unit': decimal.Decimal(price_config['unit']),
'currency': price_config['currency']
}
return price_config
self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()

logger.debug(f"model: {self.name} price_config: {self._price_config}")
return self._price_config

def run(self, messages: List[PromptMessage], def run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
callbacks: Callbacks = None, callbacks: Callbacks = None,
""" """
raise NotImplementedError raise NotImplementedError


@abstractmethod
def get_token_price(self, tokens: int, message_type: MessageType):
def calc_tokens_price(self, tokens:int, message_type: MessageType):
""" """
get token price.
calc tokens total price.


:param tokens: :param tokens:
:param message_type: :param message_type:
:return: :return:
""" """
raise NotImplementedError
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt']
else:
unit_price = self.price_config['completion']
unit = self.price_config['unit']

total_price = tokens * unit_price * unit
total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
return total_price

def get_tokens_unit_price(self, message_type: MessageType):
"""
get token price.

:param message_type:
:return: decimal.Decimal('0.0001')
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt']
else:
unit_price = self.price_config['completion']
unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
logging.debug(f"unit_price={unit_price}")
return unit_price


@abstractmethod
def get_currency(self): def get_currency(self):
""" """
get token currency. get token currency.


:return:
:return: get from price config, default 'USD'
""" """
raise NotImplementedError
currency = self.price_config['currency']
return currency


def get_model_kwargs(self): def get_model_kwargs(self):
return self.model_kwargs return self.model_kwargs

+ 0
- 3
api/core/model_providers/models/llm/chatglm_model.py 파일 보기

prompts = self._get_prompt_from_messages(messages) prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0) return max(self._client.get_num_tokens(prompts), 0)


def get_token_price(self, tokens: int, message_type: MessageType):
return decimal.Decimal('0')

def get_currency(self): def get_currency(self):
return 'RMB' return 'RMB'



+ 0
- 7
api/core/model_providers/models/llm/huggingface_hub_model.py 파일 보기

prompts = self._get_prompt_from_messages(messages) prompts = self._get_prompt_from_messages(messages)
return self._client.get_num_tokens(prompts) return self._client.get_num_tokens(prompts)


def get_token_price(self, tokens: int, message_type: MessageType):
# not support calc price
return decimal.Decimal('0')

def get_currency(self):
return 'USD'

def _set_model_kwargs(self, model_kwargs: ModelKwargs): def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
self.client.model_kwargs = provider_model_kwargs self.client.model_kwargs = provider_model_kwargs

+ 0
- 3
api/core/model_providers/models/llm/minimax_model.py 파일 보기

prompts = self._get_prompt_from_messages(messages) prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0) return max(self._client.get_num_tokens(prompts), 0)


def get_token_price(self, tokens: int, message_type: MessageType):
return decimal.Decimal('0')

def get_currency(self): def get_currency(self):
return 'RMB' return 'RMB'



+ 2
- 39
api/core/model_providers/models/llm/openai_model.py 파일 보기

self.model_mode = ModelMode.COMPLETION self.model_mode = ModelMode.COMPLETION
else: else:
self.model_mode = ModelMode.CHAT self.model_mode = ModelMode.CHAT

# TODO load price config from configs(db)
super().__init__(model_provider, name, model_kwargs, streaming, callbacks) super().__init__(model_provider, name, model_kwargs, streaming, callbacks)


def _init_client(self) -> Any: def _init_client(self) -> Any:
else: else:
return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0) return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)


def get_token_price(self, tokens: int, message_type: MessageType):
model_unit_prices = {
'gpt-4': {
'prompt': decimal.Decimal('0.03'),
'completion': decimal.Decimal('0.06'),
},
'gpt-4-32k': {
'prompt': decimal.Decimal('0.06'),
'completion': decimal.Decimal('0.12')
},
'gpt-3.5-turbo': {
'prompt': decimal.Decimal('0.0015'),
'completion': decimal.Decimal('0.002')
},
'gpt-3.5-turbo-16k': {
'prompt': decimal.Decimal('0.003'),
'completion': decimal.Decimal('0.004')
},
'text-davinci-003': {
'prompt': decimal.Decimal('0.02'),
'completion': decimal.Decimal('0.02')
},
}

if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = model_unit_prices[self.name]['prompt']
else:
unit_price = model_unit_prices[self.name]['completion']

tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)

total_price = tokens_per_1k * unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

def get_currency(self):
return 'USD'

def _set_model_kwargs(self, model_kwargs: ModelKwargs): def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
if self.name in COMPLETION_MODELS: if self.name in COMPLETION_MODELS:

+ 0
- 7
api/core/model_providers/models/llm/replicate_model.py 파일 보기



return self._client.get_num_tokens(prompts) return self._client.get_num_tokens(prompts)


def get_token_price(self, tokens: int, message_type: MessageType):
# replicate only pay for prediction seconds
return decimal.Decimal('0')

def get_currency(self):
return 'USD'

def _set_model_kwargs(self, model_kwargs: ModelKwargs): def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
self.client.input = provider_model_kwargs self.client.input = provider_model_kwargs

+ 0
- 3
api/core/model_providers/models/llm/spark_model.py 파일 보기

contents = [message.content for message in messages] contents = [message.content for message in messages]
return max(self._client.get_num_tokens("".join(contents)), 0) return max(self._client.get_num_tokens("".join(contents)), 0)


def get_token_price(self, tokens: int, message_type: MessageType):
return decimal.Decimal('0')

def get_currency(self): def get_currency(self):
return 'RMB' return 'RMB'



+ 0
- 3
api/core/model_providers/models/llm/tongyi_model.py 파일 보기

prompts = self._get_prompt_from_messages(messages) prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0) return max(self._client.get_num_tokens(prompts), 0)


def get_token_price(self, tokens: int, message_type: MessageType):
return decimal.Decimal('0')

def get_currency(self): def get_currency(self):
return 'RMB' return 'RMB'



+ 1
- 30
api/core/model_providers/models/llm/wenxin_model.py 파일 보기



def _init_client(self) -> Any: def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
# TODO load price_config from configs(db)
return Wenxin( return Wenxin(
streaming=self.streaming, streaming=self.streaming,
callbacks=self.callbacks, callbacks=self.callbacks,
prompts = self._get_prompt_from_messages(messages) prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0) return max(self._client.get_num_tokens(prompts), 0)


def get_token_price(self, tokens: int, message_type: MessageType):
model_unit_prices = {
'ernie-bot': {
'prompt': decimal.Decimal('0.012'),
'completion': decimal.Decimal('0.012'),
},
'ernie-bot-turbo': {
'prompt': decimal.Decimal('0.008'),
'completion': decimal.Decimal('0.008')
},
'bloomz-7b': {
'prompt': decimal.Decimal('0.006'),
'completion': decimal.Decimal('0.006')
}
}

if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
unit_price = model_unit_prices[self.name]['prompt']
else:
unit_price = model_unit_prices[self.name]['completion']

tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
rounding=decimal.ROUND_HALF_UP)

total_price = tokens_per_1k * unit_price
return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

def get_currency(self):
return 'RMB'

def _set_model_kwargs(self, model_kwargs: ModelKwargs): def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items(): for k, v in provider_model_kwargs.items():

+ 15
- 1
api/core/model_providers/rules/anthropic.json 파일 보기

"quota_unit": "tokens", "quota_unit": "tokens",
"quota_limit": 600000 "quota_limit": 600000
}, },
"model_flexibility": "fixed"
"model_flexibility": "fixed",
"price_config": {
"claude-instant-1": {
"prompt": "1.63",
"completion": "5.51",
"unit": "0.000001",
"currency": "USD"
},
"claude-2": {
"prompt": "11.02",
"completion": "32.68",
"unit": "0.000001",
"currency": "USD"
}
}
} }

+ 44
- 1
api/core/model_providers/rules/azure_openai.json 파일 보기

"custom" "custom"
], ],
"system_config": null, "system_config": null,
"model_flexibility": "configurable"
"model_flexibility": "configurable",
"price_config":{
"gpt-4": {
"prompt": "0.03",
"completion": "0.06",
"unit": "0.001",
"currency": "USD"
},
"gpt-4-32k": {
"prompt": "0.06",
"completion": "0.12",
"unit": "0.001",
"currency": "USD"
},
"gpt-35-turbo": {
"prompt": "0.0015",
"completion": "0.002",
"unit": "0.001",
"currency": "USD"
},
"gpt-35-turbo-16k": {
"prompt": "0.003",
"completion": "0.004",
"unit": "0.001",
"currency": "USD"
},
"text-davinci-002": {
"prompt": "0.02",
"completion": "0.02",
"unit": "0.001",
"currency": "USD"
},
"text-davinci-003": {
"prompt": "0.02",
"completion": "0.02",
"unit": "0.001",
"currency": "USD"
},
"text-embedding-ada-002":{
"completion": "0.0001",
"unit": "0.001",
"currency": "USD"
}
}
} }

+ 38
- 1
api/core/model_providers/rules/openai.json 파일 보기

"quota_unit": "times", "quota_unit": "times",
"quota_limit": 200 "quota_limit": 200
}, },
"model_flexibility": "fixed"
"model_flexibility": "fixed",
"price_config": {
"gpt-4": {
"prompt": "0.03",
"completion": "0.06",
"unit": "0.001",
"currency": "USD"
},
"gpt-4-32k": {
"prompt": "0.06",
"completion": "0.12",
"unit": "0.001",
"currency": "USD"
},
"gpt-3.5-turbo": {
"prompt": "0.0015",
"completion": "0.002",
"unit": "0.001",
"currency": "USD"
},
"gpt-3.5-turbo-16k": {
"prompt": "0.003",
"completion": "0.004",
"unit": "0.001",
"currency": "USD"
},
"text-davinci-003": {
"prompt": "0.02",
"completion": "0.02",
"unit": "0.001",
"currency": "USD"
},
"text-embedding-ada-002":{
"completion": "0.0001",
"unit": "0.001",
"currency": "USD"
}
}
} }

+ 21
- 1
api/core/model_providers/rules/wenxin.json 파일 보기

"custom" "custom"
], ],
"system_config": null, "system_config": null,
"model_flexibility": "fixed"
"model_flexibility": "fixed",
"price_config": {
"ernie-bot": {
"prompt": "0.012",
"completion": "0.012",
"unit": "0.001",
"currency": "RMB"
},
"ernie-bot-turbo": {
"prompt": "0.008",
"completion": "0.008",
"unit": "0.001",
"currency": "RMB"
},
"bloomz-7b": {
"prompt": "0.006",
"completion": "0.006",
"unit": "0.001",
"currency": "RMB"
}
}
} }

Loading…
취소
저장