
Validate returned chunk at list_chunks and add_chunk (#4153)

### What problem does this PR solve?

Validate the chunk returned by `list_chunks` and `add_chunk`.

### Type of change

- [x] Refactoring
tags/v0.15.1
Zhichang Yu 10 months ago
commit 85083ad400
2 changed files with 28 additions and 17 deletions:

1. api/apps/sdk/doc.py (+23, -14)
2. rag/utils/infinity_conn.py (+5, -3)

api/apps/sdk/doc.py (+23, -14)

 from rag.utils import rmSpace
 from rag.utils.storage_factory import STORAGE_IMPL


+from pydantic import BaseModel, Field, validator

 MAXIMUM_OF_UPLOADING_FILES = 256




+class Chunk(BaseModel):
+    id: str = ""
+    content: str = ""
+    document_id: str = ""
+    docnm_kwd: str = ""
+    important_keywords: list = Field(default_factory=list)
+    questions: list = Field(default_factory=list)
+    question_tks: str = ""
+    image_id: str = ""
+    available: bool = True
+    positions: list[list[int]] = Field(default_factory=list)
+
+    @validator('positions')
+    def validate_positions(cls, value):
+        for sublist in value:
+            if len(sublist) != 5:
+                raise ValueError("Each sublist in positions must have a length of 5")
+        return value

 @manager.route("/datasets/<dataset_id>/documents", methods=["POST"]) # noqa: F821
 @token_required
 def upload(dataset_id, tenant_id):

             "available_int": sres.field[id].get("available_int", 1),
             "positions": sres.field[id].get("position_int", []),
         }
if len(d["positions"]) % 5 == 0:
poss = []
for i in range(0, len(d["positions"]), 5):
poss.append(
[
float(d["positions"][i]),
float(d["positions"][i + 1]),
float(d["positions"][i + 2]),
float(d["positions"][i + 3]),
float(d["positions"][i + 4]),
]
)
d["positions"] = poss

         origin_chunks.append(d)
         if req.get("id"):
             if req.get("id") == id:

         if renamed_chunk["available"] == 1:
             renamed_chunk["available"] = True
         res["chunks"].append(renamed_chunk)
+        _ = Chunk(**renamed_chunk) # validate the chunk
     return get_result(data=res)




         if key in key_mapping:
             new_key = key_mapping.get(key, key)
             renamed_chunk[new_key] = value
+    _ = Chunk(**renamed_chunk) # validate the chunk
     return get_result(data={"chunk": renamed_chunk})
     # return get_result(data={"chunk_id": chunk_id})



rag/utils/infinity_conn.py (+5, -3)

 import re
 import json
 import time
+import copy
 import infinity
 from infinity.common import ConflictType, InfinityException, SortType
 from infinity.index import IndexInfo, IndexType

            self.createIdx(indexName, knowledgebaseId, vector_size)
            table_instance = db_instance.get_table(table_name)


-        for d in documents:
+        docs = copy.deepcopy(documents)
+        for d in docs:
             assert "_id" not in d
             assert "id" in d
             for k, v in d.items():

                 elif k in ["page_num_int", "top_int"]:
                     assert isinstance(v, list)
                     d[k] = "_".join(f"{num:08x}" for num in v)
-        ids = ["'{}'".format(d["id"]) for d in documents]
+        ids = ["'{}'".format(d["id"]) for d in docs]
         str_ids = ", ".join(ids)
         str_filter = f"id IN ({str_ids})"
         table_instance.delete(str_filter)
         # for doc in documents:
         #     logger.info(f"insert position_int: {doc['position_int']}")
         # logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
-        table_instance.insert(documents)
+        table_instance.insert(docs)
         self.connPool.release_conn(inf_conn)
         logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
         return []
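A side note on the `copy.deepcopy` change (my reading, not text from the PR): `insert` rewrites list-valued fields such as `page_num_int` and `top_int` into "_"-joined hex strings before handing the rows to Infinity, so encoding a copy keeps the caller's `documents` dicts intact. A small self-contained sketch with made-up data:

```python
# Sketch only (hypothetical data): why inserting from a deep copy matters.
# The insert path encodes list fields into "_"-joined hex strings; doing that
# on the original dicts would leak the encoded form back to the caller.
import copy


def encode_in_place(rows):
    # Mirrors the d[k] = "_".join(f"{num:08x}" for num in v) rewrite above.
    for d in rows:
        d["page_num_int"] = "_".join(f"{num:08x}" for num in d["page_num_int"])


# Old behaviour: the caller's documents are mutated.
documents = [{"id": "c1", "page_num_int": [1, 2]}]
encode_in_place(documents)
print(documents[0]["page_num_int"])   # '00000001_00000002'

# New behaviour: encode a deep copy, leave the input untouched.
documents = [{"id": "c1", "page_num_int": [1, 2]}]
docs = copy.deepcopy(documents)
encode_in_place(docs)
print(documents[0]["page_num_int"])   # [1, 2]
print(docs[0]["page_num_int"])        # '00000001_00000002'
```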
