|
|
|
@@ -17,7 +17,13 @@ class WenxinRerank(_CommonWenxin): |
|
|
|
def rerank(self, model: str, query: str, docs: list[str], top_n: Optional[int] = None): |
|
|
|
access_token = self._get_access_token() |
|
|
|
url = f"{self.api_bases[model]}?access_token={access_token}" |
|
|
|
|
|
|
|
# For issue #11252 |
|
|
|
# For the Wenxin rerank model, top_n must be less than or equal to the number of docs |
|
|
|
if top_n is not None and top_n > len(docs): |
|
|
|
top_n = len(docs) |
|
|
|
# For the Wenxin rerank model, query must not be an empty string |
|
|
|
if query == "": |
|
|
|
query = " " # FIXME: this is a workaround for wenxin rerank model for better user experience. |
|
|
|
try: |
|
|
|
response = httpx.post( |
|
|
|
url, |
|
|
|
@@ -25,7 +31,11 @@ class WenxinRerank(_CommonWenxin): |
|
|
|
headers={"Content-Type": "application/json"}, |
|
|
|
) |
|
|
|
response.raise_for_status() |
|
|
|
return response.json() |
|
|
|
data = response.json() |
|
|
|
# wenxin error handling |
|
|
|
if "error_code" in data: |
|
|
|
raise InternalServerError(data["error_msg"]) |
|
|
|
return data |
|
|
|
except httpx.HTTPStatusError as e: |
|
|
|
raise InternalServerError(str(e)) |
|
|
|
|
|
|
|
@@ -69,6 +79,9 @@ class WenxinRerankModel(RerankModel): |
|
|
|
results = wenxin_rerank.rerank(model, query, docs, top_n) |
|
|
|
|
|
|
|
rerank_documents = [] |
|
|
|
if "results" not in results: |
|
|
|
raise ValueError("results key not found in response") |
|
|
|
|
|
|
|
for result in results["results"]: |
|
|
|
index = result["index"] |
|
|
|
if "document" in result: |