
feat: support json schema for gemini models (#10835)

tags/0.12.0
非法操作, 11 months ago
commit bc1013dacf
18 changed files with 61 additions and 77 deletions
 1. +3  −4   api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-001.yaml
 2. +3  −4   api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-002.yaml
 3. +3  −4   api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0827.yaml
 4. +3  −4   api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0924.yaml
 5. +3  −4   api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-exp-0827.yaml
 6. +3  −4   api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml
 7. +3  −4   api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash.yaml
 8. +3  −4   api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-001.yaml
 9. +3  −4   api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-002.yaml
10. +3  −4   api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0801.yaml
11. +3  −4   api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0827.yaml
12. +3  −4   api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml
13. +3  −4   api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro.yaml
14. +3  −4   api/core/model_runtime/model_providers/google/llm/gemini-exp-1114.yaml
15. +1  −0   api/core/model_runtime/model_providers/google/llm/gemini-pro-vision.yaml
16. +1  −0   api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml
17. +10 −14  api/core/model_runtime/model_providers/google/llm/llm.py
18. +7  −7   api/tests/integration_tests/model_runtime/google/test_llm.py

+3 −4  api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-001.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
     required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
-
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'
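In each of these yaml files, the change renames the max-tokens parameter to max_output_tokens (Gemini's native name) and replaces the prompt-driven response_format rule with a native json_schema rule. As a hypothetical caller-side sketch (not part of this commit), this is what model_parameters now look like; note that json_schema must be a JSON string, because llm.py parses it with json.loads:

import json

# Hypothetical illustration: parameter names follow the parameter_rules above,
# the schema contents are invented for the example.
person_schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name"],
}

model_parameters = {
    "temperature": 0.5,
    "max_output_tokens": 2048,                  # renamed from max_tokens_to_sample
    "json_schema": json.dumps(person_schema),   # replaces response_format
}

The same three-line rule swap is applied to every gemini-1.5-* and gemini-exp model file below.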

+3 −4  api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-002.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
     required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
-
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+3 −4  api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0827.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
     required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
-
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+3 −4  api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0924.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
     required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
-
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+3 −4  api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-exp-0827.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
     required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
-
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+3 −4  api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
     required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
-
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+3 −4  api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
     required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
-
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+3 −4  api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-001.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
     required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
-
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+3 −4  api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-002.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
     required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
-
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+3 −4  api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0801.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
     required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
-
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+3 −4  api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0827.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
     required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
-
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+3 −4  api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
     required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
-
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+3 −4  api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
     required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
-
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+3 −4  api/core/model_runtime/model_providers/google/llm/gemini-exp-1114.yaml

@@ -24,14 +24,13 @@ parameter_rules:
       zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
       en_US: Only sample from the top K options for each subsequent token.
     required: false
-  - name: max_tokens_to_sample
+  - name: max_output_tokens
     use_template: max_tokens
     required: true
     default: 8192
     min: 1
     max: 8192
-  - name: response_format
-    use_template: response_format
-
+  - name: json_schema
+    use_template: json_schema
 pricing:
   input: '0.00'
   output: '0.00'

+1 −0  api/core/model_runtime/model_providers/google/llm/gemini-pro-vision.yaml

@@ -32,3 +32,4 @@ pricing:
   output: '0.00'
   unit: '0.000001'
   currency: USD
+deprecated: true

+1 −0  api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml

@@ -36,3 +36,4 @@ pricing:
   output: '0.00'
   unit: '0.000001'
   currency: USD
+deprecated: true

+10 −14  api/core/model_runtime/model_providers/google/llm/llm.py

@@ -1,7 +1,6 @@
 import base64
 import io
 import json
-import logging
 from collections.abc import Generator
 from typing import Optional, Union, cast
 
@@ -36,17 +35,6 @@ from core.model_runtime.errors.invoke import (
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 
-logger = logging.getLogger(__name__)
-
-GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
-The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
-if you are not sure about the structure.
-
-<instructions>
-{{instructions}}
-</instructions>
-""" # noqa: E501
-
 
 class GoogleLargeLanguageModel(LargeLanguageModel):
     def _invoke(
@@ -155,7 +143,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
 
         try:
             ping_message = SystemPromptMessage(content="ping")
-            self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
+            self._generate(model, credentials, [ping_message], {"max_output_tokens": 5})
 
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
@@ -184,7 +172,15 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         :return: full response or stream response chunk generator result
         """
         config_kwargs = model_parameters.copy()
-        config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
+        if schema := config_kwargs.pop("json_schema", None):
+            try:
+                schema = json.loads(schema)
+            except:
+                raise exceptions.InvalidArgument("Invalid JSON Schema")
+            if tools:
+                raise exceptions.InvalidArgument("gemini not support use Tools and JSON Schema at same time")
+            config_kwargs["response_schema"] = schema
+            config_kwargs["response_mime_type"] = "application/json"
 
         if stop:
             config_kwargs["stop_sequences"] = stop
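The new response_schema and response_mime_type keys correspond to fields of the google-generativeai SDK's GenerationConfig, which _generate builds from config_kwargs further down (that part of llm.py is outside this hunk). Below is a hedged, SDK-level sketch of the behavior being enabled; the model name, prompt, and schema are illustrative, and a recent google-generativeai version that accepts a plain-dict response_schema is assumed:

import google.generativeai as genai

# Illustrative sketch, not the committed code: shows the SDK features that
# the new config_kwargs keys map onto.
genai.configure(api_key="YOUR_GOOGLE_API_KEY")  # placeholder credential
model = genai.GenerativeModel("gemini-1.5-flash-002")

response = model.generate_content(
    "Extract name and age from: Alice is 30 years old.",
    generation_config=genai.GenerationConfig(
        max_output_tokens=2048,
        response_mime_type="application/json",
        response_schema={
            "type": "object",
            "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        },
    ),
)
print(response.text)  # reply text is JSON conforming to the schema

The InvalidArgument raised when tools are present encodes the constraint stated in the commit's own error message: Gemini cannot combine tool calling with schema-constrained JSON output in one request.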

+7 −7  api/tests/integration_tests/model_runtime/google/test_llm.py

@@ -31,7 +31,7 @@ def test_invoke_model(setup_google_mock):
     model = GoogleLargeLanguageModel()
 
     response = model.invoke(
-        model="gemini-pro",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
@@ -48,7 +48,7 @@ def test_invoke_model(setup_google_mock):
                 ]
             ),
         ],
-        model_parameters={"temperature": 0.5, "top_p": 1.0, "max_tokens_to_sample": 2048},
+        model_parameters={"temperature": 0.5, "top_p": 1.0, "max_output_tokens": 2048},
         stop=["How"],
         stream=False,
         user="abc-123",
@@ -63,7 +63,7 @@ def test_invoke_stream_model(setup_google_mock):
     model = GoogleLargeLanguageModel()
 
     response = model.invoke(
-        model="gemini-pro",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
@@ -80,7 +80,7 @@ def test_invoke_stream_model(setup_google_mock):
                 ]
             ),
         ],
-        model_parameters={"temperature": 0.2, "top_k": 5, "max_tokens_to_sample": 2048},
+        model_parameters={"temperature": 0.2, "top_k": 5, "max_tokens": 2048},
         stream=True,
         user="abc-123",
     )
@@ -99,7 +99,7 @@ def test_invoke_chat_model_with_vision(setup_google_mock):
     model = GoogleLargeLanguageModel()
 
     result = model.invoke(
-        model="gemini-pro-vision",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
@@ -128,7 +128,7 @@ def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock):
     model = GoogleLargeLanguageModel()
 
     result = model.invoke(
-        model="gemini-pro-vision",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(content="You are a helpful AI assistant."),
@@ -164,7 +164,7 @@ def test_get_num_tokens():
     model = GoogleLargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model="gemini-pro",
+        model="gemini-1.5-pro",
         credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(

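A hypothetical follow-up test, not part of this commit, showing how the new json_schema path could be exercised in the same style as the tests above (the test name, prompt, and schema are invented for illustration):

import json
import os

from core.model_runtime.entities.message_entities import UserPromptMessage
from core.model_runtime.model_providers.google.llm.llm import GoogleLargeLanguageModel


def test_invoke_model_with_json_schema():
    # Hypothetical test mirroring the structure of the existing tests.
    model = GoogleLargeLanguageModel()

    response = model.invoke(
        model="gemini-1.5-pro",
        credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")},
        prompt_messages=[UserPromptMessage(content="Return a JSON object for user Bob, age 25.")],
        model_parameters={
            "temperature": 0.2,
            "max_output_tokens": 256,
            # json_schema is passed as a JSON string, matching what llm.py expects
            "json_schema": json.dumps(
                {"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}}
            ),
        },
        stream=False,
        user="abc-123",
    )

    assert json.loads(response.message.content)  # output should parse as valid JSON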