diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index f76cf2f9d..d49a2c07c 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -385,7 +385,16 @@ def update(tenant_id, dataset_id): logging.exception(e) return get_error_data_result(message="Database operation failed") - return get_result() + try: + ok, k = KnowledgebaseService.get_by_id(kb.id) + if not ok: + return get_error_data_result(message="Dataset created failed") + except OperationalError as e: + logging.exception(e) + return get_error_data_result(message="Database operation failed") + + response_data = remap_dictionary_keys(k.to_dict()) + return get_result(data=response_data) @manager.route("/datasets", methods=["GET"]) # noqa: F821 diff --git a/sdk/python/ragflow_sdk/modules/base.py b/sdk/python/ragflow_sdk/modules/base.py index 9014dd02d..6b958fb8d 100644 --- a/sdk/python/ragflow_sdk/modules/base.py +++ b/sdk/python/ragflow_sdk/modules/base.py @@ -14,9 +14,13 @@ # limitations under the License. # + class Base: def __init__(self, rag, res_dict): self.rag = rag + self._update_from_dict(rag, res_dict) + + def _update_from_dict(self, rag, res_dict): for k, v in res_dict.items(): if isinstance(v, dict): self.__dict__[k] = Base(rag, v) @@ -27,7 +31,7 @@ class Base: pr = {} for name in dir(self): value = getattr(self, name) - if not name.startswith('__') and not callable(value) and name != "rag": + if not name.startswith("__") and not callable(value) and name != "rag": if isinstance(value, Base): pr[name] = value.to_json() else: @@ -35,7 +39,7 @@ class Base: return pr def post(self, path, json=None, stream=False, files=None): - res = self.rag.post(path, json, stream=stream,files=files) + res = self.rag.post(path, json, stream=stream, files=files) return res def get(self, path, params=None): @@ -46,8 +50,8 @@ class Base: res = self.rag.delete(path, json) return res - def put(self,path, json): - res = self.rag.put(path,json) + def put(self, path, json): + res = self.rag.put(path, json) return res def __str__(self): diff --git a/sdk/python/ragflow_sdk/modules/dataset.py b/sdk/python/ragflow_sdk/modules/dataset.py index fdecde4a5..3d0a83361 100644 --- a/sdk/python/ragflow_sdk/modules/dataset.py +++ b/sdk/python/ragflow_sdk/modules/dataset.py @@ -14,9 +14,8 @@ # limitations under the License. # -from .document import Document - from .base import Base +from .document import Document class DataSet(Base): @@ -43,12 +42,14 @@ class DataSet(Base): super().__init__(rag, res_dict) def update(self, update_message: dict): - res = self.put(f'/datasets/{self.id}', - update_message) + res = self.put(f"/datasets/{self.id}", update_message) res = res.json() if res.get("code") != 0: raise Exception(res["message"]) + self._update_from_dict(self.rag, res.get("data", {})) + return self + def upload_documents(self, document_list: list[dict]): url = f"/datasets/{self.id}/documents" files = [("file", (ele["display_name"], ele["blob"])) for ele in document_list] @@ -62,11 +63,8 @@ class DataSet(Base): return doc_list raise Exception(res.get("message")) - def list_documents(self, id: str | None = None, keywords: str | None = None, page: int = 1, page_size: int = 30, - orderby: str = "create_time", desc: bool = True): - res = self.get(f"/datasets/{self.id}/documents", - params={"id": id, "keywords": keywords, "page": page, "page_size": page_size, "orderby": orderby, - "desc": desc}) + def list_documents(self, id: str | None = None, keywords: str | None = None, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True): + res = self.get(f"/datasets/{self.id}/documents", params={"id": id, "keywords": keywords, "page": page, "page_size": page_size, "orderby": orderby, "desc": desc}) res = res.json() documents = [] if res.get("code") == 0: