Refactor: HTTP API create dataset and test cases (#7393)

### What problem does this PR solve?

This PR introduces Pydantic-based validation for the create-dataset HTTP
API, improving code clarity and robustness. Key changes include:
1. Pydantic Validation
2. Error Handling
3. Test Updates
4. Documentation

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Documentation Update
- [x] Refactoring
This commit is contained in:
liu an
2025-04-29 16:53:57 +08:00
committed by GitHub
parent c88e4b3fc0
commit 78380fa181
11 changed files with 1239 additions and 812 deletions

View File

@ -13,16 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import requests
from .modules.agent import Agent
from .modules.chat import Chat
from .modules.chunk import Chunk
from .modules.dataset import DataSet
from .modules.agent import Agent
class RAGFlow:
def __init__(self, api_key, base_url, version='v1'):
def __init__(self, api_key, base_url, version="v1"):
"""
api_url: http://<host_address>/api/v1
"""
@ -31,11 +33,11 @@ class RAGFlow:
self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.user_key)}
def post(self, path, json=None, stream=False, files=None):
res = requests.post(url=self.api_url + path, json=json, headers=self.authorization_header, stream=stream,files=files)
res = requests.post(url=self.api_url + path, json=json, headers=self.authorization_header, stream=stream, files=files)
return res
def get(self, path, params=None, json=None):
res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header,json=json)
res = requests.get(url=self.api_url + path, params=params, headers=self.authorization_header, json=json)
return res
def delete(self, path, json):
@ -43,54 +45,73 @@ class RAGFlow:
return res
def put(self, path, json):
res = requests.put(url=self.api_url + path, json= json,headers=self.authorization_header)
res = requests.put(url=self.api_url + path, json=json, headers=self.authorization_header)
return res
def create_dataset(self, name: str, avatar: str = "", description: str = "", embedding_model:str = "BAAI/bge-large-zh-v1.5",
language: str = "English",
permission: str = "me",chunk_method: str = "naive",
parser_config: DataSet.ParserConfig = None) -> DataSet:
def create_dataset(
self,
name: str,
avatar: Optional[str] = None,
description: Optional[str] = None,
embedding_model: Optional[str] = "BAAI/bge-large-zh-v1.5@BAAI",
permission: str = "me",
chunk_method: str = "naive",
pagerank: int = 0,
parser_config: DataSet.ParserConfig = None,
) -> DataSet:
if parser_config:
parser_config = parser_config.to_json()
res = self.post("/datasets",
{"name": name, "avatar": avatar, "description": description,"embedding_model":embedding_model,
"language": language,
"permission": permission, "chunk_method": chunk_method,
"parser_config": parser_config
}
)
res = self.post(
"/datasets",
{
"name": name,
"avatar": avatar,
"description": description,
"embedding_model": embedding_model,
"permission": permission,
"chunk_method": chunk_method,
"pagerank": pagerank,
"parser_config": parser_config,
},
)
res = res.json()
if res.get("code") == 0:
return DataSet(self, res["data"])
raise Exception(res["message"])
def delete_datasets(self, ids: list[str] | None = None):
res = self.delete("/datasets",{"ids": ids})
res=res.json()
res = self.delete("/datasets", {"ids": ids})
res = res.json()
if res.get("code") != 0:
raise Exception(res["message"])
def get_dataset(self,name: str):
def get_dataset(self, name: str):
_list = self.list_datasets(name=name)
if len(_list) > 0:
return _list[0]
raise Exception("Dataset %s not found" % name)
def list_datasets(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True,
id: str | None = None, name: str | None = None) -> \
list[DataSet]:
res = self.get("/datasets",
{"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
def list_datasets(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None) -> list[DataSet]:
res = self.get(
"/datasets",
{
"page": page,
"page_size": page_size,
"orderby": orderby,
"desc": desc,
"id": id,
"name": name,
},
)
res = res.json()
result_list = []
if res.get("code") == 0:
for data in res['data']:
for data in res["data"]:
result_list.append(DataSet(self, data))
return result_list
raise Exception(res["message"])
def create_chat(self, name: str, avatar: str = "", dataset_ids=None,
llm: Chat.LLM | None = None, prompt: Chat.Prompt | None = None) -> Chat:
def create_chat(self, name: str, avatar: str = "", dataset_ids=None, llm: Chat.LLM | None = None, prompt: Chat.Prompt | None = None) -> Chat:
if dataset_ids is None:
dataset_ids = []
dataset_list = []
@ -98,25 +119,33 @@ class RAGFlow:
dataset_list.append(id)
if llm is None:
llm = Chat.LLM(self, {"model_name": None,
"temperature": 0.1,
"top_p": 0.3,
"presence_penalty": 0.4,
"frequency_penalty": 0.7,
"max_tokens": 512, })
llm = Chat.LLM(
self,
{
"model_name": None,
"temperature": 0.1,
"top_p": 0.3,
"presence_penalty": 0.4,
"frequency_penalty": 0.7,
"max_tokens": 512,
},
)
if prompt is None:
prompt = Chat.Prompt(self, {"similarity_threshold": 0.2,
"keywords_similarity_weight": 0.7,
"top_n": 8,
"top_k": 1024,
"variables": [{
"key": "knowledge",
"optional": True
}], "rerank_model": "",
"empty_response": None,
"opener": None,
"show_quote": True,
"prompt": None})
prompt = Chat.Prompt(
self,
{
"similarity_threshold": 0.2,
"keywords_similarity_weight": 0.7,
"top_n": 8,
"top_k": 1024,
"variables": [{"key": "knowledge", "optional": True}],
"rerank_model": "",
"empty_response": None,
"opener": None,
"show_quote": True,
"prompt": None,
},
)
if prompt.opener is None:
prompt.opener = "Hi! I'm your assistant, what can I do for you?"
if prompt.prompt is None:
@ -127,70 +156,93 @@ class RAGFlow:
"Answers need to consider chat history.\nHere is the knowledge base:\n{knowledge}\nThe above is the knowledge base."
)
temp_dict = {"name": name,
"avatar": avatar,
"dataset_ids": dataset_list if dataset_list else [],
"llm": llm.to_json(),
"prompt": prompt.to_json()}
temp_dict = {"name": name, "avatar": avatar, "dataset_ids": dataset_list if dataset_list else [], "llm": llm.to_json(), "prompt": prompt.to_json()}
res = self.post("/chats", temp_dict)
res = res.json()
if res.get("code") == 0:
return Chat(self, res["data"])
raise Exception(res["message"])
def delete_chats(self,ids: list[str] | None = None):
res = self.delete('/chats',
{"ids":ids})
def delete_chats(self, ids: list[str] | None = None):
res = self.delete("/chats", {"ids": ids})
res = res.json()
if res.get("code") != 0:
raise Exception(res["message"])
def list_chats(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True,
id: str | None = None, name: str | None = None) -> list[Chat]:
res = self.get("/chats",{"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "name": name})
def list_chats(self, page: int = 1, page_size: int = 30, orderby: str = "create_time", desc: bool = True, id: str | None = None, name: str | None = None) -> list[Chat]:
res = self.get(
"/chats",
{
"page": page,
"page_size": page_size,
"orderby": orderby,
"desc": desc,
"id": id,
"name": name,
},
)
res = res.json()
result_list = []
if res.get("code") == 0:
for data in res['data']:
for data in res["data"]:
result_list.append(Chat(self, data))
return result_list
raise Exception(res["message"])
def retrieve(
self,
dataset_ids,
document_ids=None,
question="",
page=1,
page_size=30,
similarity_threshold=0.2,
vector_similarity_weight=0.3,
top_k=1024,
rerank_id: str | None = None,
keyword: bool = False,
):
if document_ids is None:
document_ids = []
data_json = {
"page": page,
"page_size": page_size,
"similarity_threshold": similarity_threshold,
"vector_similarity_weight": vector_similarity_weight,
"top_k": top_k,
"rerank_id": rerank_id,
"keyword": keyword,
"question": question,
"dataset_ids": dataset_ids,
"document_ids": document_ids,
}
# Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
res = self.post("/retrieval", json=data_json)
res = res.json()
if res.get("code") == 0:
chunks = []
for chunk_data in res["data"].get("chunks"):
chunk = Chunk(self, chunk_data)
chunks.append(chunk)
return chunks
raise Exception(res.get("message"))
def retrieve(self, dataset_ids, document_ids=None, question="", page=1, page_size=30, similarity_threshold=0.2, vector_similarity_weight=0.3, top_k=1024, rerank_id: str | None = None, keyword:bool=False, ):
if document_ids is None:
document_ids = []
data_json ={
def list_agents(self, page: int = 1, page_size: int = 30, orderby: str = "update_time", desc: bool = True, id: str | None = None, title: str | None = None) -> list[Agent]:
res = self.get(
"/agents",
{
"page": page,
"page_size": page_size,
"similarity_threshold": similarity_threshold,
"vector_similarity_weight": vector_similarity_weight,
"top_k": top_k,
"rerank_id": rerank_id,
"keyword": keyword,
"question": question,
"dataset_ids": dataset_ids,
"document_ids": document_ids
}
# Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
res = self.post('/retrieval',json=data_json)
res = res.json()
if res.get("code") ==0:
chunks=[]
for chunk_data in res["data"].get("chunks"):
chunk=Chunk(self,chunk_data)
chunks.append(chunk)
return chunks
raise Exception(res.get("message"))
def list_agents(self, page: int = 1, page_size: int = 30, orderby: str = "update_time", desc: bool = True,
id: str | None = None, title: str | None = None) -> list[Agent]:
res = self.get("/agents",{"page": page, "page_size": page_size, "orderby": orderby, "desc": desc, "id": id, "title": title})
"orderby": orderby,
"desc": desc,
"id": id,
"title": title,
},
)
res = res.json()
result_list = []
if res.get("code") == 0:
for data in res['data']:
for data in res["data"]:
result_list.append(Agent(self, data))
return result_list
raise Exception(res["message"])