@@ -1,14 +1,15 @@
 import logging
-from typing import Optional, List
+from typing import List, Optional
 
 import cohere
 import openai
-from langchain.schema import Document
-
-from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
-    LLMRateLimitError, LLMAuthorizationError
+from core.model_providers.error import (LLMAPIConnectionError,
+                                        LLMAPIUnavailableError,
+                                        LLMAuthorizationError,
+                                        LLMBadRequestError, LLMRateLimitError)
 from core.model_providers.models.reranking.base import BaseReranking
 from core.model_providers.providers.base import BaseModelProvider
+from langchain.schema import Document
 
 
 class CohereReranking(BaseReranking):
@@ -26,10 +27,14 @@ class CohereReranking(BaseReranking):
     def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
         docs = []
         doc_id = []
+        unique_documents = []
         for document in documents:
             if document.metadata['doc_id'] not in doc_id:
                 doc_id.append(document.metadata['doc_id'])
                 docs.append(document.page_content)
+                unique_documents.append(document)
+
+        documents = unique_documents
 
         results = self.client.rerank(query=query, documents=docs, model=self.name, top_n=top_k)
         rerank_documents = []
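
Reviewer note: the functional change in the second hunk is the extra bookkeeping around deduplication. The texts sent to Cohere (`docs`) and the `Document` objects kept for later use are now collected in the same pass, and `documents` is replaced by the deduplicated list, presumably so that indices returned by `self.client.rerank` map back to the correct chunk further down the method. A standalone sketch of that pattern follows; the helper name `dedupe_for_rerank` and the sample data are illustrative only, not part of the patch.

from typing import List, Tuple

from langchain.schema import Document


def dedupe_for_rerank(documents: List[Document]) -> Tuple[List[str], List[Document]]:
    # Keep the first occurrence of each doc_id, building the text list and the
    # Document list in the same pass so their indices stay aligned.
    seen_doc_ids = []
    texts = []
    unique_documents = []
    for document in documents:
        if document.metadata['doc_id'] not in seen_doc_ids:
            seen_doc_ids.append(document.metadata['doc_id'])
            texts.append(document.page_content)
            unique_documents.append(document)
    return texts, unique_documents


# Example: the duplicated doc_id 'b' is kept only once, and position i in
# `texts` always corresponds to position i in `unique`.
chunks = [
    Document(page_content="chunk a", metadata={'doc_id': 'a'}),
    Document(page_content="chunk b", metadata={'doc_id': 'b'}),
    Document(page_content="chunk b", metadata={'doc_id': 'b'}),
]
texts, unique = dedupe_for_rerank(chunks)
assert len(texts) == len(unique) == 2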