|
|
|
@@ -2,12 +2,12 @@ import array |
|
|
|
import json |
|
|
|
import re |
|
|
|
import uuid |
|
|
|
from contextlib import contextmanager |
|
|
|
from typing import Any |
|
|
|
|
|
|
|
import jieba.posseg as pseg # type: ignore |
|
|
|
import numpy |
|
|
|
import oracledb |
|
|
|
from oracledb.connection import Connection |
|
|
|
from pydantic import BaseModel, model_validator |
|
|
|
|
|
|
|
from configs import dify_config |
|
|
|
@@ -70,6 +70,7 @@ class OracleVector(BaseVector): |
|
|
|
super().__init__(collection_name) |
|
|
|
self.pool = self._create_connection_pool(config) |
|
|
|
self.table_name = f"embedding_{collection_name}" |
|
|
|
self.config = config |
|
|
|
|
|
|
|
def get_type(self) -> str:
    """Return the vector store type identifier used for Oracle."""
    vector_type = VectorType.ORACLE
    return vector_type
|
|
|
@@ -107,16 +108,19 @@ class OracleVector(BaseVector): |
|
|
|
outconverter=self.numpy_converter_out, |
|
|
|
) |
|
|
|
|
|
|
|
def _get_connection(self) -> Connection:
    """Open and return a new standalone Oracle connection.

    The connection is created with ``oracledb.connect`` from the user,
    password and DSN stored on ``self.config``; it is not acquired from
    ``self.pool``.

    :return: The newly opened connection.
    """
    # NOTE(review): the pool is given extra wallet settings when
    # ``config.is_autonomous`` is set, but none are passed here -- confirm
    # whether standalone connections need them too.
    connect_args = {
        "user": self.config.user,
        "password": self.config.password,
        "dsn": self.config.dsn,
    }
    return oracledb.connect(**connect_args)
|
|
|
|
|
|
|
def _create_connection_pool(self, config: OracleVectorConfig): |
|
|
|
pool_params = { |
|
|
|
"user": config.user, |
|
|
|
"password": config.password, |
|
|
|
"dsn": config.dsn, |
|
|
|
"min": 1, |
|
|
|
"max": 50, |
|
|
|
"max": 5, |
|
|
|
"increment": 1, |
|
|
|
} |
|
|
|
|
|
|
|
if config.is_autonomous: |
|
|
|
pool_params.update( |
|
|
|
{ |
|
|
|
@@ -125,22 +129,8 @@ class OracleVector(BaseVector): |
|
|
|
"wallet_password": config.wallet_password, |
|
|
|
} |
|
|
|
) |
|
|
|
|
|
|
|
return oracledb.create_pool(**pool_params) |
|
|
|
|
|
|
|
@contextmanager
def _get_cursor(self):
    """Yield a cursor on a pooled connection and finalize the transaction.

    A connection is acquired from ``self.pool`` and configured with this
    object's input/output type handlers before a cursor is opened on it.
    When the ``with`` body finishes normally the transaction is committed;
    when it raises, the transaction is rolled back and the exception
    propagates unchanged. The cursor and the connection are closed in both
    cases (closing a pooled connection hands it back to the pool).

    :return: Context manager yielding the open cursor.
    """
    conn = self.pool.acquire()
    try:
        conn.inputtypehandler = self.input_type_handler
        conn.outputtypehandler = self.output_type_handler
        cur = conn.cursor()
        try:
            yield cur
        except BaseException:
            # Never persist a half-finished unit of work; re-raise untouched.
            conn.rollback()
            raise
        else:
            conn.commit()
        finally:
            cur.close()
    finally:
        # Runs even if commit/rollback/cursor close fails, so the pool slot
        # is not leaked.
        conn.close()
