Sfoglia il codice sorgente

feat: Support for Vertex AI - load Default Application Configuration (#4641)

Co-authored-by: miendinh <miendinh@users.noreply.github.com>
Co-authored-by: crazywoola <427733928@qq.com>
tags/0.6.9
miendinh 1 anno fa
parent
commit
f804adbff3
Nessun account collegato all'indirizzo email del committer

+ 5
- 2
api/core/model_runtime/model_providers/vertex_ai/llm/llm.py Vedi File

@@ -164,10 +164,13 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
config_kwargs["stop_sequences"] = stop

service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
project_id = credentials["vertex_project_id"]
location = credentials["vertex_location"]
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
if service_account_info:
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
else:
aiplatform.init(project=project_id, location=location)

history = []
system_instruction = GEMINI_BLOCK_MODE_PROMPT

+ 10
- 6
api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py Vedi File

@@ -41,15 +41,16 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
:return: embeddings result
"""
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
project_id = credentials["vertex_project_id"]
location = credentials["vertex_location"]
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
if service_account_info:
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
else:
aiplatform.init(project=project_id, location=location)

client = VertexTextEmbeddingModel.from_pretrained(model)


embeddings_batch, embedding_used_tokens = self._embedding_invoke(
client=client,
texts=texts
@@ -103,10 +104,13 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
"""
try:
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
project_id = credentials["vertex_project_id"]
location = credentials["vertex_location"]
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
if service_account_info:
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
else:
aiplatform.init(project=project_id, location=location)

client = VertexTextEmbeddingModel.from_pretrained(model)


+ 2
- 2
api/core/model_runtime/model_providers/vertex_ai/vertex_ai.yaml Vedi File

@@ -36,8 +36,8 @@ provider_credential_schema:
en_US: Enter your Google Cloud Location
- variable: vertex_service_account_key
label:
en_US: Service Account Key
en_US: Service Account Key (Leave blank if you use Application Default Credentials)
type: secret-input
required: true
required: false
placeholder:
en_US: Enter your Google Cloud Service Account Key in base64 format

Loading…
Annulla
Salva