Fix: DataSet.update() now refreshes object data (#8058)

### What problem does this PR solve?

#8057 

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Liu An
2025-06-05 09:26:19 +08:00
committed by GitHub
parent ec60b322ab
commit ab5e3ded68
3 changed files with 25 additions and 14 deletions

View File

@ -14,9 +14,13 @@
# limitations under the License.
#
class Base:
def __init__(self, rag, res_dict):
self.rag = rag
self._update_from_dict(rag, res_dict)
def _update_from_dict(self, rag, res_dict):
for k, v in res_dict.items():
if isinstance(v, dict):
self.__dict__[k] = Base(rag, v)
@ -27,7 +31,7 @@ class Base:
pr = {}
for name in dir(self):
value = getattr(self, name)
if not name.startswith('__') and not callable(value) and name != "rag":
if not name.startswith("__") and not callable(value) and name != "rag":
if isinstance(value, Base):
pr[name] = value.to_json()
else:
@ -35,7 +39,7 @@ class Base:
return pr
def post(self, path, json=None, stream=False, files=None):
res = self.rag.post(path, json, stream=stream,files=files)
res = self.rag.post(path, json, stream=stream, files=files)
return res
def get(self, path, params=None):
@ -46,8 +50,8 @@ class Base:
res = self.rag.delete(path, json)
return res
def put(self,path, json):
res = self.rag.put(path,json)
def put(self, path, json):
res = self.rag.put(path, json)
return res
def __str__(self):

View File

@ -14,9 +14,8 @@
# limitations under the License.
#
from .document import Document
from .base import Base
from .document import Document
class DataSet(Base):
@ -43,12 +42,14 @@ class DataSet(Base):
super().__init__(rag, res_dict)
def update(self, update_message: dict):
res = self.put(f'/datasets/{self.id}',
update_message)
res = self.put(f"/datasets/{self.id}", update_message)
res = res.json()
if res.get("code") != 0:
raise Exception(res["message"])
self._update_from_dict(self.rag, res.get("data", {}))
return self
def upload_documents(self, document_list: list[dict]):
url = f"/datasets/{self.id}/documents"
files = [("file", (ele["display_name"], ele["blob"])) for ele in document_list]
@ -62,11 +63,8 @@ class DataSet(Base):
return doc_list
raise Exception(res.get("message"))
def list_documents(self, id: str | None = None, keywords: str | None = None, page: int = 1, page_size: int = 30,
orderby: str = "create_time", desc: bool = True):
res = self.get(f"/datasets/{self.id}/documents",
params={"id": id, "keywords": keywords, "page": page, "page_size": page_size, "orderby": orderby,
"desc": desc})
def list_documents(self, id: str | None = None, keywords: str | None = None, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True):
res = self.get(f"/datasets/{self.id}/documents", params={"id": id, "keywords": keywords, "page": page, "page_size": page_size, "orderby": orderby, "desc": desc})
res = res.json()
documents = []
if res.get("code") == 0: