Refactor Chunk API (#2855)

### What problem does this PR solve?

Refactor Chunk API
#2846
### Type of change


- [x] Refactoring

---------

Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
liuhua
2024-10-16 18:41:24 +08:00
committed by GitHub
parent b9fa00f341
commit dab92ac1e8
11 changed files with 760 additions and 791 deletions

View File

@ -17,32 +17,11 @@ class Chunk(Base):
res_dict.pop(k)
super().__init__(rag, res_dict)
def delete(self) -> bool:
"""
Delete the chunk in the document.
"""
res = self.post('/doc/chunk/rm',
{"document_id": self.document_id, 'chunk_ids': [self.id]})
res = res.json()
if res.get("retmsg") == "success":
return True
raise Exception(res["retmsg"])
def save(self) -> bool:
"""
Save the document details to the server.
"""
res = self.post('/doc/chunk/set',
{"chunk_id": self.id,
"knowledgebase_id": self.knowledgebase_id,
"name": self.document_name,
"content": self.content,
"important_keywords": self.important_keywords,
"document_id": self.document_id,
"available": self.available,
})
def update(self,update_message:dict):
res = self.put(f"/dataset/{self.knowledgebase_id}/document/{self.document_id}/chunk/{self.id}",update_message)
res = res.json()
if res.get("retmsg") == "success":
return True
raise Exception(res["retmsg"])
if res.get("code") != 0 :
raise Exception(res["message"])

View File

@ -65,3 +65,14 @@ class DataSet(Base):
if res.get("code") != 0:
raise Exception(res["message"])
def async_parse_documents(self,document_ids):
res = self.post(f"/dataset/{self.id}/chunk",{"document_ids":document_ids})
res = res.json()
if res.get("code") != 0:
raise Exception(res.get("message"))
def async_cancel_parse_documents(self,document_ids):
res = self.rm(f"/dataset/{self.id}/chunk",{"document_ids":document_ids})
res = res.json()
if res.get("code") != 0:
raise Exception(res.get("message"))

View File

@ -1,7 +1,10 @@
import time
from PIL.ImageFile import raise_oserror
from .base import Base
from .chunk import Chunk
from typing import List
class Document(Base):
@ -29,160 +32,28 @@ class Document(Base):
res_dict.pop(k)
super().__init__(rag, res_dict)
def update(self,update_message:dict) -> bool:
"""
Save the document details to the server.
"""
res = self.post(f'/dataset/{self.knowledgebase_id}/info/{self.id}',update_message)
def list_chunks(self,offset=0, limit=30, keywords="", id:str=None):
data={"document_id": self.id,"keywords": keywords,"offset":offset,"limit":limit,"id":id}
res = self.get(f'/dataset/{self.knowledgebase_id}/document/{self.id}/chunk', data)
res = res.json()
if res.get("code") != 0:
raise Exception(res["message"])
if res.get("code") == 0:
chunks=[]
for data in res["data"].get("chunks"):
chunk = Chunk(self.rag,data)
chunks.append(chunk)
return chunks
raise Exception(res.get("message"))
def delete(self) -> bool:
"""
Delete the document from the server.
"""
res = self.rm('/doc/delete',
{"document_id": self.id})
res = res.json()
if res.get("retmsg") == "success":
return True
raise Exception(res["retmsg"])
def download(self) -> bytes:
"""
Download the document content from the server using the Flask API.
:return: The downloaded document content in bytes.
"""
# Construct the URL for the API request using the document ID and knowledge base ID
res = self.get(f"/dataset/{self.knowledgebase_id}/document/{self.id}")
# Check the response status code to ensure the request was successful
if res.status_code == 200:
# Return the document content as bytes
return res.content
else:
# Handle the error and raise an exception
raise Exception(
f"Failed to download document. Server responded with: {res.status_code}, {res.text}"
)
def async_parse(self):
"""
Initiate document parsing asynchronously without waiting for completion.
"""
try:
# Construct request data including document ID and run status (assuming 1 means to run)
data = {"document_ids": [self.id], "run": 1}
# Send a POST request to the specified parsing status endpoint to start parsing
res = self.post(f'/doc/run', data)
# Check the server response status code
if res.status_code != 200:
raise Exception(f"Failed to start async parsing: {res.text}")
print("Async parsing started successfully.")
except Exception as e:
# Catch and handle exceptions
print(f"Error occurred during async parsing: {str(e)}")
raise
import time
def join(self, interval=5, timeout=3600):
"""
Wait for the asynchronous parsing to complete and yield parsing progress periodically.
:param interval: The time interval (in seconds) for progress reports.
:param timeout: The timeout (in seconds) for the parsing operation.
:return: An iterator yielding parsing progress and messages.
"""
start_time = time.time()
while time.time() - start_time < timeout:
# Check the parsing status
res = self.get(f'/doc/{self.id}/status', {"document_ids": [self.id]})
res_data = res.json()
data = res_data.get("data", [])
# Retrieve progress and status message
progress = data.get("progress", 0)
progress_msg = data.get("status", "")
yield progress, progress_msg # Yield progress and message
if progress == 100: # Parsing completed
break
time.sleep(interval)
def cancel(self):
"""
Cancel the parsing task for the document.
"""
try:
# Construct request data, including document ID and action to cancel (assuming 2 means cancel)
data = {"document_ids": [self.id], "run": 2}
# Send a POST request to the specified parsing status endpoint to cancel parsing
res = self.post(f'/doc/run', data)
# Check the server response status code
if res.status_code != 200:
print("Failed to cancel parsing. Server response:", res.text)
else:
print("Parsing cancelled successfully.")
except Exception as e:
print(f"Error occurred during async parsing cancellation: {str(e)}")
raise
def list_chunks(self, page=1, offset=0, limit=12,size=30, keywords="", available_int=None):
"""
List all chunks associated with this document by calling the external API.
Args:
page (int): The page number to retrieve (default 1).
size (int): The number of chunks per page (default 30).
keywords (str): Keywords for searching specific chunks (default "").
available_int (int): Filter for available chunks (optional).
Returns:
list: A list of chunks returned from the API.
"""
data = {
"document_id": self.id,
"page": page,
"size": size,
"keywords": keywords,
"offset":offset,
"limit":limit
}
if available_int is not None:
data["available_int"] = available_int
res = self.post(f'/doc/chunk/list', data)
if res.status_code == 200:
res_data = res.json()
if res_data.get("retmsg") == "success":
chunks=[]
for chunk_data in res_data["data"].get("chunks", []):
chunk=Chunk(self.rag,chunk_data)
chunks.append(chunk)
return chunks
else:
raise Exception(f"Error fetching chunks: {res_data.get('retmsg')}")
else:
raise Exception(f"API request failed with status code {res.status_code}")
def add_chunk(self, content: str):
res = self.post('/doc/chunk/create', {"document_id": self.id, "content":content})
if res.status_code == 200:
res_data = res.json().get("data")
chunk_data = res_data.get("chunk")
return Chunk(self.rag,chunk_data)
else:
raise Exception(f"Failed to add chunk: {res.status_code} {res.text}")
res = self.post(f'/dataset/{self.knowledgebase_id}/document/{self.id}/chunk', {"content":content})
res = res.json()
if res.get("code") == 0:
return Chunk(self.rag,res["data"].get("chunk"))
raise Exception(res.get("message"))
def delete_chunks(self,ids:List[str]):
res = self.rm(f"dataset/{self.knowledgebase_id}/document/{self.id}/chunk",{"ids":ids})
res = res.json()
if res.get("code")!=0:
raise Exception(res.get("message"))

View File

@ -15,8 +15,8 @@ class Session(Base):
for message in self.messages:
if "reference" in message:
message.pop("reference")
res = self.post(f"/chat/{self.chat_id}/session/{self.id}/completion",
{"question": question, "stream": True}, stream=stream)
res = self.post(f"/chat/{self.chat_id}/completion",
{"question": question, "stream": True,"session_id":self.id}, stream=stream)
for line in res.iter_lines():
line = line.decode("utf-8")
if line.startswith("{"):
@ -82,3 +82,4 @@ class Chunk(Base):
self.term_similarity = None
self.positions = None
super().__init__(rag, res_dict)

View File

@ -158,105 +158,30 @@ class RAGFlow:
raise Exception(res["message"])
def async_parse_documents(self, doc_ids):
"""
Asynchronously start parsing multiple documents without waiting for completion.
:param doc_ids: A list containing multiple document IDs.
"""
try:
if not doc_ids or not isinstance(doc_ids, list):
raise ValueError("doc_ids must be a non-empty list of document IDs")
data = {"document_ids": doc_ids, "run": 1}
res = self.post(f'/doc/run', data)
if res.status_code != 200:
raise Exception(f"Failed to start async parsing for documents: {res.text}")
print(f"Async parsing started successfully for documents: {doc_ids}")
except Exception as e:
print(f"Error occurred during async parsing for documents: {str(e)}")
raise
def async_cancel_parse_documents(self, doc_ids):
"""
Cancel the asynchronous parsing of multiple documents.
:param doc_ids: A list containing multiple document IDs.
"""
try:
if not doc_ids or not isinstance(doc_ids, list):
raise ValueError("doc_ids must be a non-empty list of document IDs")
data = {"document_ids": doc_ids, "run": 2}
res = self.post(f'/doc/run', data)
if res.status_code != 200:
raise Exception(f"Failed to cancel async parsing for documents: {res.text}")
print(f"Async parsing canceled successfully for documents: {doc_ids}")
except Exception as e:
print(f"Error occurred during canceling parsing for documents: {str(e)}")
raise
def retrieval(self,
question,
datasets=None,
documents=None,
offset=0,
limit=6,
similarity_threshold=0.1,
vector_similarity_weight=0.3,
top_k=1024):
"""
Perform document retrieval based on the given parameters.
:param question: The query question.
:param datasets: A list of datasets (optional, as documents may be provided directly).
:param documents: A list of documents (if specific documents are provided).
:param offset: Offset for the retrieval results.
:param limit: Maximum number of retrieval results.
:param similarity_threshold: Similarity threshold.
:param vector_similarity_weight: Weight of vector similarity.
:param top_k: Number of top most similar documents to consider (for pre-filtering or ranking).
Note: This is a hypothetical implementation and may need adjustments based on the actual backend service API.
"""
try:
data = {
"question": question,
"datasets": datasets if datasets is not None else [],
"documents": [doc.id if hasattr(doc, 'id') else doc for doc in
documents] if documents is not None else [],
def retrieve(self, question="",datasets=None,documents=None, offset=1, limit=30, similarity_threshold=0.2,vector_similarity_weight=0.3,top_k=1024,rerank_id:str=None,keyword:bool=False,):
data_params = {
"offset": offset,
"limit": limit,
"similarity_threshold": similarity_threshold,
"vector_similarity_weight": vector_similarity_weight,
"top_k": top_k,
"knowledgebase_id": datasets,
"rerank_id":rerank_id,
"keyword":keyword
}
data_json ={
"question": question,
"datasets": datasets,
"documents": documents
}
# Send a POST request to the backend service (using requests library as an example, actual implementation may vary)
res = self.post(f'/doc/retrieval_test', data)
# Check the response status code
if res.status_code == 200:
res_data = res.json()
if res_data.get("retmsg") == "success":
chunks = []
for chunk_data in res_data["data"].get("chunks", []):
chunk = Chunk(self, chunk_data)
chunks.append(chunk)
return chunks
else:
raise Exception(f"Error fetching chunks: {res_data.get('retmsg')}")
else:
raise Exception(f"API request failed with status code {res.status_code}")
except Exception as e:
print(f"An error occurred during retrieval: {e}")
raise
res = self.get(f'/retrieval', data_params,data_json)
res = res.json()
if res.get("code") ==0:
chunks=[]
for chunk_data in res["data"].get("chunks"):
chunk=Chunk(self,chunk_data)
chunks.append(chunk)
return chunks
raise Exception(res.get("message"))

View File

@ -63,17 +63,13 @@ class TestDocument(TestSdk):
# Check if the retrieved document is of type Document
if isinstance(doc, Document):
# Download the document content and save it to a file
try:
with open("ragflow.txt", "wb+") as file:
file.write(doc.download())
# Print the document object for debugging
print(doc)
with open("./ragflow.txt", "wb+") as file:
file.write(doc.download())
# Print the document object for debugging
print(doc)
# Assert that the download was successful
assert True, "Document downloaded successfully."
except Exception as e:
# If an error occurs, raise an assertion error
assert False, f"Failed to download document, error: {str(e)}"
# Assert that the download was successful
assert True, f"Failed to download document, error: {doc}"
else:
# If the document retrieval fails, assert failure
assert False, f"Failed to get document, error: {doc}"
@ -100,7 +96,7 @@ class TestDocument(TestSdk):
blob2 = b"Sample document content for ingestion test222."
list_1 = [{"name":name1,"blob":blob1},{"name":name2,"blob":blob2}]
ds.upload_documents(list_1)
for d in ds.list_docs(keywords="test", offset=0, limit=12):
for d in ds.list_documents(keywords="test", offset=0, limit=12):
assert isinstance(d, Document), "Failed to upload documents"
def test_delete_documents_in_dataset_with_success(self):
@ -123,16 +119,11 @@ class TestDocument(TestSdk):
blob1 = b"Sample document content for ingestion test333."
name2 = "Test Document444.txt"
blob2 = b"Sample document content for ingestion test444."
name3 = 'test.txt'
path = 'test_data/test.txt'
rag.create_document(ds, name=name3, blob=open(path, "rb").read())
rag.create_document(ds, name=name1, blob=blob1)
rag.create_document(ds, name=name2, blob=blob2)
for d in ds.list_docs(keywords="document", offset=0, limit=12):
ds.upload_documents([{"name":name1,"blob":blob1},{"name":name2,"blob":blob2}])
for d in ds.list_documents(keywords="document", offset=0, limit=12):
assert isinstance(d, Document)
d.delete()
print(d)
remaining_docs = ds.list_docs(keywords="rag", offset=0, limit=12)
ds.delete_documents([d.id])
remaining_docs = ds.list_documents(keywords="rag", offset=0, limit=12)
assert len(remaining_docs) == 0, "Documents were not properly deleted."
def test_parse_and_cancel_document(self):
@ -144,16 +135,15 @@ class TestDocument(TestSdk):
# Define the document name and path
name3 = 'westworld.pdf'
path = 'test_data/westworld.pdf'
path = './test_data/westworld.pdf'
# Create a document in the dataset using the file path
rag.create_document(ds, name=name3, blob=open(path, "rb").read())
ds.upload_documents({"name":name3, "blob":open(path, "rb").read()})
# Retrieve the document by name
doc = rag.get_document(name="westworld.pdf")
# Initiate asynchronous parsing
doc.async_parse()
doc = rag.list_documents(name="westworld.pdf")
doc = doc[0]
ds.async_parse_documents(document_ids=[])
# Print message to confirm asynchronous parsing has been initiated
print("Async parsing initiated")