|
|
|
|
|
|
|
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): |
|
|
|
dimension = len(embeddings[0]) |
|
|
|
self._create_collection(dimension) |
|
|
|
@@ -162,41 +152,68 @@ class OracleVector(BaseVector): |
|
|
|
numpy.array(embeddings[i]), |
|
|
|
) |
|
|
|
) |
|
|
|
# print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)") |
|
|
|
with self._get_cursor() as cur: |
|
|
|
cur.executemany( |
|
|
|
f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values |
|
|
|
) |
|
|
|
with self._get_connection() as conn: |
|
|
|
conn.inputtypehandler = self.input_type_handler |
|
|
|
conn.outputtypehandler = self.output_type_handler |
|
|
|
# with conn.cursor() as cur: |
|
|
|
# cur.executemany( |
|
|
|
# f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values |
|
|
|
# ) |
|
|
|
# conn.commit() |
|
|
|
for value in values: |
|
|
|
with conn.cursor() as cur: |
|
|
|
try: |
|
|
|
cur.execute( |
|
|
|
f"""INSERT INTO {self.table_name} (id, text, meta, embedding) |
|
|
|
VALUES (:1, :2, :3, :4)""", |
|
|
|
value, |
|
|
|
) |
|
|
|
conn.commit() |
|
|
|
except Exception as e: |
|
|
|
print(e) |
|
|
|
conn.close() |
|
|
|
return pks |
|
|
|
|
|
|
|
def text_exists(self, id: str) -> bool:
    """Report whether a row with the given id exists in the table.

    :param id: Primary key value to look up.
    :return: True if a matching row was fetched, otherwise False.
    """
    # The id is sent as a bind variable instead of being formatted into the
    # statement text, so quotes inside the value cannot change the SQL.
    query = f"SELECT id FROM {self.table_name} WHERE id = :1"
    # Leaving the ``with`` blocks closes the cursor and the connection, also
    # on the early return below.
    with self._get_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(query, [id])
            return cur.fetchone() is not None
|
|
|
|
|
|
|
def get_by_ids(self, ids: list[str]) -> list[Document]:
    """Fetch the stored documents whose id is in ``ids``.

    :param ids: Primary key values to look up.
    :return: One Document per matching row, built from the ``text`` and
        ``meta`` columns; an empty list when ``ids`` is empty.
    """
    if not ids:
        # An empty IN () list is not valid SQL, and nothing could match.
        return []
    # One positional bind placeholder per id (":1, :2, ..."); the values are
    # passed separately and never formatted into the statement.
    placeholders = ", ".join(f":{pos}" for pos in range(1, len(ids) + 1))
    query = f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})"
    # The connection is standalone (not from the pool); the ``with`` block
    # closes it, so no explicit release/close is needed.
    with self._get_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(query, list(ids))
            return [Document(page_content=text, metadata=meta) for meta, text in cur]
|
|
|
|
|
|
|
def delete_by_ids(self, ids: list[str]) -> None:
    """Delete the rows whose id is in ``ids`` and commit.

    :param ids: Primary key values to delete; nothing is executed when the
        list is empty.
    """
    if not ids:
        return
    # One positional bind placeholder per id (":1, :2, ..."). Formatting a
    # Python tuple into the statement would inject the raw values and, for a
    # single id, produce "('a',)" which is not valid SQL.
    placeholders = ", ".join(f":{pos}" for pos in range(1, len(ids) + 1))
    statement = f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})"
    # The ``with`` block closes the connection on exit.
    with self._get_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(statement, list(ids))
        conn.commit()
