浏览代码

Validate returned chunk at list_chunks and add_chunk (#4153)

### What problem does this PR solve?

Validate returned chunk at list_chunks and add_chunk

### Type of change

- [x] Refactoring
tags/v0.15.1
Zhichang Yu 10 个月前
父节点
当前提交
85083ad400
没有帐户链接到提交者的电子邮件
共有 2 个文件被更改,包括 28 次插入和 17 次删除
  1. 23
    14
      api/apps/sdk/doc.py
  2. 5
    3
      rag/utils/infinity_conn.py

+ 23
- 14
api/apps/sdk/doc.py 查看文件

@@ -42,9 +42,30 @@ from rag.nlp import search
from rag.utils import rmSpace
from rag.utils.storage_factory import STORAGE_IMPL

from pydantic import BaseModel, Field, validator

MAXIMUM_OF_UPLOADING_FILES = 256


class Chunk(BaseModel):
    """Pydantic schema used to validate chunk payloads before they are
    returned by the SDK document endpoints (list_chunks / add_chunk).

    All fields default to empty/neutral values so a partially-populated
    chunk dict still validates; only ``positions`` carries a shape check.
    """

    id: str = ""
    content: str = ""
    document_id: str = ""
    docnm_kwd: str = ""
    important_keywords: list = Field(default_factory=list)
    questions: list = Field(default_factory=list)
    question_tks: str = ""
    image_id: str = ""
    available: bool = True
    positions: list[list[int]] = Field(default_factory=list)

    @validator('positions')
    def validate_positions(cls, value):
        # Each entry is expected to be a quintuple (page number plus four
        # coordinates); reject the payload if any entry deviates.
        if any(len(entry) != 5 for entry in value):
            raise ValueError("Each sublist in positions must have a length of 5")
        return value

@manager.route("/datasets/<dataset_id>/documents", methods=["POST"]) # noqa: F821
@token_required
def upload(dataset_id, tenant_id):
@@ -848,20 +869,6 @@ def list_chunks(tenant_id, dataset_id, document_id):
"available_int": sres.field[id].get("available_int", 1),
"positions": sres.field[id].get("position_int", []),
}
if len(d["positions"]) % 5 == 0:
poss = []
for i in range(0, len(d["positions"]), 5):
poss.append(
[
float(d["positions"][i]),
float(d["positions"][i + 1]),
float(d["positions"][i + 2]),
float(d["positions"][i + 3]),
float(d["positions"][i + 4]),
]
)
d["positions"] = poss

origin_chunks.append(d)
if req.get("id"):
if req.get("id") == id:
@@ -892,6 +899,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
if renamed_chunk["available"] == 1:
renamed_chunk["available"] = True
res["chunks"].append(renamed_chunk)
_ = Chunk(**renamed_chunk) # validate the chunk
return get_result(data=res)


@@ -1031,6 +1039,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
if key in key_mapping:
new_key = key_mapping.get(key, key)
renamed_chunk[new_key] = value
_ = Chunk(**renamed_chunk) # validate the chunk
return get_result(data={"chunk": renamed_chunk})
# return get_result(data={"chunk_id": chunk_id})


+ 5
- 3
rag/utils/infinity_conn.py 查看文件

@@ -3,6 +3,7 @@ import os
import re
import json
import time
import copy
import infinity
from infinity.common import ConflictType, InfinityException, SortType
from infinity.index import IndexInfo, IndexType
@@ -390,7 +391,8 @@ class InfinityConnection(DocStoreConnection):
self.createIdx(indexName, knowledgebaseId, vector_size)
table_instance = db_instance.get_table(table_name)

for d in documents:
docs = copy.deepcopy(documents)
for d in docs:
assert "_id" not in d
assert "id" in d
for k, v in d.items():
@@ -407,14 +409,14 @@ class InfinityConnection(DocStoreConnection):
elif k in ["page_num_int", "top_int"]:
assert isinstance(v, list)
d[k] = "_".join(f"{num:08x}" for num in v)
ids = ["'{}'".format(d["id"]) for d in documents]
ids = ["'{}'".format(d["id"]) for d in docs]
str_ids = ", ".join(ids)
str_filter = f"id IN ({str_ids})"
table_instance.delete(str_filter)
# for doc in documents:
# logger.info(f"insert position_int: {doc['position_int']}")
# logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
table_instance.insert(documents)
table_instance.insert(docs)
self.connPool.release_conn(inf_conn)
logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
return []

正在加载...
取消
保存