@@ -1,11 +1,15 @@
import array
import json
import re
import uuid
from contextlib import contextmanager
from typing import Any

import jieba.posseg as pseg
import nltk
import numpy
import oracledb
from nltk.corpus import stopwords
from pydantic import BaseModel, model_validator

from configs import dify_config
@@ -50,6 +54,11 @@ CREATE TABLE IF NOT EXISTS {table_name} (
    ,embedding vector NOT NULL
)
"""

SQL_CREATE_INDEX = """
CREATE INDEX idx_docs_{table_name} ON {table_name}(text)
INDEXTYPE IS CTXSYS.CONTEXT PARAMETERS
('FILTER CTXSYS.NULL_FILTER SECTION GROUP CTXSYS.HTML_SECTION_GROUP LEXER sys.my_chinese_vgram_lexer')
"""
@@ -188,7 +197,53 @@ class OracleVector(BaseVector):
        return docs

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        # BM25 ranking is not supported; rows are ordered by Oracle Text's SCORE() instead
        top_k = kwargs.get("top_k", 5)
        # fetching by score_threshold is not implemented yet; the value is read but not applied
        score_threshold = kwargs.get("score_threshold") or 0.0
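        # extract search terms from the query, then run an Oracle Text CONTAINS() query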
        if len(query) > 0:
            # check which language the query is in
            zh_pattern = re.compile("[\u4e00-\u9fa5]+")
            match = zh_pattern.search(query)
            entities = []
            # if the query contains Chinese characters, segment it with jieba; otherwise use nltk
            if match:
                words = pseg.cut(query)
                current_entity = ""
                for word, pos in words:
                    # keep content-bearing tags: nr (person name), Ng (noun morpheme),
                    # eng (English token), nz (other proper noun), n (noun), ORG (organization), v (verb)
                    if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}:
                        # consecutive kept words are merged into a single entity
                        current_entity += word
                    else:
                        if current_entity:
                            entities.append(current_entity)
                        current_entity = ""
                if current_entity:
                    entities.append(current_entity)
            else:
                # make sure the punkt tokenizer and stopword corpus are available locally
                try:
                    nltk.data.find("tokenizers/punkt")
                    nltk.data.find("corpora/stopwords")
                except LookupError:
                    nltk.download("punkt")
                    nltk.download("stopwords")
                # strip punctuation, tokenize, and drop English stop words
                e_str = re.sub(r"[^\w ]", "", query)
                all_tokens = nltk.word_tokenize(e_str)
                stop_words = stopwords.words("english")
                for token in all_tokens:
                    if token not in stop_words:
                        entities.append(token)
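            # join terms with ACCUM so rows matching more terms rank higher;
            # SCORE(1) refers to the label 1 passed as the third argument of CONTAINS()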
            with self._get_cursor() as cur:
                cur.execute(
                    f"select meta, text FROM {self.table_name} WHERE CONTAINS(text, :1, 1) > 0"
                    f" order by score(1) desc fetch first {top_k} rows only",
                    [" ACCUM ".join(entities)],
                )
                docs = []
                for record in cur:
                    metadata, text = record
                    docs.append(Document(page_content=text, metadata=metadata))
            return docs
        # an empty query matches nothing
        return []

    def delete(self) -> None:
@@ -206,6 +261,8 @@ class OracleVector(BaseVector):
            with self._get_cursor() as cur:
                cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
            redis_client.set(collection_exist_cache_key, 1, ex=3600)
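            # the full-text index is created right after the table; the redis flag above
            # marks the collection as initialized so this setup runs only once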
            with self._get_cursor() as cur:
                cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))


class OracleVectorFactory(AbstractVectorFactory): |