|
|
|
|
|
|
|
def delete_by_metadata_field(self, key: str, value: str) -> None:
    """Delete the rows whose JSON metadata has ``key`` equal to ``value``.

    :param key: Top-level field name inside the ``meta`` column. It must be
        a plain identifier (letters, digits, underscore) because it becomes
        part of the JSON path text.
    :param value: Value the field must equal; sent as a bind variable.
    :raises ValueError: If ``key`` is not a plain identifier.
    """
    # The JSON path has to be written into the statement text, so the key is
    # restricted to identifier characters to keep it from altering the SQL.
    if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", key):
        raise ValueError(f"Invalid metadata key: {key!r}")
    # JSON_VALUE with a ":1" bind replaces the PostgreSQL-style
    # "meta->>%s = %s" form, which Oracle cannot execute.
    # NOTE(review): assumes ``meta`` is a JSON (or JSON-text) column -- the
    # table DDL is not visible here.
    statement = f"DELETE FROM {self.table_name} WHERE JSON_VALUE(meta, '$.{key}') = :1"
    # The ``with`` block closes the connection on exit.
    with self._get_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(statement, [value])
        conn.commit()
|
|
|
|
|
|
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: |
|
|
|
""" |
|
|
|
Search the nearest neighbors to a vector. |
|
|
|
|
|
|
|
:param query_vector: The input vector to search for similar items. |
|
|
|
:param top_k: The number of nearest neighbors to return, default is 5. |
|
|
|
:return: List of Documents that are nearest to the query vector. |
|
|
|
""" |
|
|
|
top_k = kwargs.get("top_k", 4) |
|
|
|
@@ -205,20 +222,25 @@ class OracleVector(BaseVector): |
|
|
|
if document_ids_filter: |
|
|
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) |
|
|
|
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})" |
|
|
|
with self._get_cursor() as cur: |
|
|
|
cur.execute( |
|
|
|
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}" |
|
|
|
f" {where_clause} ORDER BY distance fetch first {top_k} rows only", |
|
|
|
[numpy.array(query_vector)], |
|
|
|
) |
|
|
|
docs = [] |
|
|
|
score_threshold = float(kwargs.get("score_threshold") or 0.0) |
|
|
|
for record in cur: |
|
|
|
metadata, text, distance = record |
|
|
|
score = 1 - distance |
|
|
|
metadata["score"] = score |
|
|
|
if score > score_threshold: |
|
|
|
docs.append(Document(page_content=text, metadata=metadata)) |
|
|
|
with self._get_connection() as conn: |
|
|
|
conn.inputtypehandler = self.input_type_handler |
|
|
|
conn.outputtypehandler = self.output_type_handler |
|
|
|
with conn.cursor() as cur: |
|
|
|
cur.execute( |
|
|
|
f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine) |
|
|
|
AS distance FROM {self.table_name} |
|
|
|
{where_clause} ORDER BY distance fetch first {top_k} rows only""", |
|
|
|
[numpy.array(query_vector)], |
|
|
|
) |
|
|
|
docs = [] |
|
|
|
score_threshold = float(kwargs.get("score_threshold") or 0.0) |
|
|
|
for record in cur: |
|
|
|
metadata, text, distance = record |
|
|
|
score = 1 - distance |
|
|
|
metadata["score"] = score |
|
|
|
if score > score_threshold: |
|
|
|
docs.append(Document(page_content=text, metadata=metadata)) |
|
|
|
conn.close() |
|
|
|
return docs |
|
|
|
|
|
|
|
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: |
|
|
|
@@ -228,7 +250,7 @@ class OracleVector(BaseVector): |
|
|
|
|
|
|
|
top_k = kwargs.get("top_k", 5) |
|
|
|
# just not implement fetch by score_threshold now, may be later |
|
|
|
# score_threshold = float(kwargs.get("score_threshold") or 0.0) |
|
|
|
score_threshold = float(kwargs.get("score_threshold") or 0.0) |
|
|
|
if len(query) > 0: |
|
|
|
# Check which language the query is in |
|
|
|
zh_pattern = re.compile("[\u4e00-\u9fa5]+") |
|
|
|
@@ -239,7 +261,7 @@ class OracleVector(BaseVector): |
|
|
|
words = pseg.cut(query) |
|
|
|
current_entity = "" |
|
|
|
for word, pos in words: |
|
|
|
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名,ns: 地名,nt: 机构名 |
|
|
|
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名 |
|
|
|
current_entity += word |
|
|
|
else: |
|
|
|
if current_entity: |
|
|
|
@@ -260,30 +282,35 @@ class OracleVector(BaseVector): |
|
|
|
for token in all_tokens: |
|
|
|
if token not in stop_words: |
|
|
|
entities.append(token) |
|
|
|
with self._get_cursor() as cur: |
|
|
|
document_ids_filter = kwargs.get("document_ids_filter") |
|
|
|
where_clause = "" |
|
|
|
if document_ids_filter: |
|
|
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) |
|
|
|
where_clause = f" AND metadata->>'document_id' in ({document_ids}) " |
|
|
|
cur.execute( |
|
|
|
f"select meta, text, embedding FROM {self.table_name}" |
|
|
|
f"WHERE CONTAINS(text, :1, 1) > 0 {where_clause} " |
|
|
|
f"order by score(1) desc fetch first {top_k} rows only", |
|
|
|
[" ACCUM ".join(entities)], |
|
|
|
) |
|
|
|
docs = [] |
|
|
|
for record in cur: |
|
|
|
metadata, text, embedding = record |
|
|
|
docs.append(Document(page_content=text, vector=embedding, metadata=metadata)) |
|
|
|
with self._get_connection() as conn: |
|
|
|
with conn.cursor() as cur: |
|
|
|
document_ids_filter = kwargs.get("document_ids_filter") |
|
|
|
where_clause = "" |
|
|
|
if document_ids_filter: |
|
|
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) |
|
|
|
where_clause = f" AND metadata->>'document_id' in ({document_ids}) " |
|
|
|
cur.execute( |
|
|
|
f"""select meta, text, embedding FROM {self.table_name} |
|
|
|
WHERE CONTAINS(text, :kk, 1) > 0 {where_clause} |
|
|
|
order by score(1) desc fetch first {top_k} rows only""", |
|
|
|
kk=" ACCUM ".join(entities), |
|
|
|
) |
|
|
|
docs = [] |
|
|
|
for record in cur: |
|
|
|
metadata, text, embedding = record |
|
|
|
docs.append(Document(page_content=text, vector=embedding, metadata=metadata)) |
|
|
|
conn.close() |
|
|
|
return docs |
|
|
|
else: |
|
|
|
return [Document(page_content="", metadata={})] |
|
|
|
return [] |
|
|
|
|
|
|
|
def delete(self) -> None:
    """Drop this collection's table, including dependent constraints.

    Issues ``DROP TABLE IF EXISTS ... cascade constraints`` for
    ``self.table_name`` once, then commits.
    """
    statement = f"DROP TABLE IF EXISTS {self.table_name} cascade constraints"
    # The ``with`` block closes the connection on exit, so no explicit
    # close() call is needed.
    with self._get_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(statement)
        conn.commit()
|
|
|
|
|
|
|
def _create_collection(self, dimension: int): |
|
|
|
cache_key = f"vector_indexing_{self._collection_name}" |
|
|
|
@@ -293,11 +320,14 @@ class OracleVector(BaseVector): |
|
|
|
if redis_client.get(collection_exist_cache_key): |
|
|
|
return |
|
|
|
|
|
|
|
with self._get_cursor() as cur: |
|
|
|
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name)) |
|
|
|
redis_client.set(collection_exist_cache_key, 1, ex=3600) |
|
|
|
with self._get_cursor() as cur: |
|
|
|
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name)) |
|
|
|
with self._get_connection() as conn: |
|
|
|
with conn.cursor() as cur: |
|
|
|
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name)) |
|
|
|
redis_client.set(collection_exist_cache_key, 1, ex=3600) |
|
|
|
with conn.cursor() as cur: |
|
|
|
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name)) |
|
|
|
conn.commit() |
|
|
|
conn.close() |
|
|
|
|
|
|
|
|
|
|
|
class OracleVectorFactory(AbstractVectorFactory): |