Parcourir la source

Fix: DataSet.update() now refreshes object data (#8058)

### What problem does this PR solve?

#8057 

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
tags/v0.19.1
Liu An il y a 5 mois
Parent
révision
ab5e3ded68
Aucun compte lié à l'adresse e-mail de l'auteur

+ 10
- 1
api/apps/sdk/dataset.py Voir le fichier

@@ -385,7 +385,16 @@ def update(tenant_id, dataset_id):
logging.exception(e)
return get_error_data_result(message="Database operation failed")

return get_result()
try:
ok, k = KnowledgebaseService.get_by_id(kb.id)
if not ok:
return get_error_data_result(message="Dataset created failed")
except OperationalError as e:
logging.exception(e)
return get_error_data_result(message="Database operation failed")

response_data = remap_dictionary_keys(k.to_dict())
return get_result(data=response_data)


@manager.route("/datasets", methods=["GET"]) # noqa: F821

+ 8
- 4
sdk/python/ragflow_sdk/modules/base.py Voir le fichier

@@ -14,9 +14,13 @@
# limitations under the License.
#


class Base:
def __init__(self, rag, res_dict):
self.rag = rag
self._update_from_dict(rag, res_dict)

def _update_from_dict(self, rag, res_dict):
for k, v in res_dict.items():
if isinstance(v, dict):
self.__dict__[k] = Base(rag, v)
@@ -27,7 +31,7 @@ class Base:
pr = {}
for name in dir(self):
value = getattr(self, name)
if not name.startswith('__') and not callable(value) and name != "rag":
if not name.startswith("__") and not callable(value) and name != "rag":
if isinstance(value, Base):
pr[name] = value.to_json()
else:
@@ -35,7 +39,7 @@ class Base:
return pr

def post(self, path, json=None, stream=False, files=None):
res = self.rag.post(path, json, stream=stream,files=files)
res = self.rag.post(path, json, stream=stream, files=files)
return res

def get(self, path, params=None):
@@ -46,8 +50,8 @@ class Base:
res = self.rag.delete(path, json)
return res

def put(self,path, json):
res = self.rag.put(path,json)
def put(self, path, json):
res = self.rag.put(path, json)
return res

def __str__(self):

+ 7
- 9
sdk/python/ragflow_sdk/modules/dataset.py Voir le fichier

@@ -14,9 +14,8 @@
# limitations under the License.
#

from .document import Document

from .base import Base
from .document import Document


class DataSet(Base):
@@ -43,12 +42,14 @@ class DataSet(Base):
super().__init__(rag, res_dict)

def update(self, update_message: dict):
res = self.put(f'/datasets/{self.id}',
update_message)
res = self.put(f"/datasets/{self.id}", update_message)
res = res.json()
if res.get("code") != 0:
raise Exception(res["message"])

self._update_from_dict(self.rag, res.get("data", {}))
return self

def upload_documents(self, document_list: list[dict]):
url = f"/datasets/{self.id}/documents"
files = [("file", (ele["display_name"], ele["blob"])) for ele in document_list]
@@ -62,11 +63,8 @@ class DataSet(Base):
return doc_list
raise Exception(res.get("message"))

def list_documents(self, id: str | None = None, keywords: str | None = None, page: int = 1, page_size: int = 30,
orderby: str = "create_time", desc: bool = True):
res = self.get(f"/datasets/{self.id}/documents",
params={"id": id, "keywords": keywords, "page": page, "page_size": page_size, "orderby": orderby,
"desc": desc})
def list_documents(self, id: str | None = None, keywords: str | None = None, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True):
res = self.get(f"/datasets/{self.id}/documents", params={"id": id, "keywords": keywords, "page": page, "page_size": page_size, "orderby": orderby, "desc": desc})
res = res.json()
documents = []
if res.get("code") == 0:

Chargement…
Annuler
Enregistrer