
Feat: Add model provider Text Embedding Inference for embedding and rerank (#7132)

tags/0.7.0
Yanyi Liu, 1 year ago
Commit 5b32f2e0dd

+ 0
- 0
api/core/model_runtime/model_providers/huggingface_tei/__init__.py


+ 11
- 0
api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py

@@ -0,0 +1,11 @@
import logging
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class HuggingfaceTeiProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
pass
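Provider-level validation is a no-op because the provider only offers customizable-model configuration (see the YAML below): there are no provider-wide credentials, and each configured model is validated by its own model class. A minimal sketch of that per-model check, with a placeholder server URL:

from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import (
    HuggingfaceTeiTextEmbeddingModel,
)

# Hypothetical values; 'server_url' must point at a TEI instance serving an embedding model.
HuggingfaceTeiTextEmbeddingModel().validate_credentials(
    model='my-embedding-model',
    credentials={'server_url': 'http://localhost:8080'},
)
# Raises CredentialsValidateFailedError if the server is unreachable or hosts a reranker.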

+ 36
- 0
api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.yaml

@@ -0,0 +1,36 @@
provider: huggingface_tei
label:
  en_US: Text Embedding Inference
description:
  en_US: A blazing fast inference solution for text embeddings models.
  zh_Hans: 用于文本嵌入模型的超快速推理解决方案。
background: "#FFF8DC"
help:
  title:
    en_US: How to deploy Text Embedding Inference
    zh_Hans: 如何部署 Text Embedding Inference
  url:
    en_US: https://github.com/huggingface/text-embeddings-inference
supported_model_types:
  - text-embedding
  - rerank
configurate_methods:
  - customizable-model
model_credential_schema:
  model:
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    placeholder:
      en_US: Enter your model name
      zh_Hans: 输入模型名称
  credential_form_schemas:
    - variable: server_url
      label:
        zh_Hans: 服务器URL
        en_US: Server url
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入Text Embedding Inference的服务器地址,如 http://192.168.1.100:8080
        en_US: Enter the url of your Text Embedding Inference, e.g. http://192.168.1.100:8080
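The schema only asks for a server URL and a model name; whether the entry ends up as a text-embedding or a rerank model is decided by querying the server itself (see tei_helper.py below). A quick probe, assuming a local deployment at a placeholder URL:

import httpx

server_url = 'http://localhost:8080'  # placeholder; whatever you would enter as server_url above

# TEI exposes its configuration on /info; 'model_type' is keyed by 'embedding' or 'reranker'.
info = httpx.get(f'{server_url}/info').json()
print(list(info['model_type'].keys())[0], info.get('max_input_length'))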

+ 0
- 0
api/core/model_runtime/model_providers/huggingface_tei/rerank/__init__.py


+ 137
- 0
api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py

@@ -0,0 +1,137 @@
from typing import Optional
import httpx
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper
class HuggingfaceTeiRerankModel(RerankModel):
"""
Model class for Text Embedding Inference rerank model.
"""
def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(model=model, docs=[])
server_url = credentials['server_url']
if server_url.endswith('/'):
server_url = server_url[:-1]
try:
results = TeiHelper.invoke_rerank(server_url, query, docs)
rerank_documents = []
for result in results:
rerank_document = RerankDocument(
index=result['index'],
text=result['text'],
score=result['score'],
)
if score_threshold is None or result['score'] >= score_threshold:
rerank_documents.append(rerank_document)
if top_n is not None and len(rerank_documents) >= top_n:
break
return RerankResult(model=model, docs=rerank_documents)
except httpx.HTTPStatusError as e:
raise InvokeServerUnavailableError(str(e))
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
server_url = credentials['server_url']
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
if extra_args.model_type != 'reranker':
raise CredentialsValidateFailedError('Current model is not a rerank model')
credentials['context_size'] = extra_args.max_input_length
self.invoke(
model=model,
credentials=credentials,
query='Who is Kasumi?',
docs=[
'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
'Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ',
'and she leads a team named PopiParty.',
],
score_threshold=0.8,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError],
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.RERANK,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)),
},
parameter_rules=[],
)
return entity
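A minimal usage sketch mirroring the integration test further down, assuming a TEI reranker is reachable at a placeholder URL; documents below score_threshold are dropped and at most top_n are returned:

from core.model_runtime.model_providers.huggingface_tei.rerank.rerank import HuggingfaceTeiRerankModel

model = HuggingfaceTeiRerankModel()
result = model.invoke(
    model='my-reranker',                                   # hypothetical model name
    credentials={'server_url': 'http://localhost:8080'},   # hypothetical TEI reranker endpoint
    query='Who is Kasumi?',
    docs=[
        'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
        'Deep Learning is a subfield of machine learning.',
    ],
    score_threshold=0.5,
    top_n=1,
)
print([(doc.index, doc.score) for doc in result.docs])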

+ 183
- 0
api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py

@@ -0,0 +1,183 @@
from threading import Lock
from time import time
from typing import Optional
import httpx
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, MissingSchema, Timeout
from requests.sessions import Session
from yarl import URL
class TeiModelExtraParameter:
model_type: str
max_input_length: int
max_client_batch_size: int
def __init__(self, model_type: str, max_input_length: int, max_client_batch_size: Optional[int] = None) -> None:
self.model_type = model_type
self.max_input_length = max_input_length
self.max_client_batch_size = max_client_batch_size
cache = {}
cache_lock = Lock()
class TeiHelper:
@staticmethod
def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
TeiHelper._clean_cache()
with cache_lock:
if model_name not in cache:
cache[model_name] = {
'expires': time() + 300,
'value': TeiHelper._get_tei_extra_parameter(server_url),
}
return cache[model_name]['value']
@staticmethod
def _clean_cache() -> None:
try:
with cache_lock:
expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()]
for model_uid in expired_keys:
del cache[model_uid]
        except RuntimeError:
            pass
@staticmethod
def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter:
"""
get tei model extra parameter like model_type, max_input_length, max_batch_requests
"""
url = str(URL(server_url) / 'info')
        # This method runs while holding the cache lock, and a default requests call may hang forever, so mount an HTTPAdapter with max_retries=3
session = Session()
session.mount('http://', HTTPAdapter(max_retries=3))
session.mount('https://', HTTPAdapter(max_retries=3))
try:
response = session.get(url, timeout=10)
except (MissingSchema, ConnectionError, Timeout) as e:
raise RuntimeError(f'get tei model extra parameter failed, url: {url}, error: {e}')
if response.status_code != 200:
raise RuntimeError(
f'get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}'
)
response_json = response.json()
model_type = response_json.get('model_type', {})
if len(model_type.keys()) < 1:
raise RuntimeError('model_type is empty')
model_type = list(model_type.keys())[0]
if model_type not in ['embedding', 'reranker']:
raise RuntimeError(f'invalid model_type: {model_type}')
max_input_length = response_json.get('max_input_length', 512)
max_client_batch_size = response_json.get('max_client_batch_size', 1)
return TeiModelExtraParameter(
model_type=model_type,
max_input_length=max_input_length,
max_client_batch_size=max_client_batch_size
)
@staticmethod
def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
"""
Invoke tokenize endpoint
Example response:
[
[
{
"id": 0,
"text": "<s>",
"special": true,
"start": null,
"stop": null
},
{
"id": 7704,
"text": "str",
"special": false,
"start": 0,
"stop": 3
},
< MORE TOKENS >
]
]
:param server_url: server url
:param texts: texts to tokenize
"""
resp = httpx.post(
f'{server_url}/tokenize',
json={'inputs': texts},
)
resp.raise_for_status()
return resp.json()
@staticmethod
def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
"""
Invoke embeddings endpoint
Example response:
{
"object": "list",
"data": [
{
"object": "embedding",
"embedding": [...],
"index": 0
}
],
"model": "MODEL_NAME",
"usage": {
"prompt_tokens": 3,
"total_tokens": 3
}
}
:param server_url: server url
:param texts: texts to embed
"""
# Use OpenAI compatible API here, which has usage tracking
resp = httpx.post(
f'{server_url}/v1/embeddings',
json={'input': texts},
)
resp.raise_for_status()
return resp.json()
@staticmethod
def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]:
"""
Invoke rerank endpoint
Example response:
[
{
"index": 0,
"text": "Deep Learning is ...",
"score": 0.9950755
}
]
:param server_url: server url
        :param query: query text
        :param docs: candidate documents to rerank against the query
"""
params = {'query': query, 'texts': docs, 'return_text': True}
response = httpx.post(
server_url + '/rerank',
json=params,
)
response.raise_for_status()
return response.json()
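Taken together, the helper can also be exercised directly against a running server; a minimal sketch, assuming a TEI instance at a placeholder URL:

from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper

server_url = 'http://localhost:8080'  # placeholder; point this at your TEI deployment

# /info tells us whether the server hosts an 'embedding' or a 'reranker' model.
info = TeiHelper.get_tei_extra_parameter(server_url, 'my-model')
print(info.model_type, info.max_input_length, info.max_client_batch_size)

if info.model_type == 'embedding':
    # The OpenAI-compatible /v1/embeddings response includes token usage.
    resp = TeiHelper.invoke_embeddings(server_url, ['hello world'])
    print(len(resp['data'][0]['embedding']), resp['usage']['total_tokens'])
else:
    # /rerank returns index, text and score for each candidate document.
    for doc in TeiHelper.invoke_rerank(server_url, 'what is TEI?', ['doc a', 'doc b']):
        print(doc['index'], doc['score'])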

+ 0
- 0
api/core/model_runtime/model_providers/huggingface_tei/text_embedding/__init__.py


+ 204
- 0
api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py

@@ -0,0 +1,204 @@
import time
from typing import Optional
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper
class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
"""
Model class for Text Embedding Inference text embedding model.
"""
def _invoke(
self, model: str, credentials: dict, texts: list[str], user: Optional[str] = None
) -> TextEmbeddingResult:
"""
Invoke text embedding model
credentials should be like:
{
'server_url': 'server url',
}
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
server_url = credentials['server_url']
if server_url.endswith('/'):
server_url = server_url[:-1]
# get model properties
context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials)
inputs = []
indices = []
used_tokens = 0
# get tokenized results from TEI
batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts)
for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)):
# Check if the number of tokens is larger than the context size
num_tokens = len(tokenize_result)
if num_tokens >= context_size:
# Find the best cutoff point
pre_special_token_count = 0
for token in tokenize_result:
if token['special']:
pre_special_token_count += 1
else:
break
rest_special_token_count = len([token for token in tokenize_result if token['special']]) - pre_special_token_count
# Calculate the cutoff point, leave 20 extra space to avoid exceeding the limit
token_cutoff = context_size - rest_special_token_count - 20
# Find the cutoff index
cutpoint_token = tokenize_result[token_cutoff]
cutoff = cutpoint_token['start']
inputs.append(text[0: cutoff])
else:
inputs.append(text)
indices += [i]
batched_embeddings = []
_iter = range(0, len(inputs), max_chunks)
try:
used_tokens = 0
for i in _iter:
iter_texts = inputs[i : i + max_chunks]
results = TeiHelper.invoke_embeddings(server_url, iter_texts)
embeddings = results['data']
embeddings = [embedding['embedding'] for embedding in embeddings]
batched_embeddings.extend(embeddings)
usage = results['usage']
used_tokens += usage['total_tokens']
except RuntimeError as e:
raise InvokeServerUnavailableError(str(e))
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
result = TextEmbeddingResult(model=model, embeddings=batched_embeddings, usage=usage)
return result
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
num_tokens = 0
server_url = credentials['server_url']
if server_url.endswith('/'):
server_url = server_url[:-1]
batch_tokens = TeiHelper.invoke_tokenize(server_url, texts)
num_tokens = sum(len(tokens) for tokens in batch_tokens)
return num_tokens
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
server_url = credentials['server_url']
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
            if extra_args.model_type != 'embedding':
                raise CredentialsValidateFailedError('Current model is not an embedding model')
credentials['context_size'] = extra_args.max_input_length
credentials['max_chunks'] = extra_args.max_client_batch_size
self._invoke(model=model, credentials=credentials, texts=['ping'])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [KeyError],
}
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
:param model: model name
:param credentials: model credentials
:param tokens: input tokens
:return: usage
"""
# get input price info
input_price_info = self.get_price(
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
)
# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at,
)
return usage
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={
ModelPropertyKey.MAX_CHUNKS: int(credentials.get('max_chunks', 1)),
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)),
},
parameter_rules=[],
)
return entity
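The truncation in _invoke relies on the character offsets that /tokenize returns: when a text has more tokens than the context size, it keeps context_size minus the trailing special tokens minus 20 tokens of headroom, and cuts the original string at the start offset of the first dropped token. A standalone sketch of the same arithmetic, assuming a local TEI server at a placeholder URL:

from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper

server_url = 'http://localhost:8080'  # placeholder TEI deployment
context_size = 512
text = 'a very long document ...'

tokens = TeiHelper.invoke_tokenize(server_url, [text])[0]
if len(tokens) >= context_size:
    # Leading special tokens (e.g. <s>) stay; count the special tokens that sit after them.
    leading_specials = 0
    for token in tokens:
        if token['special']:
            leading_specials += 1
        else:
            break
    trailing_specials = sum(1 for token in tokens if token['special']) - leading_specials
    # Leave 20 tokens of headroom so the truncated text cannot exceed the limit.
    cutoff_token = tokens[context_size - trailing_specials - 20]
    text = text[:cutoff_token['start']]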

+ 2
- 0
api/pyproject.toml

@@ -93,6 +93,8 @@ CODE_MAX_STRING_LENGTH = "80000"
CODE_EXECUTION_ENDPOINT = "http://127.0.0.1:8194"
CODE_EXECUTION_API_KEY = "dify-sandbox"
FIRECRAWL_API_KEY = "fc-"
TEI_EMBEDDING_SERVER_URL = "http://a.abc.com:11451"
TEI_RERANK_SERVER_URL = "http://a.abc.com:11451"

[tool.poetry]
name = "dify-api"
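These two keys only give the integration tests a placeholder default; the tests below read the URL from the environment and switch to the bundled mock when MOCK_SWITCH is enabled. A small sketch of how they are consumed:

import os

# Mirrors how the tests below pick up the server URL and the mock switch.
server_url = os.environ.get('TEI_EMBEDDING_SERVER_URL', '')
use_mock = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'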

+ 94
- 0
api/tests/integration_tests/model_runtime/__mock/huggingface_tei.py

@@ -0,0 +1,94 @@

from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiModelExtraParameter


class MockTEIClass:
@staticmethod
def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
# During mock, we don't have a real server to query, so we just return a dummy value
if 'rerank' in model_name:
model_type = 'reranker'
else:
model_type = 'embedding'

return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1)
@staticmethod
def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
# Use space as token separator, and split the text into tokens
tokenized_texts = []
for text in texts:
tokens = text.split(' ')
current_index = 0
tokenized_text = []
for idx, token in enumerate(tokens):
s_token = {
'id': idx,
'text': token,
'special': False,
'start': current_index,
'stop': current_index + len(token),
}
current_index += len(token) + 1
tokenized_text.append(s_token)
tokenized_texts.append(tokenized_text)
return tokenized_texts

@staticmethod
def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
# {
# "object": "list",
# "data": [
# {
# "object": "embedding",
# "embedding": [...],
# "index": 0
# }
# ],
# "model": "MODEL_NAME",
# "usage": {
# "prompt_tokens": 3,
# "total_tokens": 3
# }
# }
embeddings = []
for idx, text in enumerate(texts):
embedding = [0.1] * 768
embeddings.append(
{
'object': 'embedding',
'embedding': embedding,
'index': idx,
}
)
return {
'object': 'list',
'data': embeddings,
'model': 'MODEL_NAME',
'usage': {
'prompt_tokens': sum(len(text.split(' ')) for text in texts),
'total_tokens': sum(len(text.split(' ')) for text in texts),
},
}

    @staticmethod
    def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]:
# Example response:
# [
# {
# "index": 0,
# "text": "Deep Learning is ...",
# "score": 0.9950755
# }
# ]
reranked_docs = []
for idx, text in enumerate(texts):
reranked_docs.append(
{
'index': idx,
'text': text,
'score': 0.9,
}
)
# For mock, only return the first document
break
return reranked_docs
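A quick way to sanity-check the mock's response shapes without pytest (all values below are the mock's own dummy data):

from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass

# Tokens are whitespace-split, so 'hello world' counts as 2 tokens in usage.
resp = MockTEIClass.invoke_embeddings('http://mock', ['hello world'])
assert resp['usage']['total_tokens'] == 2
assert len(resp['data'][0]['embedding']) == 768

# The rerank mock returns only the first document, which is what test_rerank.py asserts on.
docs = MockTEIClass.invoke_rerank('http://mock', 'query', ['doc a', 'doc b'])
assert len(docs) == 1 and docs[0]['score'] == 0.9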

+ 0
- 0
api/tests/integration_tests/model_runtime/huggingface_tei/__init__.py


+ 72
- 0
api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py

@@ -0,0 +1,72 @@
import os

import pytest
from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper

from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import (
HuggingfaceTeiTextEmbeddingModel,
)
from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass

MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'


@pytest.fixture
def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
if MOCK:
monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
yield

if MOCK:
monkeypatch.undo()


@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
def test_validate_credentials(setup_tei_mock):
model = HuggingfaceTeiTextEmbeddingModel()
# model name is only used in mock
model_name = 'embedding'

if MOCK:
        # The TEI provider checks the model type via the API endpoint; against a real server the type is always correct,
        # so the mismatch case is only exercised here in the mock.
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='reranker',
credentials={
'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
}
)

model.validate_credentials(
model=model_name,
credentials={
'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
}
)

@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
def test_invoke_model(setup_tei_mock):
model = HuggingfaceTeiTextEmbeddingModel()
model_name = 'embedding'

result = model.invoke(
model=model_name,
credentials={
'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
},
texts=[
"hello",
"world"
],
user="abc-123"
)

assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
assert result.usage.total_tokens > 0

+ 76
- 0
api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py

@@ -0,0 +1,76 @@
import os

import pytest

from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.huggingface_tei.rerank.rerank import (
HuggingfaceTeiRerankModel,
)
from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper
from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass

MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'


@pytest.fixture
def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
if MOCK:
monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
yield

if MOCK:
monkeypatch.undo()

@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
def test_validate_credentials(setup_tei_mock):
model = HuggingfaceTeiRerankModel()
# model name is only used in mock
model_name = 'reranker'

if MOCK:
        # The TEI provider checks the model type via the API endpoint; against a real server the type is always correct,
        # so the mismatch case is only exercised here in the mock.
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='embedding',
credentials={
'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
}
)

model.validate_credentials(
model=model_name,
credentials={
'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
}
)

@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
def test_invoke_model(setup_tei_mock):
model = HuggingfaceTeiRerankModel()
# model name is only used in mock
model_name = 'reranker'

result = model.invoke(
model=model_name,
credentials={
'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
},
query="Who is Kasumi?",
docs=[
"Kasumi is a girl's name of Japanese origin meaning \"mist\".",
"Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
"and she leads a team named PopiParty."
],
score_threshold=0.8
)

assert isinstance(result, RerankResult)
assert len(result.docs) == 1
assert result.docs[0].index == 0
assert result.docs[0].score >= 0.8
