|
|
|
@@ -188,14 +188,17 @@ class OracleVector(BaseVector): |
|
|
|
def text_exists(self, id: str) -> bool:
    """Return whether a row with the given id exists in this store's table.

    The id is handed to the driver as a positional bind variable (``:1``)
    instead of being formatted into the SQL text, so its content cannot
    change the statement. Only ``self.table_name`` is interpolated.

    :param id: Value to look up in the ``id`` column.
    :return: True if the query yields at least one row, otherwise False.
    """
    with self._get_connection() as conn:
        with conn.cursor() as cur:
            # Single parameterized lookup; the earlier variant that built the
            # literal with "'%s'" % (id,) was injectable and is dropped.
            cur.execute(f"SELECT id FROM {self.table_name} WHERE id = :1", (id,))
            # The original trailing conn.close() came after this return and
            # could never run, so it is removed. Cleanup is left to the
            # connection's context manager.
            # NOTE(review): confirm that exiting the connection block
            # releases the connection.
            return cur.fetchone() is not None
|
|
|
|
|
|
|
def get_by_ids(self, ids: list[str]) -> list[Document]: |
|
|
|
if not ids: |
|
|
|
return [] |
|
|
|
with self._get_connection() as conn: |
|
|
|
with conn.cursor() as cur: |
|
|
|
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) |
|
|
|
placeholders = ", ".join(f":{i + 1}" for i in range(len(ids))) |
|
|
|
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids) |
|
|
|
docs = [] |
|
|
|
for record in cur: |
|
|
|
docs.append(Document(page_content=record[1], metadata=record[0])) |
|
|
|
@@ -208,14 +211,15 @@ class OracleVector(BaseVector): |
|
|
|
return |
|
|
|
with self._get_connection() as conn: |
|
|
|
with conn.cursor() as cur: |
|
|
|
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),)) |
|
|
|
placeholders = ", ".join(f":{i + 1}" for i in range(len(ids))) |
|
|
|
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})", ids) |
|
|
|
conn.commit() |
|
|
|
conn.close() |
|
|
|
|
|
|
|
def delete_by_metadata_field(self, key: str, value: str) -> None:
    """Delete every row whose JSON metadata field ``key`` equals ``value``.

    ``value`` is sent as a positional bind variable (``:1``). ``key`` has to
    be spliced into the JSON path literal of ``JSON_VALUE``, so it is
    validated first and rejected if it is anything other than a plain
    (optionally dotted) identifier path.

    :param key: Metadata field name, e.g. ``document_id``; letters, digits
        and underscores, with optional ``.``-separated nested segments.
    :param value: Value the metadata field must equal for a row to be deleted.
    :raises ValueError: If ``key`` contains characters outside the allowed set.
    """
    import re  # local import: keeps this method self-contained

    # The path cannot be bound as a parameter, so a quote or parenthesis in
    # `key` would otherwise break out of the '$.<key>' literal.
    if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*", key):
        raise ValueError(f"Invalid metadata field name: {key!r}")

    with self._get_connection() as conn:
        with conn.cursor() as cur:
            # One statement only; the former "meta->>%s = %s" form is not
            # Oracle bind syntax and is dropped.
            cur.execute(
                f"DELETE FROM {self.table_name} WHERE JSON_VALUE(meta, '$.{key}') = :1",
                (value,),
            )
        # NOTE(review): nesting reconstructed from a flattened diff — confirm
        # commit/close belong inside the connection block.
        conn.commit()
        conn.close()
|
|
|
|
|
|
|
@@ -227,12 +231,20 @@ class OracleVector(BaseVector): |
|
|
|
:param top_k: The number of nearest neighbors to return, default is 5. |
|
|
|
:return: List of Documents that are nearest to the query vector. |
|
|
|
""" |
|
|
|
# Validate and sanitize top_k to prevent SQL injection |
|
|
|
top_k = kwargs.get("top_k", 4) |
|
|
|
if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000: |
|
|
|
top_k = 4 # Use default if invalid |
|
|
|
|
|
|
|
document_ids_filter = kwargs.get("document_ids_filter") |
|
|
|
where_clause = "" |
|
|
|
params = [numpy.array(query_vector)] |
|
|
|
|
|
|
|
if document_ids_filter: |
|
|
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) |
|
|
|
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})" |
|
|
|
placeholders = ", ".join(f":{i + 2}" for i in range(len(document_ids_filter))) |
|
|
|
where_clause = f"WHERE JSON_VALUE(meta, '$.document_id') IN ({placeholders})" |
|
|
|
params.extend(document_ids_filter) |
|
|
|
|
|
|
|
with self._get_connection() as conn: |
|
|
|
conn.inputtypehandler = self.input_type_handler |
|
|
|
conn.outputtypehandler = self.output_type_handler |
|
|
|
@@ -241,7 +253,7 @@ class OracleVector(BaseVector): |
|
|
|
f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine) |
|
|
|
AS distance FROM {self.table_name} |
|
|
|
{where_clause} ORDER BY distance fetch first {top_k} rows only""", |
|
|
|
[numpy.array(query_vector)], |
|
|
|
params, |
|
|
|
) |
|
|
|
docs = [] |
|
|
|
score_threshold = float(kwargs.get("score_threshold") or 0.0) |
|
|
|
@@ -259,7 +271,10 @@ class OracleVector(BaseVector): |
|
|
|
import nltk # type: ignore |
|
|
|
from nltk.corpus import stopwords # type: ignore |
|
|
|
|
|
|
|
# Validate and sanitize top_k to prevent SQL injection |
|
|
|
top_k = kwargs.get("top_k", 5) |
|
|
|
if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000: |
|
|
|
top_k = 5 # Use default if invalid |
|
|
|
# Fetching by score_threshold is not implemented yet; it may be added later. |
|
|
|
score_threshold = float(kwargs.get("score_threshold") or 0.0) |
|
|
|
if len(query) > 0: |
|
|
|
@@ -297,14 +312,21 @@ class OracleVector(BaseVector): |
|
|
|
with conn.cursor() as cur: |
|
|
|
document_ids_filter = kwargs.get("document_ids_filter") |
|
|
|
where_clause = "" |
|
|
|
params: dict[str, Any] = {"kk": " ACCUM ".join(entities)} |
|
|
|
|
|
|
|
if document_ids_filter: |
|
|
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) |
|
|
|
where_clause = f" AND metadata->>'document_id' in ({document_ids}) " |
|
|
|
placeholders = [] |
|
|
|
for i, doc_id in enumerate(document_ids_filter): |
|
|
|
param_name = f"doc_id_{i}" |
|
|
|
placeholders.append(f":{param_name}") |
|
|
|
params[param_name] = doc_id |
|
|
|
where_clause = f" AND JSON_VALUE(meta, '$.document_id') IN ({', '.join(placeholders)}) " |
|
|
|
|
|
|
|
cur.execute( |
|
|
|
f"""select meta, text, embedding FROM {self.table_name} |
|
|
|
WHERE CONTAINS(text, :kk, 1) > 0 {where_clause} |
|
|
|
order by score(1) desc fetch first {top_k} rows only""", |
|
|
|
kk=" ACCUM ".join(entities), |
|
|
|
params, |
|
|
|
) |
|
|
|
docs = [] |
|
|
|
for record in cur: |