| @@ -239,7 +239,7 @@ class AnalyticdbVector(BaseVector): | |||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | |||
| from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | |||
| score_threshold = kwargs.get("score_threshold", 0.0) | |||
| score_threshold = kwargs.get("score_threshold") or 0.0 | |||
| request = gpdb_20160503_models.QueryCollectionDataRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| @@ -267,7 +267,7 @@ class AnalyticdbVector(BaseVector): | |||
| def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | |||
| from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | |||
| score_threshold = kwargs.get("score_threshold", 0.0) | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| request = gpdb_20160503_models.QueryCollectionDataRequest( | |||
| dbinstance_id=self.config.instance_id, | |||
| region_id=self.config.region_id, | |||
| @@ -92,7 +92,7 @@ class ChromaVector(BaseVector): | |||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | |||
| collection = self._client.get_or_create_collection(self._collection_name) | |||
| results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) | |||
| score_threshold = kwargs.get("score_threshold", 0.0) | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| ids: list[str] = results["ids"][0] | |||
| documents: list[str] = results["documents"][0] | |||
| @@ -131,7 +131,7 @@ class ElasticSearchVector(BaseVector): | |||
| docs = [] | |||
| for doc, score in docs_and_scores: | |||
| score_threshold = kwargs.get("score_threshold", 0.0) | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| if score > score_threshold: | |||
| doc.metadata["score"] = score | |||
| docs.append(doc) | |||
| @@ -141,7 +141,7 @@ class MilvusVector(BaseVector): | |||
| for result in results[0]: | |||
| metadata = result["entity"].get(Field.METADATA_KEY.value) | |||
| metadata["score"] = result["distance"] | |||
| score_threshold = kwargs.get("score_threshold", 0.0) | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| if result["distance"] > score_threshold: | |||
| doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata) | |||
| docs.append(doc) | |||
| @@ -122,7 +122,7 @@ class MyScaleVector(BaseVector): | |||
| def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]: | |||
| top_k = kwargs.get("top_k", 5) | |||
| score_threshold = kwargs.get("score_threshold", 0.0) | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| where_str = ( | |||
| f"WHERE dist < {1 - score_threshold}" | |||
| if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0 | |||
| @@ -170,7 +170,7 @@ class OpenSearchVector(BaseVector): | |||
| metadata = {} | |||
| metadata["score"] = hit["_score"] | |||
| score_threshold = kwargs.get("score_threshold", 0.0) | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| if hit["_score"] > score_threshold: | |||
| doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata) | |||
| docs.append(doc) | |||
| @@ -200,7 +200,7 @@ class OracleVector(BaseVector): | |||
| [numpy.array(query_vector)], | |||
| ) | |||
| docs = [] | |||
| score_threshold = kwargs.get("score_threshold", 0.0) | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| for record in cur: | |||
| metadata, text, distance = record | |||
| score = 1 - distance | |||
| @@ -212,7 +212,7 @@ class OracleVector(BaseVector): | |||
| def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | |||
| top_k = kwargs.get("top_k", 5) | |||
| # just not implement fetch by score_threshold now, may be later | |||
| score_threshold = kwargs.get("score_threshold", 0.0) | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| if len(query) > 0: | |||
| # Check which language the query is in | |||
| zh_pattern = re.compile("[\u4e00-\u9fa5]+") | |||
| @@ -198,7 +198,7 @@ class PGVectoRS(BaseVector): | |||
| metadata = record.meta | |||
| score = 1 - dis | |||
| metadata["score"] = score | |||
| score_threshold = kwargs.get("score_threshold", 0.0) | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| if score > score_threshold: | |||
| doc = Document(page_content=record.text, metadata=metadata) | |||
| docs.append(doc) | |||
| @@ -144,7 +144,7 @@ class PGVector(BaseVector): | |||
| (json.dumps(query_vector),), | |||
| ) | |||
| docs = [] | |||
| score_threshold = kwargs.get("score_threshold", 0.0) | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| for record in cur: | |||
| metadata, text, distance = record | |||
| score = 1 - distance | |||
| @@ -333,13 +333,13 @@ class QdrantVector(BaseVector): | |||
| limit=kwargs.get("top_k", 4), | |||
| with_payload=True, | |||
| with_vectors=True, | |||
| score_threshold=kwargs.get("score_threshold", 0.0), | |||
| score_threshold=float(kwargs.get("score_threshold") or 0.0), | |||
| ) | |||
| docs = [] | |||
| for result in results: | |||
| metadata = result.payload.get(Field.METADATA_KEY.value) or {} | |||
| # duplicate check score threshold | |||
| score_threshold = kwargs.get("score_threshold", 0.0) | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| if result.score > score_threshold: | |||
| metadata["score"] = result.score | |||
| doc = Document( | |||
| @@ -230,7 +230,7 @@ class RelytVector(BaseVector): | |||
| # Organize results. | |||
| docs = [] | |||
| for document, score in results: | |||
| score_threshold = kwargs.get("score_threshold", 0.0) | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| if 1 - score > score_threshold: | |||
| docs.append(document) | |||
| return docs | |||
| @@ -153,7 +153,7 @@ class TencentVector(BaseVector): | |||
| limit=kwargs.get("top_k", 4), | |||
| timeout=self._client_config.timeout, | |||
| ) | |||
| score_threshold = kwargs.get("score_threshold", 0.0) | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| return self._get_search_res(res, score_threshold) | |||
| def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: | |||
| @@ -185,7 +185,7 @@ class TiDBVector(BaseVector): | |||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | |||
| top_k = kwargs.get("top_k", 5) | |||
| score_threshold = kwargs.get("score_threshold", 0.0) | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| filter = kwargs.get("filter") | |||
| distance = 1 - score_threshold | |||
| @@ -205,7 +205,7 @@ class WeaviateVector(BaseVector): | |||
| docs = [] | |||
| for doc, score in docs_and_scores: | |||
| score_threshold = kwargs.get("score_threshold", 0.0) | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| # check score threshold | |||
| if score > score_threshold: | |||
| doc.metadata["score"] = score | |||