Explorar el Código

Refactor: Remove Useless split for BedrockEmbed (#9067)

### What problem does this PR solve?

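Remove the redundant `self.model_name.split(".")[0]` calls in `BedrockEmbed`: the provider check is now computed once and stored as `self.is_amazon` / `self.is_cohere`, then reused in the batch-encoding loop and in `encode_queries`.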

### Type of change

- [x] Refactoring
tags/v0.20.0
Stephen Hu hace 3 meses
padre
commit
86b4da0844
No account linked to committer's email address
Se ha modificado 1 fichero con 7 adiciones y 5 borrados
  1. 7
    5
      rag/llm/embedding_model.py

+ 7
- 5
rag/llm/embedding_model.py Ver fichero

@@ -291,7 +291,7 @@ class OllamaEmbed(Base):
arr = []
tks_num = 0
for txt in texts:
# remove special tokens if they exist
# remove special tokens if they exist base on regex in one request
for token in OllamaEmbed._special_tokens:
txt = txt.replace(token, "")
res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive)
@@ -487,6 +487,8 @@ class BedrockEmbed(Base):
self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
self.bedrock_region = json.loads(key).get("bedrock_region", "")
self.model_name = model_name
self.is_amazon = self.model_name.split(".")[0] == "amazon"
self.is_cohere = self.model_name.split(".")[0] == "cohere"

if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
# Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
@@ -499,9 +501,9 @@ class BedrockEmbed(Base):
embeddings = []
token_count = 0
for text in texts:
if self.model_name.split(".")[0] == "amazon":
if self.is_amazon:
body = {"inputText": text}
elif self.model_name.split(".")[0] == "cohere":
elif self.is_cohere:
body = {"texts": [text], "input_type": "search_document"}

response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
@@ -517,9 +519,9 @@ class BedrockEmbed(Base):
def encode_queries(self, text):
embeddings = []
token_count = num_tokens_from_string(text)
if self.model_name.split(".")[0] == "amazon":
if self.is_amazon:
body = {"inputText": truncate(text, 8196)}
elif self.model_name.split(".")[0] == "cohere":
elif self.is_cohere:
body = {"texts": [truncate(text, 8196)], "input_type": "search_query"}

response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))

Cargando…
Cancelar
Guardar