|
|
|
@@ -291,7 +291,7 @@ class OllamaEmbed(Base): |
|
|
|
arr = [] |
|
|
|
tks_num = 0 |
|
|
|
for txt in texts: |
|
|
|
# remove special tokens if they exist |
|
|
|
# remove special tokens if they exist based on regex in one request |
|
|
|
for token in OllamaEmbed._special_tokens: |
|
|
|
txt = txt.replace(token, "") |
|
|
|
res = self.client.embeddings(prompt=txt, model=self.model_name, options={"use_mmap": True}, keep_alive=self.keep_alive) |
|
|
|
@@ -487,6 +487,8 @@ class BedrockEmbed(Base): |
|
|
|
self.bedrock_sk = json.loads(key).get("bedrock_sk", "") |
|
|
|
self.bedrock_region = json.loads(key).get("bedrock_region", "") |
|
|
|
self.model_name = model_name |
|
|
|
self.is_amazon = self.model_name.split(".")[0] == "amazon" |
|
|
|
self.is_cohere = self.model_name.split(".")[0] == "cohere" |
|
|
|
|
|
|
|
if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "": |
|
|
|
# Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.) |
|
|
|
@@ -499,9 +501,9 @@ class BedrockEmbed(Base): |
|
|
|
embeddings = [] |
|
|
|
token_count = 0 |
|
|
|
for text in texts: |
|
|
|
if self.model_name.split(".")[0] == "amazon": |
|
|
|
if self.is_amazon: |
|
|
|
body = {"inputText": text} |
|
|
|
elif self.model_name.split(".")[0] == "cohere": |
|
|
|
elif self.is_cohere: |
|
|
|
body = {"texts": [text], "input_type": "search_document"} |
|
|
|
|
|
|
|
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body)) |
|
|
|
@@ -517,9 +519,9 @@ class BedrockEmbed(Base): |
|
|
|
def encode_queries(self, text): |
|
|
|
embeddings = [] |
|
|
|
token_count = num_tokens_from_string(text) |
|
|
|
if self.model_name.split(".")[0] == "amazon": |
|
|
|
if self.is_amazon: |
|
|
|
body = {"inputText": truncate(text, 8196)} |
|
|
|
elif self.model_name.split(".")[0] == "cohere": |
|
|
|
elif self.is_cohere: |
|
|
|
body = {"texts": [truncate(text, 8196)], "input_type": "search_query"} |
|
|
|
|
|
|
|
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body)) |