Light GraphRAG (#4585)

### What problem does this PR solve?

#4543

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-01-22 19:43:14 +08:00
committed by GitHub
parent 1a367664f1
commit dd0ebbea35
55 changed files with 5461 additions and 4000 deletions

View File

@ -15,7 +15,7 @@
#
from flask import request, jsonify
from api.db import LLMType, ParserType
from api.db import LLMType
from api.db.services.dialog_service import label_question
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
@ -30,6 +30,7 @@ def retrieval(tenant_id):
req = request.json
question = req["query"]
kb_id = req["knowledge_id"]
use_kg = req.get("use_kg", False)
retrieval_setting = req.get("retrieval_setting", {})
similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0))
top = int(retrieval_setting.get("top_k", 1024))
@ -45,8 +46,7 @@ def retrieval(tenant_id):
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
ranks = retr.retrieval(
ranks = settings.retrievaler.retrieval(
question,
embd_mdl,
kb.tenant_id,
@ -58,6 +58,16 @@ def retrieval(tenant_id):
top=top,
rank_feature=label_question(question, [kb])
)
if use_kg:
ck = settings.kg_retrievaler.retrieval(question,
[tenant_id],
[kb_id],
embd_mdl,
LLMBundle(kb.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)
records = []
for c in ranks["chunks"]:
c.pop("vector", None)

View File

@ -1297,15 +1297,15 @@ def retrieval_test(tenant_id):
kb_ids = req["dataset_ids"]
if not isinstance(kb_ids, list):
return get_error_data_result("`dataset_ids` should be a list")
kbs = KnowledgebaseService.get_by_ids(kb_ids)
for id in kb_ids:
if not KnowledgebaseService.accessible(kb_id=id, user_id=tenant_id):
return get_error_data_result(f"You don't own the dataset {id}.")
kbs = KnowledgebaseService.get_by_ids(kb_ids)
embd_nms = list(set([kb.embd_id for kb in kbs]))
if len(embd_nms) != 1:
return get_result(
message='Datasets use different embedding models."',
code=settings.RetCode.AUTHENTICATION_ERROR,
code=settings.RetCode.DATA_ERROR,
)
if "question" not in req:
return get_error_data_result("`question` is required.")
@ -1313,6 +1313,7 @@ def retrieval_test(tenant_id):
size = int(req.get("page_size", 30))
question = req["question"]
doc_ids = req.get("document_ids", [])
use_kg = req.get("use_kg", False)
if not isinstance(doc_ids, list):
return get_error_data_result("`documents` should be a list")
doc_ids_list = KnowledgebaseService.list_documents_by_ids(kb_ids)
@ -1342,8 +1343,7 @@ def retrieval_test(tenant_id):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
ranks = retr.retrieval(
ranks = settings.retrievaler.retrieval(
question,
embd_mdl,
kb.tenant_id,
@ -1358,6 +1358,15 @@ def retrieval_test(tenant_id):
highlight=highlight,
rank_feature=label_question(question, kbs)
)
if use_kg:
ck = settings.kg_retrievaler.retrieval(question,
[k.tenant_id for k in kbs],
kb_ids,
embd_mdl,
LLMBundle(kb.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)
for c in ranks["chunks"]:
c.pop("vector", None)