Explorar el Código

Add support for AWS Bedrock Cohere embedding (#3444)

tags/0.6.3
kerlion hace 1 año
padre
commit
200010be19
No account linked to committer's email address

+ 2
- 0
api/core/model_runtime/model_providers/bedrock/text_embedding/_position.yaml Ver fichero

- amazon.titan-embed-text-v1 - amazon.titan-embed-text-v1
- cohere.embed-english-v3
- cohere.embed-multilingual-v3

+ 8
- 0
api/core/model_runtime/model_providers/bedrock/text_embedding/cohere.embed-english-v3.yaml Ver fichero

model: cohere.embed-english-v3
model_type: text-embedding
model_properties:
  context_size: 512
pricing:
  input: '0.1'
  unit: '0.000001'
  currency: USD

+ 8
- 0
api/core/model_runtime/model_providers/bedrock/text_embedding/cohere.embed-multilingual-v3.yaml Ver fichero

model: cohere.embed-multilingual-v3
model_type: text-embedding
model_properties:
  context_size: 512
pricing:
  input: '0.1'
  unit: '0.000001'
  currency: USD

+ 40
- 15
api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py Ver fichero

import json import json
import logging
import time import time
from typing import Optional from typing import Optional


) )
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel


logger = logging.getLogger(__name__)


class BedrockTextEmbeddingModel(TextEmbeddingModel): class BedrockTextEmbeddingModel(TextEmbeddingModel):




embeddings = [] embeddings = []
token_usage = 0 token_usage = 0
model_prefix = model.split('.')[0] model_prefix = model.split('.')[0]
if model_prefix == "amazon":
for text in texts:
body = {
"inputText": text,
}
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
embeddings.extend([response_body.get('embedding')])
token_usage += response_body.get('inputTextTokenCount')
result = TextEmbeddingResult(
if model_prefix == "amazon" :
for text in texts:
body = {
"inputText": text,
}
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
embeddings.extend([response_body.get('embedding')])
token_usage += response_body.get('inputTextTokenCount')
logger.warning(f'Total Tokens: {token_usage}')
result = TextEmbeddingResult(
model=model, model=model,
embeddings=embeddings, embeddings=embeddings,
usage=self._calc_response_usage( usage=self._calc_response_usage(
credentials=credentials, credentials=credentials,
tokens=token_usage tokens=token_usage
) )
)
else:
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")

return result
)
return result
if model_prefix == "cohere" :
input_type = 'search_document' if len(texts) > 1 else 'search_query'
for text in texts:
body = {
"texts": [text],
"input_type": input_type,
}
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
embeddings.extend(response_body.get('embeddings'))
token_usage += len(text)
result = TextEmbeddingResult(
model=model,
embeddings=embeddings,
usage=self._calc_response_usage(
model=model,
credentials=credentials,
tokens=token_usage
)
)
return result
#others
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")




def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:

Cargando…
Cancelar
Guardar