Просмотр исходного кода

Merge commit from fork

* fix(oraclevector): SQL Injection

Signed-off-by: -LAN- <laipz8200@outlook.com>

* fix(oraclevector): Remove bind variables from FETCH FIRST clause

Oracle doesn't support bind variables in the FETCH FIRST clause.
Fixed by using validated integers directly in the SQL string while
maintaining proper input validation to prevent SQL injection.

- Updated search_by_vector method to use validated top_k directly
- Updated search_by_full_text method to use validated top_k directly
- Adjusted parameter numbering for document_ids_filter placeholders

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Claude <noreply@anthropic.com>
tags/1.8.0
-LAN- 2 месяцев назад
Родитель
Сommit
04954918a5
Аккаунт пользователя с таким Email не найден
1 измененных файлов: 32 добавлений и 10 удалений
  1. 32
    10
      api/core/rag/datasource/vdb/oracle/oraclevector.py

+ 32
- 10
api/core/rag/datasource/vdb/oracle/oraclevector.py Просмотреть файл

def text_exists(self, id: str) -> bool: def text_exists(self, id: str) -> bool:
with self._get_connection() as conn: with self._get_connection() as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,))
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = :1", (id,))
return cur.fetchone() is not None return cur.fetchone() is not None
conn.close() conn.close()


def get_by_ids(self, ids: list[str]) -> list[Document]: def get_by_ids(self, ids: list[str]) -> list[Document]:
if not ids:
return []
with self._get_connection() as conn: with self._get_connection() as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
placeholders = ", ".join(f":{i + 1}" for i in range(len(ids)))
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids)
docs = [] docs = []
for record in cur: for record in cur:
docs.append(Document(page_content=record[1], metadata=record[0])) docs.append(Document(page_content=record[1], metadata=record[0]))
return return
with self._get_connection() as conn: with self._get_connection() as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
placeholders = ", ".join(f":{i + 1}" for i in range(len(ids)))
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})", ids)
conn.commit() conn.commit()
conn.close() conn.close()


def delete_by_metadata_field(self, key: str, value: str) -> None: def delete_by_metadata_field(self, key: str, value: str) -> None:
with self._get_connection() as conn: with self._get_connection() as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
cur.execute(f"DELETE FROM {self.table_name} WHERE JSON_VALUE(meta, '$." + key + "') = :1", (value,))
conn.commit() conn.commit()
conn.close() conn.close()


:param top_k: The number of nearest neighbors to return, default is 5. :param top_k: The number of nearest neighbors to return, default is 5.
:return: List of Documents that are nearest to the query vector. :return: List of Documents that are nearest to the query vector.
""" """
# Validate and sanitize top_k to prevent SQL injection
top_k = kwargs.get("top_k", 4) top_k = kwargs.get("top_k", 4)
if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000:
top_k = 4 # Use default if invalid

document_ids_filter = kwargs.get("document_ids_filter") document_ids_filter = kwargs.get("document_ids_filter")
where_clause = "" where_clause = ""
params = [numpy.array(query_vector)]

if document_ids_filter: if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
placeholders = ", ".join(f":{i + 2}" for i in range(len(document_ids_filter)))
where_clause = f"WHERE JSON_VALUE(meta, '$.document_id') IN ({placeholders})"
params.extend(document_ids_filter)

with self._get_connection() as conn: with self._get_connection() as conn:
conn.inputtypehandler = self.input_type_handler conn.inputtypehandler = self.input_type_handler
conn.outputtypehandler = self.output_type_handler conn.outputtypehandler = self.output_type_handler
f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine) f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine)
AS distance FROM {self.table_name} AS distance FROM {self.table_name}
{where_clause} ORDER BY distance fetch first {top_k} rows only""", {where_clause} ORDER BY distance fetch first {top_k} rows only""",
[numpy.array(query_vector)],
params,
) )
docs = [] docs = []
score_threshold = float(kwargs.get("score_threshold") or 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
import nltk # type: ignore import nltk # type: ignore
from nltk.corpus import stopwords # type: ignore from nltk.corpus import stopwords # type: ignore


# Validate and sanitize top_k to prevent SQL injection
top_k = kwargs.get("top_k", 5) top_k = kwargs.get("top_k", 5)
if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000:
top_k = 5 # Use default if invalid
# just not implement fetch by score_threshold now, may be later # just not implement fetch by score_threshold now, may be later
score_threshold = float(kwargs.get("score_threshold") or 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
if len(query) > 0: if len(query) > 0:
with conn.cursor() as cur: with conn.cursor() as cur:
document_ids_filter = kwargs.get("document_ids_filter") document_ids_filter = kwargs.get("document_ids_filter")
where_clause = "" where_clause = ""
params: dict[str, Any] = {"kk": " ACCUM ".join(entities)}

if document_ids_filter: if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
placeholders = []
for i, doc_id in enumerate(document_ids_filter):
param_name = f"doc_id_{i}"
placeholders.append(f":{param_name}")
params[param_name] = doc_id
where_clause = f" AND JSON_VALUE(meta, '$.document_id') IN ({', '.join(placeholders)}) "

cur.execute( cur.execute(
f"""select meta, text, embedding FROM {self.table_name} f"""select meta, text, embedding FROM {self.table_name}
WHERE CONTAINS(text, :kk, 1) > 0 {where_clause} WHERE CONTAINS(text, :kk, 1) > 0 {where_clause}
order by score(1) desc fetch first {top_k} rows only""", order by score(1) desc fetch first {top_k} rows only""",
kk=" ACCUM ".join(entities),
params,
) )
docs = [] docs = []
for record in cur: for record in cur:

Загрузка…
Отмена
Сохранить