Light GraphRAG (#4585)

### What problem does this PR solve?

#4543

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-01-22 19:43:14 +08:00
committed by GitHub
parent 1a367664f1
commit dd0ebbea35
55 changed files with 5461 additions and 4000 deletions

View File

@ -28,7 +28,7 @@ from peewee import fn
from api.db.db_utils import bulk_insert_into_db
from api import settings
from api.utils import current_timestamp, get_format_time, get_uuid
from graphrag.mind_map_extractor import MindMapExtractor
from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.settings import SVR_QUEUE_NAME
from rag.utils.storage_factory import STORAGE_IMPL
from rag.nlp import search, rag_tokenizer
@ -105,8 +105,19 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def remove_document(cls, doc, tenant_id):
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
cls.clear_chunk_num(doc.id)
try:
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "source_id": doc.id},
{"remove": {"source_id": doc.id}},
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
{"removed_kwd": "Y"},
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "community_report"], "must_not": {"exists": "source_id"}},
search.index_name(tenant_id), doc.kb_id)
except Exception:
pass
return cls.delete_by_id(doc.id)
@classmethod
@ -142,7 +153,7 @@ class DocumentService(CommonService):
@DB.connection_context()
def get_unfinished_docs(cls):
fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg,
cls.model.run]
cls.model.run, cls.model.parser_id]
docs = cls.model.select(*fields) \
.where(
cls.model.status == StatusEnum.VALID.value,
@ -295,9 +306,9 @@ class DocumentService(CommonService):
Tenant.asr_id,
Tenant.llm_id,
)
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
.where(cls.model.id == doc_id)
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
.where(cls.model.id == doc_id)
)
configs = configs.dicts()
if not configs:
@ -365,6 +376,12 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def update_progress(cls):
MSG = {
"raptor": "Start RAPTOR (Recursive Abstractive Processing for Tree-Organized Retrieval).",
"graphrag": "Start Graph Extraction",
"graph_resolution": "Start Graph Resolution",
"graph_community": "Start Graph Community Reports Generation"
}
docs = cls.get_unfinished_docs()
for d in docs:
try:
@ -390,15 +407,27 @@ class DocumentService(CommonService):
prg = -1
status = TaskStatus.FAIL.value
elif finished:
if d["parser_config"].get("raptor", {}).get("use_raptor") and d["progress_msg"].lower().find(
" raptor") < 0:
queue_raptor_tasks(d)
m = "\n".join(sorted(msg))
if d["parser_config"].get("raptor", {}).get("use_raptor") and m.find(MSG["raptor"]) < 0:
queue_raptor_o_graphrag_tasks(d, "raptor", MSG["raptor"])
prg = 0.98 * len(tsks) / (len(tsks) + 1)
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and m.find(MSG["graphrag"]) < 0:
queue_raptor_o_graphrag_tasks(d, "graphrag", MSG["graphrag"])
prg = 0.98 * len(tsks) / (len(tsks) + 1)
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
and d["parser_config"].get("graphrag", {}).get("resolution") \
and m.find(MSG["graph_resolution"]) < 0:
queue_raptor_o_graphrag_tasks(d, "graph_resolution", MSG["graph_resolution"])
prg = 0.98 * len(tsks) / (len(tsks) + 1)
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
and d["parser_config"].get("graphrag", {}).get("community") \
and m.find(MSG["graph_community"]) < 0:
queue_raptor_o_graphrag_tasks(d, "graph_community", MSG["graph_community"])
prg = 0.98 * len(tsks) / (len(tsks) + 1)
msg.append("------ RAPTOR -------")
else:
status = TaskStatus.DONE.value
msg = "\n".join(msg)
msg = "\n".join(sorted(msg))
info = {
"process_duation": datetime.timestamp(
datetime.now()) -
@ -430,7 +459,7 @@ class DocumentService(CommonService):
return False
def queue_raptor_tasks(doc):
def queue_raptor_o_graphrag_tasks(doc, ty, msg):
chunking_config = DocumentService.get_chunking_config(doc["id"])
hasher = xxhash.xxh64()
for field in sorted(chunking_config.keys()):
@ -443,15 +472,16 @@ def queue_raptor_tasks(doc):
"doc_id": doc["id"],
"from_page": 100000000,
"to_page": 100000000,
"progress_msg": "Start to do RAPTOR (Recursive Abstractive Processing for Tree-Organized Retrieval)."
"progress_msg": datetime.now().strftime("%H:%M:%S") + " " + msg
}
task = new_task()
for field in ["doc_id", "from_page", "to_page"]:
hasher.update(str(task.get(field, "")).encode("utf-8"))
hasher.update(ty.encode("utf-8"))
task["digest"] = hasher.hexdigest()
bulk_insert_into_db(Task, [task], True)
task["type"] = "raptor"
task["task_type"] = ty
assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
@ -489,7 +519,7 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
ParserType.AUDIO.value: audio,
ParserType.EMAIL.value: email
}
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": False}
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
exe = ThreadPoolExecutor(max_workers=12)
threads = []
doc_nm = {}
@ -592,4 +622,4 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
DocumentService.increment_chunk_num(
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
return [d["id"] for d, _ in files]
return [d["id"] for d, _ in files]