### What problem does this PR solve?

#8531 #7417 #6761 #6573 #6477

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
@@ -51,6 +51,7 @@ def set_dialog():
    similarity_threshold = req.get("similarity_threshold", 0.1)
    vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
    llm_setting = req.get("llm_setting", {})
    meta_data_filter = req.get("meta_data_filter", {})
    prompt_config = req["prompt_config"]
    if not is_create:
@@ -85,6 +86,7 @@ def set_dialog():
        "llm_id": llm_id,
        "llm_setting": llm_setting,
        "prompt_config": prompt_config,
        "meta_data_filter": meta_data_filter,
        "top_n": top_n,
        "top_k": top_k,
        "rerank_id": rerank_id,
@@ -681,6 +681,11 @@ def set_meta():
        return get_json_result(data=False, message="No authorization.", code=settings.RetCode.AUTHENTICATION_ERROR)
    try:
        meta = json.loads(req["meta"])
        if not isinstance(meta, dict):
            return get_json_result(data=False, message="Only dictionary type supported.", code=settings.RetCode.ARGUMENT_ERROR)
        for k,v in meta.items():
            if not isinstance(v, str) and not isinstance(v, int) and not isinstance(v, float):
                return get_json_result(data=False, message=f"The type is not supported: {v}", code=settings.RetCode.ARGUMENT_ERROR)
    except Exception as e:
        return get_json_result(data=False, message=f"Json syntax error: {e}", code=settings.RetCode.ARGUMENT_ERROR)
    if not isinstance(meta, dict):
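The validation above accepts only a JSON object whose values are plain scalars (str, int, or float). For illustration (example values are made up):

```python
# Payloads checked against the set_meta validation above; all values are illustrative.
valid_meta = {"color": "red", "price": 19.9, "stock": 3}   # dict of str/int/float values -> accepted
invalid_meta = {"tags": ["summer", "sale"]}                 # list value -> rejected
not_a_dict = ["color", "red"]                               # not a dict -> rejected
```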
@@ -351,6 +351,7 @@ def knowledge_graph(kb_id):
    obj["graph"]["edges"] = sorted(filtered_edges, key=lambda x: x.get("weight", 0), reverse=True)[:128]
    return get_json_result(data=obj)

@manager.route('/<kb_id>/knowledge_graph', methods=['DELETE'])  # noqa: F821
@login_required
def delete_knowledge_graph(kb_id):
@@ -364,3 +365,17 @@ def delete_knowledge_graph(kb_id):
    settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph", "entity", "relation"]}, search.index_name(kb.tenant_id), kb_id)
    return get_json_result(data=True)

@manager.route("/get_meta", methods=["GET"])  # noqa: F821
@login_required
def get_meta():
    kb_ids = request.args.get("kb_ids", "").split(",")
    for kb_id in kb_ids:
        if not KnowledgebaseService.accessible(kb_id, current_user.id):
            return get_json_result(
                data=False,
                message='No authorization.',
                code=settings.RetCode.AUTHENTICATION_ERROR
            )
    return get_json_result(data=DocumentService.get_meta_by_kbs(kb_ids))
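The endpoint returns whatever `DocumentService.get_meta_by_kbs` builds (see the document_service.py hunk below): metadata keys mapped to stringified values, each pointing at the document IDs that carry them. A hedged illustration of the response `data` (keys, values, and IDs are made up):

```python
# Illustrative shape of get_meta's response data, mirroring get_meta_by_kbs below.
example_response_data = {
    "color": {"red": ["doc_id_1"], "blue": ["doc_id_2"]},
    "listing_date": {"2025-07-11": ["doc_id_1"], "2025-08-01": ["doc_id_2"]},
}
```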
@@ -744,6 +744,7 @@ class Dialog(DataBaseModel):
        null=False,
        default={"system": "", "prologue": "Hi! I'm your assistant, what can I do for you?", "parameters": [], "empty_response": "Sorry! No relevant content was found in the knowledge base!"},
    )
    meta_data_filter = JSONField(null=True, default={})
    similarity_threshold = FloatField(default=0.2)
    vector_similarity_weight = FloatField(default=0.3)
@@ -1015,4 +1016,8 @@ def migrate_db():
        migrate(migrator.add_column("api_4_conversation", "errors", TextField(null=True, help_text="errors")))
    except Exception:
        pass
    try:
        migrate(migrator.add_column("dialog", "meta_data_filter", JSONField(null=True, default={})))
    except Exception:
        pass
    logging.disable(logging.NOTSET)
@@ -30,6 +30,7 @@ from api import settings
from api.db import LLMType, ParserType, StatusEnum
from api.db.db_models import DB, Dialog
from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.llm_service import LLMBundle, TenantLLMService
@@ -38,6 +39,7 @@ from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in
from rag.prompts.prompts import gen_meta_filter
from rag.utils import num_tokens_from_string, rmSpace
from rag.utils.tavily_conn import Tavily
@@ -250,6 +252,46 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
    return answer, idx

def meta_filter(metas: dict, filters: list[dict]):
    doc_ids = []

    def filter_out(v2docs, operator, value):
        nonlocal doc_ids
        for input,docids in v2docs.items():
            try:
                input = float(input)
                value = float(value)
            except Exception:
                input = str(input)
                value = str(value)
            for conds in [
                (operator == "contains", str(value).lower() in str(input).lower()),
                (operator == "not contains", str(value).lower() not in str(input).lower()),
                (operator == "start with", str(input).lower().startswith(str(value).lower())),
                (operator == "end with", str(input).lower().endswith(str(value).lower())),
                (operator == "empty", not input),
                (operator == "not empty", input),
                (operator == "=", input == value),
                (operator == "≠", input != value),
                (operator == ">", input > value),
                (operator == "<", input < value),
                (operator == "≥", input >= value),
                (operator == "≤", input <= value),
            ]:
                try:
                    if all(conds):
                        doc_ids.extend(docids)
                except Exception:
                    pass

    for k, v2docs in metas.items():
        for f in filters:
            if k != f["key"]:
                continue
            filter_out(v2docs, f["op"], f["value"])
    return doc_ids

def chat(dialog, messages, stream=True, **kwargs):
    assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
    if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
@@ -287,9 +329,10 @@ def chat(dialog, messages, stream=True, **kwargs):
    retriever = settings.retrievaler
    questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
    attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
    attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else []
    if "doc_ids" in messages[-1]:
        attachments = messages[-1]["doc_ids"]

    prompt_config = dialog.prompt_config
    field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
    # try to use sql if field mapping is good to go
@@ -316,6 +359,14 @@ def chat(dialog, messages, stream=True, **kwargs):
    if prompt_config.get("cross_languages"):
        questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]

    if dialog.meta_data_filter:
        metas = DocumentService.get_meta_by_kbs(dialog.kb_ids)
        if dialog.meta_data_filter.get("method") == "auto":
            filters = gen_meta_filter(chat_mdl, metas, questions[-1])
            attachments.extend(meta_filter(metas, filters))
        elif dialog.meta_data_filter.get("method") == "manual":
            attachments.extend(meta_filter(metas, dialog.meta_data_filter["manual"]))

    if prompt_config.get("keyword", False):
        questions[-1] += keyword_extraction(chat_mdl, questions[-1])
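`meta_filter` is a pure function, so its behavior is easy to check on toy data. A minimal sketch, assuming the module imports cleanly in a RAGFlow environment (the sample metadata and filters are invented):

```python
# Hedged usage sketch of meta_filter; data is illustrative only.
from api.db.services.dialog_service import meta_filter

metas = {
    "color": {"red": ["doc1"], "blue": ["doc2"]},
    "price": {"10": ["doc1"], "25": ["doc2"]},
}
filters = [
    {"key": "color", "op": "≠", "value": "blue"},  # keep non-blue documents
    {"key": "price", "op": "≤", "value": "20"},    # keep documents priced at most 20
]
print(meta_filter(metas, filters))  # ['doc1', 'doc1'] (doc1 satisfies both conditions)
```

Note that conditions are effectively OR-ed: every satisfied condition appends its document IDs, so a document matching several filters appears multiple times in the returned list.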
@@ -574,6 +574,25 @@ class DocumentService(CommonService):
    def update_meta_fields(cls, doc_id, meta_fields):
        return cls.update_by_id(doc_id, {"meta_fields": meta_fields})

    @classmethod
    @DB.connection_context()
    def get_meta_by_kbs(cls, kb_ids):
        fields = [
            cls.model.id,
            cls.model.meta_fields,
        ]
        meta = {}
        for r in cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids)):
            doc_id = r.id
            for k,v in r.meta_fields.items():
                if k not in meta:
                    meta[k] = {}
                v = str(v)
                if v not in meta[k]:
                    meta[k][v] = []
                meta[k][v].append(doc_id)
        return meta

    @classmethod
    @DB.connection_context()
    def update_progress(cls):
@@ -383,8 +383,6 @@ class Dealer:
        vector_column = f"q_{dim}_vec"
        zero_vector = [0.0] * dim
        sim_np = np.array(sim)
        if doc_ids:
            similarity_threshold = 0
        filtered_count = (sim_np >= similarity_threshold).sum()
        ranks["total"] = int(filtered_count)  # Convert from np.int64 to Python int otherwise JSON serializable error
        for i in idx:
@@ -0,0 +1,53 @@
You are a metadata filtering condition generator. Analyze the user's question and available document metadata to output a JSON array of filter objects. Follow these rules:

1. **Metadata Structure**:
   - Metadata is provided as JSON where keys are attribute names (e.g., "color"), and values are objects mapping attribute values to document IDs.
   - Example:
     {
       "color": {"red": ["doc1"], "blue": ["doc2"]},
       "listing_date": {"2025-07-11": ["doc1"], "2025-08-01": ["doc2"]}
     }

2. **Output Requirements**:
   - Always output a JSON array of filter objects
   - Each object must have:
     "key": (metadata attribute name),
     "value": (string value to compare),
     "op": (operator from allowed list)

3. **Operator Guide**:
   - Use these operators only: ["contains", "not contains", "start with", "end with", "empty", "not empty", "=", "≠", ">", "<", "≥", "≤"]
   - Date ranges: Break into two conditions (≥ start_date AND < next_month_start)
   - Negations: Always use "≠" for exclusion terms ("not", "except", "exclude", "≠")
   - Implicit logic: Derive unstated filters (e.g., "July" → [≥ YYYY-07-01, < YYYY-08-01])

4. **Processing Steps**:
   a) Identify ALL filterable attributes in the query (both explicit and implicit)
   b) For dates:
      - Infer missing year from current date if needed
      - Always format dates as "YYYY-MM-DD"
      - Convert ranges: [≥ start, < end]
   c) For values: Match EXACTLY to metadata's value keys
   d) Skip conditions if:
      - Attribute doesn't exist in metadata
      - Value has no match in metadata

5. **Example**:
- User query: "上市日期七月份的有哪些商品,不要蓝色的" | |||
   - Metadata: { "color": {...}, "listing_date": {...} }
   - Output:
     [
       {"key": "listing_date", "value": "2025-07-01", "op": "≥"},
       {"key": "listing_date", "value": "2025-08-01", "op": "<"},
       {"key": "color", "value": "blue", "op": "≠"}
     ]

6. **Final Output**:
   - ONLY output valid JSON array
   - NO additional text/explanations

**Current Task**:
- Today's date: {{current_date}}
- Available metadata keys: {{metadata_keys}}
- User query: "{{user_question}}"
@@ -149,6 +149,7 @@ NEXT_STEP = load_prompt("next_step")
REFLECT = load_prompt("reflect")
SUMMARY4MEMORY = load_prompt("summary4memory")
RANK_MEMORY = load_prompt("rank_memory")
META_FILTER = load_prompt("meta_filter")

PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)
@@ -413,3 +414,20 @@ def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[st
    ans = chat_mdl.chat(msg[0]["content"], msg[1:], stop="<|stop|>")
    return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)

def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> list:
    sys_prompt = PROMPT_JINJA_ENV.from_string(META_FILTER).render(
        current_date=datetime.datetime.today().strftime('%Y-%m-%d'),
        metadata_keys=json.dumps(meta_data),
        user_question=query
    )
    user_prompt = "Generate filters:"
    ans = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}])
    ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
    try:
        ans = json_repair.loads(ans)
        assert isinstance(ans, list), ans
        return ans
    except Exception:
        logging.exception(f"Loading json failure: {ans}")
        return []
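A hedged sketch of exercising `gen_meta_filter` with a stubbed chat model, assuming `rag.prompts.prompts` and its dependencies are importable; `StubChatModel` is hypothetical and only mimics the `chat(system, history)` call used above:

```python
# Hedged usage sketch; StubChatModel is invented for illustration only.
from rag.prompts.prompts import gen_meta_filter

class StubChatModel:
    def chat(self, system, history, **kwargs):
        # Pretend the LLM answered with a plain JSON array of filter objects.
        return '[{"key": "color", "value": "blue", "op": "≠"}]'

metas = {"color": {"red": ["doc1"], "blue": ["doc2"]}}
filters = gen_meta_filter(StubChatModel(), metas, "exclude blue items")
print(filters)  # [{'key': 'color', 'value': 'blue', 'op': '≠'}]
```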
@@ -444,7 +444,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
            tts = np.concatenate([vts for _ in range(len(tts))], axis=0)
            tk_count += c

    @timeout(5)
    @timeout(60)
    def batch_encode(txts):
        nonlocal mdl
        return mdl.encode([truncate(c, mdl.max_length-10) for c in txts])
@@ -190,3 +190,17 @@ class RAGFlowS3:
            self.__open__()
            time.sleep(1)
        return

    @use_prefix_path
    @use_default_bucket
    def rm_bucket(self, bucket, *args, **kwargs):
        for conn in self.conn:
            try:
                if not conn.bucket_exists(bucket):
                    continue
                for o in conn.list_objects_v2(Bucket=bucket):
                    conn.delete_object(bucket, o.object_name)
                conn.delete_bucket(Bucket=bucket)
                return
            except Exception as e:
                logging.error(f"Fail rm {bucket}: " + str(e))