|
|
|
|
|
|
|
|
import json |
|
|
import json |
|
|
|
|
|
import logging |
|
|
import time |
|
|
import time |
|
|
from typing import Optional |
|
|
from typing import Optional |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
) |
|
|
) |
|
|
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel |
|
|
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel |
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class BedrockTextEmbeddingModel(TextEmbeddingModel): |
|
|
class BedrockTextEmbeddingModel(TextEmbeddingModel): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
embeddings = [] |
|
|
embeddings = [] |
|
|
token_usage = 0 |
|
|
token_usage = 0 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_prefix = model.split('.')[0] |
|
|
model_prefix = model.split('.')[0] |
|
|
if model_prefix == "amazon": |
|
|
|
|
|
for text in texts: |
|
|
|
|
|
body = { |
|
|
|
|
|
"inputText": text, |
|
|
|
|
|
} |
|
|
|
|
|
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) |
|
|
|
|
|
embeddings.extend([response_body.get('embedding')]) |
|
|
|
|
|
token_usage += response_body.get('inputTextTokenCount') |
|
|
|
|
|
result = TextEmbeddingResult( |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if model_prefix == "amazon" : |
|
|
|
|
|
for text in texts: |
|
|
|
|
|
body = { |
|
|
|
|
|
"inputText": text, |
|
|
|
|
|
} |
|
|
|
|
|
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) |
|
|
|
|
|
embeddings.extend([response_body.get('embedding')]) |
|
|
|
|
|
token_usage += response_body.get('inputTextTokenCount') |
|
|
|
|
|
logger.warning(f'Total Tokens: {token_usage}') |
|
|
|
|
|
result = TextEmbeddingResult( |
|
|
model=model, |
|
|
model=model, |
|
|
embeddings=embeddings, |
|
|
embeddings=embeddings, |
|
|
usage=self._calc_response_usage( |
|
|
usage=self._calc_response_usage( |
|
|
|
|
|
|
|
|
credentials=credentials, |
|
|
credentials=credentials, |
|
|
tokens=token_usage |
|
|
tokens=token_usage |
|
|
) |
|
|
) |
|
|
) |
|
|
|
|
|
else: |
|
|
|
|
|
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") |
|
|
|
|
|
|
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
) |
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
if model_prefix == "cohere" : |
|
|
|
|
|
input_type = 'search_document' if len(texts) > 1 else 'search_query' |
|
|
|
|
|
for text in texts: |
|
|
|
|
|
body = { |
|
|
|
|
|
"texts": [text], |
|
|
|
|
|
"input_type": input_type, |
|
|
|
|
|
} |
|
|
|
|
|
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) |
|
|
|
|
|
embeddings.extend(response_body.get('embeddings')) |
|
|
|
|
|
token_usage += len(text) |
|
|
|
|
|
result = TextEmbeddingResult( |
|
|
|
|
|
model=model, |
|
|
|
|
|
embeddings=embeddings, |
|
|
|
|
|
usage=self._calc_response_usage( |
|
|
|
|
|
model=model, |
|
|
|
|
|
credentials=credentials, |
|
|
|
|
|
tokens=token_usage |
|
|
|
|
|
) |
|
|
|
|
|
) |
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
#others |
|
|
|
|
|
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: |
|
|
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: |