
Add OCI (Oracle Cloud Infrastructure) Generative AI Service as a Model Provider (#7775)

Co-authored-by: Walter Jin <jinshuhaicc@gmail.com>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: walter from vm <walter.jin@oracle.com>
tags/0.8.0
tmuife, 1 year ago
Revision 89aede80cc
23 files changed, 1679 additions and 431 deletions

  1. +0 -0    api/core/model_runtime/model_providers/oci/__init__.py
  2. +1 -0    api/core/model_runtime/model_providers/oci/_assets/icon_l_en.svg
  3. +1 -0    api/core/model_runtime/model_providers/oci/_assets/icon_s_en.svg
  4. +52 -0   api/core/model_runtime/model_providers/oci/llm/cohere.command-r-16k.yaml
  5. +52 -0   api/core/model_runtime/model_providers/oci/llm/cohere.command-r-plus.yaml
  6. +461 -0  api/core/model_runtime/model_providers/oci/llm/llm.py
  7. +51 -0   api/core/model_runtime/model_providers/oci/llm/meta.llama-3-70b-instruct.yaml
  8. +34 -0   api/core/model_runtime/model_providers/oci/oci.py
  9. +42 -0   api/core/model_runtime/model_providers/oci/oci.yaml
 10. +0 -0    api/core/model_runtime/model_providers/oci/text_embedding/__init__.py
 11. +5 -0    api/core/model_runtime/model_providers/oci/text_embedding/_position.yaml
 12. +9 -0    api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-light-v2.0.yaml
 13. +9 -0    api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-light-v3.0.yaml
 14. +9 -0    api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-v3.0.yaml
 15. +9 -0    api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-multilingual-light-v3.0.yaml
 16. +9 -0    api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-multilingual-v3.0.yaml
 17. +242 -0  api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py
 18. +484 -431  api/poetry.lock
 19. +1 -0    api/pyproject.toml
 20. +0 -0    api/tests/integration_tests/model_runtime/oci/__init__.py
 21. +130 -0  api/tests/integration_tests/model_runtime/oci/test_llm.py
 22. +20 -0   api/tests/integration_tests/model_runtime/oci/test_provider.py
 23. +58 -0   api/tests/integration_tests/model_runtime/oci/test_text_embedding.py

api/core/model_runtime/model_providers/oci/__init__.py  (+0 -0)


api/core/model_runtime/model_providers/oci/_assets/icon_l_en.svg  (+1 -0)

@@ -0,0 +1 @@
<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 231 30' preserveAspectRatio='xMinYMid'><path d='M99.61,19.52h15.24l-8.05-13L92,30H85.27l18-28.17a4.29,4.29,0,0,1,7-.05L128.32,30h-6.73l-3.17-5.25H103l-3.36-5.23m69.93,5.23V0.28h-5.72V27.16a2.76,2.76,0,0,0,.85,2,2.89,2.89,0,0,0,2.08.87h26l3.39-5.25H169.54M75,20.38A10,10,0,0,0,75,.28H50V30h5.71V5.54H74.65a4.81,4.81,0,0,1,0,9.62H58.54L75.6,30h8.29L72.43,20.38H75M14.88,30H32.15a14.86,14.86,0,0,0,0-29.71H14.88a14.86,14.86,0,1,0,0,29.71m16.88-5.23H15.26a9.62,9.62,0,0,1,0-19.23h16.5a9.62,9.62,0,1,1,0,19.23M140.25,30h17.63l3.34-5.23H140.64a9.62,9.62,0,1,1,0-19.23h16.75l3.38-5.25H140.25a14.86,14.86,0,1,0,0,29.71m69.87-5.23a9.62,9.62,0,0,1-9.26-7h24.42l3.36-5.24H200.86a9.61,9.61,0,0,1,9.26-7h16.76l3.35-5.25h-20.5a14.86,14.86,0,0,0,0,29.71h17.63l3.35-5.23h-20.6' transform='translate(-0.02 0)' style='fill:#C74634'/></svg>

api/core/model_runtime/model_providers/oci/_assets/icon_s_en.svg  (+1 -0)

@@ -0,0 +1 @@
<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 231 30' preserveAspectRatio='xMinYMid'><path d='M99.61,19.52h15.24l-8.05-13L92,30H85.27l18-28.17a4.29,4.29,0,0,1,7-.05L128.32,30h-6.73l-3.17-5.25H103l-3.36-5.23m69.93,5.23V0.28h-5.72V27.16a2.76,2.76,0,0,0,.85,2,2.89,2.89,0,0,0,2.08.87h26l3.39-5.25H169.54M75,20.38A10,10,0,0,0,75,.28H50V30h5.71V5.54H74.65a4.81,4.81,0,0,1,0,9.62H58.54L75.6,30h8.29L72.43,20.38H75M14.88,30H32.15a14.86,14.86,0,0,0,0-29.71H14.88a14.86,14.86,0,1,0,0,29.71m16.88-5.23H15.26a9.62,9.62,0,0,1,0-19.23h16.5a9.62,9.62,0,1,1,0,19.23M140.25,30h17.63l3.34-5.23H140.64a9.62,9.62,0,1,1,0-19.23h16.75l3.38-5.25H140.25a14.86,14.86,0,1,0,0,29.71m69.87-5.23a9.62,9.62,0,0,1-9.26-7h24.42l3.36-5.24H200.86a9.61,9.61,0,0,1,9.26-7h16.76l3.35-5.25h-20.5a14.86,14.86,0,0,0,0,29.71h17.63l3.35-5.23h-20.6' transform='translate(-0.02 0)' style='fill:#C74634'/></svg>

api/core/model_runtime/model_providers/oci/llm/cohere.command-r-16k.yaml  (+52 -0)

@@ -0,0 +1,52 @@
model: cohere.command-r-16k
label:
en_US: cohere.command-r-16k v1.2
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
default: 1
max: 1.0
- name: topP
use_template: top_p
default: 0.75
min: 0
max: 1
- name: topK
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 0
min: 0
max: 500
- name: presencePenalty
use_template: presence_penalty
min: 0
max: 1
default: 0
- name: frequencyPenalty
use_template: frequency_penalty
min: 0
max: 1
default: 0
- name: maxTokens
use_template: max_tokens
default: 600
max: 4000
pricing:
input: '0.004'
output: '0.004'
unit: '0.0001'
currency: USD

api/core/model_runtime/model_providers/oci/llm/cohere.command-r-plus.yaml  (+52 -0)

@@ -0,0 +1,52 @@
model: cohere.command-r-plus
label:
en_US: cohere.command-r-plus v1.2
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
default: 1
max: 1.0
- name: topP
use_template: top_p
default: 0.75
min: 0
max: 1
- name: topK
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 0
min: 0
max: 500
- name: presencePenalty
use_template: presence_penalty
min: 0
max: 1
default: 0
- name: frequencyPenalty
use_template: frequency_penalty
min: 0
max: 1
default: 0
- name: maxTokens
use_template: max_tokens
default: 600
max: 4000
pricing:
input: '0.0219'
output: '0.0219'
unit: '0.0001'
currency: USD
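
Both Cohere command YAML files above name their parameters in the camelCase form the OCI chat API expects (topP, topK, maxTokens, frequencyPenalty, presencePenalty), because llm.py later in this commit merges the incoming model_parameters dict straight into the chatRequest payload. The short sketch below (illustrative values only, not part of the diff) mirrors that merge:

import copy

# Trimmed copy of the request_template defined in llm.py later in this commit.
request_template = {
    "compartmentId": "",
    "servingMode": {"modelId": "cohere.command-r-plus", "servingType": "ON_DEMAND"},
    "chatRequest": {
        "apiFormat": "COHERE",
        "maxTokens": 600,
        "isStream": False,
        "frequencyPenalty": 0,
        "presencePenalty": 0,
        "temperature": 1,
        "topP": 0.75,
    },
}

# Example values a user might set through the parameter rules defined above.
model_parameters = {"temperature": 0.3, "topP": 0.9, "topK": 40, "maxTokens": 512}

request_args = copy.deepcopy(request_template)
# Same two steps as llm.py: pop maxTokens first, then merge the rest verbatim.
request_args["chatRequest"]["maxTokens"] = model_parameters.pop("maxTokens", 600)
request_args["chatRequest"].update(model_parameters)
print(request_args["chatRequest"]["topK"], request_args["chatRequest"]["maxTokens"])  # 40 512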

api/core/model_runtime/model_providers/oci/llm/llm.py  (+461 -0)

@@ -0,0 +1,461 @@
import base64
import copy
import json
import logging
from collections.abc import Generator
from typing import Optional, Union

import oci
from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

logger = logging.getLogger(__name__)

request_template = {
"compartmentId": "",
"servingMode": {
"modelId": "cohere.command-r-plus",
"servingType": "ON_DEMAND"
},
"chatRequest": {
"apiFormat": "COHERE",
#"preambleOverride": "You are a helpful assistant.",
#"message": "Hello!",
#"chatHistory": [],
"maxTokens": 600,
"isStream": False,
"frequencyPenalty": 0,
"presencePenalty": 0,
"temperature": 1,
"topP": 0.75
}
}
oci_config_template = {
"user": "",
"fingerprint": "",
"tenancy": "",
"region": "",
"compartment_id": "",
"key_content": ""
}

class OCILargeLanguageModel(LargeLanguageModel):
# https://docs.oracle.com/en-us/iaas/Content/generative-ai/pretrained-models.htm
_supported_models = {
"meta.llama-3-70b-instruct": {
"system": True,
"multimodal": False,
"tool_call": False,
"stream_tool_call": False,
},
"cohere.command-r-16k": {
"system": True,
"multimodal": False,
"tool_call": True,
"stream_tool_call": False,
},
"cohere.command-r-plus": {
"system": True,
"multimodal": False,
"tool_call": True,
"stream_tool_call": False,
},
}

def _is_tool_call_supported(self, model_id: str, stream: bool = False) -> bool:
feature = self._supported_models.get(model_id)
if not feature:
return False
return feature["stream_tool_call"] if stream else feature["tool_call"]

def _is_multimodal_supported(self, model_id: str) -> bool:
feature = self._supported_models.get(model_id)
if not feature:
return False
return feature["multimodal"]

def _is_system_prompt_supported(self, model_id: str) -> bool:
feature = self._supported_models.get(model_id)
if not feature:
return False
return feature["system"]

def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
Invoke large language model

:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
#print("model"+"*"*20)
#print(model)
#print("credentials"+"*"*20)
#print(credentials)
#print("model_parameters"+"*"*20)
#print(model_parameters)
#print("prompt_messages"+"*"*200)
#print(prompt_messages)
#print("tools"+"*"*20)
#print(tools)

# invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages

:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return: number of tokens
"""
prompt = self._convert_messages_to_prompt(prompt_messages)

return self._get_num_tokens_by_gpt2(prompt)

def get_num_characters(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of characters for given prompt messages

:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return: number of characters
"""
prompt = self._convert_messages_to_prompt(prompt_messages)

return len(prompt)

def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
"""
:param messages: List of PromptMessage to combine.
:return: Combined string with necessary human_prompt and ai_prompt tags.
"""
messages = messages.copy() # don't mutate the original list

text = "".join(
self._convert_one_message_to_text(message)
for message in messages
)

return text.rstrip()

def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials

:param model: model name
:param credentials: model credentials
:return:
"""
# Setup basic variables
# Auth Config
try:
ping_message = SystemPromptMessage(content="ping")
self._generate(model, credentials, [ping_message], {"maxTokens": 5})
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None
) -> Union[LLMResult, Generator]:
"""
Invoke large language model

:param model: model name
:param credentials: credentials kwargs
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# config_kwargs = model_parameters.copy()
# config_kwargs['max_output_tokens'] = config_kwargs.pop('max_tokens_to_sample', None)
# if stop:
# config_kwargs["stop_sequences"] = stop

# initialize client
# ref: https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai-inference/20231130/ChatResult/Chat
oci_config = copy.deepcopy(oci_config_template)
if "oci_config_content" in credentials:
oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8')
config_items = oci_config_content.split("/")
if len(config_items) != 5:
raise CredentialsValidateFailedError("oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))")
oci_config["user"] = config_items[0]
oci_config["fingerprint"] = config_items[1]
oci_config["tenancy"] = config_items[2]
oci_config["region"] = config_items[3]
oci_config["compartment_id"] = config_items[4]
else:
raise CredentialsValidateFailedError("need to set oci_config_content in credentials ")
if "oci_key_content" in credentials:
oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8')
oci_config["key_content"] = oci_key_content.encode(encoding="utf-8")
else:
raise CredentialsValidateFailedError("need to set oci_config_content in credentials ")

#oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile'))
compartment_id = oci_config["compartment_id"]
client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config)
# build the chat request
request_args = copy.deepcopy(request_template)
request_args["compartmentId"] = compartment_id
request_args["servingMode"]["modelId"] = model

chathistory = []
system_prompts = []
#if "meta.llama" in model:
# request_args["chatRequest"]["apiFormat"] = "GENERIC"
request_args["chatRequest"]["maxTokens"] = model_parameters.pop('maxTokens', 600)
request_args["chatRequest"].update(model_parameters)
frequency_penalty = model_parameters.get("frequencyPenalty", 0)
presence_penalty = model_parameters.get("presencePenalty", 0)
if frequency_penalty > 0 and presence_penalty > 0:
raise InvokeBadRequestError("Cannot set both frequency penalty and presence penalty")

# for msg in prompt_messages: # makes message roles strictly alternating
# content = self._format_message_to_glm_content(msg)
# if history and history[-1]["role"] == content["role"]:
# history[-1]["parts"].extend(content["parts"])
# else:
# history.append(content)

# temporary not implement the tool call function
valid_value = self._is_tool_call_supported(model, stream)
if tools is not None and len(tools) > 0:
if not valid_value:
raise InvokeBadRequestError("Does not support function calling")
if model.startswith("cohere"):
#print("run cohere " * 10)
for message in prompt_messages[:-1]:
text = ""
if isinstance(message.content, str):
text = message.content
if isinstance(message, UserPromptMessage):
chathistory.append({"role": "USER", "message": text})
else:
chathistory.append({"role": "CHATBOT", "message": text})
if isinstance(message, SystemPromptMessage):
if isinstance(message.content, str):
system_prompts.append(message.content)
args = {"apiFormat": "COHERE",
"preambleOverride": ' '.join(system_prompts),
"message": prompt_messages[-1].content,
"chatHistory": chathistory, }
request_args["chatRequest"].update(args)
elif model.startswith("meta"):
#print("run meta " * 10)
meta_messages = []
for message in prompt_messages:
text = message.content
meta_messages.append({"role": message.role.name, "content": [{"type": "TEXT", "text": text}]})
args = {"apiFormat": "GENERIC",
"messages": meta_messages,
"numGenerations": 1,
"topK": -1}
request_args["chatRequest"].update(args)

if stream:
request_args["chatRequest"]["isStream"] = True
#print("final request" + "|" * 20)
#print(request_args)
response = client.chat(request_args)
#print(vars(response))

if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)

return self._handle_generate_response(model, credentials, response, prompt_messages)

def _handle_generate_response(self, model: str, credentials: dict, response: BaseChatResponse,
prompt_messages: list[PromptMessage]) -> LLMResult:
"""
Handle llm response

:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:return: llm response
"""
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=response.data.chat_response.text
)

# calculate num tokens
prompt_tokens = self.get_num_characters(model, credentials, prompt_messages)
completion_tokens = self.get_num_characters(model, credentials, [assistant_prompt_message])

# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

# transform response
result = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,
)

return result

def _handle_generate_stream_response(self, model: str, credentials: dict, response: BaseChatResponse,
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm stream response

:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:return: llm response chunk generator result
"""
index = -1
events = response.data.events()
for stream in events:
chunk = json.loads(stream.data)
#print(chunk)
#chunk: {'apiFormat': 'COHERE', 'text': 'Hello'}



#for chunk in response:
#for part in chunk.parts:
#if part.function_call:
# assistant_prompt_message.tool_calls = [
# AssistantPromptMessage.ToolCall(
# id=part.function_call.name,
# type='function',
# function=AssistantPromptMessage.ToolCall.ToolCallFunction(
# name=part.function_call.name,
# arguments=json.dumps(dict(part.function_call.args.items()))
# )
# )
# ]

if "finishReason" not in chunk:
assistant_prompt_message = AssistantPromptMessage(
content=''
)
if model.startswith("cohere"):
if chunk["text"]:
assistant_prompt_message.content += chunk["text"]
elif model.startswith("meta"):
assistant_prompt_message.content += chunk["message"]["content"][0]["text"]
index += 1
# transform assistant message to prompt message
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message
)
)
else:
# calculate num tokens
prompt_tokens = self.get_num_characters(model, credentials, prompt_messages)
completion_tokens = self.get_num_characters(model, credentials, [assistant_prompt_message])

# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
finish_reason=str(chunk["finishReason"]),
usage=usage
)
)

def _convert_one_message_to_text(self, message: PromptMessage) -> str:
"""
Convert a single message to a string.

:param message: PromptMessage to convert.
:return: String representation of the message.
"""
human_prompt = "\n\nuser:"
ai_prompt = "\n\nmodel:"

content = message.content
if isinstance(content, list):
content = "".join(
c.data for c in content if c.type != PromptMessageContentType.IMAGE
)

if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}"
elif isinstance(message, AssistantPromptMessage):
message_text = f"{ai_prompt} {content}"
elif isinstance(message, SystemPromptMessage):
message_text = f"{human_prompt} {content}"
elif isinstance(message, ToolPromptMessage):
message_text = f"{human_prompt} {content}"
else:
raise ValueError(f"Got unknown type {message}")

return message_text

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.

:return: Invoke error mapping
"""
return {
InvokeConnectionError: [],
InvokeServerUnavailableError: [],
InvokeRateLimitError: [],
InvokeAuthorizationError: [],
InvokeBadRequestError: []
}
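
For reference, this is roughly the COHERE-format payload that _generate() above assembles and passes to client.chat() once system prompts, chat history, and the latest user message have been folded in. All identifiers and message contents below are placeholders, not values from the commit:

request_args = {
    "compartmentId": "ocid1.compartment.oc1..example",  # placeholder OCID
    "servingMode": {"modelId": "cohere.command-r-plus", "servingType": "ON_DEMAND"},
    "chatRequest": {
        "apiFormat": "COHERE",
        "preambleOverride": "You are a helpful assistant.",  # joined SystemPromptMessage contents
        "message": "What is OCI Generative AI?",             # content of the last prompt message
        "chatHistory": [                                     # earlier messages, USER vs. CHATBOT
            {"role": "USER", "message": "Hi"},
            {"role": "CHATBOT", "message": "Hello! How can I help?"},
        ],
        "maxTokens": 600,
        "isStream": False,
        "frequencyPenalty": 0,
        "presencePenalty": 0,
        "temperature": 1,
        "topP": 0.75,
    },
}
# client.chat(request_args) returns the full response; with "isStream": True the
# stream handler above instead iterates response.data.events() and parses each JSON chunk.

Meta Llama models follow the same path but with apiFormat GENERIC and a messages list, as handled in the elif branch above.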

api/core/model_runtime/model_providers/oci/llm/meta.llama-3-70b-instruct.yaml  (+51 -0)

@@ -0,0 +1,51 @@
model: meta.llama-3-70b-instruct
label:
zh_Hans: meta.llama-3-70b-instruct
en_US: meta.llama-3-70b-instruct
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature
default: 1
max: 2.0
- name: topP
use_template: top_p
default: 0.75
min: 0
max: 1
- name: topK
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 0
min: 0
max: 500
- name: presencePenalty
use_template: presence_penalty
min: -2
max: 2
default: 0
- name: frequencyPenalty
use_template: frequency_penalty
min: -2
max: 2
default: 0
- name: maxTokens
use_template: max_tokens
default: 600
max: 8000
pricing:
input: '0.015'
output: '0.015'
unit: '0.0001'
currency: USD

api/core/model_runtime/model_providers/oci/oci.py  (+34 -0)

@@ -0,0 +1,34 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class OCIGENAIProvider(ModelProvider):

def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials

if validate failed, raise exception

:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)

# Use the `cohere.command-r-plus` model for validation,
model_instance.validate_credentials(
model='cohere.command-r-plus',
credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
raise ex



api/core/model_runtime/model_providers/oci/oci.yaml  (+42 -0)

@@ -0,0 +1,42 @@
provider: oci
label:
en_US: OCIGenerativeAI
description:
en_US: Models provided by OCI, such as Cohere Command R and Cohere Command R+.
zh_Hans: OCI 提供的模型,例如 Cohere Command R 和 Cohere Command R+。
icon_small:
en_US: icon_s_en.svg
icon_large:
en_US: icon_l_en.svg
background: "#FFFFFF"
help:
title:
en_US: Get your API Key from OCI
zh_Hans: 从 OCI 获取 API Key
url:
en_US: https://docs.cloud.oracle.com/Content/API/Concepts/sdkconfig.htm
supported_model_types:
- llm
- text-embedding
#- rerank
configurate_methods:
- predefined-model
#- customizable-model
provider_credential_schema:
credential_form_schemas:
- variable: oci_config_content
label:
en_US: OCI API key config file's content
type: text-input
required: true
placeholder:
zh_Hans: 在此输入您的 oci api key config 文件的内容(base64.b64encode("user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid".encode('utf-8')) )
en_US: Enter your OCI API key config file's content (base64.b64encode("user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid".encode('utf-8')))
- variable: oci_key_content
label:
en_US: OCI API key file's content
type: text-input
required: true
placeholder:
zh_Hans: 在此输入您的 oci api key 文件的内容(base64.b64encode("pem file content".encode('utf-8')))
en_US: Enter your OCI API key file's content (base64.b64encode("pem file content".encode('utf-8')))
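
The two credential fields above are plain base64 strings rather than file paths. A minimal sketch of producing them, assuming placeholder OCIDs and a local PEM key file (none of these values come from the commit):

import base64
from pathlib import Path

user_ocid = "ocid1.user.oc1..example"
fingerprint = "aa:bb:cc:dd:ee:ff:00:11:22:33:44:55:66:77:88:99"
tenancy_ocid = "ocid1.tenancy.oc1..example"
region = "us-chicago-1"
compartment_ocid = "ocid1.compartment.oc1..example"

# oci_config_content: five fields joined with "/" and base64-encoded.
# None of the fields may themselves contain "/", since the provider code
# splits the decoded string on that character.
config_plain = "/".join([user_ocid, fingerprint, tenancy_ocid, region, compartment_ocid])
oci_config_content = base64.b64encode(config_plain.encode("utf-8")).decode("utf-8")

# oci_key_content: the private key PEM, base64-encoded (placeholder path).
pem = Path("~/.oci/oci_api_key.pem").expanduser().read_text()
oci_key_content = base64.b64encode(pem.encode("utf-8")).decode("utf-8")

credentials = {"oci_config_content": oci_config_content, "oci_key_content": oci_key_content}

The decoded config string must contain exactly five "/"-separated fields, which is what _generate() and _embedding_invoke() in this commit split and validate.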

api/core/model_runtime/model_providers/oci/text_embedding/__init__.py  (+0 -0)


api/core/model_runtime/model_providers/oci/text_embedding/_position.yaml  (+5 -0)

@@ -0,0 +1,5 @@
- cohere.embed-english-light-v2.0
- cohere.embed-english-light-v3.0
- cohere.embed-english-v3.0
- cohere.embed-multilingual-light-v3.0
- cohere.embed-multilingual-v3.0

api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-light-v2.0.yaml  (+9 -0)

@@ -0,0 +1,9 @@
model: cohere.embed-english-light-v2.0
model_type: text-embedding
model_properties:
context_size: 1024
max_chunks: 48
pricing:
input: '0.001'
unit: '0.0001'
currency: USD

api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-light-v3.0.yaml  (+9 -0)

@@ -0,0 +1,9 @@
model: cohere.embed-english-light-v3.0
model_type: text-embedding
model_properties:
context_size: 384
max_chunks: 48
pricing:
input: '0.001'
unit: '0.0001'
currency: USD

api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-english-v3.0.yaml  (+9 -0)

@@ -0,0 +1,9 @@
model: cohere.embed-english-v3.0
model_type: text-embedding
model_properties:
context_size: 1024
max_chunks: 48
pricing:
input: '0.001'
unit: '0.0001'
currency: USD

api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-multilingual-light-v3.0.yaml  (+9 -0)

@@ -0,0 +1,9 @@
model: cohere.embed-multilingual-light-v3.0
model_type: text-embedding
model_properties:
context_size: 384
max_chunks: 48
pricing:
input: '0.001'
unit: '0.0001'
currency: USD

api/core/model_runtime/model_providers/oci/text_embedding/cohere.embed-multilingual-v3.0.yaml  (+9 -0)

@@ -0,0 +1,9 @@
model: cohere.embed-multilingual-v3.0
model_type: text-embedding
model_properties:
context_size: 1024
max_chunks: 48
pricing:
input: '0.001'
unit: '0.0001'
currency: USD

api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py  (+242 -0)

@@ -0,0 +1,242 @@
import base64
import copy
import time
from typing import Optional

import numpy as np
import oci

from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel

request_template = {
"compartmentId": "",
"servingMode": {
"modelId": "cohere.embed-english-light-v3.0",
"servingType": "ON_DEMAND"
},
"truncate": "NONE",
"inputs": [""]
}
oci_config_template = {
"user": "",
"fingerprint": "",
"tenancy": "",
"region": "",
"compartment_id": "",
"key_content": ""
}
class OCITextEmbeddingModel(TextEmbeddingModel):
"""
Model class for OCI text embedding model.
"""

def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke text embedding model

:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
# get model properties
context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials)

inputs = []
indices = []
used_tokens = 0

for i, text in enumerate(texts):

# Here token count is only an approximation based on the GPT2 tokenizer
num_tokens = self._get_num_tokens_by_gpt2(text)

if num_tokens >= context_size:
cutoff = int(len(text) * (np.floor(context_size / num_tokens)))
# if num tokens is larger than context length, only use the start
inputs.append(text[0: cutoff])
else:
inputs.append(text)
indices += [i]

batched_embeddings = []
_iter = range(0, len(inputs), max_chunks)

for i in _iter:
# call embedding model
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
model=model,
credentials=credentials,
texts=inputs[i: i + max_chunks]
)

used_tokens += embedding_used_tokens
batched_embeddings += embeddings_batch

# calc usage
usage = self._calc_response_usage(
model=model,
credentials=credentials,
tokens=used_tokens
)

return TextEmbeddingResult(
embeddings=batched_embeddings,
usage=usage,
model=model
)

def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages

:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)

def get_num_characters(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of characters for the given texts

:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
characters = 0
for text in texts:
characters += len(text)
return characters
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials

:param model: model name
:param credentials: model credentials
:return:
"""
try:
# call embedding model
self._embedding_invoke(
model=model,
credentials=credentials,
texts=['ping']
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> tuple[list[list[float]], int]:
"""
Invoke embedding model

:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return: embeddings and used tokens
"""

# oci
# initialize client
oci_config = copy.deepcopy(oci_config_template)
if "oci_config_content" in credentials:
oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8')
config_items = oci_config_content.split("/")
if len(config_items) != 5:
raise CredentialsValidateFailedError("oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))")
oci_config["user"] = config_items[0]
oci_config["fingerprint"] = config_items[1]
oci_config["tenancy"] = config_items[2]
oci_config["region"] = config_items[3]
oci_config["compartment_id"] = config_items[4]
else:
raise CredentialsValidateFailedError("need to set oci_config_content in credentials ")
if "oci_key_content" in credentials:
oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8')
oci_config["key_content"] = oci_key_content.encode(encoding="utf-8")
else:
raise CredentialsValidateFailedError("need to set oci_config_content in credentials ")
# oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile'))
compartment_id = oci_config["compartment_id"]
client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config)
# call embedding model
request_args = copy.deepcopy(request_template)
request_args["compartmentId"] = compartment_id
request_args["servingMode"]["modelId"] = model
request_args["inputs"] = texts
response = client.embed_text(request_args)
return response.data.embeddings, self.get_num_characters(model=model, credentials=credentials, texts=texts)

def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage

:param model: model name
:param credentials: model credentials
:param tokens: input tokens
:return: usage
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
)

# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
)

return usage

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
KeyError
]
}
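
The _invoke() method above sends texts to _embedding_invoke() in groups of at most max_chunks (48 for the Cohere embedding models in this commit) and concatenates the per-batch embeddings. A standalone sketch of that batching, with placeholder inputs:

def batch_texts(texts: list[str], max_chunks: int) -> list[list[str]]:
    """Split texts into consecutive batches of at most max_chunks items."""
    return [texts[i:i + max_chunks] for i in range(0, len(texts), max_chunks)]

texts = [f"document {n}" for n in range(100)]  # placeholder inputs
batches = batch_texts(texts, max_chunks=48)    # 48 matches the embedding YAMLs above
print([len(b) for b in batches])               # [48, 48, 4]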

api/poetry.lock  (+484 -431; diff omitted because the file is too large)


api/pyproject.toml  (+1 -0)

@@ -190,6 +190,7 @@ zhipuai = "1.0.7"
azure-ai-ml = "^1.19.0"
azure-ai-inference = "^1.0.0b3"
volcengine-python-sdk = {extras = ["ark"], version = "^1.0.98"}
oci = "^2.133.0"
[tool.poetry.group.indriect.dependencies]
kaleido = "0.2.1"
rank-bm25 = "~0.2.2"
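
The only change here is the new oci dependency. A quick import check (not part of the commit) confirms the SDK and the inference client used by llm.py and text_embedding.py are available:

import oci
from oci.generative_ai_inference import GenerativeAiInferenceClient

# The ^2.133.0 constraint above should resolve to 2.133.0 or newer (below 3.0).
print(oci.__version__)
print(GenerativeAiInferenceClient.__name__)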

api/tests/integration_tests/model_runtime/oci/__init__.py  (+0 -0)


api/tests/integration_tests/model_runtime/oci/test_llm.py  (+130 -0)

@@ -0,0 +1,130 @@
import os
from collections.abc import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.oci.llm.llm import OCILargeLanguageModel


def test_validate_credentials():
model = OCILargeLanguageModel()

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="cohere.command-r-plus",
credentials={"oci_config_content": "invalid_key", "oci_key_content": "invalid_key"},
)

model.validate_credentials(
model="cohere.command-r-plus",
credentials={
"oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
"oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
},
)


def test_invoke_model():
model = OCILargeLanguageModel()

response = model.invoke(
model="cohere.command-r-plus",
credentials={
"oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
"oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
},
prompt_messages=[UserPromptMessage(content="Hi")],
model_parameters={"temperature": 0.5, "max_tokens": 10},
stop=["How"],
stream=False,
user="abc-123",
)

assert isinstance(response, LLMResult)
assert len(response.message.content) > 0


def test_invoke_stream_model():
model = OCILargeLanguageModel()

response = model.invoke(
model="meta.llama-3-70b-instruct",
credentials={
"oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
"oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
},
prompt_messages=[UserPromptMessage(content="Hi")],
model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
stream=True,
user="abc-123",
)

assert isinstance(response, Generator)

for chunk in response:
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True


def test_invoke_model_with_function():
model = OCILargeLanguageModel()

response = model.invoke(
model="cohere.command-r-plus",
credentials={
"oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
"oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
},
prompt_messages=[UserPromptMessage(content="Hi")],
model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
stream=False,
user="abc-123",
tools=[
PromptMessageTool(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type": "object",
"properties": {
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
)
],
)

assert isinstance(response, LLMResult)
assert len(response.message.content) > 0


def test_get_num_tokens():
model = OCILargeLanguageModel()

num_tokens = model.get_num_tokens(
model="cohere.command-r-plus",
credentials={
"oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
"oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
},
prompt_messages=[
SystemPromptMessage(
content="You are a helpful AI assistant.",
),
UserPromptMessage(content="Hello World!"),
],
)

assert num_tokens == 18

api/tests/integration_tests/model_runtime/oci/test_provider.py  (+20 -0)

@@ -0,0 +1,20 @@
import os

import pytest

from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.oci.oci import OCIGENAIProvider


def test_validate_provider_credentials():
provider = OCIGENAIProvider()

with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(credentials={})

provider.validate_provider_credentials(
credentials={
"oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
"oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
}
)

api/tests/integration_tests/model_runtime/oci/test_text_embedding.py  (+58 -0)

@@ -0,0 +1,58 @@
import os

import pytest

from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.oci.text_embedding.text_embedding import OCITextEmbeddingModel


def test_validate_credentials():
model = OCITextEmbeddingModel()

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="cohere.embed-multilingual-v3.0",
credentials={"oci_config_content": "invalid_key", "oci_key_content": "invalid_key"},
)

model.validate_credentials(
model="cohere.embed-multilingual-v3.0",
credentials={
"oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
"oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
},
)


def test_invoke_model():
model = OCITextEmbeddingModel()

result = model.invoke(
model="cohere.embed-multilingual-v3.0",
credentials={
"oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
"oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
},
texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
user="abc-123",
)

assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 4
# assert result.usage.total_tokens == 811


def test_get_num_tokens():
model = OCITextEmbeddingModel()

num_tokens = model.get_num_tokens(
model="cohere.embed-multilingual-v3.0",
credentials={
"oci_config_content": os.environ.get("OCI_CONFIG_CONTENT"),
"oci_key_content": os.environ.get("OCI_KEY_CONTENT"),
},
texts=["hello", "world"],
)

assert num_tokens == 2
